diff --git a/include/grpc++/impl/codegen/method_handler_impl.h b/include/grpc++/impl/codegen/method_handler_impl.h index 52f927631c9..8f417f671aa 100644 --- a/include/grpc++/impl/codegen/method_handler_impl.h +++ b/include/grpc++/impl/codegen/method_handler_impl.h @@ -236,6 +236,20 @@ class StreamedUnaryHandler ServerUnaryStreamer, true>(func) {} }; +template +class SplitServerStreamingHandler + : public TemplatedBidiStreamingHandler< + ServerSplitStreamer, false> { + public: + explicit SplitServerStreamingHandler( + std::function*)> + func) + : TemplatedBidiStreamingHandler< + ServerSplitStreamer, false>(func) {} +}; + // Handle unknown method by returning UNIMPLEMENTED error. class UnknownMethodHandler : public MethodHandler { public: diff --git a/include/grpc++/impl/codegen/service_type.h b/include/grpc++/impl/codegen/service_type.h index 72b22253128..de0862fce2b 100644 --- a/include/grpc++/impl/codegen/service_type.h +++ b/include/grpc++/impl/codegen/service_type.h @@ -147,14 +147,16 @@ class Service { methods_[index].reset(); } - void MarkMethodStreamedUnary(int index, - MethodHandler* streamed_unary_method) { + void MarkMethodStreamed(int index, + MethodHandler* streamed_method) { GPR_CODEGEN_ASSERT(methods_[index] && methods_[index]->handler() && - "Cannot mark an async or generic method Streamed Unary"); - methods_[index]->SetHandler(streamed_unary_method); + "Cannot mark an async or generic method Streamed"); + methods_[index]->SetHandler(streamed_method); // From the server's point of view, streamed unary is a special - // case of BIDI_STREAMING that has 1 read and 1 write, in that order. + // case of BIDI_STREAMING that has 1 read and 1 write, in that order, + // and split server-side streaming is BIDI_STREAMING with 1 read and + // any number of writes, in that order methods_[index]->SetMethodType(::grpc::RpcMethod::BIDI_STREAMING); } diff --git a/include/grpc++/impl/codegen/sync_stream.h b/include/grpc++/impl/codegen/sync_stream.h index e3c5a919b1e..f2b1e9a3adf 100644 --- a/include/grpc++/impl/codegen/sync_stream.h +++ b/include/grpc++/impl/codegen/sync_stream.h @@ -538,7 +538,7 @@ class ServerReaderWriter GRPC_FINAL : public ServerReaderWriterInterface { /// the \a NextMessageSize method to determine an upper-bound on the size of /// the message. /// A key difference relative to streaming: ServerUnaryStreamer -/// must have exactly 1 Read and exactly 1 Write, in that order, to function +/// must have exactly 1 Read and exactly 1 Write, in that order, to function /// correctly. Otherwise, the RPC is in error. template class ServerUnaryStreamer GRPC_FINAL @@ -577,6 +577,43 @@ class ServerUnaryStreamer GRPC_FINAL bool write_done_; }; +/// A class to represent a flow-controlled server-side streaming call. +/// This is something of a hybrid between server-side and bidi streaming. +/// This is invoked through a server-side streaming call on the client side, +/// but the server responds to it as though it were a bidi streaming call that +/// must first have exactly 1 Read and then any number of Writes. +template +class ServerSplitStreamer GRPC_FINAL + : public ServerReaderWriterInterface { + public: + ServerSplitStreamer(Call* call, ServerContext* ctx) + : body_(call, ctx), read_done_(false) {} + + void SendInitialMetadata() GRPC_OVERRIDE { body_.SendInitialMetadata(); } + + bool NextMessageSize(uint32_t* sz) GRPC_OVERRIDE { + return body_.NextMessageSize(sz); + } + + bool Read(RequestType* request) GRPC_OVERRIDE { + if (read_done_) { + return false; + } + read_done_ = true; + return body_.Read(request); + } + + using WriterInterface::Write; + bool Write(const ResponseType& response, + const WriteOptions& options) GRPC_OVERRIDE { + return !read_done_ && body_.Write(response, options); + } + + private: + internal::ServerReaderWriterBody body_; + bool read_done_; +}; + } // namespace grpc #endif // GRPCXX_IMPL_CODEGEN_SYNC_STREAM_H diff --git a/src/compiler/cpp_generator.cc b/src/compiler/cpp_generator.cc index d0c35ea1ab3..d0f164e58df 100644 --- a/src/compiler/cpp_generator.cc +++ b/src/compiler/cpp_generator.cc @@ -605,7 +605,7 @@ void PrintHeaderServerMethodAsync(Printer *printer, const Method *method, printer->Print(*vars, "};\n"); } -void PrintHeaderServerMethodStreamedUnary( + void PrintHeaderServerMethodStreamedUnary( Printer *printer, const Method *method, std::map *vars) { (*vars)["Method"] = method->name(); @@ -624,7 +624,7 @@ void PrintHeaderServerMethodStreamedUnary( printer->Indent(); printer->Print(*vars, "WithStreamedUnaryMethod_$Method$() {\n" - " ::grpc::Service::MarkMethodStreamedUnary($Idx$,\n" + " ::grpc::Service::MarkMethodStreamed($Idx$,\n" " new ::grpc::StreamedUnaryHandler< $Request$, " "$Response$>(std::bind" "(&WithStreamedUnaryMethod_$Method$::" @@ -656,6 +656,57 @@ void PrintHeaderServerMethodStreamedUnary( } } +void PrintHeaderServerMethodSplitStreaming( + Printer *printer, const Method *method, + std::map *vars) { + (*vars)["Method"] = method->name(); + (*vars)["Request"] = method->input_type_name(); + (*vars)["Response"] = method->output_type_name(); + if (method->ServerOnlyStreaming()) { + printer->Print(*vars, "template \n"); + printer->Print(*vars, + "class WithSplitStreamingMethod_$Method$ : " + "public BaseClass {\n"); + printer->Print( + " private:\n" + " void BaseClassMustBeDerivedFromService(const Service *service) " + "{}\n"); + printer->Print(" public:\n"); + printer->Indent(); + printer->Print(*vars, + "WithSplitStreamingMethod_$Method$() {\n" + " ::grpc::Service::MarkMethodStreamed($Idx$,\n" + " new ::grpc::SplitServerStreamingHandler< $Request$, " + "$Response$>(std::bind" + "(&WithSplitStreamingMethod_$Method$::" + "Streamed$Method$, this, std::placeholders::_1, " + "std::placeholders::_2)));\n" + "}\n"); + printer->Print(*vars, + "~WithSplitStreamingMethod_$Method$() GRPC_OVERRIDE {\n" + " BaseClassMustBeDerivedFromService(this);\n" + "}\n"); + printer->Print( + *vars, + "// disable regular version of this method\n" + "::grpc::Status $Method$(" + "::grpc::ServerContext* context, const $Request$* request, " + "$Response$* response) GRPC_FINAL GRPC_OVERRIDE {\n" + " abort();\n" + " return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, \"\");\n" + "}\n"); + printer->Print(*vars, + "// replace default version of method with split streamed\n" + "virtual ::grpc::Status Streamed$Method$(" + "::grpc::ServerContext* context, " + "::grpc::ServerSplitStreamer< " + "$Request$,$Response$>* server_split_streamer)" + " = 0;\n"); + printer->Outdent(); + printer->Print(*vars, "};\n"); + } +} + void PrintHeaderServerMethodGeneric( Printer *printer, const Method *method, std::map *vars) { @@ -844,6 +895,48 @@ void PrintHeaderService(Printer *printer, const Service *service, } printer->Print(" StreamedUnaryService;\n"); + // Server side - controlled server-side streaming + for (int i = 0; i < service->method_count(); ++i) { + (*vars)["Idx"] = as_string(i); + PrintHeaderServerMethodSplitStreaming(printer, service->method(i).get(), + vars); + } + + printer->Print("typedef "); + for (int i = 0; i < service->method_count(); ++i) { + (*vars)["method_name"] = service->method(i).get()->name(); + if (service->method(i)->ServerOnlyStreaming()) { + printer->Print(*vars, "WithSplitStreamingMethod_$method_name$<"); + } + } + printer->Print("Service"); + for (int i = 0; i < service->method_count(); ++i) { + if (service->method(i)->ServerOnlyStreaming()) { + printer->Print(" >"); + } + } + printer->Print(" SplitStreamedService;\n"); + + // Server side - typedef for controlled both unary and server-side streaming + printer->Print("typedef "); + for (int i = 0; i < service->method_count(); ++i) { + (*vars)["method_name"] = service->method(i).get()->name(); + if (service->method(i)->ServerOnlyStreaming()) { + printer->Print(*vars, "WithSplitStreamingMethod_$method_name$<"); + } + if (service->method(i)->NoStreaming()) { + printer->Print(*vars, "WithStreamedUnaryMethod_$method_name$<"); + } + } + printer->Print("Service"); + for (int i = 0; i < service->method_count(); ++i) { + if (service->method(i)->NoStreaming() || + service->method(i)->ServerOnlyStreaming()) { + printer->Print(" >"); + } + } + printer->Print(" StreamedService;\n"); + printer->Outdent(); printer->Print("};\n"); printer->Print(service->GetTrailingComments().c_str());