Allow the interceptor to know the method type

pull/17430/head
Vijay Pai 6 years ago
parent e97c9457e2
commit 97de30d7b3
  1. 1
      include/grpcpp/impl/codegen/channel_interface.h
  2. 6
      include/grpcpp/impl/codegen/client_context.h
  3. 48
      include/grpcpp/impl/codegen/client_interceptor.h
  4. 4
      include/grpcpp/impl/codegen/server_context.h
  5. 25
      include/grpcpp/impl/codegen/server_interceptor.h
  6. 18
      include/grpcpp/impl/codegen/server_interface.h
  7. 5
      src/cpp/client/channel_cc.cc
  8. 17
      src/cpp/server/server_cc.cc
  9. 1
      test/cpp/end2end/client_interceptors_end2end_test.cc
  10. 27
      test/cpp/end2end/server_interceptors_end2end_test.cc

@ -21,7 +21,6 @@
#include <grpc/impl/codegen/connectivity_state.h>
#include <grpcpp/impl/codegen/call.h>
#include <grpcpp/impl/codegen/client_context.h>
#include <grpcpp/impl/codegen/status.h>
#include <grpcpp/impl/codegen/time.h>

@ -46,6 +46,7 @@
#include <grpcpp/impl/codegen/core_codegen_interface.h>
#include <grpcpp/impl/codegen/create_auth_context.h>
#include <grpcpp/impl/codegen/metadata_map.h>
#include <grpcpp/impl/codegen/rpc_method.h>
#include <grpcpp/impl/codegen/security/auth_context.h>
#include <grpcpp/impl/codegen/slice.h>
#include <grpcpp/impl/codegen/status.h>
@ -418,12 +419,13 @@ class ClientContext {
void set_call(grpc_call* call, const std::shared_ptr<Channel>& channel);
experimental::ClientRpcInfo* set_client_rpc_info(
const char* method, grpc::ChannelInterface* channel,
const char* method, internal::RpcMethod::RpcType type,
grpc::ChannelInterface* channel,
const std::vector<
std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>&
creators,
size_t interceptor_pos) {
rpc_info_ = experimental::ClientRpcInfo(this, method, channel);
rpc_info_ = experimental::ClientRpcInfo(this, type, method, channel);
rpc_info_.RegisterInterceptors(creators, interceptor_pos);
return &rpc_info_;
}

@ -23,6 +23,7 @@
#include <vector>
#include <grpcpp/impl/codegen/interceptor.h>
#include <grpcpp/impl/codegen/rpc_method.h>
#include <grpcpp/impl/codegen/string_ref.h>
namespace grpc {
@ -52,23 +53,56 @@ extern experimental::ClientInterceptorFactoryInterface*
namespace experimental {
class ClientRpcInfo {
public:
ClientRpcInfo() {}
// TODO(yashykt): Stop default-constructing ClientRpcInfo and remove UNKNOWN
// from the list of possible Types.
enum class Type {
UNARY,
CLIENT_STREAMING,
SERVER_STREAMING,
BIDI_STREAMING,
UNKNOWN // UNKNOWN is not API and will be removed later
};
~ClientRpcInfo(){};
ClientRpcInfo(const ClientRpcInfo&) = delete;
ClientRpcInfo(ClientRpcInfo&&) = default;
ClientRpcInfo& operator=(ClientRpcInfo&&) = default;
// Getter methods
const char* method() { return method_; }
const char* method() const { return method_; }
ChannelInterface* channel() { return channel_; }
grpc::ClientContext* client_context() { return ctx_; }
Type type() const { return type_; }
private:
ClientRpcInfo(grpc::ClientContext* ctx, const char* method,
grpc::ChannelInterface* channel)
: ctx_(ctx), method_(method), channel_(channel) {}
static_assert(Type::UNARY ==
static_cast<Type>(internal::RpcMethod::NORMAL_RPC),
"violated expectation about Type enum");
static_assert(Type::CLIENT_STREAMING ==
static_cast<Type>(internal::RpcMethod::CLIENT_STREAMING),
"violated expectation about Type enum");
static_assert(Type::SERVER_STREAMING ==
static_cast<Type>(internal::RpcMethod::SERVER_STREAMING),
"violated expectation about Type enum");
static_assert(Type::BIDI_STREAMING ==
static_cast<Type>(internal::RpcMethod::BIDI_STREAMING),
"violated expectation about Type enum");
// Default constructor should only be used by ClientContext
ClientRpcInfo() = default;
// Constructor will only be called from ClientContext
ClientRpcInfo(grpc::ClientContext* ctx, internal::RpcMethod::RpcType type,
const char* method, grpc::ChannelInterface* channel)
: ctx_(ctx),
type_(static_cast<Type>(type)),
method_(method),
channel_(channel) {}
// Move assignment should only be used by ClientContext
// TODO(yashykt): Delete move assignment
ClientRpcInfo& operator=(ClientRpcInfo&&) = default;
// Runs interceptor at pos \a pos.
void RunInterceptor(
experimental::InterceptorBatchMethods* interceptor_methods, size_t pos) {
@ -97,6 +131,8 @@ class ClientRpcInfo {
}
grpc::ClientContext* ctx_ = nullptr;
// TODO(yashykt): make type_ const once move-assignment is deleted
Type type_{Type::UNKNOWN};
const char* method_ = nullptr;
grpc::ChannelInterface* channel_ = nullptr;
std::vector<std::unique_ptr<experimental::Interceptor>> interceptors_;

@ -314,12 +314,12 @@ class ServerContext {
uint32_t initial_metadata_flags() const { return 0; }
experimental::ServerRpcInfo* set_server_rpc_info(
const char* method,
const char* method, internal::RpcMethod::RpcType type,
const std::vector<
std::unique_ptr<experimental::ServerInterceptorFactoryInterface>>&
creators) {
if (creators.size() != 0) {
rpc_info_ = new experimental::ServerRpcInfo(this, method);
rpc_info_ = new experimental::ServerRpcInfo(this, method, type);
rpc_info_->RegisterInterceptors(creators);
}
return rpc_info_;

@ -23,6 +23,7 @@
#include <vector>
#include <grpcpp/impl/codegen/interceptor.h>
#include <grpcpp/impl/codegen/rpc_method.h>
#include <grpcpp/impl/codegen/string_ref.h>
namespace grpc {
@ -44,6 +45,8 @@ class ServerInterceptorFactoryInterface {
class ServerRpcInfo {
public:
enum class Type { UNARY, CLIENT_STREAMING, SERVER_STREAMING, BIDI_STREAMING };
~ServerRpcInfo(){};
ServerRpcInfo(const ServerRpcInfo&) = delete;
@ -51,12 +54,27 @@ class ServerRpcInfo {
ServerRpcInfo& operator=(ServerRpcInfo&&) = default;
// Getter methods
const char* method() { return method_; }
const char* method() const { return method_; }
Type type() const { return type_; }
grpc::ServerContext* server_context() { return ctx_; }
private:
ServerRpcInfo(grpc::ServerContext* ctx, const char* method)
: ctx_(ctx), method_(method) {
static_assert(Type::UNARY ==
static_cast<Type>(internal::RpcMethod::NORMAL_RPC),
"violated expectation about Type enum");
static_assert(Type::CLIENT_STREAMING ==
static_cast<Type>(internal::RpcMethod::CLIENT_STREAMING),
"violated expectation about Type enum");
static_assert(Type::SERVER_STREAMING ==
static_cast<Type>(internal::RpcMethod::SERVER_STREAMING),
"violated expectation about Type enum");
static_assert(Type::BIDI_STREAMING ==
static_cast<Type>(internal::RpcMethod::BIDI_STREAMING),
"violated expectation about Type enum");
ServerRpcInfo(grpc::ServerContext* ctx, const char* method,
internal::RpcMethod::RpcType type)
: ctx_(ctx), method_(method), type_(static_cast<Type>(type)) {
ref_.store(1);
}
@ -86,6 +104,7 @@ class ServerRpcInfo {
grpc::ServerContext* ctx_ = nullptr;
const char* method_ = nullptr;
const Type type_;
std::atomic_int ref_;
std::vector<std::unique_ptr<experimental::Interceptor>> interceptors_;

@ -174,13 +174,14 @@ class ServerInterface : public internal::CallHook {
bool done_intercepting_;
};
/// RegisteredAsyncRequest is not part of the C++ API
class RegisteredAsyncRequest : public BaseAsyncRequest {
public:
RegisteredAsyncRequest(ServerInterface* server, ServerContext* context,
internal::ServerAsyncStreamingInterface* stream,
CompletionQueue* call_cq,
ServerCompletionQueue* notification_cq, void* tag,
const char* name);
const char* name, internal::RpcMethod::RpcType type);
virtual bool FinalizeResult(void** tag, bool* status) override {
/* If we are done intercepting, then there is nothing more for us to do */
@ -189,7 +190,7 @@ class ServerInterface : public internal::CallHook {
}
call_wrapper_ = internal::Call(
call_, server_, call_cq_, server_->max_receive_message_size(),
context_->set_server_rpc_info(name_,
context_->set_server_rpc_info(name_, type_,
*server_->interceptor_creators()));
return BaseAsyncRequest::FinalizeResult(tag, status);
}
@ -198,6 +199,7 @@ class ServerInterface : public internal::CallHook {
void IssueRequest(void* registered_method, grpc_byte_buffer** payload,
ServerCompletionQueue* notification_cq);
const char* name_;
const internal::RpcMethod::RpcType type_;
};
class NoPayloadAsyncRequest final : public RegisteredAsyncRequest {
@ -207,9 +209,9 @@ class ServerInterface : public internal::CallHook {
internal::ServerAsyncStreamingInterface* stream,
CompletionQueue* call_cq,
ServerCompletionQueue* notification_cq, void* tag)
: RegisteredAsyncRequest(server, context, stream, call_cq,
notification_cq, tag,
registered_method->name()) {
: RegisteredAsyncRequest(
server, context, stream, call_cq, notification_cq, tag,
registered_method->name(), registered_method->method_type()) {
IssueRequest(registered_method->server_tag(), nullptr, notification_cq);
}
@ -225,9 +227,9 @@ class ServerInterface : public internal::CallHook {
CompletionQueue* call_cq,
ServerCompletionQueue* notification_cq, void* tag,
Message* request)
: RegisteredAsyncRequest(server, context, stream, call_cq,
notification_cq, tag,
registered_method->name()),
: RegisteredAsyncRequest(
server, context, stream, call_cq, notification_cq, tag,
registered_method->name(), registered_method->method_type()),
registered_method_(registered_method),
server_(server),
context_(context),

@ -149,8 +149,9 @@ internal::Call Channel::CreateCallInternal(const internal::RpcMethod& method,
// ClientRpcInfo should be set before call because set_call also checks
// whether the call has been cancelled, and if the call was cancelled, we
// should notify the interceptors too/
auto* info = context->set_client_rpc_info(
method.name(), this, interceptor_creators_, interceptor_pos);
auto* info =
context->set_client_rpc_info(method.name(), method.method_type(), this,
interceptor_creators_, interceptor_pos);
context->set_call(c_call, shared_from_this());
return internal::Call(c_call, this, cq, info);

@ -236,9 +236,10 @@ class Server::SyncRequest final : public internal::CompletionQueueTag {
: nullptr),
request_(nullptr),
method_(mrd->method_),
call_(mrd->call_, server, &cq_, server->max_receive_message_size(),
ctx_.set_server_rpc_info(method_->name(),
server->interceptor_creators_)),
call_(
mrd->call_, server, &cq_, server->max_receive_message_size(),
ctx_.set_server_rpc_info(method_->name(), method_->method_type(),
server->interceptor_creators_)),
server_(server),
global_callbacks_(nullptr),
resources_(false) {
@ -427,7 +428,8 @@ class Server::CallbackRequest final : public internal::CompletionQueueTag {
req_->call_, req_->server_, req_->cq_,
req_->server_->max_receive_message_size(),
req_->ctx_.set_server_rpc_info(
req_->method_->name(), req_->server_->interceptor_creators_));
req_->method_->name(), req_->method_->method_type(),
req_->server_->interceptor_creators_));
req_->interceptor_methods_.SetCall(call_);
req_->interceptor_methods_.SetReverse();
@ -1041,10 +1043,12 @@ void ServerInterface::BaseAsyncRequest::
ServerInterface::RegisteredAsyncRequest::RegisteredAsyncRequest(
ServerInterface* server, ServerContext* context,
internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq,
ServerCompletionQueue* notification_cq, void* tag, const char* name)
ServerCompletionQueue* notification_cq, void* tag, const char* name,
internal::RpcMethod::RpcType type)
: BaseAsyncRequest(server, context, stream, call_cq, notification_cq, tag,
true),
name_(name) {}
name_(name),
type_(type) {}
void ServerInterface::RegisteredAsyncRequest::IssueRequest(
void* registered_method, grpc_byte_buffer** payload,
@ -1091,6 +1095,7 @@ bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag,
call_, server_, call_cq_, server_->max_receive_message_size(),
context_->set_server_rpc_info(
static_cast<GenericServerContext*>(context_)->method_.c_str(),
internal::RpcMethod::BIDI_STREAMING,
*server_->interceptor_creators()));
return BaseAsyncRequest::FinalizeResult(tag, status);
}

@ -50,6 +50,7 @@ class HijackingInterceptor : public experimental::Interceptor {
info_ = info;
// Make sure it is the right method
EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
EXPECT_EQ(info->type(), experimental::ClientRpcInfo::Type::UNARY);
}
virtual void Intercept(experimental::InterceptorBatchMethods* methods) {

@ -44,7 +44,32 @@ namespace {
class LoggingInterceptor : public experimental::Interceptor {
public:
LoggingInterceptor(experimental::ServerRpcInfo* info) { info_ = info; }
LoggingInterceptor(experimental::ServerRpcInfo* info) {
info_ = info;
// Check the method name and compare to the type
const char* method = info->method();
experimental::ServerRpcInfo::Type type = info->type();
// Check that we use one of our standard methods with expected type.
// We accept BIDI_STREAMING for Echo in case it's an AsyncGenericService
// being tested (the GenericRpc test).
// The empty method is for the Unimplemented requests that arise
// when draining the CQ.
EXPECT_TRUE(
(strcmp(method, "/grpc.testing.EchoTestService/Echo") == 0 &&
(type == experimental::ServerRpcInfo::Type::UNARY ||
type == experimental::ServerRpcInfo::Type::BIDI_STREAMING)) ||
(strcmp(method, "/grpc.testing.EchoTestService/RequestStream") == 0 &&
type == experimental::ServerRpcInfo::Type::CLIENT_STREAMING) ||
(strcmp(method, "/grpc.testing.EchoTestService/ResponseStream") == 0 &&
type == experimental::ServerRpcInfo::Type::SERVER_STREAMING) ||
(strcmp(method, "/grpc.testing.EchoTestService/BidiStream") == 0 &&
type == experimental::ServerRpcInfo::Type::BIDI_STREAMING) ||
strcmp(method, "/grpc.testing.EchoTestService/Unimplemented") == 0 ||
(strcmp(method, "") == 0 &&
type == experimental::ServerRpcInfo::Type::BIDI_STREAMING));
}
virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
if (methods->QueryInterceptionHookPoint(

Loading…
Cancel
Save