diff --git a/include/grpcpp/impl/codegen/call.h b/include/grpcpp/impl/codegen/call.h index 2fcf306c8dd..ebb5d410016 100644 --- a/include/grpcpp/impl/codegen/call.h +++ b/include/grpcpp/impl/codegen/call.h @@ -419,9 +419,10 @@ class CallOpRecvMessage { void SetHijackingState( experimental::InterceptorBatchMethods* interceptor_methods) { hijacked_ = true; - if (message_ == nullptr || !got_message) return; + if (message_ == nullptr) return; interceptor_methods->AddInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_RECV_MESSAGE); + got_message = true; } private: @@ -514,7 +515,7 @@ class CallOpGenericRecvMessage { void SetHijackingState( experimental::InterceptorBatchMethods* interceptor_methods) { hijacked_ = true; - if (!deserialize_ || !got_message) return; + if (!deserialize_) return; interceptor_methods->AddInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_RECV_MESSAGE); } @@ -886,14 +887,8 @@ class InterceptorBatchMethodsImpl } else { /* We are going up the stack of interceptors */ if (curr_iteration_ >= 0) { - if (rpc_info->hijacked_ && - curr_iteration_ < rpc_info->hijacked_interceptor_) { - /* This is a hijacked RPC and we are done running the hijacking - * interceptor. */ - ops_->ContinueFinalizeResultAfterInterception(); - } else { - rpc_info->RunInterceptor(this, curr_iteration_); - } + /* Continue running interceptors */ + rpc_info->RunInterceptor(this, curr_iteration_); } else { /* we are done running all the interceptors without any hijacking */ ops_->ContinueFinalizeResultAfterInterception(); @@ -918,18 +913,16 @@ class InterceptorBatchMethodsImpl hooks_[static_cast(type)] = true; } - virtual void GetSendMessage(ByteBuffer** buf) override { - *buf = send_message_; - } + virtual ByteBuffer* GetSendMessage() override { return send_message_; } - virtual void GetSendInitialMetadata( - std::multimap** metadata) override { - *metadata = send_initial_metadata_; + virtual std::multimap* GetSendInitialMetadata() + override { + return send_initial_metadata_; } - virtual void GetSendStatus(Status* status) override { - *status = Status(static_cast(*code_), *error_message_, - *error_details_); + virtual Status GetSendStatus() override { + return Status(static_cast(*code_), *error_message_, + *error_details_); } virtual void ModifySendStatus(const Status& status) override { @@ -938,27 +931,23 @@ class InterceptorBatchMethodsImpl *error_message_ = status.error_message(); } - virtual void GetSendTrailingMetadata( - std::multimap** metadata) override { - *metadata = send_trailing_metadata_; + virtual std::multimap* GetSendTrailingMetadata() + override { + return send_trailing_metadata_; } - virtual void GetRecvMessage(void** message) override { - *message = recv_message_; - } + virtual void* GetRecvMessage() override { return recv_message_; } - virtual void GetRecvInitialMetadata( - std::multimap** map) override { - *map = recv_initial_metadata_->map(); + virtual std::multimap* + GetRecvInitialMetadata() override { + return recv_initial_metadata_->map(); } - virtual void GetRecvStatus(Status** status) override { - *status = recv_status_; - } + virtual Status* GetRecvStatus() override { return recv_status_; } - virtual void GetRecvTrailingMetadata( - std::multimap** map) override { - *map = recv_trailing_metadata_->map(); + virtual std::multimap* + GetRecvTrailingMetadata() override { + return recv_trailing_metadata_->map(); } virtual void SetSendMessage(ByteBuffer* buf) override { send_message_ = buf; } @@ -999,7 +988,6 @@ class InterceptorBatchMethodsImpl void SetReverse() { reverse_ = true; ClearHookPoints(); - curr_iteration_ = 0; } /* This needs to be set before interceptors are run */ @@ -1014,14 +1002,17 @@ class InterceptorBatchMethodsImpl return true; } if (!reverse_) { - rpc_info->RunInterceptor(this, 0); + curr_iteration_ = 0; } else { if (rpc_info->hijacked_) { - rpc_info->RunInterceptor(this, rpc_info->hijacked_interceptor_); + curr_iteration_ = rpc_info->hijacked_interceptor_; + gpr_log(GPR_ERROR, "running from the hijacked %d", + rpc_info->hijacked_interceptor_); } else { - rpc_info->RunInterceptor(this, rpc_info->interceptors_.size() - 1); + curr_iteration_ = rpc_info->interceptors_.size() - 1; } } + rpc_info->RunInterceptor(this, curr_iteration_); return false; } diff --git a/include/grpcpp/impl/codegen/client_interceptor.h b/include/grpcpp/impl/codegen/client_interceptor.h index 50272721f0b..c1feecd0aed 100644 --- a/include/grpcpp/impl/codegen/client_interceptor.h +++ b/include/grpcpp/impl/codegen/client_interceptor.h @@ -73,6 +73,7 @@ class ClientRpcInfo { // Getter methods const char* method() { return method_; } const Channel* channel() { return channel_; } + grpc::ClientContext* client_context() { return ctx_; } // const grpc::InterceptedMessage& outgoing_message(); // grpc::InterceptedMessage *mutable_outgoing_message(); // const grpc::InterceptedMessage& received_message(); diff --git a/include/grpcpp/impl/codegen/interceptor.h b/include/grpcpp/impl/codegen/interceptor.h index e9fd423b838..390e5ab7fb4 100644 --- a/include/grpcpp/impl/codegen/interceptor.h +++ b/include/grpcpp/impl/codegen/interceptor.h @@ -77,27 +77,27 @@ class InterceptorBatchMethods { virtual void AddInterceptionHookPoint(InterceptionHookPoints type) = 0; - virtual void GetSendMessage(ByteBuffer** buf) = 0; + virtual ByteBuffer* GetSendMessage() = 0; - virtual void GetSendInitialMetadata( - std::multimap** metadata) = 0; + virtual std::multimap* + GetSendInitialMetadata() = 0; - virtual void GetSendStatus(Status* status) = 0; + virtual Status GetSendStatus() = 0; virtual void ModifySendStatus(const Status& status) = 0; - virtual void GetSendTrailingMetadata( - std::multimap** metadata) = 0; + virtual std::multimap* + GetSendTrailingMetadata() = 0; - virtual void GetRecvMessage(void** message) = 0; + virtual void* GetRecvMessage() = 0; - virtual void GetRecvInitialMetadata( - std::multimap** map) = 0; + virtual std::multimap* + GetRecvInitialMetadata() = 0; - virtual void GetRecvStatus(Status** status) = 0; + virtual Status* GetRecvStatus() = 0; - virtual void GetRecvTrailingMetadata( - std::multimap** map) = 0; + virtual std::multimap* + GetRecvTrailingMetadata() = 0; virtual void SetSendMessage(ByteBuffer* buf) = 0; diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index 6d10f5da5b1..17b2bdb9a4d 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -60,37 +60,198 @@ class ClientInterceptorsEnd2endTest : public ::testing::Test { std::unique_ptr server_; }; +class DummyInterceptor : public experimental::ClientInterceptor { + public: + DummyInterceptor(experimental::ClientRpcInfo* info) {} + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + num_times_run_++; + } + methods->Proceed(); + } + + static void Reset() { num_times_run_.store(0); } + + static int GetNumTimesRun() { return num_times_run_.load(); } + + private: + static std::atomic num_times_run_; +}; + +std::atomic DummyInterceptor::num_times_run_; + +class DummyInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::ClientInterceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new DummyInterceptor(info); + } +}; + +class HijackingInterceptor : public experimental::ClientInterceptor { + public: + HijackingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + // Make sure it is the right method + EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0); + } + + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + gpr_log(GPR_ERROR, "ran this"); + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), 1); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + EchoRequest req; + auto* buffer = methods->GetSendMessage(); + auto copied_buffer = *buffer; + SerializationTraits::Deserialize(&copied_buffer, &req); + EXPECT_EQ(req.message(), "Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { + // Got nothing to do here for now + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here for now + EXPECT_EQ(map->size(), 0); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { + EchoResponse* resp = + static_cast(methods->GetRecvMessage()); + // Check that we got the hijacked message, and re-insert the expected + // message + EXPECT_EQ(resp->message(), "Hello1"); + resp->set_message("Hello"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) { + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here at the moment + EXPECT_EQ(map->size(), 0); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { + // Insert a different message than expected + EchoResponse* resp = + static_cast(methods->GetRecvMessage()); + resp->set_message("Hello1"); + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* map = methods->GetRecvTrailingMetadata(); + // insert the metadata that we want + EXPECT_EQ(map->size(), 0); + map->insert(std::make_pair("testkey", "testvalue")); + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::OK, ""); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + private: + experimental::ClientRpcInfo* info_; +}; + +class HijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::ClientInterceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new HijackingInterceptor(info); + } +}; + class LoggingInterceptor : public experimental::ClientInterceptor { public: - LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; } + LoggingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + // Make sure it is the right method + EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0); + } virtual void Intercept(experimental::InterceptorBatchMethods* methods) { - gpr_log(GPR_ERROR, "here\n"); + gpr_log(GPR_ERROR, "ran this"); if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { - gpr_log(GPR_ERROR, "here\n"); + auto* map = methods->GetSendInitialMetadata(); + // Check that we can see the test metadata + ASSERT_EQ(map->size(), 1); + auto iterator = map->begin(); + EXPECT_EQ("testkey", iterator->first); + EXPECT_EQ("testvalue", iterator->second); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { - gpr_log(GPR_ERROR, "here\n"); + EchoRequest req; + auto* buffer = methods->GetSendMessage(); + auto copied_buffer = *buffer; + SerializationTraits::Deserialize(&copied_buffer, &req); + EXPECT_EQ(req.message(), "Hello"); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { - gpr_log(GPR_ERROR, "here\n"); + // Got nothing to do here for now } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { - gpr_log(GPR_ERROR, "here\n"); + auto* map = methods->GetRecvInitialMetadata(); + // Got nothing better to do here for now + EXPECT_EQ(map->size(), 0); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { - gpr_log(GPR_ERROR, "here\n"); + EchoResponse* resp = + static_cast(methods->GetRecvMessage()); + EXPECT_EQ(resp->message(), "Hello"); } if (methods->QueryInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_STATUS)) { - gpr_log(GPR_ERROR, "here\n"); + auto* map = methods->GetRecvTrailingMetadata(); + bool found = false; + // Check that we received the metadata as an echo + for (const auto& pair : *map) { + found = pair.first.starts_with("testkey") && + pair.second.starts_with("testvalue"); + if (found) break; + } + EXPECT_EQ(found, true); + auto* status = methods->GetRecvStatus(); + EXPECT_EQ(status->ok(), true); } - gpr_log(GPR_ERROR, "here\n"); methods->Proceed(); } @@ -108,6 +269,72 @@ class LoggingInterceptorFactory }; TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto creators = std::unique_ptr>>( + new std::vector< + std::unique_ptr>()); + creators->push_back(std::unique_ptr( + new LoggingInterceptorFactory())); + // Add 20 dummy interceptors + for (auto i = 0; i < 20; i++) { + creators->push_back(std::unique_ptr( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + Status s = stub->Echo(&ctx, req, &resp); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), "Hello"); + // Make sure all 20 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingTest) { + ChannelArguments args; + DummyInterceptor::Reset(); + auto creators = std::unique_ptr>>( + new std::vector< + std::unique_ptr>()); + // Add 10 dummy interceptors before hijacking interceptor + for (auto i = 0; i < 20; i++) { + creators->push_back(std::unique_ptr( + new DummyInterceptorFactory())); + } + creators->push_back(std::unique_ptr( + new HijackingInterceptorFactory())); + // Add 10 dummy interceptors after hijacking interceptor + for (auto i = 0; i < 20; i++) { + creators->push_back(std::unique_ptr( + new DummyInterceptorFactory())); + } + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); + req.set_message("Hello"); + EchoResponse resp; + Status s = stub->Echo(&ctx, req, &resp); + EXPECT_EQ(s.ok(), true); + EXPECT_EQ(resp.message(), "Hello"); + // Make sure only 10 dummy interceptors were run + EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); +} + +TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLogThenHijackTest) { ChannelArguments args; auto creators = std::unique_ptr>>( @@ -115,17 +342,21 @@ TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTest) { std::unique_ptr>()); creators->push_back(std::unique_ptr( new LoggingInterceptorFactory())); + creators->push_back(std::unique_ptr( + new HijackingInterceptorFactory())); auto channel = experimental::CreateCustomChannelWithInterceptors( server_address_, InsecureChannelCredentials(), args, std::move(creators)); auto stub = grpc::testing::EchoTestService::NewStub(channel); ClientContext ctx; EchoRequest req; + req.mutable_param()->set_echo_metadata(true); + ctx.AddMetadata("testkey", "testvalue"); req.set_message("Hello"); EchoResponse resp; Status s = stub->Echo(&ctx, req, &resp); EXPECT_EQ(s.ok(), true); - std::cout << resp.message() << "\n"; + EXPECT_EQ(resp.message(), "Hello"); } } // namespace