Merge pull request #19016 from yang-g/message_allocator

Update the message allocator API
pull/19082/head
Yang Gao 6 years ago committed by GitHub
commit c4cb6e1787
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 54
      include/grpcpp/impl/codegen/message_allocator.h
  2. 101
      include/grpcpp/impl/codegen/server_callback.h
  3. 141
      test/cpp/end2end/message_allocator_end2end_test.cc

@ -22,31 +22,49 @@
namespace grpc {
namespace experimental {
// This is per rpc struct for the allocator. We can potentially put the grpc
// call arena in here in the future.
// NOTE: This is an API for advanced users who need custom allocators.
// Per rpc struct for the allocator. This is the interface to return to user.
class RpcAllocatorState {
public:
virtual ~RpcAllocatorState() = default;
// Optionally deallocate request early to reduce the size of working set.
// A custom MessageAllocator needs to be registered to make use of this.
// This is not abstract because implementing it is optional.
virtual void FreeRequest() {}
};
// This is the interface returned by the allocator.
// grpc library will call the methods to get request/response pointers and to
// release the object when it is done.
template <typename RequestT, typename ResponseT>
struct RpcAllocatorInfo {
RequestT* request;
ResponseT* response;
// per rpc allocator internal state. MessageAllocator can set it when
// AllocateMessages is called and use it later.
void* allocator_state;
class MessageHolder : public RpcAllocatorState {
public:
// Release this object. For example, if the custom allocator's
// AllocateMessasge creates an instance of a subclass with new, the Release()
// should do a "delete this;".
virtual void Release() = 0;
RequestT* request() { return request_; }
ResponseT* response() { return response_; }
protected:
void set_request(RequestT* request) { request_ = request; }
void set_response(ResponseT* response) { response_ = response; }
private:
// NOTE: subclasses should set these pointers.
RequestT* request_;
ResponseT* response_;
};
// Implementations need to be thread-safe
// A custom allocator can be set via the generated code to a callback unary
// method, such as SetMessageAllocatorFor_Echo(custom_allocator). The allocator
// needs to be alive for the lifetime of the server.
// Implementations need to be thread-safe.
template <typename RequestT, typename ResponseT>
class MessageAllocator {
public:
virtual ~MessageAllocator() = default;
// Allocate both request and response
virtual void AllocateMessages(
RpcAllocatorInfo<RequestT, ResponseT>* info) = 0;
// Optional: deallocate request early, called by
// ServerCallbackRpcController::ReleaseRequest
virtual void DeallocateRequest(RpcAllocatorInfo<RequestT, ResponseT>* info) {}
// Deallocate response and request (if applicable)
virtual void DeallocateMessages(
RpcAllocatorInfo<RequestT, ResponseT>* info) = 0;
virtual MessageHolder<RequestT, ResponseT>* AllocateMessages() = 0;
};
} // namespace experimental

@ -77,6 +77,24 @@ class ServerReactor {
std::atomic_int on_cancel_conditions_remaining_{2};
};
template <class Request, class Response>
class DefaultMessageHolder
: public experimental::MessageHolder<Request, Response> {
public:
DefaultMessageHolder() {
this->set_request(&request_obj_);
this->set_response(&response_obj_);
}
void Release() override {
// the object is allocated in the call arena.
this->~DefaultMessageHolder<Request, Response>();
}
private:
Request request_obj_;
Response response_obj_;
};
} // namespace internal
namespace experimental {
@ -137,13 +155,9 @@ class ServerCallbackRpcController {
virtual void SetCancelCallback(std::function<void()> callback) = 0;
virtual void ClearCancelCallback() = 0;
// NOTE: This is an API for advanced users who need custom allocators.
// Optionally deallocate request early to reduce the size of working set.
// A custom MessageAllocator needs to be registered to make use of this.
virtual void FreeRequest() = 0;
// NOTE: This is an API for advanced users who need custom allocators.
// Get and maybe mutate the allocator state associated with the current RPC.
virtual void* GetAllocatorState() = 0;
virtual RpcAllocatorState* GetRpcAllocatorState() = 0;
};
// NOTE: The actual streaming object classes are provided
@ -465,13 +479,13 @@ class CallbackUnaryHandler : public MethodHandler {
void RunHandler(const HandlerParameter& param) final {
// Arena allocate a controller structure (that includes request/response)
g_core_codegen_interface->grpc_call_ref(param.call->call());
auto* allocator_info =
static_cast<experimental::RpcAllocatorInfo<RequestType, ResponseType>*>(
auto* allocator_state =
static_cast<experimental::MessageHolder<RequestType, ResponseType>*>(
param.internal_data);
auto* controller = new (g_core_codegen_interface->grpc_call_arena_alloc(
param.call->call(), sizeof(ServerCallbackRpcControllerImpl)))
ServerCallbackRpcControllerImpl(param.server_context, param.call,
allocator_info, allocator_,
allocator_state,
std::move(param.call_requester));
Status status = param.status;
if (status.ok()) {
@ -489,36 +503,24 @@ class CallbackUnaryHandler : public MethodHandler {
ByteBuffer buf;
buf.set_buffer(req);
RequestType* request = nullptr;
experimental::RpcAllocatorInfo<RequestType, ResponseType>* allocator_info =
new (g_core_codegen_interface->grpc_call_arena_alloc(
call, sizeof(*allocator_info)))
experimental::RpcAllocatorInfo<RequestType, ResponseType>();
experimental::MessageHolder<RequestType, ResponseType>* allocator_state =
nullptr;
if (allocator_ != nullptr) {
allocator_->AllocateMessages(allocator_info);
allocator_state = allocator_->AllocateMessages();
} else {
allocator_info->request =
new (g_core_codegen_interface->grpc_call_arena_alloc(
call, sizeof(RequestType))) RequestType();
allocator_info->response =
new (g_core_codegen_interface->grpc_call_arena_alloc(
call, sizeof(ResponseType))) ResponseType();
allocator_state = new (g_core_codegen_interface->grpc_call_arena_alloc(
call, sizeof(DefaultMessageHolder<RequestType, ResponseType>)))
DefaultMessageHolder<RequestType, ResponseType>();
}
*handler_data = allocator_info;
request = allocator_info->request;
*handler_data = allocator_state;
request = allocator_state->request();
*status = SerializationTraits<RequestType>::Deserialize(&buf, request);
buf.Release();
if (status->ok()) {
return request;
}
// Clean up on deserialization failure.
if (allocator_ != nullptr) {
allocator_->DeallocateMessages(allocator_info);
} else {
allocator_info->request->~RequestType();
allocator_info->response->~ResponseType();
allocator_info->request = nullptr;
allocator_info->response = nullptr;
}
allocator_state->Release();
return nullptr;
}
@ -548,9 +550,8 @@ class CallbackUnaryHandler : public MethodHandler {
}
// The response is dropped if the status is not OK.
if (s.ok()) {
finish_ops_.ServerSendStatus(
&ctx_->trailing_metadata_,
finish_ops_.SendMessagePtr(allocator_info_->response));
finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_,
finish_ops_.SendMessagePtr(response()));
} else {
finish_ops_.ServerSendStatus(&ctx_->trailing_metadata_, s);
}
@ -588,14 +589,8 @@ class CallbackUnaryHandler : public MethodHandler {
void ClearCancelCallback() override { ctx_->ClearCancelCallback(); }
void FreeRequest() override {
if (allocator_ != nullptr) {
allocator_->DeallocateRequest(allocator_info_);
}
}
void* GetAllocatorState() override {
return allocator_info_->allocator_state;
experimental::RpcAllocatorState* GetRpcAllocatorState() override {
return allocator_state_;
}
private:
@ -603,35 +598,23 @@ class CallbackUnaryHandler : public MethodHandler {
ServerCallbackRpcControllerImpl(
ServerContext* ctx, Call* call,
experimental::RpcAllocatorInfo<RequestType, ResponseType>*
allocator_info,
experimental::MessageAllocator<RequestType, ResponseType>* allocator,
experimental::MessageHolder<RequestType, ResponseType>* allocator_state,
std::function<void()> call_requester)
: ctx_(ctx),
call_(*call),
allocator_info_(allocator_info),
allocator_(allocator),
allocator_state_(allocator_state),
call_requester_(std::move(call_requester)) {
ctx_->BeginCompletionOp(call, [this](bool) { MaybeDone(); }, nullptr);
}
const RequestType* request() { return allocator_info_->request; }
ResponseType* response() { return allocator_info_->response; }
const RequestType* request() { return allocator_state_->request(); }
ResponseType* response() { return allocator_state_->response(); }
void MaybeDone() {
if (--callbacks_outstanding_ == 0) {
grpc_call* call = call_.call();
auto call_requester = std::move(call_requester_);
if (allocator_ != nullptr) {
allocator_->DeallocateMessages(allocator_info_);
} else {
if (allocator_info_->request != nullptr) {
allocator_info_->request->~RequestType();
}
if (allocator_info_->response != nullptr) {
allocator_info_->response->~ResponseType();
}
}
allocator_state_->Release();
this->~ServerCallbackRpcControllerImpl(); // explicitly call destructor
g_core_codegen_interface->grpc_call_unref(call);
call_requester();
@ -647,8 +630,8 @@ class CallbackUnaryHandler : public MethodHandler {
ServerContext* ctx_;
Call call_;
experimental::RpcAllocatorInfo<RequestType, ResponseType>* allocator_info_;
experimental::MessageAllocator<RequestType, ResponseType>* allocator_;
experimental::MessageHolder<RequestType, ResponseType>* const
allocator_state_;
std::function<void()> call_requester_;
std::atomic_int callbacks_outstanding_{
2}; // reserve for Finish and CompletionOp

@ -25,6 +25,7 @@
#include <google/protobuf/arena.h>
#include <grpc/impl/codegen/log.h>
#include <gtest/gtest.h>
#include <grpcpp/channel.h>
@ -62,11 +63,9 @@ class CallbackTestServiceImpl
public:
explicit CallbackTestServiceImpl() {}
void SetFreeRequest() { free_request_ = true; }
void SetAllocatorMutator(
std::function<void(void* allocator_state, const EchoRequest* req,
EchoResponse* resp)>
std::function<void(experimental::RpcAllocatorState* allocator_state,
const EchoRequest* req, EchoResponse* resp)>
mutator) {
allocator_mutator_ = mutator;
}
@ -75,18 +74,15 @@ class CallbackTestServiceImpl
EchoResponse* response,
experimental::ServerCallbackRpcController* controller) override {
response->set_message(request->message());
if (free_request_) {
controller->FreeRequest();
} else if (allocator_mutator_) {
allocator_mutator_(controller->GetAllocatorState(), request, response);
if (allocator_mutator_) {
allocator_mutator_(controller->GetRpcAllocatorState(), request, response);
}
controller->Finish(Status::OK);
}
private:
bool free_request_ = false;
std::function<void(void* allocator_state, const EchoRequest* req,
EchoResponse* resp)>
std::function<void(experimental::RpcAllocatorState* allocator_state,
const EchoRequest* req, EchoResponse* resp)>
allocator_mutator_;
};
@ -230,26 +226,44 @@ class SimpleAllocatorTest : public MessageAllocatorEnd2endTestBase {
class SimpleAllocator
: public experimental::MessageAllocator<EchoRequest, EchoResponse> {
public:
void AllocateMessages(
experimental::RpcAllocatorInfo<EchoRequest, EchoResponse>* info) {
allocation_count++;
info->request = new EchoRequest;
info->response = new EchoResponse;
info->allocator_state = info;
}
void DeallocateRequest(
experimental::RpcAllocatorInfo<EchoRequest, EchoResponse>* info) {
request_deallocation_count++;
delete info->request;
info->request = nullptr;
}
void DeallocateMessages(
experimental::RpcAllocatorInfo<EchoRequest, EchoResponse>* info) {
messages_deallocation_count++;
delete info->request;
delete info->response;
class MessageHolderImpl
: public experimental::MessageHolder<EchoRequest, EchoResponse> {
public:
MessageHolderImpl(int* request_deallocation_count,
int* messages_deallocation_count)
: request_deallocation_count_(request_deallocation_count),
messages_deallocation_count_(messages_deallocation_count) {
set_request(new EchoRequest);
set_response(new EchoResponse);
}
void Release() override {
(*messages_deallocation_count_)++;
delete request();
delete response();
delete this;
}
void FreeRequest() override {
(*request_deallocation_count_)++;
delete request();
set_request(nullptr);
}
EchoRequest* ReleaseRequest() {
auto* ret = request();
set_request(nullptr);
return ret;
}
private:
int* request_deallocation_count_;
int* messages_deallocation_count_;
};
experimental::MessageHolder<EchoRequest, EchoResponse>* AllocateMessages()
override {
allocation_count++;
return new MessageHolderImpl(&request_deallocation_count,
&messages_deallocation_count);
}
int allocation_count = 0;
int request_deallocation_count = 0;
int messages_deallocation_count = 0;
@ -272,7 +286,16 @@ TEST_P(SimpleAllocatorTest, RpcWithEarlyFreeRequest) {
MAYBE_SKIP_TEST;
const int kRpcCount = 10;
std::unique_ptr<SimpleAllocator> allocator(new SimpleAllocator);
callback_service_.SetFreeRequest();
auto mutator = [](experimental::RpcAllocatorState* allocator_state,
const EchoRequest* req, EchoResponse* resp) {
auto* info =
static_cast<SimpleAllocator::MessageHolderImpl*>(allocator_state);
EXPECT_EQ(req, info->request());
EXPECT_EQ(resp, info->response());
allocator_state->FreeRequest();
EXPECT_EQ(nullptr, info->request());
};
callback_service_.SetAllocatorMutator(mutator);
CreateServer(allocator.get());
ResetStub();
SendRpcs(kRpcCount);
@ -286,17 +309,15 @@ TEST_P(SimpleAllocatorTest, RpcWithReleaseRequest) {
const int kRpcCount = 10;
std::unique_ptr<SimpleAllocator> allocator(new SimpleAllocator);
std::vector<EchoRequest*> released_requests;
auto mutator = [&released_requests](void* allocator_state,
const EchoRequest* req,
EchoResponse* resp) {
auto mutator = [&released_requests](
experimental::RpcAllocatorState* allocator_state,
const EchoRequest* req, EchoResponse* resp) {
auto* info =
static_cast<experimental::RpcAllocatorInfo<EchoRequest, EchoResponse>*>(
allocator_state);
EXPECT_EQ(req, info->request);
EXPECT_EQ(resp, info->response);
EXPECT_EQ(allocator_state, info->allocator_state);
released_requests.push_back(info->request);
info->request = nullptr;
static_cast<SimpleAllocator::MessageHolderImpl*>(allocator_state);
EXPECT_EQ(req, info->request());
EXPECT_EQ(resp, info->response());
released_requests.push_back(info->ReleaseRequest());
EXPECT_EQ(nullptr, info->request());
};
callback_service_.SetAllocatorMutator(mutator);
CreateServer(allocator.get());
@ -316,30 +337,27 @@ class ArenaAllocatorTest : public MessageAllocatorEnd2endTestBase {
class ArenaAllocator
: public experimental::MessageAllocator<EchoRequest, EchoResponse> {
public:
void AllocateMessages(
experimental::RpcAllocatorInfo<EchoRequest, EchoResponse>* info) {
allocation_count++;
auto* arena = new google::protobuf::Arena;
info->allocator_state = arena;
info->request =
google::protobuf::Arena::CreateMessage<EchoRequest>(arena);
info->response =
google::protobuf::Arena::CreateMessage<EchoResponse>(arena);
}
void DeallocateRequest(
experimental::RpcAllocatorInfo<EchoRequest, EchoResponse>* info) {
GPR_ASSERT(0);
}
void DeallocateMessages(
experimental::RpcAllocatorInfo<EchoRequest, EchoResponse>* info) {
deallocation_count++;
auto* arena =
static_cast<google::protobuf::Arena*>(info->allocator_state);
delete arena;
class MessageHolderImpl
: public experimental::MessageHolder<EchoRequest, EchoResponse> {
public:
MessageHolderImpl() {
set_request(
google::protobuf::Arena::CreateMessage<EchoRequest>(&arena_));
set_response(
google::protobuf::Arena::CreateMessage<EchoResponse>(&arena_));
}
void Release() override { delete this; }
void FreeRequest() override { GPR_ASSERT(0); }
private:
google::protobuf::Arena arena_;
};
experimental::MessageHolder<EchoRequest, EchoResponse>* AllocateMessages()
override {
allocation_count++;
return new MessageHolderImpl;
}
int allocation_count = 0;
int deallocation_count = 0;
};
};
@ -351,7 +369,6 @@ TEST_P(ArenaAllocatorTest, SimpleRpc) {
ResetStub();
SendRpcs(kRpcCount);
EXPECT_EQ(kRpcCount, allocator->allocation_count);
EXPECT_EQ(kRpcCount, allocator->deallocation_count);
}
std::vector<TestScenario> CreateTestScenarios(bool test_insecure) {

Loading…
Cancel
Save