Add method to fail recv msg for hijacked rpcs

pull/17179/head
Yash Tibrewal 6 years ago
parent 6a368df7ab
commit 699c10386d
  1. 4
      include/grpcpp/impl/codegen/call_op_set.h
  2. 3
      include/grpcpp/impl/codegen/interceptor.h
  3. 14
      include/grpcpp/impl/codegen/interceptor_common.h
  4. 2
      include/grpcpp/impl/codegen/server_interface.h
  5. 4
      src/cpp/server/server_cc.cc
  6. 101
      test/cpp/end2end/client_interceptors_end2end_test.cc

@ -406,7 +406,7 @@ class CallOpRecvMessage {
void SetInterceptionHookPoint( void SetInterceptionHookPoint(
InterceptorBatchMethodsImpl* interceptor_methods) { InterceptorBatchMethodsImpl* interceptor_methods) {
interceptor_methods->SetRecvMessage(message_); interceptor_methods->SetRecvMessage(message_, &got_message);
} }
void SetFinishInterceptionHookPoint( void SetFinishInterceptionHookPoint(
@ -501,7 +501,7 @@ class CallOpGenericRecvMessage {
void SetInterceptionHookPoint( void SetInterceptionHookPoint(
InterceptorBatchMethodsImpl* interceptor_methods) { InterceptorBatchMethodsImpl* interceptor_methods) {
interceptor_methods->SetRecvMessage(message_); interceptor_methods->SetRecvMessage(message_, &got_message);
} }
void SetFinishInterceptionHookPoint( void SetFinishInterceptionHookPoint(

@ -118,6 +118,9 @@ class InterceptorBatchMethods {
// only interceptors after the current interceptor are created from the // only interceptors after the current interceptor are created from the
// factory objects registered with the channel. // factory objects registered with the channel.
virtual std::unique_ptr<ChannelInterface> GetInterceptedChannel() = 0; virtual std::unique_ptr<ChannelInterface> GetInterceptedChannel() = 0;
// On a hijacked RPC, an interceptor can decide to fail a RECV MESSAGE op.
virtual void FailHijackedRecvMessage() = 0;
}; };
class Interceptor { class Interceptor {

@ -134,7 +134,10 @@ class InterceptorBatchMethodsImpl
send_trailing_metadata_ = metadata; send_trailing_metadata_ = metadata;
} }
void SetRecvMessage(void* message) { recv_message_ = message; } void SetRecvMessage(void* message, bool* got_message) {
recv_message_ = message;
got_message_ = got_message;
}
void SetRecvInitialMetadata(MetadataMap* map) { void SetRecvInitialMetadata(MetadataMap* map) {
recv_initial_metadata_ = map; recv_initial_metadata_ = map;
@ -157,6 +160,8 @@ class InterceptorBatchMethodsImpl
info->channel(), current_interceptor_index_ + 1)); info->channel(), current_interceptor_index_ + 1));
} }
void FailHijackedRecvMessage() override { *got_message_ = false; }
// Clears all state // Clears all state
void ClearState() { void ClearState() {
reverse_ = false; reverse_ = false;
@ -345,6 +350,7 @@ class InterceptorBatchMethodsImpl
std::multimap<grpc::string, grpc::string>* send_trailing_metadata_ = nullptr; std::multimap<grpc::string, grpc::string>* send_trailing_metadata_ = nullptr;
void* recv_message_ = nullptr; void* recv_message_ = nullptr;
bool* got_message_ = nullptr;
MetadataMap* recv_initial_metadata_ = nullptr; MetadataMap* recv_initial_metadata_ = nullptr;
@ -451,6 +457,12 @@ class CancelInterceptorBatchMethods
"method which has a Cancel notification"); "method which has a Cancel notification");
return std::unique_ptr<ChannelInterface>(nullptr); return std::unique_ptr<ChannelInterface>(nullptr);
} }
void FailHijackedRecvMessage() override {
GPR_CODEGEN_ASSERT(false &&
"It is illegal to call FailHijackedRecvMessage on a "
"method which has a Cancel notification");
}
}; };
} // namespace internal } // namespace internal
} // namespace grpc } // namespace grpc

@ -270,7 +270,7 @@ class ServerInterface : public internal::CallHook {
/* Set interception point for recv message */ /* Set interception point for recv message */
interceptor_methods_.AddInterceptionHookPoint( interceptor_methods_.AddInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_MESSAGE); experimental::InterceptionHookPoints::POST_RECV_MESSAGE);
interceptor_methods_.SetRecvMessage(request_); interceptor_methods_.SetRecvMessage(request_, nullptr);
return RegisteredAsyncRequest::FinalizeResult(tag, status); return RegisteredAsyncRequest::FinalizeResult(tag, status);
} }

