|
|
|
@ -270,6 +270,235 @@ class HijackingInterceptorMakesAnotherCallFactory |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor { |
|
|
|
|
public: |
|
|
|
|
BidiStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) { |
|
|
|
|
info_ = info; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual void Intercept(experimental::InterceptorBatchMethods* methods) { |
|
|
|
|
bool hijack = false; |
|
|
|
|
if (methods->QueryInterceptionHookPoint( |
|
|
|
|
experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { |
|
|
|
|
CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue"); |
|
|
|
|
hijack = true; |
|
|
|
|
} |
|
|
|
|
if (methods->QueryInterceptionHookPoint( |
|
|
|
|
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { |
|
|
|
|
EchoRequest req; |
|
|
|
|
auto* buffer = methods->GetSerializedSendMessage(); |
|
|
|
|
auto copied_buffer = *buffer; |
|
|
|
|
EXPECT_TRUE( |
|
|
|
|
SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req) |
|
|
|
|
.ok()); |
|
|
|
|
EXPECT_EQ(req.message().find("Hello"), 0u); |
|
|
|
|
msg = req.message(); |
|
|
|
|
} |
|
|
|
|
if (methods->QueryInterceptionHookPoint( |
|
|
|
|
experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { |
|
|
|
|
// Got nothing to do here for now
|
|
|
|
|
} |
|
|
|
|
if (methods->QueryInterceptionHookPoint( |
|
|
|
|
experimental::InterceptionHookPoints::POST_RECV_STATUS)) { |
|
|
|
|
CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey", |
|
|
|
|
"testvalue"); |
|
|
|
|
auto* status = methods->GetRecvStatus(); |
|
|
|
|
EXPECT_EQ(status->ok(), true); |
|
|
|
|
} |
|
|
|
|
if (methods->QueryInterceptionHookPoint( |
|
|
|
|
experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { |
|
|
|
|
EchoResponse* resp = |
|
|
|
|
static_cast<EchoResponse*>(methods->GetRecvMessage()); |
|
|
|
|
resp->set_message(msg); |
|
|
|
|
} |
|
|
|
|
if (methods->QueryInterceptionHookPoint( |
|
|
|
|
experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { |
|
|
|
|
EXPECT_EQ(static_cast<EchoResponse*>(methods->GetRecvMessage()) |
|
|
|
|
->message() |
|
|
|
|
.find("Hello"), |
|
|
|
|
0u); |
|
|
|
|
} |
|
|
|
|
if (methods->QueryInterceptionHookPoint( |
|
|
|
|
experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { |
|
|
|
|
auto* map = methods->GetRecvTrailingMetadata(); |
|
|
|
|
// insert the metadata that we want
|
|
|
|
|
EXPECT_EQ(map->size(), static_cast<unsigned>(0)); |
|
|
|
|
map->insert(std::make_pair("testkey", "testvalue")); |
|
|
|
|
auto* status = methods->GetRecvStatus(); |
|
|
|
|
*status = Status(StatusCode::OK, ""); |
|
|
|
|
} |
|
|
|
|
if (hijack) { |
|
|
|
|
methods->Hijack(); |
|
|
|
|
} else { |
|
|
|
|
methods->Proceed(); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
private: |
|
|
|
|
experimental::ClientRpcInfo* info_; |
|
|
|
|
grpc::string msg; |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
class ClientStreamingRpcHijackingInterceptor |
|
|
|
|
: public experimental::Interceptor { |
|
|
|
|
public: |
|
|
|
|
ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) { |
|
|
|
|
info_ = info; |
|
|
|
|
} |
|
|
|
|
virtual void Intercept(experimental::InterceptorBatchMethods* methods) { |
|
|
|
|
bool hijack = false; |
|
|
|
|
if (methods->QueryInterceptionHookPoint( |
|
|
|
|
experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { |
|
|
|
|
hijack = true; |
|
|
|
|
} |
|
|
|
|
if (methods->QueryInterceptionHookPoint( |
|
|
|
|
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { |
|
|
|
|
if (++count_ > 10) { |
|
|
|
|
methods->FailHijackedSendMessage(); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
if (methods->QueryInterceptionHookPoint( |
|
|
|
|
experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) { |
|
|
|
|
EXPECT_FALSE(got_failed_send_); |
|
|
|
|
got_failed_send_ = !methods->GetSendMessageStatus(); |
|
|
|
|
} |
|
|
|
|
if (methods->QueryInterceptionHookPoint( |
|
|
|
|
experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { |
|
|
|
|
auto* status = methods->GetRecvStatus(); |
|
|
|
|
*status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages"); |
|
|
|
|
} |
|
|
|
|
if (hijack) { |
|
|
|
|
methods->Hijack(); |
|
|
|
|
} else { |
|
|
|
|
methods->Proceed(); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
static bool GotFailedSend() { return got_failed_send_; } |
|
|
|
|
|
|
|
|
|
private: |
|
|
|
|
experimental::ClientRpcInfo* info_; |
|
|
|
|
int count_ = 0; |
|
|
|
|
static bool got_failed_send_; |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false; |
|
|
|
|
|
|
|
|
|
class ClientStreamingRpcHijackingInterceptorFactory |
|
|
|
|
: public experimental::ClientInterceptorFactoryInterface { |
|
|
|
|
public: |
|
|
|
|
virtual experimental::Interceptor* CreateClientInterceptor( |
|
|
|
|
experimental::ClientRpcInfo* info) override { |
|
|
|
|
return new ClientStreamingRpcHijackingInterceptor(info); |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
class ServerStreamingRpcHijackingInterceptor |
|
|
|
|
: public experimental::Interceptor { |
|
|
|
|
public: |
|
|
|
|
ServerStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) { |
|
|
|
|
info_ = info; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
virtual void Intercept(experimental::InterceptorBatchMethods* methods) { |
|
|
|
|
bool hijack = false; |
|
|
|
|
if (methods->QueryInterceptionHookPoint( |
|
|
|
|
experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { |
|
|
|
|
auto* map = methods->GetSendInitialMetadata(); |
|
|
|
|
// Check that we can see the test metadata
|
|
|
|
|
ASSERT_EQ(map->size(), static_cast<unsigned>(1)); |
|
|
|
|
auto iterator = map->begin(); |
|
|
|
|
EXPECT_EQ("testkey", iterator->first); |
|
|
|
|
EXPECT_EQ("testvalue", iterator->second); |
|
|
|
|
hijack = true; |
|
|
|
|
} |
|
|
|
|
if (methods->QueryInterceptionHookPoint( |
|
|
|
|
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { |
|
|
|
|
EchoRequest req; |
|
|
|
|
auto* buffer = methods->GetSerializedSendMessage(); |
|
|
|
|
auto copied_buffer = *buffer; |
|
|
|
|
EXPECT_TRUE( |
|
|
|
|
SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req) |
|
|
|
|
.ok()); |
|
|
|
|
EXPECT_EQ(req.message(), "Hello"); |
|
|
|
|
} |
|
|
|
|
if (methods->QueryInterceptionHookPoint( |
|
|
|
|
experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { |
|
|
|
|
// Got nothing to do here for now
|
|
|
|
|
} |
|
|
|
|
if (methods->QueryInterceptionHookPoint( |
|
|
|
|
experimental::InterceptionHookPoints::POST_RECV_STATUS)) { |
|
|
|
|
auto* map = methods->GetRecvTrailingMetadata(); |
|
|
|
|
bool found = false; |
|
|
|
|
// Check that we received the metadata as an echo
|
|
|
|
|
for (const auto& pair : *map) { |
|
|
|
|
found = pair.first.starts_with("testkey") && |
|
|
|
|
pair.second.starts_with("testvalue"); |
|
|
|
|
if (found) break; |
|
|
|
|
} |
|
|
|
|
EXPECT_EQ(found, true); |
|
|
|
|
auto* status = methods->GetRecvStatus(); |
|
|
|
|
EXPECT_EQ(status->ok(), true); |
|
|
|
|
} |
|
|
|
|
if (methods->QueryInterceptionHookPoint( |
|
|
|
|
experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) { |
|
|
|
|
if (++count_ > 10) { |
|
|
|
|
methods->FailHijackedRecvMessage(); |
|
|
|
|
} |
|
|
|
|
EchoResponse* resp = |
|
|
|
|
static_cast<EchoResponse*>(methods->GetRecvMessage()); |
|
|
|
|
resp->set_message("Hello"); |
|
|
|
|
} |
|
|
|
|
if (methods->QueryInterceptionHookPoint( |
|
|
|
|
experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) { |
|
|
|
|
// Only the last message will be a failure
|
|
|
|
|
EXPECT_FALSE(got_failed_message_); |
|
|
|
|
got_failed_message_ = methods->GetRecvMessage() == nullptr; |
|
|
|
|
} |
|
|
|
|
if (methods->QueryInterceptionHookPoint( |
|
|
|
|
experimental::InterceptionHookPoints::PRE_RECV_STATUS)) { |
|
|
|
|
auto* map = methods->GetRecvTrailingMetadata(); |
|
|
|
|
// insert the metadata that we want
|
|
|
|
|
EXPECT_EQ(map->size(), static_cast<unsigned>(0)); |
|
|
|
|
map->insert(std::make_pair("testkey", "testvalue")); |
|
|
|
|
auto* status = methods->GetRecvStatus(); |
|
|
|
|
*status = Status(StatusCode::OK, ""); |
|
|
|
|
} |
|
|
|
|
if (hijack) { |
|
|
|
|
methods->Hijack(); |
|
|
|
|
} else { |
|
|
|
|
methods->Proceed(); |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
static bool GotFailedMessage() { return got_failed_message_; } |
|
|
|
|
|
|
|
|
|
private: |
|
|
|
|
experimental::ClientRpcInfo* info_; |
|
|
|
|
static bool got_failed_message_; |
|
|
|
|
int count_ = 0; |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
bool ServerStreamingRpcHijackingInterceptor::got_failed_message_ = false; |
|
|
|
|
|
|
|
|
|
class ServerStreamingRpcHijackingInterceptorFactory |
|
|
|
|
: public experimental::ClientInterceptorFactoryInterface { |
|
|
|
|
public: |
|
|
|
|
virtual experimental::Interceptor* CreateClientInterceptor( |
|
|
|
|
experimental::ClientRpcInfo* info) override { |
|
|
|
|
return new ServerStreamingRpcHijackingInterceptor(info); |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
class BidiStreamingRpcHijackingInterceptorFactory |
|
|
|
|
: public experimental::ClientInterceptorFactoryInterface { |
|
|
|
|
public: |
|
|
|
|
virtual experimental::Interceptor* CreateClientInterceptor( |
|
|
|
|
experimental::ClientRpcInfo* info) override { |
|
|
|
|
return new BidiStreamingRpcHijackingInterceptor(info); |
|
|
|
|
} |
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
class LoggingInterceptor : public experimental::Interceptor { |
|
|
|
|
public: |
|
|
|
|
LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; } |
|
|
|
@ -550,6 +779,62 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) { |
|
|
|
|
EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) { |
|
|
|
|
ChannelArguments args; |
|
|
|
|
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> |
|
|
|
|
creators; |
|
|
|
|
creators.push_back( |
|
|
|
|
std::unique_ptr<ClientStreamingRpcHijackingInterceptorFactory>( |
|
|
|
|
new ClientStreamingRpcHijackingInterceptorFactory())); |
|
|
|
|
auto channel = experimental::CreateCustomChannelWithInterceptors( |
|
|
|
|
server_address_, InsecureChannelCredentials(), args, std::move(creators)); |
|
|
|
|
|
|
|
|
|
auto stub = grpc::testing::EchoTestService::NewStub(channel); |
|
|
|
|
ClientContext ctx; |
|
|
|
|
EchoRequest req; |
|
|
|
|
EchoResponse resp; |
|
|
|
|
req.mutable_param()->set_echo_metadata(true); |
|
|
|
|
req.set_message("Hello"); |
|
|
|
|
string expected_resp = ""; |
|
|
|
|
auto writer = stub->RequestStream(&ctx, &resp); |
|
|
|
|
for (int i = 0; i < 10; i++) { |
|
|
|
|
EXPECT_TRUE(writer->Write(req)); |
|
|
|
|
expected_resp += "Hello"; |
|
|
|
|
} |
|
|
|
|
// The interceptor will reject the 11th message
|
|
|
|
|
writer->Write(req); |
|
|
|
|
Status s = writer->Finish(); |
|
|
|
|
EXPECT_EQ(s.ok(), false); |
|
|
|
|
EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend()); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) { |
|
|
|
|
ChannelArguments args; |
|
|
|
|
DummyInterceptor::Reset(); |
|
|
|
|
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> |
|
|
|
|
creators; |
|
|
|
|
creators.push_back( |
|
|
|
|
std::unique_ptr<ServerStreamingRpcHijackingInterceptorFactory>( |
|
|
|
|
new ServerStreamingRpcHijackingInterceptorFactory())); |
|
|
|
|
auto channel = experimental::CreateCustomChannelWithInterceptors( |
|
|
|
|
server_address_, InsecureChannelCredentials(), args, std::move(creators)); |
|
|
|
|
MakeServerStreamingCall(channel); |
|
|
|
|
EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage()); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) { |
|
|
|
|
ChannelArguments args; |
|
|
|
|
DummyInterceptor::Reset(); |
|
|
|
|
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>> |
|
|
|
|
creators; |
|
|
|
|
creators.push_back( |
|
|
|
|
std::unique_ptr<BidiStreamingRpcHijackingInterceptorFactory>( |
|
|
|
|
new BidiStreamingRpcHijackingInterceptorFactory())); |
|
|
|
|
auto channel = experimental::CreateCustomChannelWithInterceptors( |
|
|
|
|
server_address_, InsecureChannelCredentials(), args, std::move(creators)); |
|
|
|
|
MakeBidiStreamingCall(channel); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) { |
|
|
|
|
ChannelArguments args; |
|
|
|
|
DummyInterceptor::Reset(); |
|
|
|
|