From d042a5acf1fc83810c5a3b3e7cf2a8340748f1ba Mon Sep 17 00:00:00 2001 From: Yash Tibrewal Date: Wed, 17 Oct 2018 22:05:48 -0700 Subject: [PATCH] some tests fail --- include/grpcpp/impl/codegen/byte_buffer.h | 10 +- include/grpcpp/impl/codegen/call.h | 157 +++++++++++++----- .../grpcpp/impl/codegen/method_handler_impl.h | 64 ++++--- .../grpcpp/impl/codegen/rpc_service_method.h | 17 +- include/grpcpp/impl/codegen/server_context.h | 8 + src/cpp/server/server_cc.cc | 59 +++++-- 6 files changed, 236 insertions(+), 79 deletions(-) diff --git a/include/grpcpp/impl/codegen/byte_buffer.h b/include/grpcpp/impl/codegen/byte_buffer.h index 8cc51581152..d54ae31852e 100644 --- a/include/grpcpp/impl/codegen/byte_buffer.h +++ b/include/grpcpp/impl/codegen/byte_buffer.h @@ -50,6 +50,11 @@ class ErrorMethodHandler; template class DeserializeFuncType; class GrpcByteBufferPeer; +template +class RpcMethodHandler; +template +class ServerStreamingHandler; + } // namespace internal /// A sequence of bytes. class ByteBuffer final { @@ -141,7 +146,10 @@ class ByteBuffer final { template friend class internal::CallOpRecvMessage; friend class internal::CallOpGenericRecvMessage; - friend class internal::MethodHandler; + template + friend class RpcMethodHandler; + template + friend class ServerStreamingHandler; template friend class internal::RpcMethodHandler; template diff --git a/include/grpcpp/impl/codegen/call.h b/include/grpcpp/impl/codegen/call.h index b4e5b05663b..ccadd2af43d 100644 --- a/include/grpcpp/impl/codegen/call.h +++ b/include/grpcpp/impl/codegen/call.h @@ -796,10 +796,15 @@ class Call final { CompletionQueue* cq() const { return cq_; } int max_receive_message_size() const { return max_receive_message_size_; } + experimental::ClientRpcInfo* client_rpc_info() const { return client_rpc_info_; } + experimental::ServerRpcInfo* server_rpc_info() const { + return server_rpc_info_; + } + private: CallHook* call_hook_; CompletionQueue* cq_; @@ -862,44 +867,17 @@ class InterceptorBatchMethodsImpl } virtual void Proceed() override { /* fill this */ - curr_iteration_ = reverse_ ? curr_iteration_ - 1 : curr_iteration_ + 1; - auto* rpc_info = call_->client_rpc_info(); - if (rpc_info->hijacked_ && - (!reverse_ && curr_iteration_ == rpc_info->hijacked_interceptor_ + 1)) { - /* We now need to provide hijacked recv ops to this interceptor */ - ClearHookPoints(); - ops_->SetHijackingState(); - rpc_info->RunInterceptor(this, curr_iteration_ - 1); - return; - } - if (!reverse_) { - /* We are going down the stack of interceptors */ - if (curr_iteration_ < static_cast(rpc_info->interceptors_.size())) { - if (rpc_info->hijacked_ && - curr_iteration_ > rpc_info->hijacked_interceptor_) { - /* This is a hijacked RPC and we are done with hijacking */ - ops_->ContinueFillOpsAfterInterception(); - } else { - rpc_info->RunInterceptor(this, curr_iteration_); - } - } else { - /* we are done running all the interceptors without any hijacking */ - ops_->ContinueFillOpsAfterInterception(); - } - } else { - /* We are going up the stack of interceptors */ - if (curr_iteration_ >= 0) { - /* Continue running interceptors */ - rpc_info->RunInterceptor(this, curr_iteration_); - } else { - /* we are done running all the interceptors without any hijacking */ - ops_->ContinueFinalizeResultAfterInterception(); - } + if (call_->client_rpc_info() != nullptr) { + return ProceedClient(); } + GPR_CODEGEN_ASSERT(call_->server_rpc_info() != nullptr); + ProceedServer(); } virtual void Hijack() override { /* fill this */ - GPR_CODEGEN_ASSERT(!reverse_); + /* Only the client can hijack when sending down initial metadata */ + GPR_CODEGEN_ASSERT(!reverse_ && ops_ != nullptr && + call_->client_rpc_info() != nullptr); auto* rpc_info = call_->client_rpc_info(); rpc_info->hijacked_ = true; rpc_info->hijacked_interceptor_ = curr_iteration_; @@ -997,12 +975,44 @@ class InterceptorBatchMethodsImpl void SetCallOpSetInterface(CallOpSetInterface* ops) { ops_ = ops; } - /* Returns true if no interceptors are run */ + /* Returns true if no interceptors are run. This should be used only by + subclasses of CallOpSetInterface. SetCall and SetCallOpSetInterface should + have been called before this. After all the interceptors are done running, + either ContinueFillOpsAfterInterception or + ContinueFinalizeOpsAfterInterception will be called. Note that neither of them + is invoked if there were no interceptors registered. + */ bool RunInterceptors() { - auto* rpc_info = call_->client_rpc_info(); - if (rpc_info == nullptr || rpc_info->interceptors_.size() == 0) { + auto* client_rpc_info = call_->client_rpc_info(); + if (client_rpc_info == nullptr || + client_rpc_info->interceptors_.size() == 0) { + return true; + } else { + RunClientInterceptors(); + return false; + } + + auto* server_rpc_info = call_->server_rpc_info(); + if (server_rpc_info == nullptr || + server_rpc_info->interceptors_.size() == 0) { return true; } + GPR_ASSERT(false); + RunServerInterceptors(); + return false; + } + + /* Returns true if no interceptors are run. Returns false otherwise if there + are interceptors registered. After the interceptors are done running \a f will + be invoked. This is to be used only by BaseAsyncRequest and SyncRequest. */ + bool RunInterceptors(std::function f) { + GPR_CODEGEN_ASSERT(reverse_ == true); + return true; + } + + private: + void RunClientInterceptors() { + auto* rpc_info = call_->client_rpc_info(); if (!reverse_) { curr_iteration_ = 0; } else { @@ -1015,10 +1025,78 @@ class InterceptorBatchMethodsImpl } } rpc_info->RunInterceptor(this, curr_iteration_); - return false; } - private: + void RunServerInterceptors() { + auto* rpc_info = call_->server_rpc_info(); + if (!reverse_) { + curr_iteration_ = 0; + } else { + curr_iteration_ = rpc_info->interceptors_.size() - 1; + } + rpc_info->RunInterceptor(this, curr_iteration_); + } + + void ProceedClient() { + curr_iteration_ = reverse_ ? curr_iteration_ - 1 : curr_iteration_ + 1; + auto* rpc_info = call_->client_rpc_info(); + if (rpc_info->hijacked_ && + (!reverse_ && curr_iteration_ == rpc_info->hijacked_interceptor_ + 1)) { + /* We now need to provide hijacked recv ops to this interceptor */ + ClearHookPoints(); + ops_->SetHijackingState(); + rpc_info->RunInterceptor(this, curr_iteration_ - 1); + return; + } + if (!reverse_) { + /* We are going down the stack of interceptors */ + if (curr_iteration_ < static_cast(rpc_info->interceptors_.size())) { + if (rpc_info->hijacked_ && + curr_iteration_ > rpc_info->hijacked_interceptor_) { + /* This is a hijacked RPC and we are done with hijacking */ + ops_->ContinueFillOpsAfterInterception(); + } else { + rpc_info->RunInterceptor(this, curr_iteration_); + } + } else { + /* we are done running all the interceptors without any hijacking */ + ops_->ContinueFillOpsAfterInterception(); + } + } else { + /* We are going up the stack of interceptors */ + if (curr_iteration_ >= 0) { + /* Continue running interceptors */ + rpc_info->RunInterceptor(this, curr_iteration_); + } else { + /* we are done running all the interceptors without any hijacking */ + ops_->ContinueFinalizeResultAfterInterception(); + } + } + } + + void ProceedServer() { + auto* rpc_info = call_->server_rpc_info(); + if (!reverse_) { + curr_iteration_++; + if (curr_iteration_ < static_cast(rpc_info->interceptors_.size())) { + return rpc_info->RunInterceptor(this, curr_iteration_); + } + } else { + curr_iteration_--; + /* We are going up the stack of interceptors */ + if (curr_iteration_ >= 0) { + /* Continue running interceptors */ + return rpc_info->RunInterceptor(this, curr_iteration_); + } + } + /* we are done running all the interceptors */ + if (ops_) { + ops_->ContinueFinalizeResultAfterInterception(); + } + GPR_CODEGEN_ASSERT(callback_); + callback_(); + } + void ClearHookPoints() { for (auto i = 0; i < static_cast( @@ -1038,6 +1116,7 @@ class InterceptorBatchMethodsImpl Call* call_ = nullptr; // The Call object is present along with CallOpSet object CallOpSetInterface* ops_ = nullptr; + std::function callback_; ByteBuffer* send_message_ = nullptr; diff --git a/include/grpcpp/impl/codegen/method_handler_impl.h b/include/grpcpp/impl/codegen/method_handler_impl.h index 1d3a2c9708c..93f4af03eed 100644 --- a/include/grpcpp/impl/codegen/method_handler_impl.h +++ b/include/grpcpp/impl/codegen/method_handler_impl.h @@ -59,13 +59,10 @@ class RpcMethodHandler : public MethodHandler { : func_(func), service_(service) {} void RunHandler(const HandlerParameter& param) final { - RequestType req; - Status status = SerializationTraits::Deserialize( - param.request.bbuf_ptr(), &req); ResponseType rsp; - if (status.ok()) { - status = CatchingFunctionHandler([this, ¶m, &req, &rsp] { - return func_(service_, param.server_context, &req, &rsp); + if (status_.ok()) { + status_ = CatchingFunctionHandler([this, ¶m, &rsp] { + return func_(service_, param.server_context, &this->req_, &rsp); }); } @@ -78,14 +75,25 @@ class RpcMethodHandler : public MethodHandler { if (param.server_context->compression_level_set()) { ops.set_compression_level(param.server_context->compression_level()); } - if (status.ok()) { - status = ops.SendMessage(rsp); + if (status_.ok()) { + status_ = ops.SendMessage(rsp); } - ops.ServerSendStatus(¶m.server_context->trailing_metadata_, status); + ops.ServerSendStatus(¶m.server_context->trailing_metadata_, status_); param.call->PerformOps(&ops); param.call->cq()->Pluck(&ops); } + void* Deserialize(grpc_byte_buffer* req) final { + ByteBuffer buf; + buf.set_buffer(req); + status_ = SerializationTraits::Deserialize(&buf, &req_); + buf.Release(); + if (status_.ok()) { + return &req_; + } + return nullptr; + } + private: /// Application provided rpc handler function. std::function::Deserialize( - param.request.bbuf_ptr(), &req); - - if (status.ok()) { + if (status_.ok()) { ServerWriter writer(param.call, param.server_context); - status = CatchingFunctionHandler([this, ¶m, &req, &writer] { - return func_(service_, param.server_context, &req, &writer); + status_ = CatchingFunctionHandler([this, ¶m, &writer] { + return func_(service_, param.server_context, &this->req_, &writer); }); } @@ -169,7 +175,7 @@ class ServerStreamingHandler : public MethodHandler { ops.set_compression_level(param.server_context->compression_level()); } } - ops.ServerSendStatus(¶m.server_context->trailing_metadata_, status); + ops.ServerSendStatus(¶m.server_context->trailing_metadata_, status_); param.call->PerformOps(&ops); if (param.server_context->has_pending_ops_) { param.call->cq()->Pluck(¶m.server_context->pending_ops_); @@ -177,11 +183,24 @@ class ServerStreamingHandler : public MethodHandler { param.call->cq()->Pluck(&ops); } + void* Deserialize(grpc_byte_buffer* req) final { + ByteBuffer buf; + buf.set_buffer(req); + status_ = SerializationTraits::Deserialize(&buf, &req_); + buf.Release(); + if (status_.ok()) { + return &req_; + } + return nullptr; + } + private: std::function*)> func_; ServiceType* service_; + RequestType req_; + Status status_; }; /// A wrapper class of an application provided bidi-streaming handler. @@ -296,11 +315,14 @@ class ErrorMethodHandler : public MethodHandler { FillOps(param.server_context, &ops); param.call->PerformOps(&ops); param.call->cq()->Pluck(&ops); - // We also have to destroy any request payload in the handler parameter - ByteBuffer* payload = param.request.bbuf_ptr(); - if (payload != nullptr) { - payload->Clear(); + } + + void* Deserialize(grpc_byte_buffer* req) final { + // We have to destroy any request payload + if (req != nullptr) { + g_core_codegen_interface->grpc_byte_buffer_destroy(req); } + return nullptr; } }; diff --git a/include/grpcpp/impl/codegen/rpc_service_method.h b/include/grpcpp/impl/codegen/rpc_service_method.h index 5cf88e216f9..04607efd7d1 100644 --- a/include/grpcpp/impl/codegen/rpc_service_method.h +++ b/include/grpcpp/impl/codegen/rpc_service_method.h @@ -40,17 +40,20 @@ class MethodHandler { public: virtual ~MethodHandler() {} struct HandlerParameter { - HandlerParameter(Call* c, ServerContext* context, grpc_byte_buffer* req) - : call(c), server_context(context) { - request.set_buffer(req); - } - ~HandlerParameter() { request.Release(); } + HandlerParameter(Call* c, ServerContext* context) + : call(c), server_context(context) {} + ~HandlerParameter() {} Call* call; ServerContext* server_context; - // Handler required to destroy these contents - ByteBuffer request; }; virtual void RunHandler(const HandlerParameter& param) = 0; + + /* Returns pointer to the deserialized request. Ownership is retained by the + handler. Returns nullptr if deserialization failed */ + virtual void* Deserialize(grpc_byte_buffer* req) { + GPR_CODEGEN_ASSERT(req == nullptr); + return nullptr; + } }; /// Server side rpc method class diff --git a/include/grpcpp/impl/codegen/server_context.h b/include/grpcpp/impl/codegen/server_context.h index b58f029de93..ad6f04260fa 100644 --- a/include/grpcpp/impl/codegen/server_context.h +++ b/include/grpcpp/impl/codegen/server_context.h @@ -285,6 +285,12 @@ class ServerContext { uint32_t initial_metadata_flags() const { return 0; } + experimental::ServerRpcInfo* set_server_rpc_info( + experimental::ServerRpcInfo info) { + rpc_info_ = std::move(info); + return &rpc_info_; + } + CompletionOp* completion_op_; bool has_notify_when_done_tag_; void* async_notify_when_done_tag_; @@ -306,6 +312,8 @@ class ServerContext { internal::CallOpSendMessage> pending_ops_; bool has_pending_ops_; + + experimental::ServerRpcInfo rpc_info_; }; } // namespace grpc diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index 807a66baefb..59121ba136a 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -27,7 +27,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -208,14 +210,17 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { public: explicit CallData(Server* server, SyncRequest* mrd) : cq_(mrd->cq_), - call_(mrd->call_, server, &cq_, server->max_receive_message_size(), - nullptr), ctx_(mrd->deadline_, &mrd->request_metadata_), has_request_payload_(mrd->has_request_payload_), request_payload_(has_request_payload_ ? mrd->request_payload_ : nullptr), method_(mrd->method_), - server_(server) { + call_(mrd->call_, server, &cq_, server->max_receive_message_size(), + ctx_.set_server_rpc_info(experimental::ServerRpcInfo( + &ctx_, method_->name(), server->interceptor_creators_))), + server_(server), + global_callbacks_(nullptr), + resources_(false) { ctx_.set_call(mrd->call_); ctx_.cq_ = &cq_; GPR_ASSERT(mrd->in_flight_); @@ -231,14 +236,43 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { void Run(const std::shared_ptr& global_callbacks, bool resources) { + global_callbacks_ = global_callbacks; + resources_ = resources; + + /* Set interception point for RECV INITIAL METADATA */ + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA); + interceptor_methods_.SetRecvInitialMetadata(&ctx_.client_metadata_); + + if (has_request_payload_) { + /* Set interception point for RECV MESSAGE */ + auto* handler = resources_ ? method_->handler() + : server_->resource_exhausted_handler_.get(); + auto* request = handler->Deserialize(request_payload_); + request_payload_ = nullptr; + interceptor_methods_.AddInterceptionHookPoint( + experimental::InterceptionHookPoints::POST_RECV_MESSAGE); + interceptor_methods_.SetRecvMessage(request); + } + interceptor_methods_.SetCall(&call_); + interceptor_methods_.SetReverse(); + auto f = std::bind(&CallData::ContinueRunAfterInterception, this); + if (interceptor_methods_.RunInterceptors(f)) { + ContinueRunAfterInterception(); + } else { + /* There were interceptors to be run, so ContinueRunAfterInterception + will be run when interceptors are done. */ + } + } + + void ContinueRunAfterInterception() { ctx_.BeginCompletionOp(&call_); - global_callbacks->PreSynchronousRequest(&ctx_); - auto* handler = resources ? method_->handler() - : server_->resource_exhausted_handler_.get(); - handler->RunHandler(internal::MethodHandler::HandlerParameter( - &call_, &ctx_, request_payload_)); - global_callbacks->PostSynchronousRequest(&ctx_); - request_payload_ = nullptr; + global_callbacks_->PreSynchronousRequest(&ctx_); + auto* handler = resources_ ? method_->handler() + : server_->resource_exhausted_handler_.get(); + handler->RunHandler( + internal::MethodHandler::HandlerParameter(&call_, &ctx_)); + global_callbacks_->PostSynchronousRequest(&ctx_); cq_.Shutdown(); @@ -252,12 +286,15 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { private: CompletionQueue cq_; - internal::Call call_; ServerContext ctx_; const bool has_request_payload_; grpc_byte_buffer* request_payload_; internal::RpcServiceMethod* const method_; + internal::Call call_; Server* server_; + std::shared_ptr global_callbacks_; + bool resources_; + internal::InterceptorBatchMethodsImpl interceptor_methods_; }; private: