diff --git a/include/grpcpp/impl/codegen/channel_interface.h b/include/grpcpp/impl/codegen/channel_interface.h index 728a7b90496..5353f5feaa6 100644 --- a/include/grpcpp/impl/codegen/channel_interface.h +++ b/include/grpcpp/impl/codegen/channel_interface.h @@ -21,7 +21,6 @@ #include #include -#include #include #include diff --git a/include/grpcpp/impl/codegen/client_context.h b/include/grpcpp/impl/codegen/client_context.h index 142cfa35dd0..0a71f3d9b6a 100644 --- a/include/grpcpp/impl/codegen/client_context.h +++ b/include/grpcpp/impl/codegen/client_context.h @@ -46,6 +46,7 @@ #include #include #include +#include #include #include #include @@ -418,12 +419,13 @@ class ClientContext { void set_call(grpc_call* call, const std::shared_ptr& channel); experimental::ClientRpcInfo* set_client_rpc_info( - const char* method, grpc::ChannelInterface* channel, + const char* method, internal::RpcMethod::RpcType type, + grpc::ChannelInterface* channel, const std::vector< std::unique_ptr>& creators, size_t interceptor_pos) { - rpc_info_ = experimental::ClientRpcInfo(this, method, channel); + rpc_info_ = experimental::ClientRpcInfo(this, type, method, channel); rpc_info_.RegisterInterceptors(creators, interceptor_pos); return &rpc_info_; } diff --git a/include/grpcpp/impl/codegen/client_interceptor.h b/include/grpcpp/impl/codegen/client_interceptor.h index f69c99ab229..2bae11a2516 100644 --- a/include/grpcpp/impl/codegen/client_interceptor.h +++ b/include/grpcpp/impl/codegen/client_interceptor.h @@ -23,6 +23,7 @@ #include #include +#include #include namespace grpc { @@ -52,23 +53,56 @@ extern experimental::ClientInterceptorFactoryInterface* namespace experimental { class ClientRpcInfo { public: - ClientRpcInfo() {} + // TODO(yashykt): Stop default-constructing ClientRpcInfo and remove UNKNOWN + // from the list of possible Types. + enum class Type { + UNARY, + CLIENT_STREAMING, + SERVER_STREAMING, + BIDI_STREAMING, + UNKNOWN // UNKNOWN is not API and will be removed later + }; ~ClientRpcInfo(){}; ClientRpcInfo(const ClientRpcInfo&) = delete; ClientRpcInfo(ClientRpcInfo&&) = default; - ClientRpcInfo& operator=(ClientRpcInfo&&) = default; // Getter methods - const char* method() { return method_; } + const char* method() const { return method_; } ChannelInterface* channel() { return channel_; } grpc::ClientContext* client_context() { return ctx_; } + Type type() const { return type_; } private: - ClientRpcInfo(grpc::ClientContext* ctx, const char* method, - grpc::ChannelInterface* channel) - : ctx_(ctx), method_(method), channel_(channel) {} + static_assert(Type::UNARY == + static_cast(internal::RpcMethod::NORMAL_RPC), + "violated expectation about Type enum"); + static_assert(Type::CLIENT_STREAMING == + static_cast(internal::RpcMethod::CLIENT_STREAMING), + "violated expectation about Type enum"); + static_assert(Type::SERVER_STREAMING == + static_cast(internal::RpcMethod::SERVER_STREAMING), + "violated expectation about Type enum"); + static_assert(Type::BIDI_STREAMING == + static_cast(internal::RpcMethod::BIDI_STREAMING), + "violated expectation about Type enum"); + + // Default constructor should only be used by ClientContext + ClientRpcInfo() = default; + + // Constructor will only be called from ClientContext + ClientRpcInfo(grpc::ClientContext* ctx, internal::RpcMethod::RpcType type, + const char* method, grpc::ChannelInterface* channel) + : ctx_(ctx), + type_(static_cast(type)), + method_(method), + channel_(channel) {} + + // Move assignment should only be used by ClientContext + // TODO(yashykt): Delete move assignment + ClientRpcInfo& operator=(ClientRpcInfo&&) = default; + // Runs interceptor at pos \a pos. void RunInterceptor( experimental::InterceptorBatchMethods* interceptor_methods, size_t pos) { @@ -97,6 +131,8 @@ class ClientRpcInfo { } grpc::ClientContext* ctx_ = nullptr; + // TODO(yashykt): make type_ const once move-assignment is deleted + Type type_{Type::UNKNOWN}; const char* method_ = nullptr; grpc::ChannelInterface* channel_ = nullptr; std::vector> interceptors_; diff --git a/include/grpcpp/impl/codegen/server_context.h b/include/grpcpp/impl/codegen/server_context.h index 4a5f9e2dd91..ccb5925e7d8 100644 --- a/include/grpcpp/impl/codegen/server_context.h +++ b/include/grpcpp/impl/codegen/server_context.h @@ -314,12 +314,12 @@ class ServerContext { uint32_t initial_metadata_flags() const { return 0; } experimental::ServerRpcInfo* set_server_rpc_info( - const char* method, + const char* method, internal::RpcMethod::RpcType type, const std::vector< std::unique_ptr>& creators) { if (creators.size() != 0) { - rpc_info_ = new experimental::ServerRpcInfo(this, method); + rpc_info_ = new experimental::ServerRpcInfo(this, method, type); rpc_info_->RegisterInterceptors(creators); } return rpc_info_; diff --git a/include/grpcpp/impl/codegen/server_interceptor.h b/include/grpcpp/impl/codegen/server_interceptor.h index 5fb5df28b70..cd7c0600b60 100644 --- a/include/grpcpp/impl/codegen/server_interceptor.h +++ b/include/grpcpp/impl/codegen/server_interceptor.h @@ -23,6 +23,7 @@ #include #include +#include #include namespace grpc { @@ -44,6 +45,8 @@ class ServerInterceptorFactoryInterface { class ServerRpcInfo { public: + enum class Type { UNARY, CLIENT_STREAMING, SERVER_STREAMING, BIDI_STREAMING }; + ~ServerRpcInfo(){}; ServerRpcInfo(const ServerRpcInfo&) = delete; @@ -51,12 +54,27 @@ class ServerRpcInfo { ServerRpcInfo& operator=(ServerRpcInfo&&) = default; // Getter methods - const char* method() { return method_; } + const char* method() const { return method_; } + Type type() const { return type_; } grpc::ServerContext* server_context() { return ctx_; } private: - ServerRpcInfo(grpc::ServerContext* ctx, const char* method) - : ctx_(ctx), method_(method) { + static_assert(Type::UNARY == + static_cast(internal::RpcMethod::NORMAL_RPC), + "violated expectation about Type enum"); + static_assert(Type::CLIENT_STREAMING == + static_cast(internal::RpcMethod::CLIENT_STREAMING), + "violated expectation about Type enum"); + static_assert(Type::SERVER_STREAMING == + static_cast(internal::RpcMethod::SERVER_STREAMING), + "violated expectation about Type enum"); + static_assert(Type::BIDI_STREAMING == + static_cast(internal::RpcMethod::BIDI_STREAMING), + "violated expectation about Type enum"); + + ServerRpcInfo(grpc::ServerContext* ctx, const char* method, + internal::RpcMethod::RpcType type) + : ctx_(ctx), method_(method), type_(static_cast(type)) { ref_.store(1); } @@ -86,6 +104,7 @@ class ServerRpcInfo { grpc::ServerContext* ctx_ = nullptr; const char* method_ = nullptr; + const Type type_; std::atomic_int ref_; std::vector> interceptors_; diff --git a/include/grpcpp/impl/codegen/server_interface.h b/include/grpcpp/impl/codegen/server_interface.h index 55c94f4d2fa..e0e26298270 100644 --- a/include/grpcpp/impl/codegen/server_interface.h +++ b/include/grpcpp/impl/codegen/server_interface.h @@ -174,13 +174,14 @@ class ServerInterface : public internal::CallHook { bool done_intercepting_; }; + /// RegisteredAsyncRequest is not part of the C++ API class RegisteredAsyncRequest : public BaseAsyncRequest { public: RegisteredAsyncRequest(ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag, - const char* name); + const char* name, internal::RpcMethod::RpcType type); virtual bool FinalizeResult(void** tag, bool* status) override { /* If we are done intercepting, then there is nothing more for us to do */ @@ -189,7 +190,7 @@ class ServerInterface : public internal::CallHook { } call_wrapper_ = internal::Call( call_, server_, call_cq_, server_->max_receive_message_size(), - context_->set_server_rpc_info(name_, + context_->set_server_rpc_info(name_, type_, *server_->interceptor_creators())); return BaseAsyncRequest::FinalizeResult(tag, status); } @@ -198,6 +199,7 @@ class ServerInterface : public internal::CallHook { void IssueRequest(void* registered_method, grpc_byte_buffer** payload, ServerCompletionQueue* notification_cq); const char* name_; + const internal::RpcMethod::RpcType type_; }; class NoPayloadAsyncRequest final : public RegisteredAsyncRequest { @@ -207,9 +209,9 @@ class ServerInterface : public internal::CallHook { internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag) - : RegisteredAsyncRequest(server, context, stream, call_cq, - notification_cq, tag, - registered_method->name()) { + : RegisteredAsyncRequest( + server, context, stream, call_cq, notification_cq, tag, + registered_method->name(), registered_method->method_type()) { IssueRequest(registered_method->server_tag(), nullptr, notification_cq); } @@ -225,9 +227,9 @@ class ServerInterface : public internal::CallHook { CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag, Message* request) - : RegisteredAsyncRequest(server, context, stream, call_cq, - notification_cq, tag, - registered_method->name()), + : RegisteredAsyncRequest( + server, context, stream, call_cq, notification_cq, tag, + registered_method->name(), registered_method->method_type()), registered_method_(registered_method), server_(server), context_(context), diff --git a/src/cpp/client/channel_cc.cc b/src/cpp/client/channel_cc.cc index d1c55319f7c..a31d0b30b15 100644 --- a/src/cpp/client/channel_cc.cc +++ b/src/cpp/client/channel_cc.cc @@ -149,8 +149,9 @@ internal::Call Channel::CreateCallInternal(const internal::RpcMethod& method, // ClientRpcInfo should be set before call because set_call also checks // whether the call has been cancelled, and if the call was cancelled, we // should notify the interceptors too/ - auto* info = context->set_client_rpc_info( - method.name(), this, interceptor_creators_, interceptor_pos); + auto* info = + context->set_client_rpc_info(method.name(), method.method_type(), this, + interceptor_creators_, interceptor_pos); context->set_call(c_call, shared_from_this()); return internal::Call(c_call, this, cq, info); diff --git a/src/cpp/server/server_cc.cc b/src/cpp/server/server_cc.cc index 69af43a6562..1e3c57446f4 100644 --- a/src/cpp/server/server_cc.cc +++ b/src/cpp/server/server_cc.cc @@ -236,9 +236,10 @@ class Server::SyncRequest final : public internal::CompletionQueueTag { : nullptr), request_(nullptr), method_(mrd->method_), - call_(mrd->call_, server, &cq_, server->max_receive_message_size(), - ctx_.set_server_rpc_info(method_->name(), - server->interceptor_creators_)), + call_( + mrd->call_, server, &cq_, server->max_receive_message_size(), + ctx_.set_server_rpc_info(method_->name(), method_->method_type(), + server->interceptor_creators_)), server_(server), global_callbacks_(nullptr), resources_(false) { @@ -427,7 +428,8 @@ class Server::CallbackRequest final : public internal::CompletionQueueTag { req_->call_, req_->server_, req_->cq_, req_->server_->max_receive_message_size(), req_->ctx_.set_server_rpc_info( - req_->method_->name(), req_->server_->interceptor_creators_)); + req_->method_->name(), req_->method_->method_type(), + req_->server_->interceptor_creators_)); req_->interceptor_methods_.SetCall(call_); req_->interceptor_methods_.SetReverse(); @@ -1041,10 +1043,12 @@ void ServerInterface::BaseAsyncRequest:: ServerInterface::RegisteredAsyncRequest::RegisteredAsyncRequest( ServerInterface* server, ServerContext* context, internal::ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, - ServerCompletionQueue* notification_cq, void* tag, const char* name) + ServerCompletionQueue* notification_cq, void* tag, const char* name, + internal::RpcMethod::RpcType type) : BaseAsyncRequest(server, context, stream, call_cq, notification_cq, tag, true), - name_(name) {} + name_(name), + type_(type) {} void ServerInterface::RegisteredAsyncRequest::IssueRequest( void* registered_method, grpc_byte_buffer** payload, @@ -1091,6 +1095,7 @@ bool ServerInterface::GenericAsyncRequest::FinalizeResult(void** tag, call_, server_, call_cq_, server_->max_receive_message_size(), context_->set_server_rpc_info( static_cast(context_)->method_.c_str(), + internal::RpcMethod::BIDI_STREAMING, *server_->interceptor_creators())); return BaseAsyncRequest::FinalizeResult(tag, status); } diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc index 3a191d1e038..f55ece1c2bf 100644 --- a/test/cpp/end2end/client_interceptors_end2end_test.cc +++ b/test/cpp/end2end/client_interceptors_end2end_test.cc @@ -50,6 +50,7 @@ class HijackingInterceptor : public experimental::Interceptor { info_ = info; // Make sure it is the right method EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0); + EXPECT_EQ(info->type(), experimental::ClientRpcInfo::Type::UNARY); } virtual void Intercept(experimental::InterceptorBatchMethods* methods) { diff --git a/test/cpp/end2end/server_interceptors_end2end_test.cc b/test/cpp/end2end/server_interceptors_end2end_test.cc index c98b6143c66..1e0e3668707 100644 --- a/test/cpp/end2end/server_interceptors_end2end_test.cc +++ b/test/cpp/end2end/server_interceptors_end2end_test.cc @@ -44,7 +44,32 @@ namespace { class LoggingInterceptor : public experimental::Interceptor { public: - LoggingInterceptor(experimental::ServerRpcInfo* info) { info_ = info; } + LoggingInterceptor(experimental::ServerRpcInfo* info) { + info_ = info; + + // Check the method name and compare to the type + const char* method = info->method(); + experimental::ServerRpcInfo::Type type = info->type(); + + // Check that we use one of our standard methods with expected type. + // We accept BIDI_STREAMING for Echo in case it's an AsyncGenericService + // being tested (the GenericRpc test). + // The empty method is for the Unimplemented requests that arise + // when draining the CQ. + EXPECT_TRUE( + (strcmp(method, "/grpc.testing.EchoTestService/Echo") == 0 && + (type == experimental::ServerRpcInfo::Type::UNARY || + type == experimental::ServerRpcInfo::Type::BIDI_STREAMING)) || + (strcmp(method, "/grpc.testing.EchoTestService/RequestStream") == 0 && + type == experimental::ServerRpcInfo::Type::CLIENT_STREAMING) || + (strcmp(method, "/grpc.testing.EchoTestService/ResponseStream") == 0 && + type == experimental::ServerRpcInfo::Type::SERVER_STREAMING) || + (strcmp(method, "/grpc.testing.EchoTestService/BidiStream") == 0 && + type == experimental::ServerRpcInfo::Type::BIDI_STREAMING) || + strcmp(method, "/grpc.testing.EchoTestService/Unimplemented") == 0 || + (strcmp(method, "") == 0 && + type == experimental::ServerRpcInfo::Type::BIDI_STREAMING)); + } virtual void Intercept(experimental::InterceptorBatchMethods* methods) { if (methods->QueryInterceptionHookPoint(