Fix interceptor batch method FailHijackedRecvMessage for async APIs

pull/22746/head
Yash Tibrewal 5 years ago
parent b8e34c84ed
commit edbae5d8e6
  1. 61
      include/grpcpp/impl/codegen/call_op_set.h
  2. 8
      include/grpcpp/impl/codegen/interceptor_common.h
  3. 181
      test/cpp/end2end/client_interceptors_end2end_test.cc
  4. 57
      test/cpp/end2end/interceptors_util.cc
  5. 18
      test/cpp/end2end/interceptors_util.h

@ -421,17 +421,14 @@ Status CallOpSendMessage::SendMessagePtr(const M* message) {
template <class R>
class CallOpRecvMessage {
public:
CallOpRecvMessage()
: got_message(false),
message_(nullptr),
allow_not_getting_message_(false) {}
CallOpRecvMessage() {}
void RecvMessage(R* message) { message_ = message; }
// Do not change status if no message is received.
void AllowNoMessage() { allow_not_getting_message_ = true; }
bool got_message;
bool got_message = false;
protected:
void AddOp(grpc_op* ops, size_t* nops) {
@ -444,7 +441,7 @@ class CallOpRecvMessage {
}
void FinishOp(bool* status) {
if (message_ == nullptr || hijacked_) return;
if (message_ == nullptr) return;
if (recv_buf_.Valid()) {
if (*status) {
got_message = *status =
@ -455,18 +452,20 @@ class CallOpRecvMessage {
got_message = false;
recv_buf_.Clear();
}
} else {
got_message = false;
if (!allow_not_getting_message_) {
*status = false;
} else if (hijacked_) {
if (!hijacked_recv_message_status_) {
FinishOpRecvMessageFailureHandler(status);
}
} else {
FinishOpRecvMessageFailureHandler(status);
}
}
void SetInterceptionHookPoint(
InterceptorBatchMethodsImpl* interceptor_methods) {
if (message_ == nullptr) return;
interceptor_methods->SetRecvMessage(message_, &got_message);
interceptor_methods->SetRecvMessage(message_,
&hijacked_recv_message_status_);
}
void SetFinishInterceptionHookPoint(
@ -485,10 +484,19 @@ class CallOpRecvMessage {
}
private:
R* message_;
// Sets got_message and \a status for a failed recv message op
void FinishOpRecvMessageFailureHandler(bool* status) {
got_message = false;
if (!allow_not_getting_message_) {
*status = false;
}
}
R* message_ = nullptr;
ByteBuffer recv_buf_;
bool allow_not_getting_message_;
bool allow_not_getting_message_ = false;
bool hijacked_ = false;
bool hijacked_recv_message_status_ = true;
};
class DeserializeFunc {
@ -513,8 +521,7 @@ class DeserializeFuncType final : public DeserializeFunc {
class CallOpGenericRecvMessage {
public:
CallOpGenericRecvMessage()
: got_message(false), allow_not_getting_message_(false) {}
CallOpGenericRecvMessage() {}
template <class R>
void RecvMessage(R* message) {
@ -528,7 +535,7 @@ class CallOpGenericRecvMessage {
// Do not change status if no message is received.
void AllowNoMessage() { allow_not_getting_message_ = true; }
bool got_message;
bool got_message = false;
protected:
void AddOp(grpc_op* ops, size_t* nops) {
@ -551,6 +558,10 @@ class CallOpGenericRecvMessage {
got_message = false;
recv_buf_.Clear();
}
} else if (hijacked_) {
if (!hijacked_recv_message_status_) {
FinishOpRecvMessageFailureHandler(status);
}
} else {
got_message = false;
if (!allow_not_getting_message_) {
@ -562,7 +573,8 @@ class CallOpGenericRecvMessage {
void SetInterceptionHookPoint(
InterceptorBatchMethodsImpl* interceptor_methods) {
if (!deserialize_) return;
interceptor_methods->SetRecvMessage(message_, &got_message);
interceptor_methods->SetRecvMessage(message_,
&hijacked_recv_message_status_);
}
void SetFinishInterceptionHookPoint(
@ -582,11 +594,20 @@ class CallOpGenericRecvMessage {
}
private:
void* message_;
bool hijacked_ = false;
// Sets got_message and \a status for a failed recv message op
void FinishOpRecvMessageFailureHandler(bool* status) {
got_message = false;
if (!allow_not_getting_message_) {
*status = false;
}
}
void* message_ = nullptr;
std::unique_ptr<DeserializeFunc> deserialize_;
ByteBuffer recv_buf_;
bool allow_not_getting_message_;
bool allow_not_getting_message_ = false;
bool hijacked_ = false;
bool hijacked_recv_message_status_ = true;
};
class CallOpClientSendClose {

@ -166,9 +166,9 @@ class InterceptorBatchMethodsImpl
send_trailing_metadata_ = metadata;
}
void SetRecvMessage(void* message, bool* got_message) {
void SetRecvMessage(void* message, bool* hijacked_recv_message_status) {
recv_message_ = message;
got_message_ = got_message;
hijacked_recv_message_status_ = hijacked_recv_message_status;
}
void SetRecvInitialMetadata(MetadataMap* map) {
@ -195,7 +195,7 @@ class InterceptorBatchMethodsImpl
void FailHijackedRecvMessage() override {
GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>(
experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)]);
*got_message_ = false;
*hijacked_recv_message_status_ = false;
}
// Clears all state
@ -407,7 +407,7 @@ class InterceptorBatchMethodsImpl
std::multimap<grpc::string, grpc::string>* send_trailing_metadata_ = nullptr;
void* recv_message_ = nullptr;
bool* got_message_ = nullptr;
bool* hijacked_recv_message_status_ = nullptr;
MetadataMap* recv_initial_metadata_ = nullptr;

@ -43,6 +43,17 @@ namespace grpc {
namespace testing {
namespace {
enum class RPCType {
kSyncUnary,
kSyncClientStreaming,
kSyncServerStreaming,
kSyncBidiStreaming,
kAsyncCQUnary,
kAsyncCQClientStreaming,
kAsyncCQServerStreaming,
kAsyncCQBidiStreaming,
};
/* Hijacks Echo RPC and fills in the expected values */
class HijackingInterceptor : public experimental::Interceptor {
public:
@ -400,6 +411,7 @@ class ServerStreamingRpcHijackingInterceptor
public:
ServerStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
info_ = info;
got_failed_message_ = false;
}
virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
@ -531,10 +543,22 @@ class LoggingInterceptor : public experimental::Interceptor {
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
EchoRequest req;
EXPECT_EQ(static_cast<const EchoRequest*>(methods->GetSendMessage())
->message()
.find("Hello"),
0u);
auto* send_msg = methods->GetSendMessage();
if (send_msg == nullptr) {
// We did not get the non-serialized form of the message. Get the
// serialized form.
auto* buffer = methods->GetSerializedSendMessage();
auto copied_buffer = *buffer;
EchoRequest req;
EXPECT_TRUE(
SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
.ok());
EXPECT_EQ(req.message(), "Hello");
} else {
EXPECT_EQ(
static_cast<const EchoRequest*>(send_msg)->message().find("Hello"),
0u);
}
auto* buffer = methods->GetSerializedSendMessage();
auto copied_buffer = *buffer;
EXPECT_TRUE(
@ -582,6 +606,27 @@ class LoggingInterceptor : public experimental::Interceptor {
methods->Proceed();
}
static void VerifyCall(RPCType type) {
switch (type) {
case RPCType::kSyncUnary:
case RPCType::kAsyncCQUnary:
VerifyUnaryCall();
break;
case RPCType::kSyncClientStreaming:
case RPCType::kAsyncCQClientStreaming:
VerifyClientStreamingCall();
break;
case RPCType::kSyncServerStreaming:
case RPCType::kAsyncCQServerStreaming:
VerifyServerStreamingCall();
break;
case RPCType::kSyncBidiStreaming:
case RPCType::kAsyncCQBidiStreaming:
VerifyBidiStreamingCall();
break;
}
}
static void VerifyCallCommon() {
EXPECT_TRUE(pre_send_initial_metadata_);
EXPECT_TRUE(pre_send_close_);
@ -638,9 +683,31 @@ class LoggingInterceptorFactory
}
};
class ClientInterceptorsEnd2endTest : public ::testing::Test {
class TestScenario {
public:
explicit TestScenario(const RPCType& type) : type_(type) {}
RPCType type() const { return type_; }
private:
RPCType type_;
};
std::vector<TestScenario> CreateTestScenarios() {
std::vector<TestScenario> scenarios;
scenarios.emplace_back(RPCType::kSyncUnary);
scenarios.emplace_back(RPCType::kSyncClientStreaming);
scenarios.emplace_back(RPCType::kSyncServerStreaming);
scenarios.emplace_back(RPCType::kSyncBidiStreaming);
scenarios.emplace_back(RPCType::kAsyncCQUnary);
scenarios.emplace_back(RPCType::kAsyncCQServerStreaming);
return scenarios;
}
class ParameterizedClientInterceptorsEnd2endTest
: public ::testing::TestWithParam<TestScenario> {
protected:
ClientInterceptorsEnd2endTest() {
ParameterizedClientInterceptorsEnd2endTest() {
int port = grpc_pick_unused_port_or_die();
ServerBuilder builder;
@ -650,14 +717,44 @@ class ClientInterceptorsEnd2endTest : public ::testing::Test {
server_ = builder.BuildAndStart();
}
~ClientInterceptorsEnd2endTest() { server_->Shutdown(); }
~ParameterizedClientInterceptorsEnd2endTest() { server_->Shutdown(); }
void SendRPC(const std::shared_ptr<Channel>& channel) {
switch (GetParam().type()) {
case RPCType::kSyncUnary:
MakeCall(channel);
break;
case RPCType::kSyncClientStreaming:
MakeClientStreamingCall(channel);
break;
case RPCType::kSyncServerStreaming:
MakeServerStreamingCall(channel);
break;
case RPCType::kSyncBidiStreaming:
MakeBidiStreamingCall(channel);
break;
case RPCType::kAsyncCQUnary:
MakeAsyncCQCall(channel);
break;
case RPCType::kAsyncCQClientStreaming:
// TODO(yashykt) : Fill this out
break;
case RPCType::kAsyncCQServerStreaming:
MakeAsyncCQServerStreamingCall(channel);
break;
case RPCType::kAsyncCQBidiStreaming:
// TODO(yashykt) : Fill this out
break;
}
}
std::string server_address_;
TestServiceImpl service_;
EchoTestServiceStreamingImpl service_;
std::unique_ptr<Server> server_;
};
TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTest) {
TEST_P(ParameterizedClientInterceptorsEnd2endTest,
ClientInterceptorLoggingTest) {
ChannelArguments args;
DummyInterceptor::Reset();
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
@ -671,12 +768,36 @@ TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTest) {
}
auto channel = experimental::CreateCustomChannelWithInterceptors(
server_address_, InsecureChannelCredentials(), args, std::move(creators));
MakeCall(channel);
LoggingInterceptor::VerifyUnaryCall();
SendRPC(channel);
LoggingInterceptor::VerifyCall(GetParam().type());
// Make sure all 20 dummy interceptors were run
EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
}
INSTANTIATE_TEST_SUITE_P(ParameterizedClientInterceptorsEnd2end,
ParameterizedClientInterceptorsEnd2endTest,
::testing::ValuesIn(CreateTestScenarios()));
class ClientInterceptorsEnd2endTest
: public ::testing::TestWithParam<TestScenario> {
protected:
ClientInterceptorsEnd2endTest() {
int port = grpc_pick_unused_port_or_die();
ServerBuilder builder;
server_address_ = "localhost:" + std::to_string(port);
builder.AddListeningPort(server_address_, InsecureServerCredentials());
builder.RegisterService(&service_);
server_ = builder.BuildAndStart();
}
~ClientInterceptorsEnd2endTest() { server_->Shutdown(); }
std::string server_address_;
TestServiceImpl service_;
std::unique_ptr<Server> server_;
};
TEST_F(ClientInterceptorsEnd2endTest,
LameChannelClientInterceptorHijackingTest) {
ChannelArguments args;
@ -757,7 +878,26 @@ TEST_F(ClientInterceptorsEnd2endTest,
EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 12);
}
TEST_F(ClientInterceptorsEnd2endTest,
class ClientInterceptorsCallbackEnd2endTest : public ::testing::Test {
protected:
ClientInterceptorsCallbackEnd2endTest() {
int port = grpc_pick_unused_port_or_die();
ServerBuilder builder;
server_address_ = "localhost:" + std::to_string(port);
builder.AddListeningPort(server_address_, InsecureServerCredentials());
builder.RegisterService(&service_);
server_ = builder.BuildAndStart();
}
~ClientInterceptorsCallbackEnd2endTest() { server_->Shutdown(); }
std::string server_address_;
TestServiceImpl service_;
std::unique_ptr<Server> server_;
};
TEST_F(ClientInterceptorsCallbackEnd2endTest,
ClientInterceptorLoggingTestWithCallback) {
ChannelArguments args;
DummyInterceptor::Reset();
@ -778,7 +918,7 @@ TEST_F(ClientInterceptorsEnd2endTest,
EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
}
TEST_F(ClientInterceptorsEnd2endTest,
TEST_F(ClientInterceptorsCallbackEnd2endTest,
ClientInterceptorFactoryAllowsNullptrReturn) {
ChannelArguments args;
DummyInterceptor::Reset();
@ -903,6 +1043,21 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) {
EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
}
TEST_F(ClientInterceptorsStreamingEnd2endTest,
AsyncCQServerStreamingHijackingTest) {
ChannelArguments args;
DummyInterceptor::Reset();
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
creators;
creators.push_back(
std::unique_ptr<ServerStreamingRpcHijackingInterceptorFactory>(
new ServerStreamingRpcHijackingInterceptorFactory()));
auto channel = experimental::CreateCustomChannelWithInterceptors(
server_address_, InsecureChannelCredentials(), args, std::move(creators));
MakeAsyncCQServerStreamingCall(channel);
EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
}
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) {
ChannelArguments args;
DummyInterceptor::Reset();

@ -66,7 +66,6 @@ void MakeServerStreamingCall(const std::shared_ptr<Channel>& channel) {
ctx.AddMetadata("testkey", "testvalue");
req.set_message("Hello");
EchoResponse resp;
string expected_resp = "";
auto reader = stub->ResponseStream(&ctx, req);
int count = 0;
while (reader->Read(&resp)) {
@ -84,6 +83,7 @@ void MakeBidiStreamingCall(const std::shared_ptr<Channel>& channel) {
EchoRequest req;
EchoResponse resp;
ctx.AddMetadata("testkey", "testvalue");
req.mutable_param()->set_echo_metadata(true);
auto stream = stub->BidiStream(&ctx);
for (auto i = 0; i < kNumStreamingMessages; i++) {
req.set_message("Hello" + std::to_string(i));
@ -96,6 +96,60 @@ void MakeBidiStreamingCall(const std::shared_ptr<Channel>& channel) {
EXPECT_EQ(s.ok(), true);
}
void MakeAsyncCQCall(const std::shared_ptr<Channel>& channel) {
auto stub = grpc::testing::EchoTestService::NewStub(channel);
CompletionQueue cq;
EchoRequest send_request;
EchoResponse recv_response;
Status recv_status;
ClientContext cli_ctx;
send_request.set_message("Hello");
cli_ctx.AddMetadata("testkey", "testvalue");
std::unique_ptr<ClientAsyncResponseReader<EchoResponse>> response_reader(
stub->AsyncEcho(&cli_ctx, send_request, &cq));
response_reader->Finish(&recv_response, &recv_status, tag(1));
Verifier().Expect(1, true).Verify(&cq);
EXPECT_EQ(send_request.message(), recv_response.message());
EXPECT_TRUE(recv_status.ok());
}
void MakeAsyncCQClientStreamingCall(const std::shared_ptr<Channel>& channel) {
// TODO(yashykt) : Fill this out
}
void MakeAsyncCQServerStreamingCall(const std::shared_ptr<Channel>& channel) {
auto stub = grpc::testing::EchoTestService::NewStub(channel);
CompletionQueue cq;
EchoRequest send_request;
EchoResponse recv_response;
Status recv_status;
ClientContext cli_ctx;
cli_ctx.AddMetadata("testkey", "testvalue");
send_request.set_message("Hello");
std::unique_ptr<ClientAsyncReader<EchoResponse>> cli_stream(
stub->AsyncResponseStream(&cli_ctx, send_request, &cq, tag(1)));
Verifier().Expect(1, true).Verify(&cq);
// Read the expected number of messages
for (int i = 0; i < kNumStreamingMessages; i++) {
cli_stream->Read(&recv_response, tag(2));
Verifier().Expect(2, true).Verify(&cq);
ASSERT_EQ(recv_response.message(), send_request.message());
}
// The next read should fail
cli_stream->Read(&recv_response, tag(3));
Verifier().Expect(3, false).Verify(&cq);
// Get the status
cli_stream->Finish(&recv_status, tag(4));
Verifier().Expect(4, true).Verify(&cq);
EXPECT_TRUE(recv_status.ok());
}
void MakeAsyncCQBidiStreamingCall(const std::shared_ptr<Channel>& channel) {
// TODO(yashykt) : Fill this out
}
void MakeCallbackCall(const std::shared_ptr<Channel>& channel) {
auto stub = grpc::testing::EchoTestService::NewStub(channel);
ClientContext ctx;
@ -109,7 +163,6 @@ void MakeCallbackCall(const std::shared_ptr<Channel>& channel) {
EchoResponse resp;
stub->experimental_async()->Echo(&ctx, &req, &resp,
[&resp, &mu, &done, &cv](Status s) {
// gpr_log(GPR_ERROR, "got the callback");
EXPECT_EQ(s.ok(), true);
EXPECT_EQ(resp.message(), "Hello");
std::lock_guard<std::mutex> l(mu);

@ -102,6 +102,16 @@ class EchoTestServiceStreamingImpl : public EchoTestService::Service {
public:
~EchoTestServiceStreamingImpl() override {}
Status Echo(ServerContext* context, const EchoRequest* request,
EchoResponse* response) {
auto client_metadata = context->client_metadata();
for (const auto& pair : client_metadata) {
context->AddTrailingMetadata(ToString(pair.first), ToString(pair.second));
}
response->set_message(request->message());
return Status::OK;
}
Status BidiStream(
ServerContext* context,
grpc::ServerReaderWriter<EchoResponse, EchoRequest>* stream) override {
@ -162,6 +172,14 @@ void MakeServerStreamingCall(const std::shared_ptr<Channel>& channel);
void MakeBidiStreamingCall(const std::shared_ptr<Channel>& channel);
void MakeAsyncCQCall(const std::shared_ptr<Channel>& channel);
void MakeAsyncCQClientStreamingCall(const std::shared_ptr<Channel>& channel);
void MakeAsyncCQServerStreamingCall(const std::shared_ptr<Channel>& channel);
void MakeAsyncCQBidiStreamingCall(const std::shared_ptr<Channel>& channel);
void MakeCallbackCall(const std::shared_ptr<Channel>& channel);
bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,

Loading…
Cancel
Save