From 5d831da9d135d7f1c58ff61bacb6e5a2787f05c9 Mon Sep 17 00:00:00 2001 From: Yash Tibrewal Date: Tue, 2 Oct 2018 14:17:59 -0700 Subject: [PATCH] Adding hook points for interception. Code compiles and tests still run --- include/grpcpp/impl/codegen/call.h | 223 ++++++++++++------ .../grpcpp/impl/codegen/client_interceptor.h | 59 ++++- include/grpcpp/impl/codegen/interceptor.h | 2 +- src/cpp/client/channel_cc.cc | 6 +- src/cpp/client/client_interceptor.cc | 46 ++++ src/cpp/server/server_cc.cc | 2 +- src/cpp/server/server_context.cc | 4 +- 7 files changed, 268 insertions(+), 74 deletions(-) create mode 100644 src/cpp/client/client_interceptor.cc diff --git a/include/grpcpp/impl/codegen/call.h b/include/grpcpp/impl/codegen/call.h index 7cadea00555..771fc22d466 100644 --- a/include/grpcpp/impl/codegen/call.h +++ b/include/grpcpp/impl/codegen/call.h @@ -24,10 +24,12 @@ #include #include #include +#include #include #include #include +#include #include #include #include @@ -50,6 +52,58 @@ namespace internal { class Call; class CallHook; +/// Straightforward wrapping of the C call object +class Call final { + public: + /** call is owned by the caller */ + Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq) + : call_hook_(call_hook), + cq_(cq), + call_(call), + max_receive_message_size_(-1), + rpc_info_(nullptr, nullptr, nullptr) {} + + Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq, + experimental::ClientRpcInfo rpc_info, + const std::vector< + std::unique_ptr>& + creators) + : call_hook_(call_hook), + cq_(cq), + call_(call), + max_receive_message_size_(-1), + rpc_info_(rpc_info) { + for (const auto& creator : creators) { + interceptors_.push_back(creator->CreateClientInterceptor(&rpc_info_)); + } + } + + Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq, + int max_receive_message_size) + : call_hook_(call_hook), + cq_(cq), + call_(call), + max_receive_message_size_(max_receive_message_size), + rpc_info_(nullptr, nullptr, nullptr) {} + + void PerformOps(CallOpSetInterface* ops) { + call_hook_->PerformOpsOnCall(ops, this); + } + + grpc_call* call() const { return call_; } + CompletionQueue* cq() const { return cq_; } + + int max_receive_message_size() const { return max_receive_message_size_; } + + private: + CallHook* call_hook_; + CompletionQueue* cq_; + grpc_call* call_; + int max_receive_message_size_; + experimental::ClientRpcInfo rpc_info_; + std::vector interceptors_; +}; + // TODO(yangg) if the map is changed before we send, the pointers will be a // mess. Make sure it does not happen. inline grpc_metadata* FillMetadataArray( @@ -201,13 +255,45 @@ class WriteOptions { }; namespace internal { + +class InterceptorBatchMethodsImpl + : public experimental::InterceptorBatchMethods { + public: + InterceptorBatchMethodsImpl() {} + + virtual ~InterceptorBatchMethodsImpl() {} + + virtual bool QueryInterceptionHookPoint( + experimental::InterceptionHookPoints type) override { + return hooks_[static_cast(type)]; + } + + virtual void Proceed() override { /* fill this */ + } + + virtual void Hijack() override { /* fill this */ + } + + void AddInterceptionHookPoint(experimental::InterceptionHookPoints type) { + hooks_[static_cast(type)]; + } + + private: + std::array( + experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS)> + hooks_; +}; + /// Default argument for CallOpSet. I is unused by the class, but can be /// used for generating multiple names for the same thing. template class CallNoOp { protected: - void AddOp(grpc_op* ops, size_t* nops) {} - void FinishOp(bool* status) {} + void AddOp(grpc_op* ops, size_t* nops, + InterceptorBatchMethodsImpl* interceptor_methods) {} + void FinishOp(bool* status, + InterceptorBatchMethodsImpl* interceptor_methods) {} }; class CallOpSendInitialMetadata { @@ -232,7 +318,8 @@ class CallOpSendInitialMetadata { } protected: - void AddOp(grpc_op* ops, size_t* nops) { + void AddOp(grpc_op* ops, size_t* nops, + InterceptorBatchMethodsImpl* interceptor_methods) { if (!send_) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_SEND_INITIAL_METADATA; @@ -246,8 +333,11 @@ class CallOpSendInitialMetadata { op->data.send_initial_metadata.maybe_compression_level.level = maybe_compression_level_.level; } + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA); } - void FinishOp(bool* status) { + void FinishOp(bool* status, + InterceptorBatchMethodsImpl* interceptor_methods) { if (!send_) return; g_core_codegen_interface->gpr_free(initial_metadata_); send_ = false; @@ -277,7 +367,8 @@ class CallOpSendMessage { Status SendMessage(const M& message) GRPC_MUST_USE_RESULT; protected: - void AddOp(grpc_op* ops, size_t* nops) { + void AddOp(grpc_op* ops, size_t* nops, + InterceptorBatchMethodsImpl* interceptor_methods) { if (!send_buf_.Valid()) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_SEND_MESSAGE; @@ -286,8 +377,13 @@ class CallOpSendMessage { op->data.send_message.send_message = send_buf_.c_buffer(); // Flags are per-message: clear them after use. write_options_.Clear(); + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_MESSAGE); + } + void FinishOp(bool* status, + InterceptorBatchMethodsImpl* interceptor_methods) { + send_buf_.Clear(); } - void FinishOp(bool* status) { send_buf_.Clear(); } private: ByteBuffer send_buf_; @@ -331,7 +427,8 @@ class CallOpRecvMessage { bool got_message; protected: - void AddOp(grpc_op* ops, size_t* nops) { + void AddOp(grpc_op* ops, size_t* nops, + InterceptorBatchMethodsImpl* interceptor_methods) { if (message_ == nullptr) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_RECV_MESSAGE; @@ -340,7 +437,8 @@ class CallOpRecvMessage { op->data.recv_message.recv_message = recv_buf_.c_buffer_ptr(); } - void FinishOp(bool* status) { + void FinishOp(bool* status, + InterceptorBatchMethodsImpl* interceptor_methods) { if (message_ == nullptr) return; if (recv_buf_.Valid()) { if (*status) { @@ -359,6 +457,8 @@ class CallOpRecvMessage { } } message_ = nullptr; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_RECV_MESSAGE); } private: @@ -406,7 +506,8 @@ class CallOpGenericRecvMessage { bool got_message; protected: - void AddOp(grpc_op* ops, size_t* nops) { + void AddOp(grpc_op* ops, size_t* nops, + InterceptorBatchMethodsImpl* interceptor_methods) { if (!deserialize_) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_RECV_MESSAGE; @@ -415,7 +516,8 @@ class CallOpGenericRecvMessage { op->data.recv_message.recv_message = recv_buf_.c_buffer_ptr(); } - void FinishOp(bool* status) { + void FinishOp(bool* status, + InterceptorBatchMethodsImpl* interceptor_methods) { if (!deserialize_) return; if (recv_buf_.Valid()) { if (*status) { @@ -433,6 +535,8 @@ class CallOpGenericRecvMessage { } } deserialize_.reset(); + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE); } private: @@ -448,14 +552,18 @@ class CallOpClientSendClose { void ClientSendClose() { send_ = true; } protected: - void AddOp(grpc_op* ops, size_t* nops) { + void AddOp(grpc_op* ops, size_t* nops, + InterceptorBatchMethodsImpl* interceptor_methods) { if (!send_) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_SEND_CLOSE_FROM_CLIENT; op->flags = 0; op->reserved = NULL; } - void FinishOp(bool* status) { send_ = false; } + void FinishOp(bool* status, + InterceptorBatchMethodsImpl* interceptor_methods) { + send_ = false; + } private: bool send_; @@ -477,7 +585,8 @@ class CallOpServerSendStatus { } protected: - void AddOp(grpc_op* ops, size_t* nops) { + void AddOp(grpc_op* ops, size_t* nops, + InterceptorBatchMethodsImpl* interceptor_methods) { if (!send_status_available_) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_SEND_STATUS_FROM_SERVER; @@ -490,9 +599,12 @@ class CallOpServerSendStatus { send_error_message_.empty() ? nullptr : &error_message_slice_; op->flags = 0; op->reserved = NULL; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::PRE_SEND_STATUS); } - void FinishOp(bool* status) { + void FinishOp(bool* status, + InterceptorBatchMethodsImpl* interceptor_methods) { if (!send_status_available_) return; g_core_codegen_interface->gpr_free(trailing_metadata_); send_status_available_ = false; @@ -518,7 +630,8 @@ class CallOpRecvInitialMetadata { } protected: - void AddOp(grpc_op* ops, size_t* nops) { + void AddOp(grpc_op* ops, size_t* nops, + InterceptorBatchMethodsImpl* interceptor_methods) { if (metadata_map_ == nullptr) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_RECV_INITIAL_METADATA; @@ -527,9 +640,12 @@ class CallOpRecvInitialMetadata { op->reserved = NULL; } - void FinishOp(bool* status) { + void FinishOp(bool* status, + InterceptorBatchMethodsImpl* interceptor_methods) { if (metadata_map_ == nullptr) return; metadata_map_ = nullptr; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA); } private: @@ -549,7 +665,8 @@ class CallOpClientRecvStatus { } protected: - void AddOp(grpc_op* ops, size_t* nops) { + void AddOp(grpc_op* ops, size_t* nops, + InterceptorBatchMethodsImpl* interceptor_methods) { if (recv_status_ == nullptr) return; grpc_op* op = &ops[(*nops)++]; op->op = GRPC_OP_RECV_STATUS_ON_CLIENT; @@ -561,7 +678,8 @@ class CallOpClientRecvStatus { op->reserved = NULL; } - void FinishOp(bool* status) { + void FinishOp(bool* status, + InterceptorBatchMethodsImpl* interceptor_methods) { if (recv_status_ == nullptr) return; grpc::string binary_error_details = metadata_map_->GetBinaryErrorDetails(); *recv_status_ = @@ -578,6 +696,8 @@ class CallOpClientRecvStatus { g_core_codegen_interface->gpr_free((void*)debug_error_string_); } recv_status_ = nullptr; + interceptor_methods->AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_STATUS); } private: @@ -598,7 +718,7 @@ class CallOpSetInterface : public CompletionQueueTag { public: /// Fills in grpc_op, starting from ops[*nops] and moving /// upwards. - virtual void FillOps(grpc_call* call, grpc_op* ops, size_t* nops) = 0; + virtual void FillOps(internal::Call* call, grpc_op* ops, size_t* nops) = 0; /// Get the tag to be used at the core completion queue. Generally, the /// value of cq_tag will be "this". However, it can be overridden if we @@ -624,27 +744,27 @@ class CallOpSet : public CallOpSetInterface, public Op6 { public: CallOpSet() : cq_tag_(this), return_tag_(this), call_(nullptr) {} - void FillOps(grpc_call* call, grpc_op* ops, size_t* nops) override { - this->Op1::AddOp(ops, nops); - this->Op2::AddOp(ops, nops); - this->Op3::AddOp(ops, nops); - this->Op4::AddOp(ops, nops); - this->Op5::AddOp(ops, nops); - this->Op6::AddOp(ops, nops); - g_core_codegen_interface->grpc_call_ref(call); + void FillOps(Call* call, grpc_op* ops, size_t* nops) override { + this->Op1::AddOp(ops, nops, &interceptor_methods_); + this->Op2::AddOp(ops, nops, &interceptor_methods_); + this->Op3::AddOp(ops, nops, &interceptor_methods_); + this->Op4::AddOp(ops, nops, &interceptor_methods_); + this->Op5::AddOp(ops, nops, &interceptor_methods_); + this->Op6::AddOp(ops, nops, &interceptor_methods_); + g_core_codegen_interface->grpc_call_ref(call->call()); call_ = call; } bool FinalizeResult(void** tag, bool* status) override { - this->Op1::FinishOp(status); - this->Op2::FinishOp(status); - this->Op3::FinishOp(status); - this->Op4::FinishOp(status); - this->Op5::FinishOp(status); - this->Op6::FinishOp(status); + this->Op1::FinishOp(status, &interceptor_methods_); + this->Op2::FinishOp(status, &interceptor_methods_); + this->Op3::FinishOp(status, &interceptor_methods_); + this->Op4::FinishOp(status, &interceptor_methods_); + this->Op5::FinishOp(status, &interceptor_methods_); + this->Op6::FinishOp(status, &interceptor_methods_); *tag = return_tag_; - g_core_codegen_interface->grpc_call_unref(call_); + g_core_codegen_interface->grpc_call_unref(call_->call()); return true; } @@ -661,41 +781,10 @@ class CallOpSet : public CallOpSetInterface, private: void* cq_tag_; void* return_tag_; - grpc_call* call_; + Call* call_; + InterceptorBatchMethodsImpl interceptor_methods_; }; -/// Straightforward wrapping of the C call object -class Call final { - public: - /** call is owned by the caller */ - Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq) - : call_hook_(call_hook), - cq_(cq), - call_(call), - max_receive_message_size_(-1) {} - - Call(grpc_call* call, CallHook* call_hook, CompletionQueue* cq, - int max_receive_message_size) - : call_hook_(call_hook), - cq_(cq), - call_(call), - max_receive_message_size_(max_receive_message_size) {} - - void PerformOps(CallOpSetInterface* ops) { - call_hook_->PerformOpsOnCall(ops, this); - } - - grpc_call* call() const { return call_; } - CompletionQueue* cq() const { return cq_; } - - int max_receive_message_size() const { return max_receive_message_size_; } - - private: - CallHook* call_hook_; - CompletionQueue* cq_; - grpc_call* call_; - int max_receive_message_size_; -}; } // namespace internal } // namespace grpc diff --git a/include/grpcpp/impl/codegen/client_interceptor.h b/include/grpcpp/impl/codegen/client_interceptor.h index f460c5ac0c0..f7963a57d5e 100644 --- a/include/grpcpp/impl/codegen/client_interceptor.h +++ b/include/grpcpp/impl/codegen/client_interceptor.h @@ -19,7 +19,9 @@ #ifndef GRPCPP_IMPL_CODEGEN_CLIENT_INTERCEPTOR_H #define GRPCPP_IMPL_CODEGEN_CLIENT_INTERCEPTOR_H +#include #include +#include namespace grpc { namespace experimental { @@ -30,7 +32,62 @@ class ClientInterceptor { virtual void Intercept(InterceptorBatchMethods* methods) = 0; }; -class ClientRpcInfo {}; +class ClientRpcInfo { + public: + ClientRpcInfo(grpc::ClientContext* ctx, const char* method, + const grpc::Channel* channel) + : ctx_(ctx), method_(method), channel_(channel) {} + ~ClientRpcInfo(){}; + + // Getter methods + const char* method() { return method_; } + string peer() { return ctx_->peer(); } + const Channel* channel() { return channel_; } + // const grpc::InterceptedMessage& outgoing_message(); + // grpc::InterceptedMessage *mutable_outgoing_message(); + // const grpc::InterceptedMessage& received_message(); + // grpc::InterceptedMessage *mutable_received_message(); + std::shared_ptr auth_context() { + return ctx_->auth_context(); + } + const struct census_context* census_context() { + return ctx_->census_context(); + } + gpr_timespec deadline() { return ctx_->raw_deadline(); } + // const std::multimap* client_initial_metadata() + // { return &ctx_->send_initial_metadata_; } const + // std::multimap* + // server_initial_metadata() { return &ctx_->GetServerInitialMetadata(); } + // const std::multimap* + // server_trailing_metadata() { return &ctx_->GetServerTrailingMetadata(); } + // const Status *status(); + + // Setter methods + template + void set_deadline(const T& deadline) { + ctx_->set_deadline(deadline); + } + void set_census_context(struct census_context* cc) { + ctx_->set_census_context(cc); + } + // template + // void set_outgoing_message(M* msg); // edit outgoing message + // template + // void set_received_message(M* msg); // edit received message + // for hijacking (can be called multiple times for streaming) + // template + // void inject_received_message(M* msg); + // void set_client_initial_metadata( + // const std::multimap& overwrite); + // void set_server_initial_metadata(const std::multimap& overwrite); void set_server_trailing_metadata(const + // std::multimap& overwrite); void + // set_status(Status status); + private: + grpc::ClientContext* ctx_; + const char* method_; + const grpc::Channel* channel_; +}; class ClientInterceptorFactoryInterface { public: diff --git a/include/grpcpp/impl/codegen/interceptor.h b/include/grpcpp/impl/codegen/interceptor.h index 6402a3a9466..84dce42f977 100644 --- a/include/grpcpp/impl/codegen/interceptor.h +++ b/include/grpcpp/impl/codegen/interceptor.h @@ -50,7 +50,7 @@ enum class InterceptionHookPoints { class InterceptorBatchMethods { public: - virtual ~InterceptorBatchMethods(); + virtual ~InterceptorBatchMethods(){}; // Queries to check whether the current batch has an interception hook point // of type \a type virtual bool QueryInterceptionHookPoint(InterceptionHookPoints type) = 0; diff --git a/src/cpp/client/channel_cc.cc b/src/cpp/client/channel_cc.cc index 2cab41b3f56..eba92f00e90 100644 --- a/src/cpp/client/channel_cc.cc +++ b/src/cpp/client/channel_cc.cc @@ -147,7 +147,9 @@ internal::Call Channel::CreateCall(const internal::RpcMethod& method, } grpc_census_call_set_context(c_call, context->census_context()); context->set_call(c_call, shared_from_this()); - return internal::Call(c_call, this, cq); + + experimental::ClientRpcInfo info(context, method.name(), this); + return internal::Call(c_call, this, cq, info, interceptor_creators_); } void Channel::PerformOpsOnCall(internal::CallOpSetInterface* ops, @@ -155,7 +157,7 @@ void Channel::PerformOpsOnCall(internal::CallOpSetInterface* ops, static const size_t MAX_OPS = 8; size_t nops = 0; grpc_op cops[MAX_OPS]; - ops->FillOps(call->call(), cops, &nops); + ops->FillOps(call, cops, &nops); GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(call->call(), cops, nops, ops->cq_tag(), nullptr)); } diff --git a/src/cpp/client/client_interceptor.cc b/src/cpp/client/client_interceptor.cc new file mode 100644 index 00000000000..7cc07fa46de --- /dev/null +++ b/src/cpp/client/client_interceptor.cc @@ -0,0 +1,46 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include + +namespace grpc { +namespace experimental { +const ClientRpcInfo::grpc::InterceptedMessage& outgoing_message() {} +grpc::InterceptedMessage* ClientRpcInfo::mutable_outgoing_message() {} +const grpc::InterceptedMessage& ClientRpcInfo::received_message() {} +grpc::InterceptedMessage* ClientRpcInfo::mutable_received_message() {} +const Status ClientRpcInfo::*status() {} + +// Setter methods +template +void ClientRpcInfo::set_outgoing_message(M* msg) {} // edit outgoing message +template +void ClientRpcInfo::set_received_message(M* msg) {} // edit received message +// for hijacking (can be called multiple times for streaming) +template +void ClientRpcInfo::inject_received_message(M* msg) {} +void set_client_initial_metadata( + const ClientRpcInfo::std::multimap& overwrite) { +} +void ClientRpcInfo::set_server_initial_metadata( + const std::multimap& overwrite) {} +void ClientRpcInfo::set_server_trailing_metadata( + const std::multimap& overwrite) {} +void ClientRpcInfo::set_status(Status status) {} +} // namespace experimental +} // namespace grpc diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index 7c764f4bce8..27629f2be09 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -670,7 +670,7 @@ void Server::PerformOpsOnCall(internal::CallOpSetInterface* ops, static const size_t MAX_OPS = 8; size_t nops = 0; grpc_op cops[MAX_OPS]; - ops->FillOps(call->call(), cops, &nops); + ops->FillOps(call, cops, &nops); // TODO(vjpai): Use ops->cq_tag once this case supports callbacks auto result = grpc_call_start_batch(call->call(), cops, nops, ops, nullptr); if (result != GRPC_CALL_OK) { diff --git a/src/cpp/server/server_context.cc b/src/cpp/server/server_context.cc index b7254b6bb9c..cfa6c8d7e83 100644 --- a/src/cpp/server/server_context.cc +++ b/src/cpp/server/server_context.cc @@ -47,7 +47,7 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { finalized_(false), cancelled_(0) {} - void FillOps(grpc_call* call, grpc_op* ops, size_t* nops) override; + void FillOps(internal::Call* call, grpc_op* ops, size_t* nops) override; bool FinalizeResult(void** tag, bool* status) override; bool CheckCancelled(CompletionQueue* cq) { @@ -88,7 +88,7 @@ void ServerContext::CompletionOp::Unref() { } } -void ServerContext::CompletionOp::FillOps(grpc_call* call, grpc_op* ops, +void ServerContext::CompletionOp::FillOps(internal::Call* call, grpc_op* ops, size_t* nops) { ops->op = GRPC_OP_RECV_CLOSE_ON_SERVER; ops->data.recv_close_on_server.cancelled = &cancelled_;