Add method to fail hijacked send messages

pull/17220/head
Yash Tibrewal 6 years ago
parent 9cfacc48ee
commit 5d7d6c0fbd
  1. 10
      include/grpcpp/impl/codegen/call_op_set.h
  2. 4
      include/grpcpp/impl/codegen/interceptor.h
  3. 18
      include/grpcpp/impl/codegen/interceptor_common.h
  4. 73
      test/cpp/end2end/client_interceptors_end2end_test.cc

@ -314,14 +314,19 @@ class CallOpSendMessage {
// Flags are per-message: clear them after use.
write_options_.Clear();
}
void FinishOp(bool* status) { send_buf_.Clear(); }
void FinishOp(bool* status) {
send_buf_.Clear();
if (hijacked_ && failed_send_) {
*status = false;
}
}
void SetInterceptionHookPoint(
InterceptorBatchMethodsImpl* interceptor_methods) {
if (!send_buf_.Valid()) return;
interceptor_methods->AddInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE);
interceptor_methods->SetSendMessage(&send_buf_);
interceptor_methods->SetSendMessage(&send_buf_, &failed_send_);
}
void SetFinishInterceptionHookPoint(
@ -333,6 +338,7 @@ class CallOpSendMessage {
private:
bool hijacked_ = false;
bool failed_send_ = false;
ByteBuffer send_buf_;
WriteOptions write_options_;
};

@ -118,6 +118,10 @@ class InterceptorBatchMethods {
// only interceptors after the current interceptor are created from the
// factory objects registered with the channel.
virtual std::unique_ptr<ChannelInterface> GetInterceptedChannel() = 0;
// On a hijacked RPC/ to-be hijacked RPC, this can be called to fail a SEND
// MESSAGE op
virtual void FailHijackedSendMessage() = 0;
};
class Interceptor {

@ -110,12 +110,21 @@ class InterceptorBatchMethodsImpl
Status* GetRecvStatus() override { return recv_status_; }
void FailHijackedSendMessage() override {
GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)]);
*fail_send_message_ = true;
}
std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata()
override {
return recv_trailing_metadata_->map();
}
void SetSendMessage(ByteBuffer* buf) { send_message_ = buf; }
void SetSendMessage(ByteBuffer* buf, bool* fail_send_message) {
send_message_ = buf;
fail_send_message_ = fail_send_message;
}
void SetSendInitialMetadata(
std::multimap<grpc::string, grpc::string>* metadata) {
@ -334,6 +343,7 @@ class InterceptorBatchMethodsImpl
std::function<void(void)> callback_;
ByteBuffer* send_message_ = nullptr;
bool* fail_send_message_ = nullptr;
std::multimap<grpc::string, grpc::string>* send_initial_metadata_;
@ -451,6 +461,12 @@ class CancelInterceptorBatchMethods
"method which has a Cancel notification");
return std::unique_ptr<ChannelInterface>(nullptr);
}
void FailHijackedSendMessage() override {
GPR_CODEGEN_ASSERT(false &&
"It is illegal to call FailHijackedSendMessage on a "
"method which has a Cancel notification");
}
};
} // namespace internal
} // namespace grpc

@ -269,6 +269,49 @@ class HijackingInterceptorMakesAnotherCallFactory
}
};
class ClientStreamingRpcHijackingInterceptor
: public experimental::Interceptor {
public:
ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
info_ = info;
}
virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
bool hijack = false;
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
hijack = true;
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
if (++count_ > 10) {
methods->FailHijackedSendMessage();
}
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
auto* status = methods->GetRecvStatus();
*status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages");
}
if (hijack) {
methods->Hijack();
} else {
methods->Proceed();
}
}
private:
experimental::ClientRpcInfo* info_;
int count_ = 0;
};
class ClientStreamingRpcHijackingInterceptorFactory
: public experimental::ClientInterceptorFactoryInterface {
public:
virtual experimental::Interceptor* CreateClientInterceptor(
experimental::ClientRpcInfo* info) override {
return new ClientStreamingRpcHijackingInterceptor(info);
}
};
class LoggingInterceptor : public experimental::Interceptor {
public:
LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; }
@ -535,6 +578,36 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) {
EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
}
TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) {
ChannelArguments args;
auto creators = std::unique_ptr<std::vector<
std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>>(
new std::vector<
std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>());
creators->push_back(
std::unique_ptr<ClientStreamingRpcHijackingInterceptorFactory>(
new ClientStreamingRpcHijackingInterceptorFactory()));
auto channel = experimental::CreateCustomChannelWithInterceptors(
server_address_, InsecureChannelCredentials(), args, std::move(creators));
auto stub = grpc::testing::EchoTestService::NewStub(channel);
ClientContext ctx;
EchoRequest req;
EchoResponse resp;
req.mutable_param()->set_echo_metadata(true);
req.set_message("Hello");
string expected_resp = "";
auto writer = stub->RequestStream(&ctx, &resp);
for (int i = 0; i < 10; i++) {
EXPECT_TRUE(writer->Write(req));
expected_resp += "Hello";
}
// Expect that the interceptor will reject the 11th message
EXPECT_FALSE(writer->Write(req));
Status s = writer->Finish();
EXPECT_EQ(s.ok(), false);
}
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {
ChannelArguments args;
DummyInterceptor::Reset();

Loading…
Cancel
Save