diff --git a/include/grpcpp/impl/codegen/call_op_set.h b/include/grpcpp/impl/codegen/call_op_set.h index d810625b3e4..fdd0b5e91b7 100644 --- a/include/grpcpp/impl/codegen/call_op_set.h +++ b/include/grpcpp/impl/codegen/call_op_set.h @@ -433,7 +433,9 @@ class CallOpRecvMessage { message_(nullptr), allow_not_getting_message_(false) {} - void RecvMessage(R* message) { message_ = message; } + void RecvMessage(R* message) { + message_ = message; + } // Do not change status if no message is received. void AllowNoMessage() { allow_not_getting_message_ = true; } @@ -468,7 +470,6 @@ class CallOpRecvMessage { *status = false; } } - message_ = nullptr; } void SetInterceptionHookPoint( @@ -565,7 +566,6 @@ class CallOpGenericRecvMessage { *status = false; } } - deserialize_.reset(); } void SetInterceptionHookPoint( @@ -580,6 +580,7 @@ class CallOpGenericRecvMessage { interceptor_methods->AddInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_MESSAGE); if (!got_message) interceptor_methods->SetRecvMessage(nullptr, nullptr); + deserialize_.reset(); } void SetHijackingState(InterceptorBatchMethodsImpl* interceptor_methods) { hijacked_ = true; diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index f1aed093dc4..70b0ecdf585 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -501,7 +501,14 @@ class BidiStreamingRpcHijackingInterceptorFactory class LoggingInterceptor : public experimental::Interceptor { public: - LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; } + LoggingInterceptor(experimental::ClientRpcInfo* info) : info_(info) { + pre_send_initial_metadata_ = false; + pre_send_message_count_ = 0; + pre_send_close_ = false; + post_recv_initial_metadata_ = false; + post_recv_message_count_ = 0; + post_recv_status_ = false; + } virtual void Intercept(experimental::InterceptorBatchMethods* methods) { if (methods->QueryInterceptionHookPoint( @@ -512,6 +519,8 @@ class LoggingInterceptor : public experimental::Interceptor { auto iterator = map->begin(); EXPECT_EQ("testkey", iterator->first); EXPECT_EQ("testvalue", iterator->second); + ASSERT_FALSE(pre_send_initial_metadata_); + pre_send_initial_metadata_ = true; } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { @@ -526,22 +535,28 @@ class LoggingInterceptor : public experimental::Interceptor { SerializationTraits::Deserialize(&copied_buffer, &req) .ok()); EXPECT_TRUE(req.message().find("Hello") == 0u); + pre_send_message_count_++; } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { // Got nothing to do here for now + pre_send_close_ = true; } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { auto* map = methods->GetRecvInitialMetadata(); // Got nothing better to do here for now EXPECT_EQ(map->size(), static_cast(0)); + post_recv_initial_metadata_ = true; } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { EchoResponse* resp = static_cast(methods->GetRecvMessage()); - EXPECT_TRUE(resp->message().find("Hello") == 0u); + if(resp != nullptr) { + EXPECT_TRUE(resp->message().find("Hello") == 0u); + post_recv_message_count_++; + } } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_STATUS)) { @@ -556,14 +571,59 @@ class LoggingInterceptor : public experimental::Interceptor { EXPECT_EQ(found, true); auto* status = methods->GetRecvStatus(); EXPECT_EQ(status->ok(), true); + post_recv_status_ = true; } methods->Proceed(); } + static void VerifyCallCommon() { + EXPECT_TRUE(pre_send_initial_metadata_); + EXPECT_TRUE(pre_send_close_); + EXPECT_TRUE(post_recv_initial_metadata_); + EXPECT_TRUE(post_recv_status_); + } + + static void VerifyUnaryCall() { + VerifyCallCommon(); + EXPECT_EQ(pre_send_message_count_, 1); + EXPECT_EQ(post_recv_message_count_, 1); + } + + static void VerifyClientStreamingCall() { + VerifyCallCommon(); + EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages); + EXPECT_EQ(post_recv_message_count_, 1); + } + + static void VerifyServerStreamingCall() { + VerifyCallCommon(); + EXPECT_EQ(pre_send_message_count_, 1); + EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages); + } + + static void VerifyBidiStreamingCall() { + VerifyCallCommon(); + EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages); + EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages); + } + private: experimental::ClientRpcInfo* info_; + static bool pre_send_initial_metadata_; + static int pre_send_message_count_; + static bool pre_send_close_; + static bool post_recv_initial_metadata_; + static int post_recv_message_count_; + static bool post_recv_status_; }; +bool LoggingInterceptor::pre_send_initial_metadata_; +int LoggingInterceptor::pre_send_message_count_; +bool LoggingInterceptor::pre_send_close_; +bool LoggingInterceptor::post_recv_initial_metadata_; +int LoggingInterceptor::post_recv_message_count_; +bool LoggingInterceptor::post_recv_status_; + class LoggingInterceptorFactory : public experimental::ClientInterceptorFactoryInterface { public: @@ -607,6 +667,7 @@ TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTest) { auto channel = experimental::CreateCustomChannelWithInterceptors( server_address_, InsecureChannelCredentials(), args, std::move(creators)); MakeCall(channel); + LoggingInterceptor::VerifyUnaryCall(); // Make sure all 20 dummy interceptors were run EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); } @@ -643,7 +704,6 @@ TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingTest) { } auto channel = experimental::CreateCustomChannelWithInterceptors( server_address_, InsecureChannelCredentials(), args, std::move(creators)); - MakeCall(channel); // Make sure only 20 dummy interceptors were run EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); @@ -659,8 +719,8 @@ TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLogThenHijackTest) { new HijackingInterceptorFactory())); auto channel = experimental::CreateCustomChannelWithInterceptors( server_address_, InsecureChannelCredentials(), args, std::move(creators)); - MakeCall(channel); + LoggingInterceptor::VerifyUnaryCall(); } TEST_F(ClientInterceptorsEnd2endTest, @@ -708,6 +768,7 @@ TEST_F(ClientInterceptorsEnd2endTest, auto channel = server_->experimental().InProcessChannelWithInterceptors( args, std::move(creators)); MakeCallbackCall(channel); + LoggingInterceptor::VerifyUnaryCall(); // Make sure all 20 dummy interceptors were run EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); } @@ -730,6 +791,7 @@ TEST_F(ClientInterceptorsEnd2endTest, auto channel = server_->experimental().InProcessChannelWithInterceptors( args, std::move(creators)); MakeCallbackCall(channel); + LoggingInterceptor::VerifyUnaryCall(); // Make sure all 20 dummy interceptors were run EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); } @@ -768,6 +830,7 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingTest) { auto channel = experimental::CreateCustomChannelWithInterceptors( server_address_, InsecureChannelCredentials(), args, std::move(creators)); MakeClientStreamingCall(channel); + LoggingInterceptor::VerifyClientStreamingCall(); // Make sure all 20 dummy interceptors were run EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); } @@ -787,6 +850,7 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) { auto channel = experimental::CreateCustomChannelWithInterceptors( server_address_, InsecureChannelCredentials(), args, std::move(creators)); MakeServerStreamingCall(channel); + LoggingInterceptor::VerifyServerStreamingCall(); // Make sure all 20 dummy interceptors were run EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); } @@ -862,6 +926,7 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) { auto channel = experimental::CreateCustomChannelWithInterceptors( server_address_, InsecureChannelCredentials(), args, std::move(creators)); MakeBidiStreamingCall(channel); + LoggingInterceptor::VerifyBidiStreamingCall(); // Make sure all 20 dummy interceptors were run EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); } @@ -928,6 +993,7 @@ TEST_F(ClientGlobalInterceptorEnd2endTest, LoggingGlobalInterceptor) { auto channel = experimental::CreateCustomChannelWithInterceptors( server_address_, InsecureChannelCredentials(), args, std::move(creators)); MakeCall(channel); + LoggingInterceptor::VerifyUnaryCall(); // Make sure all 20 dummy interceptors were run EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); experimental::TestOnlyResetGlobalClientInterceptorFactory(); diff --git a/test/cpp/end2end/interceptors_util.cc b/test/cpp/end2end/interceptors_util.cc index 900f02b5f36..6321c35ba4a 100644 --- a/test/cpp/end2end/interceptors_util.cc +++ b/test/cpp/end2end/interceptors_util.cc @@ -48,7 +48,7 @@ void MakeClientStreamingCall(const std::shared_ptr& channel) { EchoResponse resp; string expected_resp = ""; auto writer = stub->RequestStream(&ctx, &resp); - for (int i = 0; i < 10; i++) { + for (int i = 0; i < kNumStreamingMessages; i++) { writer->Write(req); expected_resp += "Hello"; } @@ -73,7 +73,7 @@ void MakeServerStreamingCall(const std::shared_ptr& channel) { EXPECT_EQ(resp.message(), "Hello"); count++; } - ASSERT_EQ(count, 10); + ASSERT_EQ(count, kNumStreamingMessages); Status s = reader->Finish(); EXPECT_EQ(s.ok(), true); } @@ -85,7 +85,7 @@ void MakeBidiStreamingCall(const std::shared_ptr& channel) { EchoResponse resp; ctx.AddMetadata("testkey", "testvalue"); auto stream = stub->BidiStream(&ctx); - for (auto i = 0; i < 10; i++) { + for (auto i = 0; i < kNumStreamingMessages; i++) { req.set_message("Hello" + std::to_string(i)); stream->Write(req); stream->Read(&resp); diff --git a/test/cpp/end2end/interceptors_util.h b/test/cpp/end2end/interceptors_util.h index 419845e5f61..1cd1448a6fa 100644 --- a/test/cpp/end2end/interceptors_util.h +++ b/test/cpp/end2end/interceptors_util.h @@ -152,6 +152,8 @@ class EchoTestServiceStreamingImpl : public EchoTestService::Service { } }; +constexpr int kNumStreamingMessages = 10; + void MakeCall(const std::shared_ptr& channel); void MakeClientStreamingCall(const std::shared_ptr& channel);