From adca91f6cfe57cbd4af1e5a8cc8bfe3b506445c5 Mon Sep 17 00:00:00 2001 From: Yash Tibrewal Date: Thu, 18 Oct 2018 16:07:00 -0700 Subject: [PATCH] Server interception for SyncRequest --- include/grpcpp/impl/codegen/call.h | 6 --- .../grpcpp/impl/codegen/method_handler_impl.h | 52 +++++++++++-------- .../grpcpp/impl/codegen/rpc_service_method.h | 16 ++++-- .../grpcpp/impl/codegen/server_interface.h | 7 +++ include/grpcpp/server.h | 4 ++ src/cpp/server/server_cc.cc | 51 ++++++++++-------- src/cpp/server/server_context.cc | 6 +-- 7 files changed, 84 insertions(+), 58 deletions(-) diff --git a/include/grpcpp/impl/codegen/call.h b/include/grpcpp/impl/codegen/call.h index ccadd2af43d..d5911447981 100644 --- a/include/grpcpp/impl/codegen/call.h +++ b/include/grpcpp/impl/codegen/call.h @@ -997,7 +997,6 @@ class InterceptorBatchMethodsImpl server_rpc_info->interceptors_.size() == 0) { return true; } - GPR_ASSERT(false); RunServerInterceptors(); return false; } @@ -1128,7 +1127,6 @@ class InterceptorBatchMethodsImpl Status send_status_; std::multimap* send_trailing_metadata_ = nullptr; - size_t* send_trailing_metadata_count_ = nullptr; void* recv_message_ = nullptr; @@ -1137,10 +1135,6 @@ class InterceptorBatchMethodsImpl Status* recv_status_ = nullptr; internal::MetadataMap* recv_trailing_metadata_ = nullptr; - - // void (*hijacking_state_setter_)(); - // void (*continue_after_interception_)(); - // void (*continue_after_reverse_interception_)(); }; /// Primary implementation of CallOpSetInterface. diff --git a/include/grpcpp/impl/codegen/method_handler_impl.h b/include/grpcpp/impl/codegen/method_handler_impl.h index 93f4af03eed..4f02e3e39b9 100644 --- a/include/grpcpp/impl/codegen/method_handler_impl.h +++ b/include/grpcpp/impl/codegen/method_handler_impl.h @@ -60,10 +60,13 @@ class RpcMethodHandler : public MethodHandler { void RunHandler(const HandlerParameter& param) final { ResponseType rsp; - if (status_.ok()) { - status_ = CatchingFunctionHandler([this, ¶m, &rsp] { - return func_(service_, param.server_context, &this->req_, &rsp); + Status status = param.status; + if (status.ok()) { + status = CatchingFunctionHandler([this, ¶m, &rsp] { + return func_(service_, param.server_context, + static_cast(param.request), &rsp); }); + delete static_cast(param.request); } GPR_CODEGEN_ASSERT(!param.server_context->sent_initial_metadata_); @@ -75,22 +78,24 @@ class RpcMethodHandler : public MethodHandler { if (param.server_context->compression_level_set()) { ops.set_compression_level(param.server_context->compression_level()); } - if (status_.ok()) { - status_ = ops.SendMessage(rsp); + if (status.ok()) { + status = ops.SendMessage(rsp); } - ops.ServerSendStatus(¶m.server_context->trailing_metadata_, status_); + ops.ServerSendStatus(¶m.server_context->trailing_metadata_, status); param.call->PerformOps(&ops); param.call->cq()->Pluck(&ops); } - void* Deserialize(grpc_byte_buffer* req) final { + void* Deserialize(grpc_byte_buffer* req, Status* status) final { ByteBuffer buf; buf.set_buffer(req); - status_ = SerializationTraits::Deserialize(&buf, &req_); + auto* request = new RequestType(); + *status = SerializationTraits::Deserialize(&buf, request); buf.Release(); - if (status_.ok()) { - return &req_; + if (status->ok()) { + return request; } + delete request; return nullptr; } @@ -101,8 +106,6 @@ class RpcMethodHandler : public MethodHandler { func_; // The class the above handler function lives in. ServiceType* service_; - RequestType req_; - Status status_; }; /// A wrapper class of an application provided client streaming handler. @@ -160,11 +163,14 @@ class ServerStreamingHandler : public MethodHandler { : func_(func), service_(service) {} void RunHandler(const HandlerParameter& param) final { - if (status_.ok()) { + Status status = param.status; + if (status.ok()) { ServerWriter writer(param.call, param.server_context); - status_ = CatchingFunctionHandler([this, ¶m, &writer] { - return func_(service_, param.server_context, &this->req_, &writer); + status = CatchingFunctionHandler([this, ¶m, &writer] { + return func_(service_, param.server_context, + static_cast(param.request), &writer); }); + delete static_cast(param.request); } CallOpSet ops; @@ -175,7 +181,7 @@ class ServerStreamingHandler : public MethodHandler { ops.set_compression_level(param.server_context->compression_level()); } } - ops.ServerSendStatus(¶m.server_context->trailing_metadata_, status_); + ops.ServerSendStatus(¶m.server_context->trailing_metadata_, status); param.call->PerformOps(&ops); if (param.server_context->has_pending_ops_) { param.call->cq()->Pluck(¶m.server_context->pending_ops_); @@ -183,14 +189,16 @@ class ServerStreamingHandler : public MethodHandler { param.call->cq()->Pluck(&ops); } - void* Deserialize(grpc_byte_buffer* req) final { + void* Deserialize(grpc_byte_buffer* req, Status* status) final { ByteBuffer buf; buf.set_buffer(req); - status_ = SerializationTraits::Deserialize(&buf, &req_); + auto* request = new RequestType(); + *status = SerializationTraits::Deserialize(&buf, request); buf.Release(); - if (status_.ok()) { - return &req_; + if (status->ok()) { + return request; } + delete request; return nullptr; } @@ -199,8 +207,6 @@ class ServerStreamingHandler : public MethodHandler { ServerWriter*)> func_; ServiceType* service_; - RequestType req_; - Status status_; }; /// A wrapper class of an application provided bidi-streaming handler. @@ -317,7 +323,7 @@ class ErrorMethodHandler : public MethodHandler { param.call->cq()->Pluck(&ops); } - void* Deserialize(grpc_byte_buffer* req) final { + void* Deserialize(grpc_byte_buffer* req, Status* status) final { // We have to destroy any request payload if (req != nullptr) { g_core_codegen_interface->grpc_byte_buffer_destroy(req); diff --git a/include/grpcpp/impl/codegen/rpc_service_method.h b/include/grpcpp/impl/codegen/rpc_service_method.h index 04607efd7d1..44da2bd7689 100644 --- a/include/grpcpp/impl/codegen/rpc_service_method.h +++ b/include/grpcpp/impl/codegen/rpc_service_method.h @@ -40,17 +40,23 @@ class MethodHandler { public: virtual ~MethodHandler() {} struct HandlerParameter { - HandlerParameter(Call* c, ServerContext* context) - : call(c), server_context(context) {} + HandlerParameter(Call* c, ServerContext* context, void* req, + Status req_status) + : call(c), server_context(context), request(req), status(req_status) {} ~HandlerParameter() {} Call* call; ServerContext* server_context; + void* request; + Status status; }; virtual void RunHandler(const HandlerParameter& param) = 0; - /* Returns pointer to the deserialized request. Ownership is retained by the - handler. Returns nullptr if deserialization failed */ - virtual void* Deserialize(grpc_byte_buffer* req) { + /* Returns a pointer to the deserialized request. \a status reflects the + result of deserialization. This pointer and the status should be filled in + a HandlerParameter and passed to RunHandler. It is illegal to access the + pointer after calling RunHandler. Ownership of the deserialized request is + retained by the handler. Returns nullptr if deserialization failed. */ + virtual void* Deserialize(grpc_byte_buffer* req, Status* status) { GPR_CODEGEN_ASSERT(req == nullptr); return nullptr; } diff --git a/include/grpcpp/impl/codegen/server_interface.h b/include/grpcpp/impl/codegen/server_interface.h index 237991cde60..9310ccb39f7 100644 --- a/include/grpcpp/impl/codegen/server_interface.h +++ b/include/grpcpp/impl/codegen/server_interface.h @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -162,6 +163,7 @@ class ServerInterface : public internal::CallHook { void* const tag_; const bool delete_on_finalize_; grpc_call* call_; + internal::InterceptorBatchMethodsImpl interceptor_methods; }; class RegisteredAsyncRequest : public BaseAsyncRequest { @@ -295,6 +297,11 @@ class ServerInterface : public internal::CallHook { new GenericAsyncRequest(this, context, stream, call_cq, notification_cq, tag, true); } + +private: + virtual const std::vector>* interceptor_creators() { + return nullptr; + } }; } // namespace grpc diff --git a/include/grpcpp/server.h b/include/grpcpp/server.h index 27d1ec0cfa0..a593b60550b 100644 --- a/include/grpcpp/server.h +++ b/include/grpcpp/server.h @@ -191,6 +191,10 @@ class Server : public ServerInterface, private GrpcLibraryCodegen { grpc_server* server() override { return server_; }; private: + const std::vector>* interceptor_creators() override { + return &interceptor_creators_; + } + friend class AsyncGenericService; friend class ServerBuilder; friend class ServerInitializer; diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index 59121ba136a..427d5d5abb7 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -214,6 +214,7 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { has_request_payload_(mrd->has_request_payload_), request_payload_(has_request_payload_ ? mrd->request_payload_ : nullptr), + request_(nullptr), method_(mrd->method_), call_(mrd->call_, server, &cq_, server->max_receive_message_size(), ctx_.set_server_rpc_info(experimental::ServerRpcInfo( @@ -248,11 +249,12 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { /* Set interception point for RECV MESSAGE */ auto* handler = resources_ ? method_->handler() : server_->resource_exhausted_handler_.get(); - auto* request = handler->Deserialize(request_payload_); + request_ = handler->Deserialize(request_payload_, &request_status_); + request_payload_ = nullptr; interceptor_methods_.AddInterceptionHookPoint( experimental::InterceptionHookPoints::POST_RECV_MESSAGE); - interceptor_methods_.SetRecvMessage(request); + interceptor_methods_.SetRecvMessage(request_); } interceptor_methods_.SetCall(&call_); interceptor_methods_.SetReverse(); @@ -266,22 +268,26 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { } void ContinueRunAfterInterception() { - ctx_.BeginCompletionOp(&call_); - global_callbacks_->PreSynchronousRequest(&ctx_); - auto* handler = resources_ ? method_->handler() - : server_->resource_exhausted_handler_.get(); - handler->RunHandler( - internal::MethodHandler::HandlerParameter(&call_, &ctx_)); - global_callbacks_->PostSynchronousRequest(&ctx_); - - cq_.Shutdown(); - - internal::CompletionQueueTag* op_tag = ctx_.GetCompletionOpTag(); - cq_.TryPluck(op_tag, gpr_inf_future(GPR_CLOCK_REALTIME)); - - /* Ensure the cq_ is shutdown */ - DummyTag ignored_tag; - GPR_ASSERT(cq_.Pluck(&ignored_tag) == false); + { + ctx_.BeginCompletionOp(&call_); + global_callbacks_->PreSynchronousRequest(&ctx_); + auto* handler = resources_ ? method_->handler() + : server_->resource_exhausted_handler_.get(); + handler->RunHandler(internal::MethodHandler::HandlerParameter( + &call_, &ctx_, request_, request_status_)); + request_ = nullptr; + global_callbacks_->PostSynchronousRequest(&ctx_); + + cq_.Shutdown(); + + internal::CompletionQueueTag* op_tag = ctx_.GetCompletionOpTag(); + cq_.TryPluck(op_tag, gpr_inf_future(GPR_CLOCK_REALTIME)); + + /* Ensure the cq_ is shutdown */ + DummyTag ignored_tag; + GPR_ASSERT(cq_.Pluck(&ignored_tag) == false); + } + delete this; } private: @@ -289,6 +295,8 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { ServerContext ctx_; const bool has_request_payload_; grpc_byte_buffer* request_payload_; + void* request_; + Status request_status_; internal::RpcServiceMethod* const method_; internal::Call call_; Server* server_; @@ -359,7 +367,7 @@ class Server::SyncRequestThreadManager : public ThreadManager { if (ok) { // Calldata takes ownership of the completion queue and interceptors // inside sync_req - SyncRequest::CallData cd(server_, sync_req); + auto* cd = new SyncRequest::CallData(server_, sync_req); // Prepare for the next request if (!IsShutdown()) { sync_req->SetupRequest(); // Create new completion queue for sync_req @@ -367,7 +375,7 @@ class Server::SyncRequestThreadManager : public ThreadManager { } GPR_TIMER_SCOPE("cd.Run()", 0); - cd.Run(global_callbacks_, resources); + cd->Run(global_callbacks_, resources); } // TODO (sreek) If ok is false here (which it isn't in case of // grpc_request_registered_call), we should still re-queue the request @@ -724,7 +732,8 @@ ServerInterface::BaseAsyncRequest::BaseAsyncRequest( call_cq_(call_cq), tag_(tag), delete_on_finalize_(delete_on_finalize), - call_(nullptr) { + call_(nullptr), + call_wrapper_() { call_cq_->RegisterAvalanching(); // This op will trigger more ops } diff --git a/src/cpp/server/server_context.cc b/src/cpp/server/server_context.cc index 03507bb6477..175394bd454 100644 --- a/src/cpp/server/server_context.cc +++ b/src/cpp/server/server_context.cc @@ -45,8 +45,8 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { tag_(nullptr), refs_(2), finalized_(false), - cancelled_(0), - done_intercepting_(false) {} + cancelled_(0) /*, + done_intercepting_(false)*/ {} void FillOps(internal::Call* call) override; bool FinalizeResult(void** tag, bool* status) override; @@ -90,7 +90,7 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface { int refs_; bool finalized_; int cancelled_; - bool done_intercepting_; + // bool done_intercepting_; internal::Call call_; internal::InterceptorBatchMethodsImpl interceptor_methods_; };