Allow the interceptor to know the method type

pull/17430/head
Vijay Pai 6 years ago
parent e97c9457e2
commit 97de30d7b3
  1. 1
      include/grpcpp/impl/codegen/channel_interface.h
  2. 6
      include/grpcpp/impl/codegen/client_context.h
  3. 48
      include/grpcpp/impl/codegen/client_interceptor.h
  4. 4
      include/grpcpp/impl/codegen/server_context.h
  5. 25
      include/grpcpp/impl/codegen/server_interceptor.h
  6. 18
      include/grpcpp/impl/codegen/server_interface.h
  7. 5
      src/cpp/client/channel_cc.cc
  8. 17
      src/cpp/server/server_cc.cc
  9. 1
      test/cpp/end2end/client_interceptors_end2end_test.cc
  10. 27
      test/cpp/end2end/server_interceptors_end2end_test.cc

@ -21,7 +21,6 @@
#include <grpc/impl/codegen/connectivity_state.h> #include <grpc/impl/codegen/connectivity_state.h>
#include <grpcpp/impl/codegen/call.h> #include <grpcpp/impl/codegen/call.h>
#include <grpcpp/impl/codegen/client_context.h>
#include <grpcpp/impl/codegen/status.h> #include <grpcpp/impl/codegen/status.h>
#include <grpcpp/impl/codegen/time.h> #include <grpcpp/impl/codegen/time.h>

