From 5d7d6c0fbdcac75ea482e1fde3e128cd0c1646c1 Mon Sep 17 00:00:00 2001 From: Yash Tibrewal Date: Wed, 14 Nov 2018 17:35:26 -0800 Subject: [PATCH] Add method to fail hijacked send messages --- include/grpcpp/impl/codegen/call_op_set.h | 10 ++- include/grpcpp/impl/codegen/interceptor.h | 4 + .../grpcpp/impl/codegen/interceptor_common.h | 18 ++++- .../client_interceptors_end2end_test.cc | 73 +++++++++++++++++++ 4 files changed, 102 insertions(+), 3 deletions(-) diff --git a/include/grpcpp/impl/codegen/call_op_set.h b/include/grpcpp/impl/codegen/call_op_set.h index b4c34a01c9a..1f2b88e9e14 100644 --- a/include/grpcpp/impl/codegen/call_op_set.h +++ b/include/grpcpp/impl/codegen/call_op_set.h @@ -314,14 +314,19 @@ class CallOpSendMessage { // Flags are per-message: clear them after use. write_options_.Clear(); } - void FinishOp(bool* status) { send_buf_.Clear(); } + void FinishOp(bool* status) { + send_buf_.Clear(); + if (hijacked_ && failed_send_) { + *status = false; + } + } void SetInterceptionHookPoint( InterceptorBatchMethodsImpl* interceptor_methods) { if (!send_buf_.Valid()) return; interceptor_methods->AddInterceptionHookPoint( experimental::InterceptionHookPoints::PRE_SEND_MESSAGE); - interceptor_methods->SetSendMessage(&send_buf_); + interceptor_methods->SetSendMessage(&send_buf_, &failed_send_); } void SetFinishInterceptionHookPoint( @@ -333,6 +338,7 @@ class CallOpSendMessage { private: bool hijacked_ = false; + bool failed_send_ = false; ByteBuffer send_buf_; WriteOptions write_options_; }; diff --git a/include/grpcpp/impl/codegen/interceptor.h b/include/grpcpp/impl/codegen/interceptor.h index e449e44a23b..47239332c86 100644 --- a/include/grpcpp/impl/codegen/interceptor.h +++ b/include/grpcpp/impl/codegen/interceptor.h @@ -118,6 +118,10 @@ class InterceptorBatchMethods { // only interceptors after the current interceptor are created from the // factory objects registered with the channel. virtual std::unique_ptr GetInterceptedChannel() = 0; + + // On a hijacked RPC/ to-be hijacked RPC, this can be called to fail a SEND + // MESSAGE op + virtual void FailHijackedSendMessage() = 0; }; class Interceptor { diff --git a/include/grpcpp/impl/codegen/interceptor_common.h b/include/grpcpp/impl/codegen/interceptor_common.h index d0aa23cb0a0..601a929afe6 100644 --- a/include/grpcpp/impl/codegen/interceptor_common.h +++ b/include/grpcpp/impl/codegen/interceptor_common.h @@ -110,12 +110,21 @@ class InterceptorBatchMethodsImpl Status* GetRecvStatus() override { return recv_status_; } + void FailHijackedSendMessage() override { + GPR_CODEGEN_ASSERT(hooks_[static_cast( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)]); + *fail_send_message_ = true; + } + std::multimap* GetRecvTrailingMetadata() override { return recv_trailing_metadata_->map(); } - void SetSendMessage(ByteBuffer* buf) { send_message_ = buf; } + void SetSendMessage(ByteBuffer* buf, bool* fail_send_message) { + send_message_ = buf; + fail_send_message_ = fail_send_message; + } void SetSendInitialMetadata( std::multimap* metadata) { @@ -334,6 +343,7 @@ class InterceptorBatchMethodsImpl std::function callback_; ByteBuffer* send_message_ = nullptr; + bool* fail_send_message_ = nullptr; std::multimap* send_initial_metadata_; @@ -451,6 +461,12 @@ class CancelInterceptorBatchMethods "method which has a Cancel notification"); return std::unique_ptr(nullptr); } + + void FailHijackedSendMessage() override { + GPR_CODEGEN_ASSERT(false && + "It is illegal to call FailHijackedSendMessage on a " + "method which has a Cancel notification"); + } }; } // namespace internal } // namespace grpc diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index 0b34ec93ae7..81efd154525 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -269,6 +269,49 @@ class HijackingInterceptorMakesAnotherCallFactory } }; +class ClientStreamingRpcHijackingInterceptor + : public experimental::Interceptor { + public: + ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) { + info_ = info; + } + virtual void Intercept(experimental::InterceptorBatchMethods* methods) { + bool hijack = false; + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + hijack = true; + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { + if (++count_ > 10) { + methods->FailHijackedSendMessage(); + } + } + if (methods->QueryInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { + auto* status = methods->GetRecvStatus(); + *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages"); + } + if (hijack) { + methods->Hijack(); + } else { + methods->Proceed(); + } + } + + private: + experimental::ClientRpcInfo* info_; + int count_ = 0; +}; +class ClientStreamingRpcHijackingInterceptorFactory + : public experimental::ClientInterceptorFactoryInterface { + public: + virtual experimental::Interceptor* CreateClientInterceptor( + experimental::ClientRpcInfo* info) override { + return new ClientStreamingRpcHijackingInterceptor(info); + } +}; + class LoggingInterceptor : public experimental::Interceptor { public: LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; } @@ -535,6 +578,36 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) { EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); } +TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) { + ChannelArguments args; + auto creators = std::unique_ptr>>( + new std::vector< + std::unique_ptr>()); + creators->push_back( + std::unique_ptr( + new ClientStreamingRpcHijackingInterceptorFactory())); + auto channel = experimental::CreateCustomChannelWithInterceptors( + server_address_, InsecureChannelCredentials(), args, std::move(creators)); + + auto stub = grpc::testing::EchoTestService::NewStub(channel); + ClientContext ctx; + EchoRequest req; + EchoResponse resp; + req.mutable_param()->set_echo_metadata(true); + req.set_message("Hello"); + string expected_resp = ""; + auto writer = stub->RequestStream(&ctx, &resp); + for (int i = 0; i < 10; i++) { + EXPECT_TRUE(writer->Write(req)); + expected_resp += "Hello"; + } + // Expect that the interceptor will reject the 11th message + EXPECT_FALSE(writer->Write(req)); + Status s = writer->Finish(); + EXPECT_EQ(s.ok(), false); +} + TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) { ChannelArguments args; DummyInterceptor::Reset();