diff --git a/include/grpc++/impl/service_type.h b/include/grpc++/impl/service_type.h index 655aa91cdc1..acc30cc7558 100644 --- a/include/grpc++/impl/service_type.h +++ b/include/grpc++/impl/service_type.h @@ -66,7 +66,7 @@ class Service { bool has_async_methods() const { for (auto it = methods_.begin(); it != methods_.end(); ++it) { - if ((*it)->handler() == nullptr) { + if (*it && (*it)->handler() == nullptr) { return true; } } @@ -75,7 +75,7 @@ class Service { bool has_synchronous_methods() const { for (auto it = methods_.begin(); it != methods_.end(); ++it) { - if ((*it)->handler() != nullptr) { + if (*it && (*it)->handler() != nullptr) { return true; } } @@ -120,14 +120,20 @@ class Service { void AddMethod(RpcServiceMethod* method) { methods_.emplace_back(method); } - void MarkMethodAsync(const grpc::string& method_name) { - for (auto it = methods_.begin(); it != methods_.end(); ++it) { - if ((*it)->name() == method_name) { - (*it)->ResetHandler(); - return; - } + void MarkMethodAsync(int index) { + if (methods_[index].get() == nullptr) { + gpr_log(GPR_ERROR, "A method cannot be marked async and generic."); + abort(); + } + methods_[index]->ResetHandler(); + } + + void MarkMethodGeneric(int index) { + if (methods_[index]->handler() == nullptr) { + gpr_log(GPR_ERROR, "A method cannot be marked async and generic."); + abort(); } - abort(); + methods_[index].reset(); } private: diff --git a/include/grpc++/server.h b/include/grpc++/server.h index 92d7a4b3cc5..31d0a4b24b1 100644 --- a/include/grpc++/server.h +++ b/include/grpc++/server.h @@ -264,6 +264,7 @@ class Server GRPC_FINAL : public GrpcLibrary, private CallHook { CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag, Message* message) { + GPR_ASSERT(method); new PayloadAsyncRequest(method->server_tag(), this, context, stream, call_cq, notification_cq, tag, message); @@ -273,6 +274,7 @@ class Server GRPC_FINAL : public GrpcLibrary, private CallHook { ServerAsyncStreamingInterface* stream, CompletionQueue* call_cq, ServerCompletionQueue* notification_cq, void* tag) { + GPR_ASSERT(method); new NoPayloadAsyncRequest(method->server_tag(), this, context, stream, call_cq, notification_cq, tag); } diff --git a/src/compiler/cpp_generator.cc b/src/compiler/cpp_generator.cc index 9d0d7eb4699..7a298574e57 100644 --- a/src/compiler/cpp_generator.cc +++ b/src/compiler/cpp_generator.cc @@ -501,8 +501,7 @@ void PrintHeaderServerMethodAsync( printer->Indent(); printer->Print(*vars, "WithAsyncMethod_$Method$() {\n" - " ::grpc::Service::MarkMethodAsync(" - "\"/$Package$$Service$/$Method$\");\n" + " ::grpc::Service::MarkMethodAsync($Idx$);\n" "}\n"); printer->Print(*vars, "~WithAsyncMethod_$Method$() GRPC_OVERRIDE {\n" @@ -601,6 +600,79 @@ void PrintHeaderServerMethodAsync( printer->Print(*vars, "};\n"); } +void PrintHeaderServerMethodGeneric( + grpc::protobuf::io::Printer *printer, + const grpc::protobuf::MethodDescriptor *method, + std::map *vars) { + (*vars)["Method"] = method->name(); + (*vars)["Request"] = + grpc_cpp_generator::ClassName(method->input_type(), true); + (*vars)["Response"] = + grpc_cpp_generator::ClassName(method->output_type(), true); + printer->Print(*vars, "template \n"); + printer->Print(*vars, + "class WithGenericMethod_$Method$ : public BaseClass {\n"); + printer->Print( + " private:\n" + " void BaseClassMustBeDerivedFromService(Service *service) {}\n"); + printer->Print(" public:\n"); + printer->Indent(); + printer->Print(*vars, + "WithGenericMethod_$Method$() {\n" + " ::grpc::Service::MarkMethodGeneric($Idx$);\n" + "}\n"); + printer->Print(*vars, + "~WithGenericMethod_$Method$() GRPC_OVERRIDE {\n" + " BaseClassMustBeDerivedFromService(this);\n" + "}\n"); + if (NoStreaming(method)) { + printer->Print( + *vars, + "// disable synchronous version of this method\n" + "::grpc::Status $Method$(" + "::grpc::ServerContext* context, const $Request$* request, " + "$Response$* response) GRPC_FINAL GRPC_OVERRIDE {\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + } else if (ClientOnlyStreaming(method)) { + printer->Print( + *vars, + "// disable synchronous version of this method\n" + "::grpc::Status $Method$(" + "::grpc::ServerContext* context, " + "::grpc::ServerReader< $Request$>* reader, " + "$Response$* response) GRPC_FINAL GRPC_OVERRIDE {\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + } else if (ServerOnlyStreaming(method)) { + printer->Print( + *vars, + "// disable synchronous version of this method\n" + "::grpc::Status $Method$(" + "::grpc::ServerContext* context, const $Request$* request, " + "::grpc::ServerWriter< $Response$>* writer) GRPC_FINAL GRPC_OVERRIDE " + "{\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + } else if (BidiStreaming(method)) { + printer->Print( + *vars, + "// disable synchronous version of this method\n" + "::grpc::Status $Method$(" + "::grpc::ServerContext* context, " + "::grpc::ServerReaderWriter< $Response$, $Request$>* stream) " + "GRPC_FINAL GRPC_OVERRIDE {\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + } + printer->Outdent(); + printer->Print(*vars, "};\n"); +} + void PrintHeaderService(grpc::protobuf::io::Printer *printer, const grpc::protobuf::ServiceDescriptor *service, std::map *vars) { @@ -686,6 +758,12 @@ void PrintHeaderService(grpc::protobuf::io::Printer *printer, } printer->Print(" AsyncService;\n"); + // Server side - Generic + for (int i = 0; i < service->method_count(); ++i) { + (*vars)["Idx"] = as_string(i); + PrintHeaderServerMethodGeneric(printer, service->method(i), vars); + } + printer->Outdent(); printer->Print("};\n"); }