Merge branch 'master' into failhijackedsend

pull/17220/head
Yash Tibrewal 6 years ago
commit 8ba5922e87
  1. 91
      test/cpp/end2end/client_interceptors_end2end_test.cc
  2. 10
      test/cpp/end2end/interceptors_util.cc
  3. 3
      test/cpp/end2end/interceptors_util.h

@ -270,6 +270,75 @@ class HijackingInterceptorMakesAnotherCallFactory
}
};
class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor {
public:
BidiStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
info_ = info;
}
virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
bool hijack = false;
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue");
hijack = true;
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
EchoRequest req;
auto* buffer = methods->GetSerializedSendMessage();
auto copied_buffer = *buffer;
EXPECT_TRUE(
SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
.ok());
EXPECT_EQ(req.message().find("Hello"), 0u);
msg = req.message();
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
// Got nothing to do here for now
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey",
"testvalue");
auto* status = methods->GetRecvStatus();
EXPECT_EQ(status->ok(), true);
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
EchoResponse* resp =
static_cast<EchoResponse*>(methods->GetRecvMessage());
resp->set_message(msg);
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
EXPECT_EQ(static_cast<EchoResponse*>(methods->GetRecvMessage())
->message()
.find("Hello"),
0u);
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
auto* map = methods->GetRecvTrailingMetadata();
// insert the metadata that we want
EXPECT_EQ(map->size(), static_cast<unsigned>(0));
map->insert(std::make_pair("testkey", "testvalue"));
auto* status = methods->GetRecvStatus();
*status = Status(StatusCode::OK, "");
}
if (hijack) {
methods->Hijack();
} else {
methods->Proceed();
}
}
private:
experimental::ClientRpcInfo* info_;
grpc::string msg;
};
class ClientStreamingRpcHijackingInterceptor
: public experimental::Interceptor {
public:
@ -324,6 +393,15 @@ class ClientStreamingRpcHijackingInterceptorFactory
}
};
class BidiStreamingRpcHijackingInterceptorFactory
: public experimental::ClientInterceptorFactoryInterface {
public:
virtual experimental::Interceptor* CreateClientInterceptor(
experimental::ClientRpcInfo* info) override {
return new BidiStreamingRpcHijackingInterceptor(info);
}
};
class LoggingInterceptor : public experimental::Interceptor {
public:
LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; }
@ -633,6 +711,19 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) {
EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend());
}
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) {
ChannelArguments args;
DummyInterceptor::Reset();
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
creators;
creators.push_back(
std::unique_ptr<BidiStreamingRpcHijackingInterceptorFactory>(
new BidiStreamingRpcHijackingInterceptorFactory()));
auto channel = experimental::CreateCustomChannelWithInterceptors(
server_address_, InsecureChannelCredentials(), args, std::move(creators));
MakeBidiStreamingCall(channel);
}
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {
ChannelArguments args;
DummyInterceptor::Reset();

@ -132,6 +132,16 @@ bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,
return false;
}
bool CheckMetadata(const std::multimap<grpc::string, grpc::string>& map,
const string& key, const string& value) {
for (const auto& pair : map) {
if (pair.first == key && pair.second == value) {
return true;
}
}
return false;
}
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
CreateDummyClientInterceptors() {
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>

@ -165,6 +165,9 @@ void MakeCallbackCall(const std::shared_ptr<Channel>& channel);
bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,
const string& key, const string& value);
bool CheckMetadata(const std::multimap<grpc::string, grpc::string>& map,
const string& key, const string& value);
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
CreateDummyClientInterceptors();

Loading…
Cancel
Save