From 456231b26d56e13b9a56b93baabede4dd8fc2519 Mon Sep 17 00:00:00 2001 From: Yash Tibrewal Date: Thu, 18 Oct 2018 22:43:49 -0700 Subject: [PATCH] Server side interception for CompletionOp and AsyncRequest --- include/grpcpp/impl/codegen/call.h | 12 ++- .../grpcpp/impl/codegen/server_interface.h | 85 ++++++++++++---- include/grpcpp/server.h | 4 +- src/cpp/server/server_cc.cc | 99 +++++++++++++++---- src/cpp/server/server_context.cc | 74 +++++++++++--- 5 files changed, 219 insertions(+), 55 deletions(-) diff --git a/include/grpcpp/impl/codegen/call.h b/include/grpcpp/impl/codegen/call.h index d5911447981..167db8d4ddd 100644 --- a/include/grpcpp/impl/codegen/call.h +++ b/include/grpcpp/impl/codegen/call.h @@ -1004,9 +1004,17 @@ class InterceptorBatchMethodsImpl /* Returns true if no interceptors are run. Returns false otherwise if there are interceptors registered. After the interceptors are done running \a f will be invoked. This is to be used only by BaseAsyncRequest and SyncRequest. */ - bool RunInterceptors(std::function f) { + bool RunInterceptors(std::function f) { GPR_CODEGEN_ASSERT(reverse_ == true); - return true; + GPR_CODEGEN_ASSERT(call_->client_rpc_info() == nullptr); + auto* server_rpc_info = call_->server_rpc_info(); + if (server_rpc_info == nullptr || + server_rpc_info->interceptors_.size() == 0) { + return true; + } + callback_ = std::move(f); + RunServerInterceptors(); + return false; } private: diff --git a/include/grpcpp/impl/codegen/server_interface.h b/include/grpcpp/impl/codegen/server_interface.h index 9310ccb39f7..c83d4aa1947 100644 --- a/include/grpcpp/impl/codegen/server_interface.h +++ b/include/grpcpp/impl/codegen/server_interface.h @@ -20,12 +20,14 @@ #define GRPCPP_IMPL_CODEGEN_SERVER_INTERFACE_H #include +//#include #include #include #include #include #include #include +#include namespace grpc { @@ -149,45 +151,69 @@ class ServerInterface : public internal::CallHook { public: BaseAsyncRequest(ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, - CompletionQueue* call_cq, void* tag, + CompletionQueue* call_cq, + ServerCompletionQueue* notification_cq, void* tag, bool delete_on_finalize); virtual ~BaseAsyncRequest(); bool FinalizeResult(void** tag, bool* status) override; + private: + void ContinueFinalizeResultAfterInterception(); + protected: ServerInterface* const server_; ServerContext* const context_; internal::ServerAsyncStreamingInterface* const stream_; CompletionQueue* const call_cq_; + ServerCompletionQueue* const notification_cq_; void* const tag_; const bool delete_on_finalize_; grpc_call* call_; - internal::InterceptorBatchMethodsImpl interceptor_methods; + internal::Call call_wrapper_; + internal::InterceptorBatchMethodsImpl interceptor_methods_; + bool done_intercepting_; + void* dummy_alarm_; /* This should have been Alarm, but we cannot depend on + alarm.h here */ }; class RegisteredAsyncRequest : public BaseAsyncRequest { public: RegisteredAsyncRequest(ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, - CompletionQueue* call_cq, void* tag); - - // uses BaseAsyncRequest::FinalizeResult + CompletionQueue* call_cq, + ServerCompletionQueue* notification_cq, void* tag, + const char* name); + + virtual bool FinalizeResult(void** tag, bool* status) override { + /* If we are done intercepting, then there is nothing more for us to do */ + if (done_intercepting_) { + return BaseAsyncRequest::FinalizeResult(tag, status); + } + call_wrapper_ = internal::Call( + call_, server_, call_cq_, server_->max_receive_message_size(), + context_->set_server_rpc_info(experimental::ServerRpcInfo( + context_, name_, *server_->interceptor_creators()))); + return BaseAsyncRequest::FinalizeResult(tag, status); + } protected: void IssueRequest(void* registered_method, grpc_byte_buffer** payload, ServerCompletionQueue* notification_cq); + const char* name_; }; class NoPayloadAsyncRequest final : public RegisteredAsyncRequest { public: - NoPayloadAsyncRequest(void* registered_method, ServerInterface* server, - ServerContext* context, + NoPayloadAsyncRequest(internal::RpcServiceMethod* registered_method, + ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag) - : RegisteredAsyncRequest(server, context, stream, call_cq, tag) { - IssueRequest(registered_method, nullptr, notification_cq); + : RegisteredAsyncRequest(server, context, stream, call_cq, + notification_cq, tag, + registered_method->name()) { + IssueRequest(registered_method->server_tag(), nullptr, notification_cq); } // uses RegisteredAsyncRequest::FinalizeResult @@ -196,13 +222,15 @@ class ServerInterface : public internal::CallHook { template class PayloadAsyncRequest final : public RegisteredAsyncRequest { public: - PayloadAsyncRequest(void* registered_method, ServerInterface* server, - ServerContext* context, + PayloadAsyncRequest(internal::RpcServiceMethod* registered_method, + ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag, Message* request) - : RegisteredAsyncRequest(server, context, stream, call_cq, tag), + : RegisteredAsyncRequest(server, context, stream, call_cq, + notification_cq, tag, + registered_method->name()), registered_method_(registered_method), server_(server), context_(context), @@ -211,7 +239,8 @@ class ServerInterface : public internal::CallHook { notification_cq_(notification_cq), tag_(tag), request_(request) { - IssueRequest(registered_method, payload_.bbuf_ptr(), notification_cq); + IssueRequest(registered_method->server_tag(), payload_.bbuf_ptr(), + notification_cq); } ~PayloadAsyncRequest() { @@ -219,6 +248,10 @@ class ServerInterface : public internal::CallHook { } bool FinalizeResult(void** tag, bool* status) override { + /* If we are done intercepting, then there is nothing more for us to do */ + if (done_intercepting_) { + return RegisteredAsyncRequest::FinalizeResult(tag, status); + } if (*status) { if (!payload_.Valid() || !SerializationTraits::Deserialize( payload_.bbuf_ptr(), request_) @@ -237,15 +270,24 @@ class ServerInterface : public internal::CallHook { return false; } } + call_wrapper_ = internal::Call( + call_, server_, call_cq_, server_->max_receive_message_size(), + context_->set_server_rpc_info(experimental::ServerRpcInfo( + context_, name_, *server_->interceptor_creators()))); + /* Set interception point for recv message */ + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + interceptor_methods_.SetRecvMessage(request_); return RegisteredAsyncRequest::FinalizeResult(tag, status); } private: - void* const registered_method_; + internal::RpcServiceMethod* const registered_method_; ServerInterface* const server_; ServerContext* const context_; internal::ServerAsyncStreamingInterface* const stream_; CompletionQueue* const call_cq_; + ServerCompletionQueue* const notification_cq_; void* const tag_; Message* const request_; @@ -274,9 +316,8 @@ class ServerInterface : public internal::CallHook { ServerCompletionQueue* notification_cq, void* tag, Message* message) { GPR_CODEGEN_ASSERT(method); - new PayloadAsyncRequest(method->server_tag(), this, context, - stream, call_cq, notification_cq, tag, - message); + new PayloadAsyncRequest(method, this, context, stream, call_cq, + notification_cq, tag, message); } void RequestAsyncCall(internal::RpcServiceMethod* method, @@ -285,8 +326,8 @@ class ServerInterface : public internal::CallHook { CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag) { GPR_CODEGEN_ASSERT(method); - new NoPayloadAsyncRequest(method->server_tag(), this, context, stream, - call_cq, notification_cq, tag); + new NoPayloadAsyncRequest(method, this, context, stream, call_cq, + notification_cq, tag); } void RequestAsyncGenericCall(GenericServerContext* context, @@ -298,8 +339,10 @@ class ServerInterface : public internal::CallHook { tag, true); } -private: - virtual const std::vector>* interceptor_creators() { + private: + virtual const std::vector< + std::unique_ptr>* + interceptor_creators() { return nullptr; } }; diff --git a/include/grpcpp/server.h b/include/grpcpp/server.h index a593b60550b..2b89ffd317a 100644 --- a/include/grpcpp/server.h +++ b/include/grpcpp/server.h @@ -191,7 +191,9 @@ class Server : public ServerInterface, private GrpcLibraryCodegen { grpc_server* server() override { return server_; }; private: - const std::vector>* interceptor_creators() override { + const std::vector< + std::unique_ptr>* + interceptor_creators() override { return &interceptor_creators_; } diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index 427d5d5abb7..d53c3534a94 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -240,6 +241,8 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { global_callbacks_ = global_callbacks; resources_ = resources; + interceptor_methods_.SetCall(&call_); + interceptor_methods_.SetReverse(); /* Set interception point for RECV INITIAL METADATA */ interceptor_methods_.AddInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA); @@ -256,8 +259,7 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { experimental::InterceptionHookPoints::POST_RECV_MESSAGE); interceptor_methods_.SetRecvMessage(request_); } - interceptor_methods_.SetCall(&call_); - interceptor_methods_.SetReverse(); + auto f = std::bind(&CallData::ContinueRunAfterInterception, this); if (interceptor_methods_.RunInterceptors(f)) { ContinueRunAfterInterception(); @@ -725,15 +727,21 @@ void Server::PerformOpsOnCall(internal::CallOpSetInterface* ops, ServerInterface::BaseAsyncRequest::BaseAsyncRequest( ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, - void* tag, bool delete_on_finalize) + ServerCompletionQueue* notification_cq, void* tag, bool delete_on_finalize) : server_(server), context_(context), stream_(stream), call_cq_(call_cq), + notification_cq_(notification_cq), tag_(tag), delete_on_finalize_(delete_on_finalize), call_(nullptr), - call_wrapper_() { + done_intercepting_(false) { + /* Set up interception state partially for the receive ops. call_wrapper_ is + * not filled at this point, but it will be filled before the interceptors are + * run. */ + interceptor_methods_.SetCall(&call_wrapper_); + interceptor_methods_.SetReverse(); call_cq_->RegisterAvalanching(); // This op will trigger more ops } @@ -743,17 +751,47 @@ ServerInterface::BaseAsyncRequest::~BaseAsyncRequest() { bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, bool* status) { + if (done_intercepting_) { + delete static_cast(dummy_alarm_); + dummy_alarm_ = nullptr; + *tag = tag_; + if (delete_on_finalize_) { + delete this; + } + return true; + } context_->set_call(call_); context_->cq_ = call_cq_; - internal::Call call(call_, server_, call_cq_, - server_->max_receive_message_size(), nullptr); + if (call_wrapper_.call() == nullptr) { + /* Fill it since it is empty. */ + call_wrapper_ = internal::Call( + call_, server_, call_cq_, server_->max_receive_message_size(), nullptr); + } + // just the pointers inside call are copied here + stream_->BindCall(&call_wrapper_); + + if (*status && call_ && call_wrapper_.server_rpc_info()) { + done_intercepting_ = true; + /* Set interception point for RECV INITIAL METADATA */ + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA); + interceptor_methods_.SetRecvInitialMetadata(&context_->client_metadata_); + auto f = std::bind(&ServerInterface::BaseAsyncRequest:: + ContinueFinalizeResultAfterInterception, + this); + if (interceptor_methods_.RunInterceptors(f)) { + /* There are no interceptors to run. Continue */ + } else { + /* There were interceptors to be run, so + ContinueFinalizeResultAfterInterception will be run when interceptors are + done. */ + return false; + } + } if (*status && call_) { - context_->BeginCompletionOp(&call); + context_->BeginCompletionOp(&call_wrapper_); } - // just the pointers inside call are copied here - stream_->BindCall(&call); - *tag = tag_; if (delete_on_finalize_) { delete this; @@ -761,11 +799,23 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag, return true; } +void ServerInterface::BaseAsyncRequest:: + ContinueFinalizeResultAfterInterception() { + context_->BeginCompletionOp(&call_wrapper_); + /* Queue a tag which will be returned immediately */ + dummy_alarm_ = new Alarm(); + static_cast(dummy_alarm_) + ->Set(notification_cq_, + g_core_codegen_interface->gpr_time_0(GPR_CLOCK_MONOTONIC), this); +} + ServerInterface::RegisteredAsyncRequest::RegisteredAsyncRequest( ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, - void* tag) - : BaseAsyncRequest(server, context, stream, call_cq, tag, true) {} + ServerCompletionQueue* notification_cq, void* tag, const char* name) + : BaseAsyncRequest(server, context, stream, call_cq, notification_cq, tag, + true), + name_(name) {} void ServerInterface::RegisteredAsyncRequest::IssueRequest( void* registered_method, grpc_byte_buffer** payload, @@ -781,7 +831,7 @@ ServerInterface::GenericAsyncRequest::GenericAsyncRequest( ServerInterface* server, GenericServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag, bool delete_on_finalize) - : BaseAsyncRequest(server, context, stream, call_cq, tag, + : BaseAsyncRequest(server, context, stream, call_cq, notification_cq, tag, delete_on_finalize) { grpc_call_details_init(&call_details_); GPR_ASSERT(notification_cq); @@ -794,6 +844,10 @@ ServerInterface::GenericAsyncRequest::GenericAsyncRequest( bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag, bool* status) { + /* If we are done intercepting, there is nothing more for us to do */ + if (done_intercepting_) { + return BaseAsyncRequest::FinalizeResult(tag, status); + } // TODO(yangg) remove the copy here. if (*status) { static_cast(context_)->method_ = @@ -804,16 +858,27 @@ bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag, } grpc_slice_unref(call_details_.method); grpc_slice_unref(call_details_.host); + call_wrapper_ = internal::Call( + call_, server_, call_cq_, server_->max_receive_message_size(), + context_->set_server_rpc_info(experimental::ServerRpcInfo( + context_, + static_cast(context_)->method_.c_str(), + *server_->interceptor_creators()))); return BaseAsyncRequest::FinalizeResult(tag, status); } bool Server::UnimplementedAsyncRequest::FinalizeResult(void** tag, bool* status) { - if (GenericAsyncRequest::FinalizeResult(tag, status) && *status) { - new UnimplementedAsyncRequest(server_, cq_); - new UnimplementedAsyncResponse(this); + if (GenericAsyncRequest::FinalizeResult(tag, status)) { + /* We either had no interceptors run or we are done interceptinh */ + if (*status) { + new UnimplementedAsyncRequest(server_, cq_); + new UnimplementedAsyncResponse(this); + } else { + delete this; + } } else { - delete this; + /* The tag was swallowed due to interception. We will see it again. */ } return false; } diff --git a/src/cpp/server/server_context.cc b/src/cpp/server/server_context.cc index 175394bd454..42ae0ed1389 100644 --- a/src/cpp/server/server_context.cc +++ b/src/cpp/server/server_context.cc @@ -45,8 +45,8 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { tag_(nullptr), refs_(2), finalized_(false), - cancelled_(0) /*, - done_intercepting_(false)*/ {} + cancelled_(0), + done_intercepting_(false) {} void FillOps(internal::Call* call) override; bool FinalizeResult(void** tag, bool* status) override; @@ -69,14 +69,32 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { // This will be called while interceptors are run if the RPC is a hijacked // RPC. This should set hijacking state for each of the ops. - void SetHijackingState() override {} + void SetHijackingState() override { + /* Servers don't allow hijacking */ + GPR_CODEGEN_ASSERT(false); + } /* Should be called after interceptors are done running */ void ContinueFillOpsAfterInterception() override {} /* Should be called after interceptors are done running on the finalize result * path */ - void ContinueFinalizeResultAfterInterception() override {} + void ContinueFinalizeResultAfterInterception() override { + done_intercepting_ = true; + if (!has_tag_) { + /* We don't have a tag to return. */ + std::unique_lock lock(mu_); + if (--refs_ == 0) { + lock.unlock(); + delete this; + } + return; + } + /* Start a dummy op so that we can return the tag */ + GPR_CODEGEN_ASSERT(GRPC_CALL_OK == + g_core_codegen_interface->grpc_call_start_batch( + call_.call(), nullptr, 0, this, nullptr)); + } private: bool CheckCancelledNoPluck() { @@ -90,7 +108,7 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { int refs_; bool finalized_; int cancelled_; - // bool done_intercepting_; + bool done_intercepting_; internal::Call call_; internal::InterceptorBatchMethodsImpl interceptor_methods_; }; @@ -111,24 +129,52 @@ void ServerContext::CompletionOp::FillOps(internal::Call* call) { ops.reserved = nullptr; call_ = *call; interceptor_methods_.SetCall(&call_); + interceptor_methods_.SetReverse(); + interceptor_methods_.SetCallOpSetInterface(this); GPR_ASSERT(GRPC_CALL_OK == grpc_call_start_batch(call->call(), &ops, 1, this, nullptr)); + /* No interceptors to run here */ } bool ServerContext::CompletionOp::FinalizeResult(void** tag, bool* status) { - std::unique_lock lock(mu_); - finalized_ = true; bool ret = false; - if (has_tag_) { - *tag = tag_; - ret = true; + std::unique_lock lock(mu_); + if (done_intercepting_) { + /* We are done intercepting. */ + if (has_tag_) { + *tag = tag_; + ret = true; + } + if (--refs_ == 0) { + lock.unlock(); + delete this; + } + return ret; } + finalized_ = true; + if (!*status) cancelled_ = 1; - if (--refs_ == 0) { - lock.unlock(); - delete this; + /* Release the lock since we are going to be running through interceptors now + */ + lock.unlock(); + /* Add interception point and run through interceptors */ + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_CLOSE); + if (interceptor_methods_.RunInterceptors()) { + /* No interceptors were run */ + if (has_tag_) { + *tag = tag_; + ret = true; + } + lock.lock(); + if (--refs_ == 0) { + lock.unlock(); + delete this; + } + return ret; } - return ret; + /* There are interceptors to be run. Return false for now */ + return false; } // ServerContext body