Fix a bug where POST_RECV_MESSAGE is not being triggered

pull/19142/head
Yash Tibrewal 6 years ago
parent 435b1ee961
commit 67bdbbdf6f
  1. 7
      include/grpcpp/impl/codegen/call_op_set.h
  2. 74
      test/cpp/end2end/client_interceptors_end2end_test.cc
  3. 6
      test/cpp/end2end/interceptors_util.cc
  4. 2
      test/cpp/end2end/interceptors_util.h

@ -433,7 +433,9 @@ class CallOpRecvMessage {
message_(nullptr),
allow_not_getting_message_(false) {}
void RecvMessage(R* message) { message_ = message; }
void RecvMessage(R* message) {
message_ = message;
}
// Do not change status if no message is received.
void AllowNoMessage() { allow_not_getting_message_ = true; }
@ -468,7 +470,6 @@ class CallOpRecvMessage {
*status = false;
}
}
message_ = nullptr;
}
void SetInterceptionHookPoint(
@ -565,7 +566,6 @@ class CallOpGenericRecvMessage {
*status = false;
}
}
deserialize_.reset();
}
void SetInterceptionHookPoint(
@ -580,6 +580,7 @@ class CallOpGenericRecvMessage {
interceptor_methods->AddInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_MESSAGE);
if (!got_message) interceptor_methods->SetRecvMessage(nullptr, nullptr);
deserialize_.reset();
}
void SetHijackingState(InterceptorBatchMethodsImpl* interceptor_methods) {
hijacked_ = true;

@ -501,7 +501,14 @@ class BidiStreamingRpcHijackingInterceptorFactory
class LoggingInterceptor : public experimental::Interceptor {
public:
LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; }
LoggingInterceptor(experimental::ClientRpcInfo* info) : info_(info) {
pre_send_initial_metadata_ = false;
pre_send_message_count_ = 0;
pre_send_close_ = false;
post_recv_initial_metadata_ = false;
post_recv_message_count_ = 0;
post_recv_status_ = false;
}
virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
if (methods->QueryInterceptionHookPoint(
@ -512,6 +519,8 @@ class LoggingInterceptor : public experimental::Interceptor {
auto iterator = map->begin();
EXPECT_EQ("testkey", iterator->first);
EXPECT_EQ("testvalue", iterator->second);
ASSERT_FALSE(pre_send_initial_metadata_);
pre_send_initial_metadata_ = true;
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
@ -526,22 +535,28 @@ class LoggingInterceptor : public experimental::Interceptor {
SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
.ok());
EXPECT_TRUE(req.message().find("Hello") == 0u);
pre_send_message_count_++;
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
// Got nothing to do here for now
pre_send_close_ = true;
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
auto* map = methods->GetRecvInitialMetadata();
// Got nothing better to do here for now
EXPECT_EQ(map->size(), static_cast<unsigned>(0));
post_recv_initial_metadata_ = true;
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
EchoResponse* resp =
static_cast<EchoResponse*>(methods->GetRecvMessage());
EXPECT_TRUE(resp->message().find("Hello") == 0u);
if(resp != nullptr) {
EXPECT_TRUE(resp->message().find("Hello") == 0u);
post_recv_message_count_++;
}
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
@ -556,14 +571,59 @@ class LoggingInterceptor : public experimental::Interceptor {
EXPECT_EQ(found, true);
auto* status = methods->GetRecvStatus();
EXPECT_EQ(status->ok(), true);
post_recv_status_ = true;
}
methods->Proceed();
}
static void VerifyCallCommon() {
EXPECT_TRUE(pre_send_initial_metadata_);
EXPECT_TRUE(pre_send_close_);
EXPECT_TRUE(post_recv_initial_metadata_);
EXPECT_TRUE(post_recv_status_);
}
static void VerifyUnaryCall() {
VerifyCallCommon();
EXPECT_EQ(pre_send_message_count_, 1);
EXPECT_EQ(post_recv_message_count_, 1);
}
static void VerifyClientStreamingCall() {
VerifyCallCommon();
EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages);
EXPECT_EQ(post_recv_message_count_, 1);
}
static void VerifyServerStreamingCall() {
VerifyCallCommon();
EXPECT_EQ(pre_send_message_count_, 1);
EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages);
}
static void VerifyBidiStreamingCall() {
VerifyCallCommon();
EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages);
EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages);
}
private:
experimental::ClientRpcInfo* info_;
static bool pre_send_initial_metadata_;
static int pre_send_message_count_;
static bool pre_send_close_;
static bool post_recv_initial_metadata_;
static int post_recv_message_count_;
static bool post_recv_status_;
};
bool LoggingInterceptor::pre_send_initial_metadata_;
int LoggingInterceptor::pre_send_message_count_;
bool LoggingInterceptor::pre_send_close_;
bool LoggingInterceptor::post_recv_initial_metadata_;
int LoggingInterceptor::post_recv_message_count_;
bool LoggingInterceptor::post_recv_status_;
class LoggingInterceptorFactory
: public experimental::ClientInterceptorFactoryInterface {
public:
@ -607,6 +667,7 @@ TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLoggingTest) {
auto channel = experimental::CreateCustomChannelWithInterceptors(
server_address_, InsecureChannelCredentials(), args, std::move(creators));
MakeCall(channel);
LoggingInterceptor::VerifyUnaryCall();
// Make sure all 20 dummy interceptors were run
EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
}
@ -643,7 +704,6 @@ TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingTest) {
}
auto channel = experimental::CreateCustomChannelWithInterceptors(
server_address_, InsecureChannelCredentials(), args, std::move(creators));
MakeCall(channel);
// Make sure only 20 dummy interceptors were run
EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
@ -659,8 +719,8 @@ TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLogThenHijackTest) {
new HijackingInterceptorFactory()));
auto channel = experimental::CreateCustomChannelWithInterceptors(
server_address_, InsecureChannelCredentials(), args, std::move(creators));
MakeCall(channel);
LoggingInterceptor::VerifyUnaryCall();
}
TEST_F(ClientInterceptorsEnd2endTest,
@ -708,6 +768,7 @@ TEST_F(ClientInterceptorsEnd2endTest,
auto channel = server_->experimental().InProcessChannelWithInterceptors(
args, std::move(creators));
MakeCallbackCall(channel);
LoggingInterceptor::VerifyUnaryCall();
// Make sure all 20 dummy interceptors were run
EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
}
@ -730,6 +791,7 @@ TEST_F(ClientInterceptorsEnd2endTest,
auto channel = server_->experimental().InProcessChannelWithInterceptors(
args, std::move(creators));
MakeCallbackCall(channel);
LoggingInterceptor::VerifyUnaryCall();
// Make sure all 20 dummy interceptors were run
EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
}
@ -768,6 +830,7 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingTest) {
auto channel = experimental::CreateCustomChannelWithInterceptors(
server_address_, InsecureChannelCredentials(), args, std::move(creators));
MakeClientStreamingCall(channel);
LoggingInterceptor::VerifyClientStreamingCall();
// Make sure all 20 dummy interceptors were run
EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
}
@ -787,6 +850,7 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) {
auto channel = experimental::CreateCustomChannelWithInterceptors(
server_address_, InsecureChannelCredentials(), args, std::move(creators));
MakeServerStreamingCall(channel);
LoggingInterceptor::VerifyServerStreamingCall();
// Make sure all 20 dummy interceptors were run
EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
}
@ -862,6 +926,7 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {
auto channel = experimental::CreateCustomChannelWithInterceptors(
server_address_, InsecureChannelCredentials(), args, std::move(creators));
MakeBidiStreamingCall(channel);
LoggingInterceptor::VerifyBidiStreamingCall();
// Make sure all 20 dummy interceptors were run
EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
}
@ -928,6 +993,7 @@ TEST_F(ClientGlobalInterceptorEnd2endTest, LoggingGlobalInterceptor) {
auto channel = experimental::CreateCustomChannelWithInterceptors(
server_address_, InsecureChannelCredentials(), args, std::move(creators));
MakeCall(channel);
LoggingInterceptor::VerifyUnaryCall();
// Make sure all 20 dummy interceptors were run
EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
experimental::TestOnlyResetGlobalClientInterceptorFactory();

@ -48,7 +48,7 @@ void MakeClientStreamingCall(const std::shared_ptr<Channel>& channel) {
EchoResponse resp;
string expected_resp = "";
auto writer = stub->RequestStream(&ctx, &resp);
for (int i = 0; i < 10; i++) {
for (int i = 0; i < kNumStreamingMessages; i++) {
writer->Write(req);
expected_resp += "Hello";
}
@ -73,7 +73,7 @@ void MakeServerStreamingCall(const std::shared_ptr<Channel>& channel) {
EXPECT_EQ(resp.message(), "Hello");
count++;
}
ASSERT_EQ(count, 10);
ASSERT_EQ(count, kNumStreamingMessages);
Status s = reader->Finish();
EXPECT_EQ(s.ok(), true);
}
@ -85,7 +85,7 @@ void MakeBidiStreamingCall(const std::shared_ptr<Channel>& channel) {
EchoResponse resp;
ctx.AddMetadata("testkey", "testvalue");
auto stream = stub->BidiStream(&ctx);
for (auto i = 0; i < 10; i++) {
for (auto i = 0; i < kNumStreamingMessages; i++) {
req.set_message("Hello" + std::to_string(i));
stream->Write(req);
stream->Read(&resp);

@ -152,6 +152,8 @@ class EchoTestServiceStreamingImpl : public EchoTestService::Service {
}
};
constexpr int kNumStreamingMessages = 10;
void MakeCall(const std::shared_ptr<Channel>& channel);
void MakeClientStreamingCall(const std::shared_ptr<Channel>& channel);

Loading…
Cancel
Save