@ -46,6 +46,7 @@
#include <grpcpp/impl/codegen/core_codegen_interface.h> #include <grpcpp/impl/codegen/core_codegen_interface.h>
#include <grpcpp/impl/codegen/create_auth_context.h> #include <grpcpp/impl/codegen/create_auth_context.h>
#include <grpcpp/impl/codegen/metadata_map.h> #include <grpcpp/impl/codegen/metadata_map.h>
#include <grpcpp/impl/codegen/rpc_method.h>
#include <grpcpp/impl/codegen/security/auth_context.h> #include <grpcpp/impl/codegen/security/auth_context.h>
#include <grpcpp/impl/codegen/slice.h> #include <grpcpp/impl/codegen/slice.h>
#include <grpcpp/impl/codegen/status.h> #include <grpcpp/impl/codegen/status.h>
@ -418,12 +419,13 @@ class ClientContext {
void set_call(grpc_call* call, const std::shared_ptr<Channel>& channel); void set_call(grpc_call* call, const std::shared_ptr<Channel>& channel);
experimental::ClientRpcInfo* set_client_rpc_info( experimental::ClientRpcInfo* set_client_rpc_info(
const char* method, grpc::ChannelInterface* channel, const char* method, internal::RpcMethod::RpcType type,
grpc::ChannelInterface* channel,
const std::vector< const std::vector<
std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>& std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>&
creators, creators,
size_t interceptor_pos) { size_t interceptor_pos) {
rpc_info_ = experimental::ClientRpcInfo(this, method, channel); rpc_info_ = experimental::ClientRpcInfo(this, type, method, channel);
rpc_info_.RegisterInterceptors(creators, interceptor_pos); rpc_info_.RegisterInterceptors(creators, interceptor_pos);
return &rpc_info_; return &rpc_info_;
} }

@ -23,6 +23,7 @@
#include <vector> #include <vector>
#include <grpcpp/impl/codegen/interceptor.h> #include <grpcpp/impl/codegen/interceptor.h>
#include <grpcpp/impl/codegen/rpc_method.h>
#include <grpcpp/impl/codegen/string_ref.h> #include <grpcpp/impl/codegen/string_ref.h>
namespace grpc { namespace grpc {
@ -52,23 +53,56 @@ extern experimental::ClientInterceptorFactoryInterface*
namespace experimental { namespace experimental {
class ClientRpcInfo { class ClientRpcInfo {
public: public:
ClientRpcInfo() {} // TODO(yashykt): Stop default-constructing ClientRpcInfo and remove UNKNOWN
// from the list of possible Types.
enum class Type {
UNARY,
CLIENT_STREAMING,
SERVER_STREAMING,
BIDI_STREAMING,
UNKNOWN // UNKNOWN is not API and will be removed later
};
~ClientRpcInfo(){}; ~ClientRpcInfo(){};
ClientRpcInfo(const ClientRpcInfo&) = delete; ClientRpcInfo(const ClientRpcInfo&) = delete;
ClientRpcInfo(ClientRpcInfo&&) = default; ClientRpcInfo(ClientRpcInfo&&) = default;
ClientRpcInfo& operator=(ClientRpcInfo&&) = default;
// Getter methods // Getter methods
const char* method() { return method_; } const char* method() const { return method_; }
ChannelInterface* channel() { return channel_; } ChannelInterface* channel() { return channel_; }
grpc::ClientContext* client_context() { return ctx_; } grpc::ClientContext* client_context() { return ctx_; }
Type type() const { return type_; }
private: private:
ClientRpcInfo(grpc::ClientContext* ctx, const char* method, static_assert(Type::UNARY ==
grpc::ChannelInterface* channel) static_cast<Type>(internal::RpcMethod::NORMAL_RPC),
: ctx_(ctx), method_(method), channel_(channel) {} "violated expectation about Type enum");
static_assert(Type::CLIENT_STREAMING ==
static_cast<Type>(internal::RpcMethod::CLIENT_STREAMING),
"violated expectation about Type enum");
static_assert(Type::SERVER_STREAMING ==
static_cast<Type>(internal::RpcMethod::SERVER_STREAMING),
"violated expectation about Type enum");
static_assert(Type::BIDI_STREAMING ==
static_cast<Type>(internal::RpcMethod::BIDI_STREAMING),
"violated expectation about Type enum");
// Default constructor should only be used by ClientContext
ClientRpcInfo() = default;
// Constructor will only be called from ClientContext
ClientRpcInfo(grpc::ClientContext* ctx, internal::RpcMethod::RpcType type,
const char* method, grpc::ChannelInterface* channel)
: ctx_(ctx),
type_(static_cast<Type>(type)),
method_(method),
channel_(channel) {}
// Move assignment should only be used by ClientContext
// TODO(yashykt): Delete move assignment
ClientRpcInfo& operator=(ClientRpcInfo&&) = default;
// Runs interceptor at pos \a pos. // Runs interceptor at pos \a pos.
void RunInterceptor( void RunInterceptor(
experimental::InterceptorBatchMethods* interceptor_methods, size_t pos) { experimental::InterceptorBatchMethods* interceptor_methods, size_t pos) {
@ -97,6 +131,8 @@ class ClientRpcInfo {
} }
grpc::ClientContext* ctx_ = nullptr; grpc::ClientContext* ctx_ = nullptr;
// TODO(yashykt): make type_ const once move-assignment is deleted
Type type_{Type::UNKNOWN};
const char* method_ = nullptr; const char* method_ = nullptr;
grpc::ChannelInterface* channel_ = nullptr; grpc::ChannelInterface* channel_ = nullptr;
std::vector<std::unique_ptr<experimental::Interceptor>> interceptors_; std::vector<std::unique_ptr<experimental::Interceptor>> interceptors_;

@ -314,12 +314,12 @@ class ServerContext {
uint32_t initial_metadata_flags() const { return 0; } uint32_t initial_metadata_flags() const { return 0; }
experimental::ServerRpcInfo* set_server_rpc_info( experimental::ServerRpcInfo* set_server_rpc_info(
const char* method, const char* method, internal::RpcMethod::RpcType type,
const std::vector< const std::vector<
std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>& std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>&
creators) { creators) {
if (creators.size() != 0) { if (creators.size() != 0) {
rpc_info_ = new experimental::ServerRpcInfo(this, method); rpc_info_ = new experimental::ServerRpcInfo(this, method, type);
rpc_info_->RegisterInterceptors(creators); rpc_info_->RegisterInterceptors(creators);
} }
return rpc_info_; return rpc_info_;

@ -23,6 +23,7 @@
#include <vector> #include <vector>
#include <grpcpp/impl/codegen/interceptor.h> #include <grpcpp/impl/codegen/interceptor.h>
#include <grpcpp/impl/codegen/rpc_method.h>
#include <grpcpp/impl/codegen/string_ref.h> #include <grpcpp/impl/codegen/string_ref.h>
namespace grpc { namespace grpc {
@ -44,6 +45,8 @@ class ServerInterceptorFactoryInterface {
class ServerRpcInfo { class ServerRpcInfo {
public: public:
enum class Type { UNARY, CLIENT_STREAMING, SERVER_STREAMING, BIDI_STREAMING };
~ServerRpcInfo(){}; ~ServerRpcInfo(){};
ServerRpcInfo(const ServerRpcInfo&) = delete; ServerRpcInfo(const ServerRpcInfo&) = delete;
@ -51,12 +54,27 @@ class ServerRpcInfo {
ServerRpcInfo& operator=(ServerRpcInfo&&) = default; ServerRpcInfo& operator=(ServerRpcInfo&&) = default;
// Getter methods // Getter methods
const char* method() { return method_; } const char* method() const { return method_; }
Type type() const { return type_; }
grpc::ServerContext* server_context() { return ctx_; } grpc::ServerContext* server_context() { return ctx_; }
private: private:
ServerRpcInfo(grpc::ServerContext* ctx, const char* method) static_assert(Type::UNARY ==
: ctx_(ctx), method_(method) { static_cast<Type>(internal::RpcMethod::NORMAL_RPC),
"violated expectation about Type enum");
static_assert(Type::CLIENT_STREAMING ==
static_cast<Type>(internal::RpcMethod::CLIENT_STREAMING),
"violated expectation about Type enum");
static_assert(Type::SERVER_STREAMING ==
static_cast<Type>(internal::RpcMethod::SERVER_STREAMING),
"violated expectation about Type enum");
static_assert(Type::BIDI_STREAMING ==
static_cast<Type>(internal::RpcMethod::BIDI_STREAMING),
"violated expectation about Type enum");
ServerRpcInfo(grpc::ServerContext* ctx, const char* method,
internal::RpcMethod::RpcType type)
: ctx_(ctx), method_(method), type_(static_cast<Type>(type)) {
ref_.store(1); ref_.store(1);
} }
@ -86,6 +104,7 @@ class ServerRpcInfo {
grpc::ServerContext* ctx_ = nullptr; grpc::ServerContext* ctx_ = nullptr;
const char* method_ = nullptr; const char* method_ = nullptr;
const Type type_;
std::atomic_int ref_; std::atomic_int ref_;
std::vector<std::unique_ptr<experimental::Interceptor>> interceptors_; std::vector<std::unique_ptr<experimental::Interceptor>> interceptors_;

@ -174,13 +174,14 @@ class ServerInterface : public internal::CallHook {
bool done_intercepting_; bool done_intercepting_;
}; };
/// RegisteredAsyncRequest is not part of the C++ API
class RegisteredAsyncRequest : public BaseAsyncRequest { class RegisteredAsyncRequest : public BaseAsyncRequest {
public: public:
RegisteredAsyncRequest(ServerInterface* server, ServerContext* context, RegisteredAsyncRequest(ServerInterface* server, ServerContext* context,
internal::ServerAsyncStreamingInterface* stream, internal::ServerAsyncStreamingInterface* stream,
CompletionQueue* call_cq, CompletionQueue* call_cq,
ServerCompletionQueue* notification_cq, void* tag, ServerCompletionQueue* notification_cq, void* tag,
const char* name); const char* name, internal::RpcMethod::RpcType type);
virtual bool FinalizeResult(void** tag, bool* status) override { virtual bool FinalizeResult(void** tag, bool* status) override {
/* If we are done intercepting, then there is nothing more for us to do */ /* If we are done intercepting, then there is nothing more for us to do */
@ -189,7 +190,7 @@ class ServerInterface : public internal::CallHook {
} }
call_wrapper_ = internal::Call( call_wrapper_ = internal::Call(
call_, server_, call_cq_, server_->max_receive_message_size(), call_, server_, call_cq_, server_->max_receive_message_size(),
context_->set_server_rpc_info(name_, context_->set_server_rpc_info(name_, type_,
*server_->interceptor_creators())); *server_->interceptor_creators()));
return BaseAsyncRequest::FinalizeResult(tag, status); return BaseAsyncRequest::FinalizeResult(tag, status);
} }
@ -198,6 +199,7 @@ class ServerInterface : public internal::CallHook {
void IssueRequest(void* registered_method, grpc_byte_buffer** payload, void IssueRequest(void* registered_method, grpc_byte_buffer** payload,
ServerCompletionQueue* notification_cq); ServerCompletionQueue* notification_cq);
const char* name_; const char* name_;
const internal::RpcMethod::RpcType type_;
}; };
class NoPayloadAsyncRequest final : public RegisteredAsyncRequest { class NoPayloadAsyncRequest final : public RegisteredAsyncRequest {
@ -207,9 +209,9 @@ class ServerInterface : public internal::CallHook {
internal::ServerAsyncStreamingInterface* stream, internal::ServerAsyncStreamingInterface* stream,
CompletionQueue* call_cq, CompletionQueue* call_cq,
ServerCompletionQueue* notification_cq, void* tag) ServerCompletionQueue* notification_cq, void* tag)
: RegisteredAsyncRequest(server, context, stream, call_cq, : RegisteredAsyncRequest(
notification_cq, tag, server, context, stream, call_cq, notification_cq, tag,
registered_method->name()) { registered_method->name(), registered_method->method_type()) {
IssueRequest(registered_method->server_tag(), nullptr, notification_cq); IssueRequest(registered_method->server_tag(), nullptr, notification_cq);
} }
@ -225,9 +227,9 @@ class ServerInterface : public internal::CallHook {
CompletionQueue* call_cq, CompletionQueue* call_cq,
ServerCompletionQueue* notification_cq, void* tag, ServerCompletionQueue* notification_cq, void* tag,
Message* request) Message* request)
: RegisteredAsyncRequest(server, context, stream, call_cq, : RegisteredAsyncRequest(
notification_cq, tag, server, context, stream, call_cq, notification_cq, tag,
registered_method->name()), registered_method->name(), registered_method->method_type()),
registered_method_(registered_method), registered_method_(registered_method),
server_(server), server_(server),
context_(context), context_(context),

@ -149,8 +149,9 @@ internal::Call Channel::CreateCallInternal(const internal::RpcMethod& method,
// ClientRpcInfo should be set before call because set_call also checks // ClientRpcInfo should be set before call because set_call also checks
// whether the call has been cancelled, and if the call was cancelled, we // whether the call has been cancelled, and if the call was cancelled, we
// should notify the interceptors too/ // should notify the interceptors too/
auto* info = context->set_client_rpc_info( auto* info =
method.name(), this, interceptor_creators_, interceptor_pos); context->set_client_rpc_info(method.name(), method.method_type(), this,
interceptor_creators_, interceptor_pos);
context->set_call(c_call, shared_from_this()); context->set_call(c_call, shared_from_this());
return internal::Call(c_call, this, cq, info); return internal::Call(c_call, this, cq, info);

@ -236,9 +236,10 @@ class Server::SyncRequest final : public internal::CompletionQueueTag {
: nullptr), : nullptr),
request_(nullptr), request_(nullptr),
method_(mrd->method_), method_(mrd->method_),
call_(mrd->call_, server, &cq_, server->max_receive_message_size(), call_(
ctx_.set_server_rpc_info(method_->name(), mrd->call_, server, &cq_, server->max_receive_message_size(),
server->interceptor_creators_)), ctx_.set_server_rpc_info(method_->name(), method_->method_type(),
server->interceptor_creators_)),
server_(server), server_(server),
global_callbacks_(nullptr), global_callbacks_(nullptr),
resources_(false) { resources_(false) {
@ -427,7 +428,8 @@ class Server::CallbackRequest final : public internal::CompletionQueueTag {
req_->call_, req_->server_, req_->cq_, req_->call_, req_->server_, req_->cq_,
req_->server_->max_receive_message_size(), req_->server_->max_receive_message_size(),
req_->ctx_.set_server_rpc_info( req_->ctx_.set_server_rpc_info(
req_->method_->name(), req_->server_->interceptor_creators_)); req_->method_->name(), req_->method_->method_type(),
req_->server_->interceptor_creators_));
req_->interceptor_methods_.SetCall(call_); req_->interceptor_methods_.SetCall(call_);
req_->interceptor_methods_.SetReverse(); req_->interceptor_methods_.SetReverse();
@ -1041,10 +1043,12 @@ void ServerInterface::BaseAsyncRequest::
ServerInterface::RegisteredAsyncRequest::RegisteredAsyncRequest( ServerInterface::RegisteredAsyncRequest::RegisteredAsyncRequest(
ServerInterface* server, ServerContext* context, ServerInterface* server, ServerContext* context,
internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq,
ServerCompletionQueue* notification_cq, void* tag, const char* name) ServerCompletionQueue* notification_cq, void* tag, const char* name,
internal::RpcMethod::RpcType type)
: BaseAsyncRequest(server, context, stream, call_cq, notification_cq, tag, : BaseAsyncRequest(server, context, stream, call_cq, notification_cq, tag,
true), true),
name_(name) {} name_(name),
type_(type) {}
void ServerInterface::RegisteredAsyncRequest::IssueRequest( void ServerInterface::RegisteredAsyncRequest::IssueRequest(
void* registered_method, grpc_byte_buffer** payload, void* registered_method, grpc_byte_buffer** payload,
@ -1091,6 +1095,7 @@ bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag,
call_, server_, call_cq_, server_->max_receive_message_size(), call_, server_, call_cq_, server_->max_receive_message_size(),
context_->set_server_rpc_info( context_->set_server_rpc_info(
static_cast<GenericServerContext*>(context_)->method_.c_str(), static_cast<GenericServerContext*>(context_)->method_.c_str(),
internal::RpcMethod::BIDI_STREAMING,
*server_->interceptor_creators())); *server_->interceptor_creators()));
return BaseAsyncRequest::FinalizeResult(tag, status); return BaseAsyncRequest::FinalizeResult(tag, status);
} }

@ -50,6 +50,7 @@ class HijackingInterceptor : public experimental::Interceptor {
info_ = info; info_ = info;
// Make sure it is the right method // Make sure it is the right method
EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0); EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
EXPECT_EQ(info->type(), experimental::ClientRpcInfo::Type::UNARY);
} }
virtual void Intercept(experimental::InterceptorBatchMethods* methods) { virtual void Intercept(experimental::InterceptorBatchMethods* methods) {

@ -44,7 +44,32 @@ namespace {
class LoggingInterceptor : public experimental::Interceptor { class LoggingInterceptor : public experimental::Interceptor {
public: public:
LoggingInterceptor(experimental::ServerRpcInfo* info) { info_ = info; } LoggingInterceptor(experimental::ServerRpcInfo* info) {
info_ = info;
// Check the method name and compare to the type
const char* method = info->method();
experimental::ServerRpcInfo::Type type = info->type();
// Check that we use one of our standard methods with expected type.
// We accept BIDI_STREAMING for Echo in case it's an AsyncGenericService
// being tested (the GenericRpc test).
// The empty method is for the Unimplemented requests that arise
// when draining the CQ.
EXPECT_TRUE(
(strcmp(method, "/grpc.testing.EchoTestService/Echo") == 0 &&
(type == experimental::ServerRpcInfo::Type::UNARY ||
type == experimental::ServerRpcInfo::Type::BIDI_STREAMING)) ||
(strcmp(method, "/grpc.testing.EchoTestService/RequestStream") == 0 &&
type == experimental::ServerRpcInfo::Type::CLIENT_STREAMING) ||
(strcmp(method, "/grpc.testing.EchoTestService/ResponseStream") == 0 &&
type == experimental::ServerRpcInfo::Type::SERVER_STREAMING) ||
(strcmp(method, "/grpc.testing.EchoTestService/BidiStream") == 0 &&
type == experimental::ServerRpcInfo::Type::BIDI_STREAMING) ||
strcmp(method, "/grpc.testing.EchoTestService/Unimplemented") == 0 ||
(strcmp(method, "") == 0 &&
type == experimental::ServerRpcInfo::Type::BIDI_STREAMING));
}
virtual void Intercept(experimental::InterceptorBatchMethods* methods) { virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
if (methods->QueryInterceptionHookPoint( if (methods->QueryInterceptionHookPoint(

Loading…
Cancel
Save