Provide GetOriginalSendMessage for some APIs

pull/17609/head
Yash Tibrewal 6 years ago
parent a76465c65c
commit 4aeba42528
  1. 29
      include/grpcpp/impl/codegen/call_op_set.h
  2. 6
      include/grpcpp/impl/codegen/client_callback.h
  3. 2
      include/grpcpp/impl/codegen/client_unary_call.h
  4. 6
      include/grpcpp/impl/codegen/interceptor.h
  5. 16
      include/grpcpp/impl/codegen/interceptor_common.h
  6. 8
      include/grpcpp/impl/codegen/server_callback.h
  7. 10
      include/grpcpp/impl/codegen/sync_stream.h
  8. 5
      test/cpp/end2end/client_interceptors_end2end_test.cc

@ -303,6 +303,18 @@ class CallOpSendMessage {
template <class M> template <class M>
Status SendMessage(const M& message) GRPC_MUST_USE_RESULT; Status SendMessage(const M& message) GRPC_MUST_USE_RESULT;
/// Send \a message using \a options for the write. The \a options are cleared
/// after use. This form of SendMessage allows gRPC to reference \a message
/// beyond the lifetime of SendMessage.
template <class M>
Status SendMessage(const M* message,
WriteOptions options) GRPC_MUST_USE_RESULT;
/// This form of SendMessage allows gRPC to reference \a message beyond the
/// lifetime of SendMessage.
template <class M>
Status SendMessage(const M* message) GRPC_MUST_USE_RESULT;
protected: protected:
void AddOp(grpc_op* ops, size_t* nops) { void AddOp(grpc_op* ops, size_t* nops) {
if (!send_buf_.Valid() || hijacked_) return; if (!send_buf_.Valid() || hijacked_) return;
@ -321,14 +333,14 @@ class CallOpSendMessage {
if (!send_buf_.Valid()) return; if (!send_buf_.Valid()) return;
interceptor_methods->AddInterceptionHookPoint( interceptor_methods->AddInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE); experimental::InterceptionHookPoints::PRE_SEND_MESSAGE);
interceptor_methods->SetSendMessage(&send_buf_); interceptor_methods->SetSendMessage(&send_buf_, msg_);
} }
void SetFinishInterceptionHookPoint( void SetFinishInterceptionHookPoint(
InterceptorBatchMethodsImpl* interceptor_methods) { InterceptorBatchMethodsImpl* interceptor_methods) {
// The contents of the SendMessage value that was previously set // The contents of the SendMessage value that was previously set
// has had its references stolen by core's operations // has had its references stolen by core's operations
interceptor_methods->SetSendMessage(nullptr); interceptor_methods->SetSendMessage(nullptr, nullptr);
} }
void SetHijackingState(InterceptorBatchMethodsImpl* interceptor_methods) { void SetHijackingState(InterceptorBatchMethodsImpl* interceptor_methods) {
@ -336,6 +348,7 @@ class CallOpSendMessage {
} }
private: private:
const void* msg_ = nullptr; // The original non-serialized message
bool hijacked_ = false; bool hijacked_ = false;
ByteBuffer send_buf_; ByteBuffer send_buf_;
WriteOptions write_options_; WriteOptions write_options_;
@ -362,6 +375,18 @@ Status CallOpSendMessage::SendMessage(const M& message) {
return SendMessage(message, WriteOptions()); return SendMessage(message, WriteOptions());
} }
template <class M>
Status CallOpSendMessage::SendMessage(const M* message, WriteOptions options) {
msg_ = message;
return SendMessage(*message, options);
}
template <class M>
Status CallOpSendMessage::SendMessage(const M* message) {
msg_ = message;
return SendMessage(*message, WriteOptions());
}
template <class R> template <class R>
class CallOpRecvMessage { class CallOpRecvMessage {
public: public:

@ -73,7 +73,7 @@ class CallbackUnaryCallImpl {
CallbackWithStatusTag(call.call(), on_completion, ops); CallbackWithStatusTag(call.call(), on_completion, ops);
// TODO(vjpai): Unify code with sync API as much as possible // TODO(vjpai): Unify code with sync API as much as possible
Status s = ops->SendMessage(*request); Status s = ops->SendMessage(request);
if (!s.ok()) { if (!s.ok()) {
tag->force_run(s); tag->force_run(s);
return; return;
@ -341,7 +341,7 @@ class ClientCallbackReaderWriterImpl
start_corked_ = false; start_corked_ = false;
} }
// TODO(vjpai): don't assert // TODO(vjpai): don't assert
GPR_CODEGEN_ASSERT(write_ops_.SendMessage(*msg).ok()); GPR_CODEGEN_ASSERT(write_ops_.SendMessage(msg).ok());
if (options.is_last_message()) { if (options.is_last_message()) {
options.set_buffer_hint(); options.set_buffer_hint();
@ -650,7 +650,7 @@ class ClientCallbackWriterImpl
start_corked_ = false; start_corked_ = false;
} }
// TODO(vjpai): don't assert // TODO(vjpai): don't assert
GPR_CODEGEN_ASSERT(write_ops_.SendMessage(*msg).ok()); GPR_CODEGEN_ASSERT(write_ops_.SendMessage(msg).ok());
if (options.is_last_message()) { if (options.is_last_message()) {
options.set_buffer_hint(); options.set_buffer_hint();

@ -57,7 +57,7 @@ class BlockingUnaryCallImpl {
CallOpRecvInitialMetadata, CallOpRecvMessage<OutputMessage>, CallOpRecvInitialMetadata, CallOpRecvMessage<OutputMessage>,
CallOpClientSendClose, CallOpClientRecvStatus> CallOpClientSendClose, CallOpClientRecvStatus>
ops; ops;
status_ = ops.SendMessage(request); status_ = ops.SendMessage(&request);
if (!status_.ok()) { if (!status_.ok()) {
return; return;
} }

@ -111,6 +111,12 @@ class InterceptorBatchMethods {
/// A return value of nullptr indicates that this ByteBuffer is not valid. /// A return value of nullptr indicates that this ByteBuffer is not valid.
virtual ByteBuffer* GetSendMessage() = 0; virtual ByteBuffer* GetSendMessage() = 0;
/// Returns a non-modifiable pointer to the original non-serialized form of
/// the message. Valid for PRE_SEND_MESSAGE interceptions. A return value of
/// nullptr indicates that this field is not valid. Also note that this is
/// only supported for sync and callback APIs at the present moment.
virtual const void* GetOriginalSendMessage() = 0;
/// Returns a modifiable multimap of the initial metadata to be sent. Valid /// Returns a modifiable multimap of the initial metadata to be sent. Valid
/// for PRE_SEND_INITIAL_METADATA interceptions. A value of nullptr indicates /// for PRE_SEND_INITIAL_METADATA interceptions. A value of nullptr indicates
/// that this field is not valid. /// that this field is not valid.

@ -81,6 +81,8 @@ class InterceptorBatchMethodsImpl
ByteBuffer* GetSendMessage() override { return send_message_; } ByteBuffer* GetSendMessage() override { return send_message_; }
const void* GetOriginalSendMessage() override { return orig_send_message_; }
std::multimap<grpc::string, grpc::string>* GetSendInitialMetadata() override { std::multimap<grpc::string, grpc::string>* GetSendInitialMetadata() override {
return send_initial_metadata_; return send_initial_metadata_;
} }
@ -115,7 +117,10 @@ class InterceptorBatchMethodsImpl
return recv_trailing_metadata_->map(); return recv_trailing_metadata_->map();
} }
void SetSendMessage(ByteBuffer* buf) { send_message_ = buf; } void SetSendMessage(ByteBuffer* buf, const void* msg) {
send_message_ = buf;
orig_send_message_ = msg;
}
void SetSendInitialMetadata( void SetSendInitialMetadata(
std::multimap<grpc::string, grpc::string>* metadata) { std::multimap<grpc::string, grpc::string>* metadata) {
@ -334,6 +339,7 @@ class InterceptorBatchMethodsImpl
std::function<void(void)> callback_; std::function<void(void)> callback_;
ByteBuffer* send_message_ = nullptr; ByteBuffer* send_message_ = nullptr;
const void* orig_send_message_ = nullptr;
std::multimap<grpc::string, grpc::string>* send_initial_metadata_; std::multimap<grpc::string, grpc::string>* send_initial_metadata_;
@ -386,6 +392,14 @@ class CancelInterceptorBatchMethods
return nullptr; return nullptr;
} }
const void* GetOriginalSendMessage() override {
GPR_CODEGEN_ASSERT(
false &&
"It is illegal to call GetOriginalSendMessage on a method which "
"has a Cancel notification");
return nullptr;
}
std::multimap<grpc::string, grpc::string>* GetSendInitialMetadata() override { std::multimap<grpc::string, grpc::string>* GetSendInitialMetadata() override {
GPR_CODEGEN_ASSERT(false && GPR_CODEGEN_ASSERT(false &&
"It is illegal to call GetSendInitialMetadata on a " "It is illegal to call GetSendInitialMetadata on a "

@ -642,7 +642,7 @@ class CallbackServerStreamingHandler : public MethodHandler {
ctx_->sent_initial_metadata_ = true; ctx_->sent_initial_metadata_ = true;
} }
// TODO(vjpai): don't assert // TODO(vjpai): don't assert
GPR_CODEGEN_ASSERT(write_ops_.SendMessage(*resp, options).ok()); GPR_CODEGEN_ASSERT(write_ops_.SendMessage(resp, options).ok());
call_.PerformOps(&write_ops_); call_.PerformOps(&write_ops_);
} }
@ -652,7 +652,7 @@ class CallbackServerStreamingHandler : public MethodHandler {
// Don't send any message if the status is bad // Don't send any message if the status is bad
if (s.ok()) { if (s.ok()) {
// TODO(vjpai): don't assert // TODO(vjpai): don't assert
GPR_CODEGEN_ASSERT(finish_ops_.SendMessage(*resp, options).ok()); GPR_CODEGEN_ASSERT(finish_ops_.SendMessage(resp, options).ok());
} }
Finish(std::move(s)); Finish(std::move(s));
} }
@ -804,7 +804,7 @@ class CallbackBidiHandler : public MethodHandler {
ctx_->sent_initial_metadata_ = true; ctx_->sent_initial_metadata_ = true;
} }
// TODO(vjpai): don't assert // TODO(vjpai): don't assert
GPR_CODEGEN_ASSERT(write_ops_.SendMessage(*resp, options).ok()); GPR_CODEGEN_ASSERT(write_ops_.SendMessage(resp, options).ok());
call_.PerformOps(&write_ops_); call_.PerformOps(&write_ops_);
} }
@ -813,7 +813,7 @@ class CallbackBidiHandler : public MethodHandler {
// Don't send any message if the status is bad // Don't send any message if the status is bad
if (s.ok()) { if (s.ok()) {
// TODO(vjpai): don't assert // TODO(vjpai): don't assert
GPR_CODEGEN_ASSERT(finish_ops_.SendMessage(*resp, options).ok()); GPR_CODEGEN_ASSERT(finish_ops_.SendMessage(resp, options).ok());
} }
Finish(std::move(s)); Finish(std::move(s));
} }

@ -253,7 +253,7 @@ class ClientReader final : public ClientReaderInterface<R> {
ops.SendInitialMetadata(&context->send_initial_metadata_, ops.SendInitialMetadata(&context->send_initial_metadata_,
context->initial_metadata_flags()); context->initial_metadata_flags());
// TODO(ctiller): don't assert // TODO(ctiller): don't assert
GPR_CODEGEN_ASSERT(ops.SendMessage(request).ok()); GPR_CODEGEN_ASSERT(ops.SendMessage(&request).ok());
ops.ClientSendClose(); ops.ClientSendClose();
call_.PerformOps(&ops); call_.PerformOps(&ops);
cq_.Pluck(&ops); cq_.Pluck(&ops);
@ -331,7 +331,7 @@ class ClientWriter : public ClientWriterInterface<W> {
context_->initial_metadata_flags()); context_->initial_metadata_flags());
context_->set_initial_metadata_corked(false); context_->set_initial_metadata_corked(false);
} }
if (!ops.SendMessage(msg, options).ok()) { if (!ops.SendMessage(&msg, options).ok()) {
return false; return false;
} }
@ -502,7 +502,7 @@ class ClientReaderWriter final : public ClientReaderWriterInterface<W, R> {
context_->initial_metadata_flags()); context_->initial_metadata_flags());
context_->set_initial_metadata_corked(false); context_->set_initial_metadata_corked(false);
} }
if (!ops.SendMessage(msg, options).ok()) { if (!ops.SendMessage(&msg, options).ok()) {
return false; return false;
} }
@ -656,7 +656,7 @@ class ServerWriter final : public ServerWriterInterface<W> {
options.set_buffer_hint(); options.set_buffer_hint();
} }
if (!ctx_->pending_ops_.SendMessage(msg, options).ok()) { if (!ctx_->pending_ops_.SendMessage(&msg, options).ok()) {
return false; return false;
} }
if (!ctx_->sent_initial_metadata_) { if (!ctx_->sent_initial_metadata_) {
@ -734,7 +734,7 @@ class ServerReaderWriterBody final {
if (options.is_last_message()) { if (options.is_last_message()) {
options.set_buffer_hint(); options.set_buffer_hint();
} }
if (!ctx_->pending_ops_.SendMessage(msg, options).ok()) { if (!ctx_->pending_ops_.SendMessage(&msg, options).ok()) {
return false; return false;
} }
if (!ctx_->sent_initial_metadata_) { if (!ctx_->sent_initial_metadata_) {

@ -293,6 +293,11 @@ class LoggingInterceptor : public experimental::Interceptor {
SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req) SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
.ok()); .ok());
EXPECT_TRUE(req.message().find("Hello") == 0); EXPECT_TRUE(req.message().find("Hello") == 0);
EXPECT_EQ(
static_cast<const EchoRequest*>(methods->GetOriginalSendMessage())
->message()
.find("Hello"),
0);
} }
if (methods->QueryInterceptionHookPoint( if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) { experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {

Loading…
Cancel
Save