diff --git a/include/grpcpp/impl/codegen/call_op_set.h b/include/grpcpp/impl/codegen/call_op_set.h index 1f2b88e9e14..f330679ffc9 100644 --- a/include/grpcpp/impl/codegen/call_op_set.h +++ b/include/grpcpp/impl/codegen/call_op_set.h @@ -315,9 +315,16 @@ class CallOpSendMessage { write_options_.Clear(); } void FinishOp(bool* status) { - send_buf_.Clear(); + if (!send_buf_.Valid()) { + return; + } if (hijacked_ && failed_send_) { + // Hijacking interceptor failed this Op *status = false; + } else if (!*status) { + // This Op was passed down to core and the Op failed + gpr_log(GPR_ERROR, "failure status"); + failed_send_ = true; } } @@ -330,7 +337,14 @@ class CallOpSendMessage { } void SetFinishInterceptionHookPoint( - InterceptorBatchMethodsImpl* interceptor_methods) {} + InterceptorBatchMethodsImpl* interceptor_methods) { + if (send_buf_.Valid()) { + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_SEND_MESSAGE); + // We had already registered failed_send_ earlier. No need to do it again. + } + send_buf_.Clear(); + } void SetHijackingState(InterceptorBatchMethodsImpl* interceptor_methods) { hijacked_ = true; diff --git a/include/grpcpp/impl/codegen/interceptor.h b/include/grpcpp/impl/codegen/interceptor.h index 47239332c86..154172dd814 100644 --- a/include/grpcpp/impl/codegen/interceptor.h +++ b/include/grpcpp/impl/codegen/interceptor.h @@ -41,9 +41,10 @@ class InterceptedMessage { }; enum class InterceptionHookPoints { - /* The first two in this list are for clients and servers */ + /* The first three in this list are for clients and servers */ PRE_SEND_INITIAL_METADATA, PRE_SEND_MESSAGE, + POST_SEND_MESSAGE, PRE_SEND_STATUS /* server only */, PRE_SEND_CLOSE /* client only */, /* The following three are for hijacked clients only and can only be @@ -85,6 +86,9 @@ class InterceptorBatchMethods { // sent virtual ByteBuffer* GetSendMessage() = 0; + // Checks whether the SEND MESSAGE op succeeded + virtual bool GetSendMessageStatus() = 0; + // Returns a modifiable multimap of the initial metadata to be sent virtual std::multimap* GetSendInitialMetadata() = 0; diff --git a/include/grpcpp/impl/codegen/interceptor_common.h b/include/grpcpp/impl/codegen/interceptor_common.h index 601a929afe6..21326df73be 100644 --- a/include/grpcpp/impl/codegen/interceptor_common.h +++ b/include/grpcpp/impl/codegen/interceptor_common.h @@ -81,6 +81,8 @@ class InterceptorBatchMethodsImpl ByteBuffer* GetSendMessage() override { return send_message_; } + bool GetSendMessageStatus() override { return !*fail_send_message_; } + std::multimap* GetSendInitialMetadata() override { return send_initial_metadata_; } @@ -113,6 +115,7 @@ class InterceptorBatchMethodsImpl void FailHijackedSendMessage() override { GPR_CODEGEN_ASSERT(hooks_[static_cast( experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)]); + gpr_log(GPR_ERROR, "failing"); *fail_send_message_ = true; } @@ -396,6 +399,13 @@ class CancelInterceptorBatchMethods return nullptr; } + bool GetSendMessageStatus() override { + GPR_CODEGEN_ASSERT( + false && + "It is illegal to call GetSendMessageStatus on a method which " + "has a Cancel notification"); + } + std::multimap* GetSendInitialMetadata() override { GPR_CODEGEN_ASSERT(false && "It is illegal to call GetSendInitialMetadata on a " diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index 81efd154525..97947e73931 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -287,6 +287,13 @@ class ClientStreamingRpcHijackingInterceptor methods->FailHijackedSendMessage(); } } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) { + EXPECT_FALSE(got_failed_send_); + gpr_log(GPR_ERROR, "%d", got_failed_send_); + got_failed_send_ = !methods->GetSendMessageStatus(); + gpr_log(GPR_ERROR, "%d", got_failed_send_); + } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { auto* status = methods->GetRecvStatus(); @@ -299,10 +306,16 @@ class ClientStreamingRpcHijackingInterceptor } } + static bool GotFailedSend() { return got_failed_send_; } + private: experimental::ClientRpcInfo* info_; int count_ = 0; + static bool got_failed_send_; }; + +bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false; + class ClientStreamingRpcHijackingInterceptorFactory : public experimental::ClientInterceptorFactoryInterface { public: @@ -602,10 +615,11 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) { EXPECT_TRUE(writer->Write(req)); expected_resp += "Hello"; } - // Expect that the interceptor will reject the 11th message - EXPECT_FALSE(writer->Write(req)); + // The interceptor will reject the 11th message + writer->Write(req); Status s = writer->Finish(); EXPECT_EQ(s.ok(), false); + EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend()); } TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {