diff --git a/include/grpcpp/impl/codegen/client_interceptor.h b/include/grpcpp/impl/codegen/client_interceptor.h index 6feec224ce9..06f009e7d31 100644 --- a/include/grpcpp/impl/codegen/client_interceptor.h +++ b/include/grpcpp/impl/codegen/client_interceptor.h @@ -47,7 +47,7 @@ class ClientRpcInfo { public: ClientRpcInfo() {} ClientRpcInfo(grpc::ClientContext* ctx, const char* method, - const grpc::Channel* channel, + grpc::Channel* channel, const std::vector>& creators) : ctx_(ctx), method_(method), channel_(channel) { @@ -64,7 +64,7 @@ class ClientRpcInfo { // Getter methods const char* method() { return method_; } - const Channel* channel() { return channel_; } + Channel* channel() { return channel_; } grpc::ClientContext* client_context() { return ctx_; } public: @@ -79,7 +79,7 @@ class ClientRpcInfo { private: grpc::ClientContext* ctx_ = nullptr; const char* method_ = nullptr; - const grpc::Channel* channel_ = nullptr; + grpc::Channel* channel_ = nullptr; std::vector> interceptors_; bool hijacked_ = false; int hijacked_interceptor_ = false; diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index d53c3534a94..5124044a8b3 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -243,13 +243,13 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { interceptor_methods_.SetCall(&call_); interceptor_methods_.SetReverse(); - /* Set interception point for RECV INITIAL METADATA */ + // Set interception point for RECV INITIAL METADATA interceptor_methods_.AddInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA); interceptor_methods_.SetRecvInitialMetadata(&ctx_.client_metadata_); if (has_request_payload_) { - /* Set interception point for RECV MESSAGE */ + // Set interception point for RECV MESSAGE auto* handler = resources_ ? method_->handler() : server_->resource_exhausted_handler_.get(); request_ = handler->Deserialize(request_payload_, &request_status_); @@ -264,8 +264,8 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { if (interceptor_methods_.RunInterceptors(f)) { ContinueRunAfterInterception(); } else { - /* There were interceptors to be run, so ContinueRunAfterInterception - will be run when interceptors are done. */ + // There were interceptors to be run, so ContinueRunAfterInterception + // will be run when interceptors are done. } } @@ -318,7 +318,6 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { grpc_metadata_array request_metadata_; grpc_byte_buffer* request_payload_; grpc_completion_queue* cq_; - bool done_intercepting_ = false; }; // Implementation of ThreadManager. Each instance of SyncRequestThreadManager @@ -763,7 +762,7 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, context_->set_call(call_); context_->cq_ = call_cq_; if (call_wrapper_.call() == nullptr) { - /* Fill it since it is empty. */ + // Fill it since it is empty. call_wrapper_ = internal::Call( call_, server_, call_cq_, server_->max_receive_message_size(), nullptr); } @@ -773,7 +772,7 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, if (*status && call_ && call_wrapper_.server_rpc_info()) { done_intercepting_ = true; - /* Set interception point for RECV INITIAL METADATA */ + // Set interception point for RECV INITIAL METADATA interceptor_methods_.AddInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA); interceptor_methods_.SetRecvInitialMetadata(&context_->client_metadata_); @@ -781,11 +780,11 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, ContinueFinalizeResultAfterInterception, this); if (interceptor_methods_.RunInterceptors(f)) { - /* There are no interceptors to run. Continue */ + // There are no interceptors to run. Continue } else { - /* There were interceptors to be run, so - ContinueFinalizeResultAfterInterception will be run when interceptors are - done. */ + // There were interceptors to be run, so + // ContinueFinalizeResultAfterInterception will be run when interceptors + // are done. return false; } } @@ -802,7 +801,7 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, void ServerInterface::BaseAsyncRequest:: ContinueFinalizeResultAfterInterception() { context_->BeginCompletionOp(&call_wrapper_); - /* Queue a tag which will be returned immediately */ + // Queue a tag which will be returned immediately dummy_alarm_ = new Alarm(); static_cast(dummy_alarm_) ->Set(notification_cq_, @@ -844,7 +843,7 @@ ServerInterface::GenericAsyncRequest::GenericAsyncRequest( bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag, bool* status) { - /* If we are done intercepting, there is nothing more for us to do */ + // If we are done intercepting, there is nothing more for us to do if (done_intercepting_) { return BaseAsyncRequest::FinalizeResult(tag, status); } @@ -870,7 +869,7 @@ bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag, bool Server::UnimplementedAsyncRequest::FinalizeResult(void** tag, bool* status) { if (GenericAsyncRequest::FinalizeResult(tag, status)) { - /* We either had no interceptors run or we are done interceptinh */ + // We either had no interceptors run or we are done intercepting if (*status) { new UnimplementedAsyncRequest(server_, cq_); new UnimplementedAsyncResponse(this); @@ -878,7 +877,7 @@ bool Server::UnimplementedAsyncRequest::FinalizeResult(void** tag, delete this; } } else { - /* The tag was swallowed due to interception. We will see it again. */ + // The tag was swallowed due to interception. We will see it again. } return false; } diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index 8537f356023..2e0db8a9b9c 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -60,6 +60,8 @@ class ClientInterceptorsEnd2endTest : public ::testing::Test { std::unique_ptr server_; }; +/* This interceptor does nothing. Just keeps a global count on the number of + * times it was invoked. */ class DummyInterceptor : public experimental::Interceptor { public: DummyInterceptor(experimental::ClientRpcInfo* info) {} @@ -91,6 +93,7 @@ class DummyInterceptorFactory } }; +/* Hijacks Echo RPC and fills in the expected values */ class HijackingInterceptor : public experimental::Interceptor { public: HijackingInterceptor(experimental::ClientRpcInfo* info) { @@ -195,6 +198,111 @@ class HijackingInterceptorFactory } }; +class HijackingInterceptorMakesAnotherCall : public experimental::Interceptor { + public: + HijackingInterceptorMakesAnotherCall(experimental::ClientRpcInfo* info) { + info_ = info; + // Make sure it is the right method + EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0); + } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + gpr_log(GPR_ERROR, "ran this"); + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), 1); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); + hijack = true; + // Make a copy of the map + metadata_map_ = *map; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSendMessage(); + auto copied_buffer = *buffer; + SerializationTraits::Deserialize(&copied_buffer, &req); + EXPECT_EQ(req.message(), "Hello"); + auto stub = grpc::testing::EchoTestService::NewStub( + std::shared_ptr(info_->channel())); + ClientContext ctx; + EchoResponse resp; + Status s = stub->Echo(&ctx, req, &resp); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), "Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here for now + EXPECT_EQ(map->size(), 0); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast(methods->GetRecvMessage()); + // Check that we got the hijacked message, and re-insert the expected + // message + EXPECT_EQ(resp->message(), "Hello1"); + resp->set_message("Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here at the moment + EXPECT_EQ(map->size(), 0); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { + // Insert a different message than expected + EchoResponse* resp = + static_cast(methods->GetRecvMessage()); + resp->set_message("Hello1"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + // insert the metadata that we want + EXPECT_EQ(map->size(), 0); + map->insert(std::make_pair("testkey", "testvalue")); + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::OK, ""); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + private: + experimental::ClientRpcInfo* info_; + std::multimap metadata_map_; +}; + class LoggingInterceptor : public experimental::Interceptor { public: LoggingInterceptor(experimental::ClientRpcInfo* info) { @@ -268,6 +376,19 @@ class LoggingInterceptorFactory } }; +void MakeCall(std::shared_ptr channel) { + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + Status s = stub->Echo(&ctx, req, &resp); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), "Hello"); +} + TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTest) { ChannelArguments args; DummyInterceptor::Reset(); @@ -284,16 +405,7 @@ TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTest) { } auto channel = experimental::CreateCustomChannelWithInterceptors( server_address_, InsecureChannelCredentials(), args, std::move(creators)); - auto stub = grpc::testing::EchoTestService::NewStub(channel); - ClientContext ctx; - EchoRequest req; - req.mutable_param()->set_echo_metadata(true); - ctx.AddMetadata("testkey", "testvalue"); - req.set_message("Hello"); - EchoResponse resp; - Status s = stub->Echo(&ctx, req, &resp); - EXPECT_EQ(s.ok(), true); - EXPECT_EQ(resp.message(), "Hello"); + MakeCall(channel); // Make sure all 20 dummy interceptors were run EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); }