Merge pull request #17179 from yashykt/failhijackedrecv

Add interceptor methods to fail recv msg for hijacked rpcs and set recv message to nullptr on failure
pull/17647/head
Yash Tibrewal 6 years ago committed by GitHub
commit 46bd2f7adb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 13
      include/grpcpp/impl/codegen/call_op_set.h
  2. 9
      include/grpcpp/impl/codegen/interceptor.h
  3. 18
      include/grpcpp/impl/codegen/interceptor_common.h
  4. 2
      include/grpcpp/impl/codegen/server_interface.h
  5. 4
      src/cpp/server/server_cc.cc
  6. 111
      test/cpp/end2end/client_interceptors_end2end_test.cc

@ -453,14 +453,16 @@ class CallOpRecvMessage {
void SetInterceptionHookPoint(
InterceptorBatchMethodsImpl* interceptor_methods) {
interceptor_methods->SetRecvMessage(message_);
if (message_ == nullptr) return;
interceptor_methods->SetRecvMessage(message_, &got_message);
}
void SetFinishInterceptionHookPoint(
InterceptorBatchMethodsImpl* interceptor_methods) {
if (!got_message) return;
if (message_ == nullptr) return;
interceptor_methods->AddInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_MESSAGE);
if (!got_message) interceptor_methods->SetRecvMessage(nullptr, nullptr);
}
void SetHijackingState(InterceptorBatchMethodsImpl* interceptor_methods) {
hijacked_ = true;
@ -548,20 +550,23 @@ class CallOpGenericRecvMessage {
void SetInterceptionHookPoint(
InterceptorBatchMethodsImpl* interceptor_methods) {
interceptor_methods->SetRecvMessage(message_);
if (!deserialize_) return;
interceptor_methods->SetRecvMessage(message_, &got_message);
}
void SetFinishInterceptionHookPoint(
InterceptorBatchMethodsImpl* interceptor_methods) {
if (!got_message) return;
if (!deserialize_) return;
interceptor_methods->AddInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_MESSAGE);
if (!got_message) interceptor_methods->SetRecvMessage(nullptr, nullptr);
}
void SetHijackingState(InterceptorBatchMethodsImpl* interceptor_methods) {
hijacked_ = true;
if (!deserialize_) return;
interceptor_methods->AddInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_RECV_MESSAGE);
got_message = true;
}
private:

@ -168,8 +168,13 @@ class InterceptorBatchMethods {
/// list.
virtual std::unique_ptr<ChannelInterface> GetInterceptedChannel() = 0;
// On a hijacked RPC/ to-be hijacked RPC, this can be called to fail a SEND
// MESSAGE op
/// On a hijacked RPC, an interceptor can decide to fail a PRE_RECV_MESSAGE
/// op. This would be a signal to the reader that there will be no more
/// messages, or the stream has failed or been cancelled.
virtual void FailHijackedRecvMessage() = 0;
/// On a hijacked RPC/ to-be hijacked RPC, this can be called to fail a SEND
/// MESSAGE op
virtual void FailHijackedSendMessage() = 0;
};