@ -280,7 +280,7 @@ class Server::SyncRequest final : public internal::CompletionQueueTag {
request_payload_ = nullptr; request_payload_ = nullptr;
interceptor_methods_.AddInterceptionHookPoint( interceptor_methods_.AddInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_MESSAGE); experimental::InterceptionHookPoints::POST_RECV_MESSAGE);
interceptor_methods_.SetRecvMessage(request_); interceptor_methods_.SetRecvMessage(request_, nullptr);
} }
if (interceptor_methods_.RunInterceptors( if (interceptor_methods_.RunInterceptors(
@ -447,7 +447,7 @@ class Server::CallbackRequest final : public internal::CompletionQueueTag {
req_->request_payload_ = nullptr; req_->request_payload_ = nullptr;
req_->interceptor_methods_.AddInterceptionHookPoint( req_->interceptor_methods_.AddInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_MESSAGE); experimental::InterceptionHookPoints::POST_RECV_MESSAGE);
req_->interceptor_methods_.SetRecvMessage(req_->request_); req_->interceptor_methods_.SetRecvMessage(req_->request_, nullptr);
} }
if (req_->interceptor_methods_.RunInterceptors( if (req_->interceptor_methods_.RunInterceptors(

@ -269,6 +269,92 @@ class HijackingInterceptorMakesAnotherCallFactory
} }
}; };
class ServerStreamingRpcHijackingInterceptor
: public experimental::Interceptor {
public:
ServerStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
info_ = info;
}
virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
bool hijack = false;
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
auto* map = methods->GetSendInitialMetadata();
// Check that we can see the test metadata
ASSERT_EQ(map->size(), static_cast<unsigned>(1));
auto iterator = map->begin();
EXPECT_EQ("testkey", iterator->first);
EXPECT_EQ("testvalue", iterator->second);
hijack = true;
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
EchoRequest req;
auto* buffer = methods->GetSendMessage();
auto copied_buffer = *buffer;
EXPECT_TRUE(
SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
.ok());
EXPECT_EQ(req.message(), "Hello");
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
// Got nothing to do here for now
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
auto* map = methods->GetRecvTrailingMetadata();
bool found = false;
// Check that we received the metadata as an echo
for (const auto& pair : *map) {
found = pair.first.starts_with("testkey") &&
pair.second.starts_with("testvalue");
if (found) break;
}
EXPECT_EQ(found, true);
auto* status = methods->GetRecvStatus();
EXPECT_EQ(status->ok(), true);
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
if (++count > 10) {
methods->FailHijackedRecvMessage();
}
EchoResponse* resp =
static_cast<EchoResponse*>(methods->GetRecvMessage());
resp->set_message("Hello");
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
auto* map = methods->GetRecvTrailingMetadata();
// insert the metadata that we want
EXPECT_EQ(map->size(), static_cast<unsigned>(0));
map->insert(std::make_pair("testkey", "testvalue"));
auto* status = methods->GetRecvStatus();
*status = Status(StatusCode::OK, "");
}
if (hijack) {
methods->Hijack();
} else {
methods->Proceed();
}
}
private:
experimental::ClientRpcInfo* info_;
int count = 0;
};
class ServerStreamingRpcHijackingInterceptorFactory
: public experimental::ClientInterceptorFactoryInterface {
public:
virtual experimental::Interceptor* CreateClientInterceptor(
experimental::ClientRpcInfo* info) override {
return new ServerStreamingRpcHijackingInterceptor(info);
}
};
class LoggingInterceptor : public experimental::Interceptor { class LoggingInterceptor : public experimental::Interceptor {
public: public:
LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; } LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; }
@ -535,6 +621,21 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) {
EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
} }
TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) {
ChannelArguments args;
DummyInterceptor::Reset();
auto creators = std::unique_ptr<std::vector<
std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>(
new std::vector<
std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>());
creators->push_back(
std::unique_ptr<ServerStreamingRpcHijackingInterceptorFactory>(
new ServerStreamingRpcHijackingInterceptorFactory()));
auto channel = experimental::CreateCustomChannelWithInterceptors(
server_address_, InsecureChannelCredentials(), args, std::move(creators));
MakeServerStreamingCall(channel);
}
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) { TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {
ChannelArguments args; ChannelArguments args;
DummyInterceptor::Reset(); DummyInterceptor::Reset();

Loading…
Cancel
Save