diff --git a/include/grpcpp/impl/codegen/call_op_set.h b/include/grpcpp/impl/codegen/call_op_set.h index e6d7b900ca1..5bfd541c8f4 100644 --- a/include/grpcpp/impl/codegen/call_op_set.h +++ b/include/grpcpp/impl/codegen/call_op_set.h @@ -421,17 +421,14 @@ Status CallOpSendMessage::SendMessagePtr(const M* message) { template class CallOpRecvMessage { public: - CallOpRecvMessage() - : got_message(false), - message_(nullptr), - allow_not_getting_message_(false) {} + CallOpRecvMessage() {} void RecvMessage(R* message) { message_ = message; } // Do not change status if no message is received. void AllowNoMessage() { allow_not_getting_message_ = true; } - bool got_message; + bool got_message = false; protected: void AddOp(grpc_op* ops, size_t* nops) { @@ -444,7 +441,7 @@ class CallOpRecvMessage { } void FinishOp(bool* status) { - if (message_ == nullptr || hijacked_) return; + if (message_ == nullptr) return; if (recv_buf_.Valid()) { if (*status) { got_message = *status = @@ -455,18 +452,20 @@ class CallOpRecvMessage { got_message = false; recv_buf_.Clear(); } - } else { - got_message = false; - if (!allow_not_getting_message_) { - *status = false; + } else if (hijacked_) { + if (!hijacked_recv_message_status_) { + FinishOpRecvMessageFailureHandler(status); } + } else { + FinishOpRecvMessageFailureHandler(status); } } void SetInterceptionHookPoint( InterceptorBatchMethodsImpl* interceptor_methods) { if (message_ == nullptr) return; - interceptor_methods->SetRecvMessage(message_, &got_message); + interceptor_methods->SetRecvMessage(message_, + &hijacked_recv_message_status_); } void SetFinishInterceptionHookPoint( @@ -485,10 +484,19 @@ class CallOpRecvMessage { } private: - R* message_; + // Sets got_message and \a status for a failed recv message op + void FinishOpRecvMessageFailureHandler(bool* status) { + got_message = false; + if (!allow_not_getting_message_) { + *status = false; + } + } + + R* message_ = nullptr; ByteBuffer recv_buf_; - bool allow_not_getting_message_; + bool allow_not_getting_message_ = false; bool hijacked_ = false; + bool hijacked_recv_message_status_ = true; }; class DeserializeFunc { @@ -513,8 +521,7 @@ class DeserializeFuncType final : public DeserializeFunc { class CallOpGenericRecvMessage { public: - CallOpGenericRecvMessage() - : got_message(false), allow_not_getting_message_(false) {} + CallOpGenericRecvMessage() {} template void RecvMessage(R* message) { @@ -528,7 +535,7 @@ class CallOpGenericRecvMessage { // Do not change status if no message is received. void AllowNoMessage() { allow_not_getting_message_ = true; } - bool got_message; + bool got_message = false; protected: void AddOp(grpc_op* ops, size_t* nops) { @@ -551,6 +558,10 @@ class CallOpGenericRecvMessage { got_message = false; recv_buf_.Clear(); } + } else if (hijacked_) { + if (!hijacked_recv_message_status_) { + FinishOpRecvMessageFailureHandler(status); + } } else { got_message = false; if (!allow_not_getting_message_) { @@ -562,7 +573,8 @@ class CallOpGenericRecvMessage { void SetInterceptionHookPoint( InterceptorBatchMethodsImpl* interceptor_methods) { if (!deserialize_) return; - interceptor_methods->SetRecvMessage(message_, &got_message); + interceptor_methods->SetRecvMessage(message_, + &hijacked_recv_message_status_); } void SetFinishInterceptionHookPoint( @@ -582,11 +594,20 @@ class CallOpGenericRecvMessage { } private: - void* message_; - bool hijacked_ = false; + // Sets got_message and \a status for a failed recv message op + void FinishOpRecvMessageFailureHandler(bool* status) { + got_message = false; + if (!allow_not_getting_message_) { + *status = false; + } + } + + void* message_ = nullptr; std::unique_ptr deserialize_; ByteBuffer recv_buf_; - bool allow_not_getting_message_; + bool allow_not_getting_message_ = false; + bool hijacked_ = false; + bool hijacked_recv_message_status_ = true; }; class CallOpClientSendClose { diff --git a/include/grpcpp/impl/codegen/interceptor_common.h b/include/grpcpp/impl/codegen/interceptor_common.h index 01ffe19bc4a..9b124197ff0 100644 --- a/include/grpcpp/impl/codegen/interceptor_common.h +++ b/include/grpcpp/impl/codegen/interceptor_common.h @@ -166,9 +166,9 @@ class InterceptorBatchMethodsImpl send_trailing_metadata_ = metadata; } - void SetRecvMessage(void* message, bool* got_message) { + void SetRecvMessage(void* message, bool* hijacked_recv_message_status) { recv_message_ = message; - got_message_ = got_message; + hijacked_recv_message_status_ = hijacked_recv_message_status; } void SetRecvInitialMetadata(MetadataMap* map) { @@ -195,7 +195,7 @@ class InterceptorBatchMethodsImpl void FailHijackedRecvMessage() override { GPR_CODEGEN_ASSERT(hooks_[static_cast( experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)]); - *got_message_ = false; + *hijacked_recv_message_status_ = false; } // Clears all state @@ -407,7 +407,7 @@ class InterceptorBatchMethodsImpl std::multimap* send_trailing_metadata_ = nullptr; void* recv_message_ = nullptr; - bool* got_message_ = nullptr; + bool* hijacked_recv_message_status_ = nullptr; MetadataMap* recv_initial_metadata_ = nullptr; diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index 7ee95a60c1f..86012a34586 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -43,6 +43,17 @@ namespace grpc { namespace testing { namespace { +enum class RPCType { + kSyncUnary, + kSyncClientStreaming, + kSyncServerStreaming, + kSyncBidiStreaming, + kAsyncCQUnary, + kAsyncCQClientStreaming, + kAsyncCQServerStreaming, + kAsyncCQBidiStreaming, +}; + /* Hijacks Echo RPC and fills in the expected values */ class HijackingInterceptor : public experimental::Interceptor { public: @@ -400,6 +411,7 @@ class ServerStreamingRpcHijackingInterceptor public: ServerStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; + got_failed_message_ = false; } virtual void Intercept(experimental::InterceptorBatchMethods* methods) { @@ -531,10 +543,22 @@ class LoggingInterceptor : public experimental::Interceptor { if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { EchoRequest req; - EXPECT_EQ(static_cast(methods->GetSendMessage()) - ->message() - .find("Hello"), - 0u); + auto* send_msg = methods->GetSendMessage(); + if (send_msg == nullptr) { + // We did not get the non-serialized form of the message. Get the + // serialized form. + auto* buffer = methods->GetSerializedSendMessage(); + auto copied_buffer = *buffer; + EchoRequest req; + EXPECT_TRUE( + SerializationTraits::Deserialize(&copied_buffer, &req) + .ok()); + EXPECT_EQ(req.message(), "Hello"); + } else { + EXPECT_EQ( + static_cast(send_msg)->message().find("Hello"), + 0u); + } auto* buffer = methods->GetSerializedSendMessage(); auto copied_buffer = *buffer; EXPECT_TRUE( @@ -582,6 +606,27 @@ class LoggingInterceptor : public experimental::Interceptor { methods->Proceed(); } + static void VerifyCall(RPCType type) { + switch (type) { + case RPCType::kSyncUnary: + case RPCType::kAsyncCQUnary: + VerifyUnaryCall(); + break; + case RPCType::kSyncClientStreaming: + case RPCType::kAsyncCQClientStreaming: + VerifyClientStreamingCall(); + break; + case RPCType::kSyncServerStreaming: + case RPCType::kAsyncCQServerStreaming: + VerifyServerStreamingCall(); + break; + case RPCType::kSyncBidiStreaming: + case RPCType::kAsyncCQBidiStreaming: + VerifyBidiStreamingCall(); + break; + } + } + static void VerifyCallCommon() { EXPECT_TRUE(pre_send_initial_metadata_); EXPECT_TRUE(pre_send_close_); @@ -638,9 +683,31 @@ class LoggingInterceptorFactory } }; -class ClientInterceptorsEnd2endTest : public ::testing::Test { +class TestScenario { + public: + explicit TestScenario(const RPCType& type) : type_(type) {} + + RPCType type() const { return type_; } + + private: + RPCType type_; +}; + +std::vector CreateTestScenarios() { + std::vector scenarios; + scenarios.emplace_back(RPCType::kSyncUnary); + scenarios.emplace_back(RPCType::kSyncClientStreaming); + scenarios.emplace_back(RPCType::kSyncServerStreaming); + scenarios.emplace_back(RPCType::kSyncBidiStreaming); + scenarios.emplace_back(RPCType::kAsyncCQUnary); + scenarios.emplace_back(RPCType::kAsyncCQServerStreaming); + return scenarios; +} + +class ParameterizedClientInterceptorsEnd2endTest + : public ::testing::TestWithParam { protected: - ClientInterceptorsEnd2endTest() { + ParameterizedClientInterceptorsEnd2endTest() { int port = grpc_pick_unused_port_or_die(); ServerBuilder builder; @@ -650,14 +717,44 @@ class ClientInterceptorsEnd2endTest : public ::testing::Test { server_ = builder.BuildAndStart(); } - ~ClientInterceptorsEnd2endTest() { server_->Shutdown(); } + ~ParameterizedClientInterceptorsEnd2endTest() { server_->Shutdown(); } + + void SendRPC(const std::shared_ptr& channel) { + switch (GetParam().type()) { + case RPCType::kSyncUnary: + MakeCall(channel); + break; + case RPCType::kSyncClientStreaming: + MakeClientStreamingCall(channel); + break; + case RPCType::kSyncServerStreaming: + MakeServerStreamingCall(channel); + break; + case RPCType::kSyncBidiStreaming: + MakeBidiStreamingCall(channel); + break; + case RPCType::kAsyncCQUnary: + MakeAsyncCQCall(channel); + break; + case RPCType::kAsyncCQClientStreaming: + // TODO(yashykt) : Fill this out + break; + case RPCType::kAsyncCQServerStreaming: + MakeAsyncCQServerStreamingCall(channel); + break; + case RPCType::kAsyncCQBidiStreaming: + // TODO(yashykt) : Fill this out + break; + } + } std::string server_address_; - TestServiceImpl service_; + EchoTestServiceStreamingImpl service_; std::unique_ptr server_; }; -TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTest) { +TEST_P(ParameterizedClientInterceptorsEnd2endTest, + ClientInterceptorLoggingTest) { ChannelArguments args; DummyInterceptor::Reset(); std::vector> @@ -671,12 +768,36 @@ TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTest) { } auto channel = experimental::CreateCustomChannelWithInterceptors( server_address_, InsecureChannelCredentials(), args, std::move(creators)); - MakeCall(channel); - LoggingInterceptor::VerifyUnaryCall(); + SendRPC(channel); + LoggingInterceptor::VerifyCall(GetParam().type()); // Make sure all 20 dummy interceptors were run EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); } +INSTANTIATE_TEST_SUITE_P(ParameterizedClientInterceptorsEnd2end, + ParameterizedClientInterceptorsEnd2endTest, + ::testing::ValuesIn(CreateTestScenarios())); + +class ClientInterceptorsEnd2endTest + : public ::testing::TestWithParam { + protected: + ClientInterceptorsEnd2endTest() { + int port = grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + std::to_string(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + ~ClientInterceptorsEnd2endTest() { server_->Shutdown(); } + + std::string server_address_; + TestServiceImpl service_; + std::unique_ptr server_; +}; + TEST_F(ClientInterceptorsEnd2endTest, LameChannelClientInterceptorHijackingTest) { ChannelArguments args; @@ -757,7 +878,26 @@ TEST_F(ClientInterceptorsEnd2endTest, EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 12); } -TEST_F(ClientInterceptorsEnd2endTest, +class ClientInterceptorsCallbackEnd2endTest : public ::testing::Test { + protected: + ClientInterceptorsCallbackEnd2endTest() { + int port = grpc_pick_unused_port_or_die(); + + ServerBuilder builder; + server_address_ = "localhost:" + std::to_string(port); + builder.AddListeningPort(server_address_, InsecureServerCredentials()); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + } + + ~ClientInterceptorsCallbackEnd2endTest() { server_->Shutdown(); } + + std::string server_address_; + TestServiceImpl service_; + std::unique_ptr server_; +}; + +TEST_F(ClientInterceptorsCallbackEnd2endTest, ClientInterceptorLoggingTestWithCallback) { ChannelArguments args; DummyInterceptor::Reset(); @@ -778,7 +918,7 @@ TEST_F(ClientInterceptorsEnd2endTest, EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); } -TEST_F(ClientInterceptorsEnd2endTest, +TEST_F(ClientInterceptorsCallbackEnd2endTest, ClientInterceptorFactoryAllowsNullptrReturn) { ChannelArguments args; DummyInterceptor::Reset(); @@ -903,6 +1043,21 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) { EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage()); } +TEST_F(ClientInterceptorsStreamingEnd2endTest, + AsyncCQServerStreamingHijackingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + std::vector> + creators; + creators.push_back( + std::unique_ptr( + new ServerStreamingRpcHijackingInterceptorFactory())); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + MakeAsyncCQServerStreamingCall(channel); + EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage()); +} + TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) { ChannelArguments args; DummyInterceptor::Reset(); diff --git a/test/cpp/end2end/interceptors_util.cc b/test/cpp/end2end/interceptors_util.cc index 6321c35ba4a..7d74a48db62 100644 --- a/test/cpp/end2end/interceptors_util.cc +++ b/test/cpp/end2end/interceptors_util.cc @@ -66,7 +66,6 @@ void MakeServerStreamingCall(const std::shared_ptr& channel) { ctx.AddMetadata("testkey", "testvalue"); req.set_message("Hello"); EchoResponse resp; - string expected_resp = ""; auto reader = stub->ResponseStream(&ctx, req); int count = 0; while (reader->Read(&resp)) { @@ -84,6 +83,7 @@ void MakeBidiStreamingCall(const std::shared_ptr& channel) { EchoRequest req; EchoResponse resp; ctx.AddMetadata("testkey", "testvalue"); + req.mutable_param()->set_echo_metadata(true); auto stream = stub->BidiStream(&ctx); for (auto i = 0; i < kNumStreamingMessages; i++) { req.set_message("Hello" + std::to_string(i)); @@ -96,6 +96,60 @@ void MakeBidiStreamingCall(const std::shared_ptr& channel) { EXPECT_EQ(s.ok(), true); } +void MakeAsyncCQCall(const std::shared_ptr& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + CompletionQueue cq; + EchoRequest send_request; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + + send_request.set_message("Hello"); + cli_ctx.AddMetadata("testkey", "testvalue"); + std::unique_ptr> response_reader( + stub->AsyncEcho(&cli_ctx, send_request, &cq)); + response_reader->Finish(&recv_response, &recv_status, tag(1)); + Verifier().Expect(1, true).Verify(&cq); + EXPECT_EQ(send_request.message(), recv_response.message()); + EXPECT_TRUE(recv_status.ok()); +} + +void MakeAsyncCQClientStreamingCall(const std::shared_ptr& channel) { + // TODO(yashykt) : Fill this out +} + +void MakeAsyncCQServerStreamingCall(const std::shared_ptr& channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + CompletionQueue cq; + EchoRequest send_request; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + + cli_ctx.AddMetadata("testkey", "testvalue"); + send_request.set_message("Hello"); + std::unique_ptr> cli_stream( + stub->AsyncResponseStream(&cli_ctx, send_request, &cq, tag(1))); + Verifier().Expect(1, true).Verify(&cq); + // Read the expected number of messages + for (int i = 0; i < kNumStreamingMessages; i++) { + cli_stream->Read(&recv_response, tag(2)); + Verifier().Expect(2, true).Verify(&cq); + ASSERT_EQ(recv_response.message(), send_request.message()); + } + // The next read should fail + cli_stream->Read(&recv_response, tag(3)); + Verifier().Expect(3, false).Verify(&cq); + // Get the status + cli_stream->Finish(&recv_status, tag(4)); + Verifier().Expect(4, true).Verify(&cq); + EXPECT_TRUE(recv_status.ok()); +} + +void MakeAsyncCQBidiStreamingCall(const std::shared_ptr& channel) { + // TODO(yashykt) : Fill this out +} + void MakeCallbackCall(const std::shared_ptr& channel) { auto stub = grpc::testing::EchoTestService::NewStub(channel); ClientContext ctx; @@ -109,7 +163,6 @@ void MakeCallbackCall(const std::shared_ptr& channel) { EchoResponse resp; stub->experimental_async()->Echo(&ctx, &req, &resp, [&resp, &mu, &done, &cv](Status s) { - // gpr_log(GPR_ERROR, "got the callback"); EXPECT_EQ(s.ok(), true); EXPECT_EQ(resp.message(), "Hello"); std::lock_guard l(mu); diff --git a/test/cpp/end2end/interceptors_util.h b/test/cpp/end2end/interceptors_util.h index 6027c9b3dcf..cac9f3a6bb8 100644 --- a/test/cpp/end2end/interceptors_util.h +++ b/test/cpp/end2end/interceptors_util.h @@ -102,6 +102,16 @@ class EchoTestServiceStreamingImpl : public EchoTestService::Service { public: ~EchoTestServiceStreamingImpl() override {} + Status Echo(ServerContext* context, const EchoRequest* request, + EchoResponse* response) { + auto client_metadata = context->client_metadata(); + for (const auto& pair : client_metadata) { + context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second)); + } + response->set_message(request->message()); + return Status::OK; + } + Status BidiStream( ServerContext* context, grpc::ServerReaderWriter* stream) override { @@ -162,6 +172,14 @@ void MakeServerStreamingCall(const std::shared_ptr& channel); void MakeBidiStreamingCall(const std::shared_ptr& channel); +void MakeAsyncCQCall(const std::shared_ptr& channel); + +void MakeAsyncCQClientStreamingCall(const std::shared_ptr& channel); + +void MakeAsyncCQServerStreamingCall(const std::shared_ptr& channel); + +void MakeAsyncCQBidiStreamingCall(const std::shared_ptr& channel); + void MakeCallbackCall(const std::shared_ptr& channel); bool CheckMetadata(const std::multimap& map,