@ -149,7 +149,10 @@ class InterceptorBatchMethodsImpl
send_trailing_metadata_ = metadata;
}
void SetRecvMessage(void* message) { recv_message_ = message; }
void SetRecvMessage(void* message, bool* got_message) {
recv_message_ = message;
got_message_ = got_message;
}
void SetRecvInitialMetadata(MetadataMap* map) {
recv_initial_metadata_ = map;
@ -172,6 +175,12 @@ class InterceptorBatchMethodsImpl
info->channel(), current_interceptor_index_ + 1));
}
void FailHijackedRecvMessage() override {
GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>(
experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)]);
*got_message_ = false;
}
// Clears all state
void ClearState() {
reverse_ = false;
@ -362,6 +371,7 @@ class InterceptorBatchMethodsImpl
std::multimap<grpc::string, grpc::string>* send_trailing_metadata_ = nullptr;
void* recv_message_ = nullptr;
bool* got_message_ = nullptr;
MetadataMap* recv_initial_metadata_ = nullptr;
@ -485,6 +495,12 @@ class CancelInterceptorBatchMethods
return std::unique_ptr<ChannelInterface>(nullptr);
}
void FailHijackedRecvMessage() override {
GPR_CODEGEN_ASSERT(false &&
"It is illegal to call FailHijackedRecvMessage on a "
"method which has a Cancel notification");
}
void FailHijackedSendMessage() override {
GPR_CODEGEN_ASSERT(false &&
"It is illegal to call FailHijackedSendMessage on a "

@ -272,7 +272,7 @@ class ServerInterface : public internal::CallHook {
/* Set interception point for recv message */
interceptor_methods_.AddInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_MESSAGE);
interceptor_methods_.SetRecvMessage(request_);
interceptor_methods_.SetRecvMessage(request_, nullptr);
return RegisteredAsyncRequest::FinalizeResult(tag, status);
}

@ -278,7 +278,7 @@ class Server::SyncRequest final : public internal::CompletionQueueTag {
request_payload_ = nullptr;
interceptor_methods_.AddInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_MESSAGE);
interceptor_methods_.SetRecvMessage(request_);
interceptor_methods_.SetRecvMessage(request_, nullptr);
}
if (interceptor_methods_.RunInterceptors(
@ -446,7 +446,7 @@ class Server::CallbackRequest final : public internal::CompletionQueueTag {
req_->request_payload_ = nullptr;
req_->interceptor_methods_.AddInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_MESSAGE);
req_->interceptor_methods_.SetRecvMessage(req_->request_);
req_->interceptor_methods_.SetRecvMessage(req_->request_, nullptr);
}
if (req_->interceptor_methods_.RunInterceptors(

@ -393,6 +393,103 @@ class ClientStreamingRpcHijackingInterceptorFactory
}
};
class ServerStreamingRpcHijackingInterceptor
: public experimental::Interceptor {
public:
ServerStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
info_ = info;
}
virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
bool hijack = false;
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
auto* map = methods->GetSendInitialMetadata();
// Check that we can see the test metadata
ASSERT_EQ(map->size(), static_cast<unsigned>(1));
auto iterator = map->begin();
EXPECT_EQ("testkey", iterator->first);
EXPECT_EQ("testvalue", iterator->second);
hijack = true;
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
EchoRequest req;
auto* buffer = methods->GetSerializedSendMessage();
auto copied_buffer = *buffer;
EXPECT_TRUE(
SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
.ok());
EXPECT_EQ(req.message(), "Hello");
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
// Got nothing to do here for now
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
auto* map = methods->GetRecvTrailingMetadata();
bool found = false;
// Check that we received the metadata as an echo
for (const auto& pair : *map) {
found = pair.first.starts_with("testkey") &&
pair.second.starts_with("testvalue");
if (found) break;
}
EXPECT_EQ(found, true);
auto* status = methods->GetRecvStatus();
EXPECT_EQ(status->ok(), true);
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
if (++count_ > 10) {
methods->FailHijackedRecvMessage();
}
EchoResponse* resp =
static_cast<EchoResponse*>(methods->GetRecvMessage());
resp->set_message("Hello");
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
// Only the last message will be a failure
EXPECT_FALSE(got_failed_message_);
got_failed_message_ = methods->GetRecvMessage() == nullptr;
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
auto* map = methods->GetRecvTrailingMetadata();
// insert the metadata that we want
EXPECT_EQ(map->size(), static_cast<unsigned>(0));
map->insert(std::make_pair("testkey", "testvalue"));
auto* status = methods->GetRecvStatus();
*status = Status(StatusCode::OK, "");
}
if (hijack) {
methods->Hijack();
} else {
methods->Proceed();
}
}
static bool GotFailedMessage() { return got_failed_message_; }
private:
experimental::ClientRpcInfo* info_;
static bool got_failed_message_;
int count_ = 0;
};
bool ServerStreamingRpcHijackingInterceptor::got_failed_message_ = false;
class ServerStreamingRpcHijackingInterceptorFactory
: public experimental::ClientInterceptorFactoryInterface {
public:
virtual experimental::Interceptor* CreateClientInterceptor(
experimental::ClientRpcInfo* info) override {
return new ServerStreamingRpcHijackingInterceptor(info);
}
};
class BidiStreamingRpcHijackingInterceptorFactory
: public experimental::ClientInterceptorFactoryInterface {
public:
@ -711,6 +808,20 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) {
EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend());
}
TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) {
ChannelArguments args;
DummyInterceptor::Reset();
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
creators;
creators.push_back(
std::unique_ptr<ServerStreamingRpcHijackingInterceptorFactory>(
new ServerStreamingRpcHijackingInterceptorFactory()));
auto channel = experimental::CreateCustomChannelWithInterceptors(
server_address_, InsecureChannelCredentials(), args, std::move(creators));
MakeServerStreamingCall(channel);
EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
}
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) {
ChannelArguments args;
DummyInterceptor::Reset();

Loading…
Cancel
Save