diff --git a/include/grpcpp/impl/codegen/message_allocator.h b/include/grpcpp/impl/codegen/message_allocator.h index 107bec62f1d..83544c64068 100644 --- a/include/grpcpp/impl/codegen/message_allocator.h +++ b/include/grpcpp/impl/codegen/message_allocator.h @@ -22,31 +22,49 @@ namespace grpc { namespace experimental { -// This is per rpc struct for the allocator. We can potentially put the grpc -// call arena in here in the future. +// NOTE: This is an API for advanced users who need custom allocators. +// Per rpc struct for the allocator. This is the interface to return to user. +class RpcAllocatorState { + public: + virtual ~RpcAllocatorState() = default; + // Optionally deallocate request early to reduce the size of working set. + // A custom MessageAllocator needs to be registered to make use of this. + // This is not abstract because implementing it is optional. + virtual void FreeRequest() {} +}; + +// This is the interface returned by the allocator. +// grpc library will call the methods to get request/response pointers and to +// release the object when it is done. template -struct RpcAllocatorInfo { - RequestT* request; - ResponseT* response; - // per rpc allocator internal state. MessageAllocator can set it when - // AllocateMessages is called and use it later. - void* allocator_state; +class MessageHolder : public RpcAllocatorState { + public: + // Release this object. For example, if the custom allocator's + // AllocateMessasge creates an instance of a subclass with new, the Release() + // should do a "delete this;". + virtual void Release() = 0; + RequestT* request() { return request_; } + ResponseT* response() { return response_; } + + protected: + void set_request(RequestT* request) { request_ = request; } + void set_response(ResponseT* response) { response_ = response; } + + private: + // NOTE: subclasses should set these pointers. + RequestT* request_; + ResponseT* response_; }; -// Implementations need to be thread-safe +// A custom allocator can be set via the generated code to a callback unary +// method, such as SetMessageAllocatorFor_Echo(custom_allocator). The allocator +// needs to be alive for the lifetime of the server. +// Implementations need to be thread-safe. template class MessageAllocator { public: virtual ~MessageAllocator() = default; - // Allocate both request and response - virtual void AllocateMessages( - RpcAllocatorInfo* info) = 0; - // Optional: deallocate request early, called by - // ServerCallbackRpcController::ReleaseRequest - virtual void DeallocateRequest(RpcAllocatorInfo* info) {} - // Deallocate response and request (if applicable) - virtual void DeallocateMessages( - RpcAllocatorInfo* info) = 0; + virtual MessageHolder* AllocateMessages() = 0; }; } // namespace experimental diff --git a/include/grpcpp/impl/codegen/server_callback.h b/include/grpcpp/impl/codegen/server_callback.h index 3f6d5cd3555..9254518b605 100644 --- a/include/grpcpp/impl/codegen/server_callback.h +++ b/include/grpcpp/impl/codegen/server_callback.h @@ -77,6 +77,24 @@ class ServerReactor { std::atomic_int on_cancel_conditions_remaining_{2}; }; +template +class DefaultMessageHolder + : public experimental::MessageHolder { + public: + DefaultMessageHolder() { + this->set_request(&request_obj_); + this->set_response(&response_obj_); + } + void Release() override { + // the object is allocated in the call arena. + this->~DefaultMessageHolder(); + } + + private: + Request request_obj_; + Response response_obj_; +}; + } // namespace internal namespace experimental { @@ -137,13 +155,9 @@ class ServerCallbackRpcController { virtual void SetCancelCallback(std::function callback) = 0; virtual void ClearCancelCallback() = 0; - // NOTE: This is an API for advanced users who need custom allocators. - // Optionally deallocate request early to reduce the size of working set. - // A custom MessageAllocator needs to be registered to make use of this. - virtual void FreeRequest() = 0; // NOTE: This is an API for advanced users who need custom allocators. // Get and maybe mutate the allocator state associated with the current RPC. - virtual void* GetAllocatorState() = 0; + virtual RpcAllocatorState* GetRpcAllocatorState() = 0; }; // NOTE: The actual streaming object classes are provided @@ -465,13 +479,13 @@ class CallbackUnaryHandler : public MethodHandler { void RunHandler(const HandlerParameter& param) final { // Arena allocate a controller structure (that includes request/response) g_core_codegen_interface->grpc_call_ref(param.call->call()); - auto* allocator_info = - static_cast*>( + auto* allocator_state = + static_cast*>( param.internal_data); auto* controller = new (g_core_codegen_interface->grpc_call_arena_alloc( param.call->call(), sizeof(ServerCallbackRpcControllerImpl))) ServerCallbackRpcControllerImpl(param.server_context, param.call, - allocator_info, allocator_, + allocator_state, std::move(param.call_requester)); Status status = param.status; if (status.ok()) { @@ -489,36 +503,24 @@ class CallbackUnaryHandler : public MethodHandler { ByteBuffer buf; buf.set_buffer(req); RequestType* request = nullptr; - experimental::RpcAllocatorInfo* allocator_info = - new (g_core_codegen_interface->grpc_call_arena_alloc( - call, sizeof(*allocator_info))) - experimental::RpcAllocatorInfo(); + experimental::MessageHolder* allocator_state = + nullptr; if (allocator_ != nullptr) { - allocator_->AllocateMessages(allocator_info); + allocator_state = allocator_->AllocateMessages(); } else { - allocator_info->request = - new (g_core_codegen_interface->grpc_call_arena_alloc( - call, sizeof(RequestType))) RequestType(); - allocator_info->response = - new (g_core_codegen_interface->grpc_call_arena_alloc( - call, sizeof(ResponseType))) ResponseType(); + allocator_state = new (g_core_codegen_interface->grpc_call_arena_alloc( + call, sizeof(DefaultMessageHolder))) + DefaultMessageHolder(); } - *handler_data = allocator_info; - request = allocator_info->request; + *handler_data = allocator_state; + request = allocator_state->request(); *status = SerializationTraits::Deserialize(&buf, request); buf.Release(); if (status->ok()) { return request; } // Clean up on deserialization failure. - if (allocator_ != nullptr) { - allocator_->DeallocateMessages(allocator_info); - } else { - allocator_info->request->~RequestType(); - allocator_info->response->~ResponseType(); - allocator_info->request = nullptr; - allocator_info->response = nullptr; - } + allocator_state->Release(); return nullptr; } @@ -548,9 +550,8 @@ class CallbackUnaryHandler : public MethodHandler { } // The response is dropped if the status is not OK. if (s.ok()) { - finish_ops_.ServerSendStatus( - &ctx_->trailing_metadata_, - finish_ops_.SendMessagePtr(allocator_info_->response)); + finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, + finish_ops_.SendMessagePtr(response())); } else { finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, s); } @@ -588,14 +589,8 @@ class CallbackUnaryHandler : public MethodHandler { void ClearCancelCallback() override { ctx_->ClearCancelCallback(); } - void FreeRequest() override { - if (allocator_ != nullptr) { - allocator_->DeallocateRequest(allocator_info_); - } - } - - void* GetAllocatorState() override { - return allocator_info_->allocator_state; + experimental::RpcAllocatorState* GetRpcAllocatorState() override { + return allocator_state_; } private: @@ -603,35 +598,23 @@ class CallbackUnaryHandler : public MethodHandler { ServerCallbackRpcControllerImpl( ServerContext* ctx, Call* call, - experimental::RpcAllocatorInfo* - allocator_info, - experimental::MessageAllocator* allocator, + experimental::MessageHolder* allocator_state, std::function call_requester) : ctx_(ctx), call_(*call), - allocator_info_(allocator_info), - allocator_(allocator), + allocator_state_(allocator_state), call_requester_(std::move(call_requester)) { ctx_->BeginCompletionOp(call, [this](bool) { MaybeDone(); }, nullptr); } - const RequestType* request() { return allocator_info_->request; } - ResponseType* response() { return allocator_info_->response; } + const RequestType* request() { return allocator_state_->request(); } + ResponseType* response() { return allocator_state_->response(); } void MaybeDone() { if (--callbacks_outstanding_ == 0) { grpc_call* call = call_.call(); auto call_requester = std::move(call_requester_); - if (allocator_ != nullptr) { - allocator_->DeallocateMessages(allocator_info_); - } else { - if (allocator_info_->request != nullptr) { - allocator_info_->request->~RequestType(); - } - if (allocator_info_->response != nullptr) { - allocator_info_->response->~ResponseType(); - } - } + allocator_state_->Release(); this->~ServerCallbackRpcControllerImpl(); // explicitly call destructor g_core_codegen_interface->grpc_call_unref(call); call_requester(); @@ -647,8 +630,8 @@ class CallbackUnaryHandler : public MethodHandler { ServerContext* ctx_; Call call_; - experimental::RpcAllocatorInfo* allocator_info_; - experimental::MessageAllocator* allocator_; + experimental::MessageHolder* const + allocator_state_; std::function call_requester_; std::atomic_int callbacks_outstanding_{ 2}; // reserve for Finish and CompletionOp diff --git a/test/cpp/end2end/message_allocator_end2end_test.cc b/test/cpp/end2end/message_allocator_end2end_test.cc index c833a4ff1b6..1c52259088a 100644 --- a/test/cpp/end2end/message_allocator_end2end_test.cc +++ b/test/cpp/end2end/message_allocator_end2end_test.cc @@ -25,6 +25,7 @@ #include +#include #include #include @@ -62,11 +63,9 @@ class CallbackTestServiceImpl public: explicit CallbackTestServiceImpl() {} - void SetFreeRequest() { free_request_ = true; } - void SetAllocatorMutator( - std::function + std::function mutator) { allocator_mutator_ = mutator; } @@ -75,18 +74,15 @@ class CallbackTestServiceImpl EchoResponse* response, experimental::ServerCallbackRpcController* controller) override { response->set_message(request->message()); - if (free_request_) { - controller->FreeRequest(); - } else if (allocator_mutator_) { - allocator_mutator_(controller->GetAllocatorState(), request, response); + if (allocator_mutator_) { + allocator_mutator_(controller->GetRpcAllocatorState(), request, response); } controller->Finish(Status::OK); } private: - bool free_request_ = false; - std::function + std::function allocator_mutator_; }; @@ -230,26 +226,44 @@ class SimpleAllocatorTest : public MessageAllocatorEnd2endTestBase { class SimpleAllocator : public experimental::MessageAllocator { public: - void AllocateMessages( - experimental::RpcAllocatorInfo* info) { + class MessageHolderImpl + : public experimental::MessageHolder { + public: + MessageHolderImpl(int* request_deallocation_count, + int* messages_deallocation_count) + : request_deallocation_count_(request_deallocation_count), + messages_deallocation_count_(messages_deallocation_count) { + set_request(new EchoRequest); + set_response(new EchoResponse); + } + void Release() override { + (*messages_deallocation_count_)++; + delete request(); + delete response(); + delete this; + } + void FreeRequest() override { + (*request_deallocation_count_)++; + delete request(); + set_request(nullptr); + } + + EchoRequest* ReleaseRequest() { + auto* ret = request(); + set_request(nullptr); + return ret; + } + + private: + int* request_deallocation_count_; + int* messages_deallocation_count_; + }; + experimental::MessageHolder* AllocateMessages() + override { allocation_count++; - info->request = new EchoRequest; - info->response = new EchoResponse; - info->allocator_state = info; - } - void DeallocateRequest( - experimental::RpcAllocatorInfo* info) { - request_deallocation_count++; - delete info->request; - info->request = nullptr; - } - void DeallocateMessages( - experimental::RpcAllocatorInfo* info) { - messages_deallocation_count++; - delete info->request; - delete info->response; + return new MessageHolderImpl(&request_deallocation_count, + &messages_deallocation_count); } - int allocation_count = 0; int request_deallocation_count = 0; int messages_deallocation_count = 0; @@ -272,7 +286,16 @@ TEST_P(SimpleAllocatorTest, RpcWithEarlyFreeRequest) { MAYBE_SKIP_TEST; const int kRpcCount = 10; std::unique_ptr allocator(new SimpleAllocator); - callback_service_.SetFreeRequest(); + auto mutator = [](experimental::RpcAllocatorState* allocator_state, + const EchoRequest* req, EchoResponse* resp) { + auto* info = + static_cast(allocator_state); + EXPECT_EQ(req, info->request()); + EXPECT_EQ(resp, info->response()); + allocator_state->FreeRequest(); + EXPECT_EQ(nullptr, info->request()); + }; + callback_service_.SetAllocatorMutator(mutator); CreateServer(allocator.get()); ResetStub(); SendRpcs(kRpcCount); @@ -286,17 +309,15 @@ TEST_P(SimpleAllocatorTest, RpcWithReleaseRequest) { const int kRpcCount = 10; std::unique_ptr allocator(new SimpleAllocator); std::vector released_requests; - auto mutator = [&released_requests](void* allocator_state, - const EchoRequest* req, - EchoResponse* resp) { + auto mutator = [&released_requests]( + experimental::RpcAllocatorState* allocator_state, + const EchoRequest* req, EchoResponse* resp) { auto* info = - static_cast*>( - allocator_state); - EXPECT_EQ(req, info->request); - EXPECT_EQ(resp, info->response); - EXPECT_EQ(allocator_state, info->allocator_state); - released_requests.push_back(info->request); - info->request = nullptr; + static_cast(allocator_state); + EXPECT_EQ(req, info->request()); + EXPECT_EQ(resp, info->response()); + released_requests.push_back(info->ReleaseRequest()); + EXPECT_EQ(nullptr, info->request()); }; callback_service_.SetAllocatorMutator(mutator); CreateServer(allocator.get()); @@ -316,30 +337,27 @@ class ArenaAllocatorTest : public MessageAllocatorEnd2endTestBase { class ArenaAllocator : public experimental::MessageAllocator { public: - void AllocateMessages( - experimental::RpcAllocatorInfo* info) { + class MessageHolderImpl + : public experimental::MessageHolder { + public: + MessageHolderImpl() { + set_request( + google::protobuf::Arena::CreateMessage(&arena_)); + set_response( + google::protobuf::Arena::CreateMessage(&arena_)); + } + void Release() override { delete this; } + void FreeRequest() override { GPR_ASSERT(0); } + + private: + google::protobuf::Arena arena_; + }; + experimental::MessageHolder* AllocateMessages() + override { allocation_count++; - auto* arena = new google::protobuf::Arena; - info->allocator_state = arena; - info->request = - google::protobuf::Arena::CreateMessage(arena); - info->response = - google::protobuf::Arena::CreateMessage(arena); - } - void DeallocateRequest( - experimental::RpcAllocatorInfo* info) { - GPR_ASSERT(0); + return new MessageHolderImpl; } - void DeallocateMessages( - experimental::RpcAllocatorInfo* info) { - deallocation_count++; - auto* arena = - static_cast(info->allocator_state); - delete arena; - } - int allocation_count = 0; - int deallocation_count = 0; }; }; @@ -351,7 +369,6 @@ TEST_P(ArenaAllocatorTest, SimpleRpc) { ResetStub(); SendRpcs(kRpcCount); EXPECT_EQ(kRpcCount, allocator->allocation_count); - EXPECT_EQ(kRpcCount, allocator->deallocation_count); } std::vector CreateTestScenarios(bool test_insecure) {