Server side interception for CompletionOp and AsyncRequest

pull/16842/head
Yash Tibrewal 6 years ago
parent adca91f6cf
commit 456231b26d
  1. 12
      include/grpcpp/impl/codegen/call.h
  2. 85
      include/grpcpp/impl/codegen/server_interface.h
  3. 4
      include/grpcpp/server.h
  4. 99
      src/cpp/server/server_cc.cc
  5. 74
      src/cpp/server/server_context.cc

@ -1004,9 +1004,17 @@ class InterceptorBatchMethodsImpl
/* Returns true if no interceptors are run. Returns false otherwise if there
are interceptors registered. After the interceptors are done running \a f will
be invoked. This is to be used only by BaseAsyncRequest and SyncRequest. */
bool RunInterceptors(std::function<void(internal::CompletionQueueTag*)> f) {
bool RunInterceptors(std::function<void(void)> f) {
GPR_CODEGEN_ASSERT(reverse_ == true);
return true;
GPR_CODEGEN_ASSERT(call_->client_rpc_info() == nullptr);
auto* server_rpc_info = call_->server_rpc_info();
if (server_rpc_info == nullptr ||
server_rpc_info->interceptors_.size() == 0) {
return true;
}
callback_ = std::move(f);
RunServerInterceptors();
return false;
}
private:

@ -20,12 +20,14 @@
#define GRPCPP_IMPL_CODEGEN_SERVER_INTERFACE_H
#include <grpc/impl/codegen/grpc_types.h>
//#include <grpcpp/alarm.h>
#include <grpcpp/impl/codegen/byte_buffer.h>
#include <grpcpp/impl/codegen/call.h>
#include <grpcpp/impl/codegen/call_hook.h>
#include <grpcpp/impl/codegen/completion_queue_tag.h>
#include <grpcpp/impl/codegen/core_codegen_interface.h>
#include <grpcpp/impl/codegen/rpc_service_method.h>
#include <grpcpp/impl/codegen/server_context.h>
namespace grpc {
@ -149,45 +151,69 @@ class ServerInterface : public internal::CallHook {
public:
BaseAsyncRequest(ServerInterface* server, ServerContext* context,
internal::ServerAsyncStreamingInterface* stream,
CompletionQueue* call_cq, void* tag,
CompletionQueue* call_cq,
ServerCompletionQueue* notification_cq, void* tag,
bool delete_on_finalize);
virtual ~BaseAsyncRequest();
bool FinalizeResult(void** tag, bool* status) override;
private:
void ContinueFinalizeResultAfterInterception();
protected:
ServerInterface* const server_;
ServerContext* const context_;
internal::ServerAsyncStreamingInterface* const stream_;
CompletionQueue* const call_cq_;
ServerCompletionQueue* const notification_cq_;
void* const tag_;
const bool delete_on_finalize_;
grpc_call* call_;
internal::InterceptorBatchMethodsImpl interceptor_methods;
internal::Call call_wrapper_;
internal::InterceptorBatchMethodsImpl interceptor_methods_;
bool done_intercepting_;
void* dummy_alarm_; /* This should have been Alarm, but we cannot depend on
alarm.h here */
};
class RegisteredAsyncRequest : public BaseAsyncRequest {
public:
RegisteredAsyncRequest(ServerInterface* server, ServerContext* context,
internal::ServerAsyncStreamingInterface* stream,
CompletionQueue* call_cq, void* tag);
// uses BaseAsyncRequest::FinalizeResult
CompletionQueue* call_cq,
ServerCompletionQueue* notification_cq, void* tag,
const char* name);
virtual bool FinalizeResult(void** tag, bool* status) override {
/* If we are done intercepting, then there is nothing more for us to do */
if (done_intercepting_) {
return BaseAsyncRequest::FinalizeResult(tag, status);
}
call_wrapper_ = internal::Call(
call_, server_, call_cq_, server_->max_receive_message_size(),
context_->set_server_rpc_info(experimental::ServerRpcInfo(
context_, name_, *server_->interceptor_creators())));
return BaseAsyncRequest::FinalizeResult(tag, status);
}
protected:
void IssueRequest(void* registered_method, grpc_byte_buffer** payload,
ServerCompletionQueue* notification_cq);
const char* name_;
};
class NoPayloadAsyncRequest final : public RegisteredAsyncRequest {
public:
NoPayloadAsyncRequest(void* registered_method, ServerInterface* server,
ServerContext* context,
NoPayloadAsyncRequest(internal::RpcServiceMethod* registered_method,
ServerInterface* server, ServerContext* context,
internal::ServerAsyncStreamingInterface* stream,
CompletionQueue* call_cq,
ServerCompletionQueue* notification_cq, void* tag)
: RegisteredAsyncRequest(server, context, stream, call_cq, tag) {
IssueRequest(registered_method, nullptr, notification_cq);
: RegisteredAsyncRequest(server, context, stream, call_cq,
notification_cq, tag,
registered_method->name()) {
IssueRequest(registered_method->server_tag(), nullptr, notification_cq);
}
// uses RegisteredAsyncRequest::FinalizeResult
@ -196,13 +222,15 @@ class ServerInterface : public internal::CallHook {
template <class Message>
class PayloadAsyncRequest final : public RegisteredAsyncRequest {
public:
PayloadAsyncRequest(void* registered_method, ServerInterface* server,
ServerContext* context,
PayloadAsyncRequest(internal::RpcServiceMethod* registered_method,
ServerInterface* server, ServerContext* context,
internal::ServerAsyncStreamingInterface* stream,
CompletionQueue* call_cq,
ServerCompletionQueue* notification_cq, void* tag,
Message* request)
: RegisteredAsyncRequest(server, context, stream, call_cq, tag),
: RegisteredAsyncRequest(server, context, stream, call_cq,
notification_cq, tag,
registered_method->name()),
registered_method_(registered_method),
server_(server),
context_(context),
@ -211,7 +239,8 @@ class ServerInterface : public internal::CallHook {
notification_cq_(notification_cq),
tag_(tag),
request_(request) {
IssueRequest(registered_method, payload_.bbuf_ptr(), notification_cq);
IssueRequest(registered_method->server_tag(), payload_.bbuf_ptr(),
notification_cq);
}
~PayloadAsyncRequest() {
@ -219,6 +248,10 @@ class ServerInterface : public internal::CallHook {
}
bool FinalizeResult(void** tag, bool* status) override {
/* If we are done intercepting, then there is nothing more for us to do */
if (done_intercepting_) {
return RegisteredAsyncRequest::FinalizeResult(tag, status);
}
if (*status) {
if (!payload_.Valid() || !SerializationTraits<Message>::Deserialize(
payload_.bbuf_ptr(), request_)
@ -237,15 +270,24 @@ class ServerInterface : public internal::CallHook {
return false;
}
}
call_wrapper_ = internal::Call(
call_, server_, call_cq_, server_->max_receive_message_size(),
context_->set_server_rpc_info(experimental::ServerRpcInfo(
context_, name_, *server_->interceptor_creators())));
/* Set interception point for recv message */
interceptor_methods_.AddInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_MESSAGE);
interceptor_methods_.SetRecvMessage(request_);
return RegisteredAsyncRequest::FinalizeResult(tag, status);
}
private:
void* const registered_method_;
internal::RpcServiceMethod* const registered_method_;
ServerInterface* const server_;
ServerContext* const context_;
internal::ServerAsyncStreamingInterface* const stream_;
CompletionQueue* const call_cq_;
ServerCompletionQueue* const notification_cq_;
void* const tag_;
Message* const request_;
@ -274,9 +316,8 @@ class ServerInterface : public internal::CallHook {
ServerCompletionQueue* notification_cq, void* tag,
Message* message) {
GPR_CODEGEN_ASSERT(method);
new PayloadAsyncRequest<Message>(method->server_tag(), this, context,
stream, call_cq, notification_cq, tag,
message);
new PayloadAsyncRequest<Message>(method, this, context, stream, call_cq,
notification_cq, tag, message);
}
void RequestAsyncCall(internal::RpcServiceMethod* method,
@ -285,8 +326,8 @@ class ServerInterface : public internal::CallHook {
CompletionQueue* call_cq,
ServerCompletionQueue* notification_cq, void* tag) {
GPR_CODEGEN_ASSERT(method);
new NoPayloadAsyncRequest(method->server_tag(), this, context, stream,
call_cq, notification_cq, tag);
new NoPayloadAsyncRequest(method, this, context, stream, call_cq,
notification_cq, tag);
}
void RequestAsyncGenericCall(GenericServerContext* context,
@ -298,8 +339,10 @@ class ServerInterface : public internal::CallHook {
tag, true);
}
private:
virtual const std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>* interceptor_creators() {
private:
virtual const std::vector<
std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>*
interceptor_creators() {
return nullptr;
}
};

@ -191,7 +191,9 @@ class Server : public ServerInterface, private GrpcLibraryCodegen {
grpc_server* server() override { return server_; };
private:
const std::vector<std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>* interceptor_creators() override {
const std::vector<
std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>*
interceptor_creators() override {
return &interceptor_creators_;
}

@ -24,6 +24,7 @@
#include <grpc/grpc.h>
#include <grpc/support/alloc.h>
#include <grpc/support/log.h>
#include <grpcpp/alarm.h>
#include <grpcpp/completion_queue.h>
#include <grpcpp/generic/async_generic_service.h>
#include <grpcpp/impl/codegen/async_unary_call.h>
@ -240,6 +241,8 @@ class Server::SyncRequest final : public internal::CompletionQueueTag {
global_callbacks_ = global_callbacks;
resources_ = resources;
interceptor_methods_.SetCall(&call_);
interceptor_methods_.SetReverse();
/* Set interception point for RECV INITIAL METADATA */
interceptor_methods_.AddInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA);
@ -256,8 +259,7 @@ class Server::SyncRequest final : public internal::CompletionQueueTag {
experimental::InterceptionHookPoints::POST_RECV_MESSAGE);
interceptor_methods_.SetRecvMessage(request_);
}
interceptor_methods_.SetCall(&call_);
interceptor_methods_.SetReverse();
auto f = std::bind(&CallData::ContinueRunAfterInterception, this);
if (interceptor_methods_.RunInterceptors(f)) {
ContinueRunAfterInterception();
@ -725,15 +727,21 @@ void Server::PerformOpsOnCall(internal::CallOpSetInterface* ops,
ServerInterface::BaseAsyncRequest::BaseAsyncRequest(
ServerInterface* server, ServerContext* context,
internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq,
void* tag, bool delete_on_finalize)
ServerCompletionQueue* notification_cq, void* tag, bool delete_on_finalize)
: server_(server),
context_(context),
stream_(stream),
call_cq_(call_cq),
notification_cq_(notification_cq),
tag_(tag),
delete_on_finalize_(delete_on_finalize),
call_(nullptr),
call_wrapper_() {
done_intercepting_(false) {
/* Set up interception state partially for the receive ops. call_wrapper_ is
* not filled at this point, but it will be filled before the interceptors are
* run. */
interceptor_methods_.SetCall(&call_wrapper_);
interceptor_methods_.SetReverse();
call_cq_->RegisterAvalanching(); // This op will trigger more ops
}
@ -743,17 +751,47 @@ ServerInterface::BaseAsyncRequest::~BaseAsyncRequest() {
bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag,
bool* status) {
if (done_intercepting_) {
delete static_cast<Alarm*>(dummy_alarm_);
dummy_alarm_ = nullptr;
*tag = tag_;
if (delete_on_finalize_) {
delete this;
}
return true;
}
context_->set_call(call_);
context_->cq_ = call_cq_;
internal::Call call(call_, server_, call_cq_,
server_->max_receive_message_size(), nullptr);
if (call_wrapper_.call() == nullptr) {
/* Fill it since it is empty. */
call_wrapper_ = internal::Call(
call_, server_, call_cq_, server_->max_receive_message_size(), nullptr);
}
// just the pointers inside call are copied here
stream_->BindCall(&call_wrapper_);
if (*status && call_ && call_wrapper_.server_rpc_info()) {
done_intercepting_ = true;
/* Set interception point for RECV INITIAL METADATA */
interceptor_methods_.AddInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA);
interceptor_methods_.SetRecvInitialMetadata(&context_->client_metadata_);
auto f = std::bind(&ServerInterface::BaseAsyncRequest::
ContinueFinalizeResultAfterInterception,
this);
if (interceptor_methods_.RunInterceptors(f)) {
/* There are no interceptors to run. Continue */
} else {
/* There were interceptors to be run, so
ContinueFinalizeResultAfterInterception will be run when interceptors are
done. */
return false;
}
}
if (*status && call_) {
context_->BeginCompletionOp(&call);
context_->BeginCompletionOp(&call_wrapper_);
}
// just the pointers inside call are copied here
stream_->BindCall(&call);
*tag = tag_;
if (delete_on_finalize_) {
delete this;
@ -761,11 +799,23 @@ bool ServerInterface::BaseAsyncRequest::FinalizeResult(void** tag,
return true;
}
void ServerInterface::BaseAsyncRequest::
ContinueFinalizeResultAfterInterception() {
context_->BeginCompletionOp(&call_wrapper_);
/* Queue a tag which will be returned immediately */
dummy_alarm_ = new Alarm();
static_cast<Alarm*>(dummy_alarm_)
->Set(notification_cq_,
g_core_codegen_interface->gpr_time_0(GPR_CLOCK_MONOTONIC), this);
}
ServerInterface::RegisteredAsyncRequest::RegisteredAsyncRequest(
ServerInterface* server, ServerContext* context,
internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq,
void* tag)
: BaseAsyncRequest(server, context, stream, call_cq, tag, true) {}
ServerCompletionQueue* notification_cq, void* tag, const char* name)
: BaseAsyncRequest(server, context, stream, call_cq, notification_cq, tag,
true),
name_(name) {}
void ServerInterface::RegisteredAsyncRequest::IssueRequest(
void* registered_method, grpc_byte_buffer** payload,
@ -781,7 +831,7 @@ ServerInterface::GenericAsyncRequest::GenericAsyncRequest(
ServerInterface* server, GenericServerContext* context,
internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq,
ServerCompletionQueue* notification_cq, void* tag, bool delete_on_finalize)
: BaseAsyncRequest(server, context, stream, call_cq, tag,
: BaseAsyncRequest(server, context, stream, call_cq, notification_cq, tag,
delete_on_finalize) {
grpc_call_details_init(&call_details_);
GPR_ASSERT(notification_cq);
@ -794,6 +844,10 @@ ServerInterface::GenericAsyncRequest::GenericAsyncRequest(
bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag,
bool* status) {
/* If we are done intercepting, there is nothing more for us to do */
if (done_intercepting_) {
return BaseAsyncRequest::FinalizeResult(tag, status);
}
// TODO(yangg) remove the copy here.
if (*status) {
static_cast<GenericServerContext*>(context_)->method_ =
@ -804,16 +858,27 @@ bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag,
}
grpc_slice_unref(call_details_.method);
grpc_slice_unref(call_details_.host);
call_wrapper_ = internal::Call(
call_, server_, call_cq_, server_->max_receive_message_size(),
context_->set_server_rpc_info(experimental::ServerRpcInfo(
context_,
static_cast<GenericServerContext*>(context_)->method_.c_str(),
*server_->interceptor_creators())));
return BaseAsyncRequest::FinalizeResult(tag, status);
}
bool Server::UnimplementedAsyncRequest::FinalizeResult(void** tag,
bool* status) {
if (GenericAsyncRequest::FinalizeResult(tag, status) && *status) {
new UnimplementedAsyncRequest(server_, cq_);
new UnimplementedAsyncResponse(this);
if (GenericAsyncRequest::FinalizeResult(tag, status)) {
/* We either had no interceptors run or we are done interceptinh */
if (*status) {
new UnimplementedAsyncRequest(server_, cq_);
new UnimplementedAsyncResponse(this);
} else {
delete this;
}
} else {
delete this;
/* The tag was swallowed due to interception. We will see it again. */
}
return false;
}

@ -45,8 +45,8 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface {
tag_(nullptr),
refs_(2),
finalized_(false),
cancelled_(0) /*,
done_intercepting_(false)*/ {}
cancelled_(0),
done_intercepting_(false) {}
void FillOps(internal::Call* call) override;
bool FinalizeResult(void** tag, bool* status) override;
@ -69,14 +69,32 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface {
// This will be called while interceptors are run if the RPC is a hijacked
// RPC. This should set hijacking state for each of the ops.
void SetHijackingState() override {}
void SetHijackingState() override {
/* Servers don't allow hijacking */
GPR_CODEGEN_ASSERT(false);
}
/* Should be called after interceptors are done running */
void ContinueFillOpsAfterInterception() override {}
/* Should be called after interceptors are done running on the finalize result
* path */
void ContinueFinalizeResultAfterInterception() override {}
void ContinueFinalizeResultAfterInterception() override {
done_intercepting_ = true;
if (!has_tag_) {
/* We don't have a tag to return. */
std::unique_lock<std::mutex> lock(mu_);
if (--refs_ == 0) {
lock.unlock();
delete this;
}
return;
}
/* Start a dummy op so that we can return the tag */
GPR_CODEGEN_ASSERT(GRPC_CALL_OK ==
g_core_codegen_interface->grpc_call_start_batch(
call_.call(), nullptr, 0, this, nullptr));
}
private:
bool CheckCancelledNoPluck() {
@ -90,7 +108,7 @@ class ServerContext::CompletionOp final : public internal::CallOpSetInterface {
int refs_;
bool finalized_;
int cancelled_;
// bool done_intercepting_;
bool done_intercepting_;
internal::Call call_;
internal::InterceptorBatchMethodsImpl interceptor_methods_;
};
@ -111,24 +129,52 @@ void ServerContext::CompletionOp::FillOps(internal::Call* call) {
ops.reserved = nullptr;
call_ = *call;
interceptor_methods_.SetCall(&call_);
interceptor_methods_.SetReverse();
interceptor_methods_.SetCallOpSetInterface(this);
GPR_ASSERT(GRPC_CALL_OK ==
grpc_call_start_batch(call->call(), &ops, 1, this, nullptr));
/* No interceptors to run here */
}
bool ServerContext::CompletionOp::FinalizeResult(void** tag, bool* status) {
std::unique_lock<std::mutex> lock(mu_);
finalized_ = true;
bool ret = false;
if (has_tag_) {
*tag = tag_;
ret = true;
std::unique_lock<std::mutex> lock(mu_);
if (done_intercepting_) {
/* We are done intercepting. */
if (has_tag_) {
*tag = tag_;
ret = true;
}
if (--refs_ == 0) {
lock.unlock();
delete this;
}
return ret;
}
finalized_ = true;
if (!*status) cancelled_ = 1;
if (--refs_ == 0) {
lock.unlock();
delete this;
/* Release the lock since we are going to be running through interceptors now
*/
lock.unlock();
/* Add interception point and run through interceptors */
interceptor_methods_.AddInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_CLOSE);
if (interceptor_methods_.RunInterceptors()) {
/* No interceptors were run */
if (has_tag_) {
*tag = tag_;
ret = true;
}
lock.lock();
if (--refs_ == 0) {
lock.unlock();
delete this;
}
return ret;
}
return ret;
/* There are interceptors to be run. Return false for now */
return false;
}
// ServerContext body

Loading…
Cancel
Save