Merge pull request #17630 from yashykt/nocopyinterception

Modifying semantics for GetSendMessage and GetSerializedSendMessage. Also adding ModifySendMessage
pull/17654/head
Yash Tibrewal 6 years ago committed by GitHub
commit 8dcda4dc36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 57
      include/grpcpp/impl/codegen/call_op_set.h
  2. 2
      include/grpcpp/impl/codegen/interceptor.h
  3. 35
      include/grpcpp/impl/codegen/interceptor_common.h
  4. 8
      test/cpp/end2end/client_interceptors_end2end_test.cc
  5. 71
      test/cpp/end2end/server_interceptors_end2end_test.cc

@ -317,7 +317,15 @@ class CallOpSendMessage {
protected: protected:
void AddOp(grpc_op* ops, size_t* nops) { void AddOp(grpc_op* ops, size_t* nops) {
if (!send_buf_.Valid() || hijacked_) return; if (msg_ == nullptr && !send_buf_.Valid()) return;
if (hijacked_) {
serializer_ = nullptr;
return;
}
if (msg_ != nullptr) {
GPR_CODEGEN_ASSERT(serializer_(msg_).ok());
}
serializer_ = nullptr;
grpc_op* op = &ops[(*nops)++]; grpc_op* op = &ops[(*nops)++];
op->op = GRPC_OP_SEND_MESSAGE; op->op = GRPC_OP_SEND_MESSAGE;
op->flags = write_options_.flags(); op->flags = write_options_.flags();
@ -327,9 +335,7 @@ class CallOpSendMessage {
write_options_.Clear(); write_options_.Clear();
} }
void FinishOp(bool* status) { void FinishOp(bool* status) {
if (!send_buf_.Valid()) { if (msg_ == nullptr && !send_buf_.Valid()) return;
return;
}
if (hijacked_ && failed_send_) { if (hijacked_ && failed_send_) {
// Hijacking interceptor failed this Op // Hijacking interceptor failed this Op
*status = false; *status = false;
@ -341,22 +347,25 @@ class CallOpSendMessage {
void SetInterceptionHookPoint( void SetInterceptionHookPoint(
InterceptorBatchMethodsImpl* interceptor_methods) { InterceptorBatchMethodsImpl* interceptor_methods) {
if (!send_buf_.Valid()) return; if (msg_ == nullptr && !send_buf_.Valid()) return;
interceptor_methods->AddInterceptionHookPoint( interceptor_methods->AddInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE); experimental::InterceptionHookPoints::PRE_SEND_MESSAGE);
interceptor_methods->SetSendMessage(&send_buf_, msg_, &failed_send_); interceptor_methods->SetSendMessage(&send_buf_, &msg_, &failed_send_,
serializer_);
} }
void SetFinishInterceptionHookPoint( void SetFinishInterceptionHookPoint(
InterceptorBatchMethodsImpl* interceptor_methods) { InterceptorBatchMethodsImpl* interceptor_methods) {
if (send_buf_.Valid()) { if (msg_ != nullptr || send_buf_.Valid()) {
interceptor_methods->AddInterceptionHookPoint( interceptor_methods->AddInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_SEND_MESSAGE); experimental::InterceptionHookPoints::POST_SEND_MESSAGE);
} }
send_buf_.Clear(); send_buf_.Clear();
msg_ = nullptr;
// The contents of the SendMessage value that was previously set // The contents of the SendMessage value that was previously set
// has had its references stolen by core's operations // has had its references stolen by core's operations
interceptor_methods->SetSendMessage(nullptr, nullptr, &failed_send_); interceptor_methods->SetSendMessage(nullptr, nullptr, &failed_send_,
nullptr);
} }
void SetHijackingState(InterceptorBatchMethodsImpl* interceptor_methods) { void SetHijackingState(InterceptorBatchMethodsImpl* interceptor_methods) {
@ -369,22 +378,32 @@ class CallOpSendMessage {
bool failed_send_ = false; bool failed_send_ = false;
ByteBuffer send_buf_; ByteBuffer send_buf_;
WriteOptions write_options_; WriteOptions write_options_;
std::function<Status(const void*)> serializer_;
}; };
template <class M> template <class M>
Status CallOpSendMessage::SendMessage(const M& message, WriteOptions options) { Status CallOpSendMessage::SendMessage(const M& message, WriteOptions options) {
write_options_ = options; write_options_ = options;
bool own_buf; serializer_ = [this](const void* message) {
// TODO(vjpai): Remove the void below when possible bool own_buf;
// The void in the template parameter below should not be needed send_buf_.Clear();
// (since it should be implicit) but is needed due to an observed // TODO(vjpai): Remove the void below when possible
// difference in behavior between clang and gcc for certain internal users // The void in the template parameter below should not be needed
Status result = SerializationTraits<M, void>::Serialize( // (since it should be implicit) but is needed due to an observed
message, send_buf_.bbuf_ptr(), &own_buf); // difference in behavior between clang and gcc for certain internal users
if (!own_buf) { Status result = SerializationTraits<M, void>::Serialize(
send_buf_.Duplicate(); *static_cast<const M*>(message), send_buf_.bbuf_ptr(), &own_buf);
} if (!own_buf) {
return result; send_buf_.Duplicate();
}
return result;
};
// Serialize immediately only if we do not have access to the message pointer
if (msg_ == nullptr) {
return serializer_(&message);
serializer_ = nullptr;
}
return Status();
} }
template <class M> template <class M>

@ -118,6 +118,8 @@ class InterceptorBatchMethods {
/// only supported for sync and callback APIs at the present moment. /// only supported for sync and callback APIs at the present moment.
virtual const void* GetSendMessage() = 0; virtual const void* GetSendMessage() = 0;
virtual void ModifySendMessage(const void* message) = 0;
/// Checks whether the SEND MESSAGE op succeeded. Valid for POST_SEND_MESSAGE /// Checks whether the SEND MESSAGE op succeeded. Valid for POST_SEND_MESSAGE
/// interceptions. /// interceptions.
virtual bool GetSendMessageStatus() = 0; virtual bool GetSendMessageStatus() = 0;

@ -79,9 +79,24 @@ class InterceptorBatchMethodsImpl
hooks_[static_cast<size_t>(type)] = true; hooks_[static_cast<size_t>(type)] = true;
} }
ByteBuffer* GetSerializedSendMessage() override { return send_message_; } ByteBuffer* GetSerializedSendMessage() override {
GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr);
if (*orig_send_message_ != nullptr) {
GPR_CODEGEN_ASSERT(serializer_(*orig_send_message_).ok());
*orig_send_message_ = nullptr;
}
return send_message_;
}
const void* GetSendMessage() override {
GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr);
return *orig_send_message_;
}
const void* GetSendMessage() override { return orig_send_message_; } void ModifySendMessage(const void* message) override {
GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr);
*orig_send_message_ = message;
}
bool GetSendMessageStatus() override { return !*fail_send_message_; } bool GetSendMessageStatus() override { return !*fail_send_message_; }
@ -125,11 +140,13 @@ class InterceptorBatchMethodsImpl
return recv_trailing_metadata_->map(); return recv_trailing_metadata_->map();
} }
void SetSendMessage(ByteBuffer* buf, const void* msg, void SetSendMessage(ByteBuffer* buf, const void** msg,
bool* fail_send_message) { bool* fail_send_message,
std::function<Status(const void*)> serializer) {
send_message_ = buf; send_message_ = buf;
orig_send_message_ = msg; orig_send_message_ = msg;
fail_send_message_ = fail_send_message; fail_send_message_ = fail_send_message;
serializer_ = serializer;
} }
void SetSendInitialMetadata( void SetSendInitialMetadata(
@ -359,7 +376,8 @@ class InterceptorBatchMethodsImpl
ByteBuffer* send_message_ = nullptr; ByteBuffer* send_message_ = nullptr;
bool* fail_send_message_ = nullptr; bool* fail_send_message_ = nullptr;
const void* orig_send_message_ = nullptr; const void** orig_send_message_ = nullptr;
std::function<Status(const void*)> serializer_;
std::multimap<grpc::string, grpc::string>* send_initial_metadata_; std::multimap<grpc::string, grpc::string>* send_initial_metadata_;
@ -429,6 +447,13 @@ class CancelInterceptorBatchMethods
return nullptr; return nullptr;
} }
void ModifySendMessage(const void* message) override {
GPR_CODEGEN_ASSERT(
false &&
"It is illegal to call ModifySendMessage on a method which "
"has a Cancel notification");
}
std::multimap<grpc::string, grpc::string>* GetSendInitialMetadata() override { std::multimap<grpc::string, grpc::string>* GetSendInitialMetadata() override {
GPR_CODEGEN_ASSERT(false && GPR_CODEGEN_ASSERT(false &&
"It is illegal to call GetSendInitialMetadata on a " "It is illegal to call GetSendInitialMetadata on a "

@ -516,16 +516,16 @@ class LoggingInterceptor : public experimental::Interceptor {
if (methods->QueryInterceptionHookPoint( if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
EchoRequest req; EchoRequest req;
EXPECT_EQ(static_cast<const EchoRequest*>(methods->GetSendMessage())
->message()
.find("Hello"),
0u);
auto* buffer = methods->GetSerializedSendMessage(); auto* buffer = methods->GetSerializedSendMessage();
auto copied_buffer = *buffer; auto copied_buffer = *buffer;
EXPECT_TRUE( EXPECT_TRUE(
SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req) SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
.ok()); .ok());
EXPECT_TRUE(req.message().find("Hello") == 0u); EXPECT_TRUE(req.message().find("Hello") == 0u);
EXPECT_EQ(static_cast<const EchoRequest*>(methods->GetSendMessage())
->message()
.find("Hello"),
0u);
} }
if (methods->QueryInterceptionHookPoint( if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {

@ -142,29 +142,68 @@ class LoggingInterceptorFactory
} }
}; };
// Test if GetSendMessage works as expected // Test if SendMessage function family works as expected for sync/callback apis
class GetSendMessageTester : public experimental::Interceptor { class SyncSendMessageTester : public experimental::Interceptor {
public: public:
GetSendMessageTester(experimental::ServerRpcInfo* info) {} SyncSendMessageTester(experimental::ServerRpcInfo* info) {}
void Intercept(experimental::InterceptorBatchMethods* methods) override { void Intercept(experimental::InterceptorBatchMethods* methods) override {
if (methods->QueryInterceptionHookPoint( if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) { experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
EXPECT_EQ(static_cast<const EchoRequest*>(methods->GetSendMessage()) string old_msg =
->message() static_cast<const EchoRequest*>(methods->GetSendMessage())->message();
.find("Hello"), EXPECT_EQ(old_msg.find("Hello"), 0u);
0u); new_msg_.set_message("World" + old_msg);
methods->ModifySendMessage(&new_msg_);
} }
methods->Proceed(); methods->Proceed();
} }
private:
EchoRequest new_msg_;
}; };
class GetSendMessageTesterFactory class SyncSendMessageTesterFactory
: public experimental::ServerInterceptorFactoryInterface { : public experimental::ServerInterceptorFactoryInterface {
public: public:
virtual experimental::Interceptor* CreateServerInterceptor( virtual experimental::Interceptor* CreateServerInterceptor(
experimental::ServerRpcInfo* info) override { experimental::ServerRpcInfo* info) override {
return new GetSendMessageTester(info); return new SyncSendMessageTester(info);
}
};
// Test if SendMessage function family works as expected for sync/callback apis
class SyncSendMessageVerifier : public experimental::Interceptor {
public:
SyncSendMessageVerifier(experimental::ServerRpcInfo* info) {}
void Intercept(experimental::InterceptorBatchMethods* methods) override {
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
// Make sure that the changes made in SyncSendMessageTester persisted
string old_msg =
static_cast<const EchoRequest*>(methods->GetSendMessage())->message();
EXPECT_EQ(old_msg.find("World"), 0u);
// Remove the "World" part of the string that we added earlier
new_msg_.set_message(old_msg.erase(0, 5));
methods->ModifySendMessage(&new_msg_);
// LoggingInterceptor verifies that changes got reverted
}
methods->Proceed();
}
private:
EchoRequest new_msg_;
};
class SyncSendMessageVerifierFactory
: public experimental::ServerInterceptorFactoryInterface {
public:
virtual experimental::Interceptor* CreateServerInterceptor(
experimental::ServerRpcInfo* info) override {
return new SyncSendMessageVerifier(info);
} }
}; };
@ -201,10 +240,13 @@ class ServerInterceptorsEnd2endSyncUnaryTest : public ::testing::Test {
creators; creators;
creators.push_back( creators.push_back(
std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
new LoggingInterceptorFactory())); new SyncSendMessageTesterFactory()));
creators.push_back( creators.push_back(
std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
new GetSendMessageTesterFactory())); new SyncSendMessageVerifierFactory()));
creators.push_back(
std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
new LoggingInterceptorFactory()));
// Add 20 dummy interceptor factories and null interceptor factories // Add 20 dummy interceptor factories and null interceptor factories
for (auto i = 0; i < 20; i++) { for (auto i = 0; i < 20; i++) {
creators.push_back(std::unique_ptr<DummyInterceptorFactory>( creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
@ -244,10 +286,13 @@ class ServerInterceptorsEnd2endSyncStreamingTest : public ::testing::Test {
creators; creators;
creators.push_back( creators.push_back(
std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
new LoggingInterceptorFactory())); new SyncSendMessageTesterFactory()));
creators.push_back( creators.push_back(
std::unique_ptr<experimental::ServerInterceptorFactoryInterface>( std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
new GetSendMessageTesterFactory())); new SyncSendMessageVerifierFactory()));
creators.push_back(
std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
new LoggingInterceptorFactory()));
for (auto i = 0; i < 20; i++) { for (auto i = 0; i < 20; i++) {
creators.push_back(std::unique_ptr<DummyInterceptorFactory>( creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
new DummyInterceptorFactory())); new DummyInterceptorFactory()));

Loading…
Cancel
Save