diff --git a/include/grpcpp/impl/codegen/call_op_set.h b/include/grpcpp/impl/codegen/call_op_set.h index b8a1c79ddf5..1c75560c04b 100644 --- a/include/grpcpp/impl/codegen/call_op_set.h +++ b/include/grpcpp/impl/codegen/call_op_set.h @@ -326,21 +326,37 @@ class CallOpSendMessage { // Flags are per-message: clear them after use. write_options_.Clear(); } - void FinishOp(bool* status) { send_buf_.Clear(); } + void FinishOp(bool* status) { + if (!send_buf_.Valid()) { + return; + } + if (hijacked_ && failed_send_) { + // Hijacking interceptor failed this Op + *status = false; + } else if (!*status) { + // This Op was passed down to core and the Op failed + failed_send_ = true; + } + } void SetInterceptionHookPoint( InterceptorBatchMethodsImpl* interceptor_methods) { if (!send_buf_.Valid()) return; interceptor_methods->AddInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_MESSAGE); - interceptor_methods->SetSendMessage(&send_buf_, msg_); + interceptor_methods->SetSendMessage(&send_buf_, msg_, &failed_send_); } void SetFinishInterceptionHookPoint( InterceptorBatchMethodsImpl* interceptor_methods) { + if (send_buf_.Valid()) { + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_SEND_MESSAGE); + } + send_buf_.Clear(); // The contents of the SendMessage value that was previously set // has had its references stolen by core's operations - interceptor_methods->SetSendMessage(nullptr, nullptr); + interceptor_methods->SetSendMessage(nullptr, nullptr, &failed_send_); } void SetHijackingState(InterceptorBatchMethodsImpl* interceptor_methods) { @@ -350,6 +366,7 @@ class CallOpSendMessage { private: const void* msg_ = nullptr; // The original non-serialized message bool hijacked_ = false; + bool failed_send_ = false; ByteBuffer send_buf_; WriteOptions write_options_; }; diff --git a/include/grpcpp/impl/codegen/interceptor.h b/include/grpcpp/impl/codegen/interceptor.h index 9b49983748a..a57a3fccbbe 100644 --- a/include/grpcpp/impl/codegen/interceptor.h +++ b/include/grpcpp/impl/codegen/interceptor.h @@ -46,9 +46,10 @@ namespace experimental { /// operation has been requested and it is available. POST_RECV means that a /// result is available but has not yet been passed back to the application. enum class InterceptionHookPoints { - /// The first two in this list are for clients and servers + /// The first three in this list are for clients and servers PRE_SEND_INITIAL_METADATA, PRE_SEND_MESSAGE, + POST_SEND_MESSAGE, PRE_SEND_STATUS, // server only PRE_SEND_CLOSE, // client only: WritesDone for stream; after write in unary /// The following three are for hijacked clients only and can only be @@ -117,6 +118,10 @@ class InterceptorBatchMethods { /// only supported for sync and callback APIs at the present moment. virtual const void* GetSendMessage() = 0; + /// Checks whether the SEND MESSAGE op succeeded. Valid for POST_SEND_MESSAGE + /// interceptions. + virtual bool GetSendMessageStatus() = 0; + /// Returns a modifiable multimap of the initial metadata to be sent. Valid /// for PRE_SEND_INITIAL_METADATA interceptions. A value of nullptr indicates /// that this field is not valid. @@ -167,6 +172,10 @@ class InterceptorBatchMethods { /// op. This would be a signal to the reader that there will be no more /// messages, or the stream has failed or been cancelled. virtual void FailHijackedRecvMessage() = 0; + + /// On a hijacked RPC/ to-be hijacked RPC, this can be called to fail a SEND + /// MESSAGE op + virtual void FailHijackedSendMessage() = 0; }; /// Interface for an interceptor. Interceptor authors must create a class diff --git a/include/grpcpp/impl/codegen/interceptor_common.h b/include/grpcpp/impl/codegen/interceptor_common.h index d60f9585fc1..345127c830e 100644 --- a/include/grpcpp/impl/codegen/interceptor_common.h +++ b/include/grpcpp/impl/codegen/interceptor_common.h @@ -83,6 +83,8 @@ class InterceptorBatchMethodsImpl const void* GetSendMessage() override { return orig_send_message_; } + bool GetSendMessageStatus() override { return !*fail_send_message_; } + std::multimap* GetSendInitialMetadata() override { return send_initial_metadata_; } @@ -112,14 +114,22 @@ class InterceptorBatchMethodsImpl Status* GetRecvStatus() override { return recv_status_; } + void FailHijackedSendMessage() override { + GPR_CODEGEN_ASSERT(hooks_[static_cast( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)]); + *fail_send_message_ = true; + } + std::multimap* GetRecvTrailingMetadata() override { return recv_trailing_metadata_->map(); } - void SetSendMessage(ByteBuffer* buf, const void* msg) { + void SetSendMessage(ByteBuffer* buf, const void* msg, + bool* fail_send_message) { send_message_ = buf; orig_send_message_ = msg; + fail_send_message_ = fail_send_message; } void SetSendInitialMetadata( @@ -348,6 +358,7 @@ class InterceptorBatchMethodsImpl std::function callback_; ByteBuffer* send_message_ = nullptr; + bool* fail_send_message_ = nullptr; const void* orig_send_message_ = nullptr; std::multimap* send_initial_metadata_; @@ -402,6 +413,14 @@ class CancelInterceptorBatchMethods return nullptr; } + bool GetSendMessageStatus() override { + GPR_CODEGEN_ASSERT( + false && + "It is illegal to call GetSendMessageStatus on a method which " + "has a Cancel notification"); + return false; + } + const void* GetSendMessage() override { GPR_CODEGEN_ASSERT( false && @@ -481,6 +500,12 @@ class CancelInterceptorBatchMethods "It is illegal to call FailHijackedRecvMessage on a " "method which has a Cancel notification"); } + + void FailHijackedSendMessage() override { + GPR_CODEGEN_ASSERT(false && + "It is illegal to call FailHijackedSendMessage on a " + "method which has a Cancel notification"); + } }; } // namespace internal } // namespace grpc diff --git a/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py b/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py index aeb02458a7e..5a5dedd5f26 100644 --- a/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py +++ b/src/python/grpcio_tests/tests/unit/_cython/_fork_test.py @@ -27,6 +27,7 @@ def _get_number_active_threads(): class ForkPosixTester(unittest.TestCase): def setUp(self): + self._saved_fork_support_flag = cygrpc._GRPC_ENABLE_FORK_SUPPORT cygrpc._GRPC_ENABLE_FORK_SUPPORT = True def testForkManagedThread(self): @@ -50,6 +51,9 @@ class ForkPosixTester(unittest.TestCase): thread.join() self.assertEqual(0, _get_number_active_threads()) + def tearDown(self): + cygrpc._GRPC_ENABLE_FORK_SUPPORT = self._saved_fork_support_flag + @unittest.skipUnless(os.name == 'nt', 'Windows-specific tests') class ForkWindowsTester(unittest.TestCase): diff --git a/src/python/grpcio_tests/tests/unit/_logging_test.py b/src/python/grpcio_tests/tests/unit/_logging_test.py index 631b9de9db5..8ff127f5062 100644 --- a/src/python/grpcio_tests/tests/unit/_logging_test.py +++ b/src/python/grpcio_tests/tests/unit/_logging_test.py @@ -14,66 +14,86 @@ """Test of gRPC Python's interaction with the python logging module""" import unittest -import six -from six.moves import reload_module import logging import grpc -import functools +import subprocess import sys +INTERPRETER = sys.executable -def patch_stderr(f): - @functools.wraps(f) - def _impl(*args, **kwargs): - old_stderr = sys.stderr - sys.stderr = six.StringIO() - try: - f(*args, **kwargs) - finally: - sys.stderr = old_stderr +class LoggingTest(unittest.TestCase): - return _impl + def test_logger_not_occupied(self): + script = """if True: + import logging + import grpc -def isolated_logging(f): + if len(logging.getLogger().handlers) != 0: + raise Exception('expected 0 logging handlers') - @functools.wraps(f) - def _impl(*args, **kwargs): - reload_module(logging) - reload_module(grpc) - try: - f(*args, **kwargs) - finally: - reload_module(logging) + """ + self._verifyScriptSucceeds(script) - return _impl + def test_handler_found(self): + script = """if True: + import logging + import grpc + """ + out, err = self._verifyScriptSucceeds(script) + self.assertEqual(0, len(err), 'unexpected output to stderr') -class LoggingTest(unittest.TestCase): + def test_can_configure_logger(self): + script = """if True: + import logging + import six - @isolated_logging - def test_logger_not_occupied(self): - self.assertEqual(0, len(logging.getLogger().handlers)) + import grpc - @patch_stderr - @isolated_logging - def test_handler_found(self): - self.assertEqual(0, len(sys.stderr.getvalue())) - @isolated_logging - def test_can_configure_logger(self): - intended_stream = six.StringIO() - logging.basicConfig(stream=intended_stream) - self.assertEqual(1, len(logging.getLogger().handlers)) - self.assertIs(logging.getLogger().handlers[0].stream, intended_stream) + intended_stream = six.StringIO() + logging.basicConfig(stream=intended_stream) + + if len(logging.getLogger().handlers) != 1: + raise Exception('expected 1 logging handler') + + if logging.getLogger().handlers[0].stream is not intended_stream: + raise Exception('wrong handler stream') + + """ + self._verifyScriptSucceeds(script) - @isolated_logging def test_grpc_logger(self): - self.assertIn("grpc", logging.Logger.manager.loggerDict) - root_logger = logging.getLogger("grpc") - self.assertEqual(1, len(root_logger.handlers)) - self.assertIsInstance(root_logger.handlers[0], logging.NullHandler) + script = """if True: + import logging + + import grpc + + if "grpc" not in logging.Logger.manager.loggerDict: + raise Exception('grpc logger not found') + + root_logger = logging.getLogger("grpc") + if len(root_logger.handlers) != 1: + raise Exception('expected 1 root logger handler') + if not isinstance(root_logger.handlers[0], logging.NullHandler): + raise Exception('expected logging.NullHandler') + + """ + self._verifyScriptSucceeds(script) + + def _verifyScriptSucceeds(self, script): + process = subprocess.Popen( + [INTERPRETER, '-c', script], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + out, err = process.communicate() + self.assertEqual( + 0, process.returncode, + 'process failed with exit code %d (stdout: %s, stderr: %s)' % + (process.returncode, out, err)) + return out, err if __name__ == '__main__': diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index cc0667b460e..9fbfd8c84a1 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -270,6 +270,129 @@ class HijackingInterceptorMakesAnotherCallFactory } }; +class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor { + public: + BidiStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue"); + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSerializedSendMessage(); + auto copied_buffer = *buffer; + EXPECT_TRUE( + SerializationTraits::Deserialize(&copied_buffer, &req) + .ok()); + EXPECT_EQ(req.message().find("Hello"), 0u); + msg = req.message(); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey", + "testvalue"); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast(methods->GetRecvMessage()); + resp->set_message(msg); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EXPECT_EQ(static_cast(methods->GetRecvMessage()) + ->message() + .find("Hello"), + 0u); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + // insert the metadata that we want + EXPECT_EQ(map->size(), static_cast(0)); + map->insert(std::make_pair("testkey", "testvalue")); + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::OK, ""); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + private: + experimental::ClientRpcInfo* info_; + grpc::string msg; +}; + +class ClientStreamingRpcHijackingInterceptor + : public experimental::Interceptor { + public: + ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + } + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + if (++count_ > 10) { + methods->FailHijackedSendMessage(); + } + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) { + EXPECT_FALSE(got_failed_send_); + got_failed_send_ = !methods->GetSendMessageStatus(); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages"); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + static bool GotFailedSend() { return got_failed_send_; } + + private: + experimental::ClientRpcInfo* info_; + int count_ = 0; + static bool got_failed_send_; +}; + +bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false; + +class ClientStreamingRpcHijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new ClientStreamingRpcHijackingInterceptor(info); + } +}; + class ServerStreamingRpcHijackingInterceptor : public experimental::Interceptor { public: @@ -292,7 +415,7 @@ class ServerStreamingRpcHijackingInterceptor if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { EchoRequest req; - auto* buffer = methods->GetSendMessage(); + auto* buffer = methods->GetSerializedSendMessage(); auto copied_buffer = *buffer; EXPECT_TRUE( SerializationTraits::Deserialize(&copied_buffer, &req) @@ -367,6 +490,15 @@ class ServerStreamingRpcHijackingInterceptorFactory } }; +class BidiStreamingRpcHijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new BidiStreamingRpcHijackingInterceptor(info); + } +}; + class LoggingInterceptor : public experimental::Interceptor { public: LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; } @@ -647,6 +779,35 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) { EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); } +TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) { + ChannelArguments args; + std::vector> + creators; + creators.push_back( + std::unique_ptr( + new ClientStreamingRpcHijackingInterceptorFactory())); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + EchoResponse resp; + req.mutable_param()->set_echo_metadata(true); + req.set_message("Hello"); + string expected_resp = ""; + auto writer = stub->RequestStream(&ctx, &resp); + for (int i = 0; i < 10; i++) { + EXPECT_TRUE(writer->Write(req)); + expected_resp += "Hello"; + } + // The interceptor will reject the 11th message + writer->Write(req); + Status s = writer->Finish(); + EXPECT_EQ(s.ok(), false); + EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend()); +} + TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) { ChannelArguments args; DummyInterceptor::Reset(); @@ -661,6 +822,19 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) { EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage()); } +TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + std::vector> + creators; + creators.push_back( + std::unique_ptr( + new BidiStreamingRpcHijackingInterceptorFactory())); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeBidiStreamingCall(channel); +} + TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) { ChannelArguments args; DummyInterceptor::Reset(); diff --git a/test/cpp/end2end/interceptors_util.cc b/test/cpp/end2end/interceptors_util.cc index e0ad7d1526c..900f02b5f36 100644 --- a/test/cpp/end2end/interceptors_util.cc +++ b/test/cpp/end2end/interceptors_util.cc @@ -132,6 +132,16 @@ bool CheckMetadata(const std::multimap& map, return false; } +bool CheckMetadata(const std::multimap& map, + const string& key, const string& value) { + for (const auto& pair : map) { + if (pair.first == key && pair.second == value) { + return true; + } + } + return false; +} + std::vector> CreateDummyClientInterceptors() { std::vector> diff --git a/test/cpp/end2end/interceptors_util.h b/test/cpp/end2end/interceptors_util.h index 659e613d2eb..419845e5f61 100644 --- a/test/cpp/end2end/interceptors_util.h +++ b/test/cpp/end2end/interceptors_util.h @@ -165,6 +165,9 @@ void MakeCallbackCall(const std::shared_ptr& channel); bool CheckMetadata(const std::multimap& map, const string& key, const string& value); +bool CheckMetadata(const std::multimap& map, + const string& key, const string& value); + std::vector> CreateDummyClientInterceptors();