diff --git a/src/compiler/python_generator.cc b/src/compiler/python_generator.cc index 9985132e372..40a77543203 100644 --- a/src/compiler/python_generator.cc +++ b/src/compiler/python_generator.cc @@ -575,10 +575,9 @@ bool PrivateGenerator::PrintAddServicerToServer( * file, with no suffixes. Since this class merely acts as a namespace, it * should never be instantiated. */ -bool PrivateGenerator::PrintServiceClass(const grpc::string& package_qualified_service_name, - const grpc_generator::Service* service, - grpc_generator::Printer* out) -{ +bool PrivateGenerator::PrintServiceClass( + const grpc::string& package_qualified_service_name, + const grpc_generator::Service* service, grpc_generator::Printer* out) { StringMap dict; dict["Service"] = service->name(); out->Print("\n\n"); @@ -609,16 +608,15 @@ bool PrivateGenerator::PrintServiceClass(const grpc::string& package_qualified_s method_dict["Method"] = method->name(); out->Print("@staticmethod\n"); out->Print(method_dict, "def $Method$("); + grpc::string request_parameter( + method->ClientStreaming() ? "request_iterator" : "request"); + StringMap args_dict; + args_dict["RequestParameter"] = request_parameter; { IndentScope args_indent(out); IndentScope args_double_indent(out); - grpc::string request_parameter(method->ClientStreaming() ? "request_iterator" : "request"); - StringMap args_dict; - args_dict["RequestParameter"] = request_parameter; out->Print(args_dict, "$RequestParameter$,\n"); out->Print("target,\n"); - out->Print("request_serializer=None,\n"); - out->Print("request_deserializer=None,\n"); out->Print("options=(),\n"); out->Print("channel_credentials=None,\n"); out->Print("call_credentials=None,\n"); @@ -632,20 +630,25 @@ bool PrivateGenerator::PrintServiceClass(const grpc::string& package_qualified_s grpc::string arity_method_name = grpc::string(method->ClientStreaming() ? "stream" : "unary") + "_" + grpc::string(method->ServerStreaming() ? "stream" : "unary"); - StringMap invocation_dict; - invocation_dict["ArityMethodName"] = arity_method_name; - invocation_dict["PackageQualifiedService"] = package_qualified_service_name; - invocation_dict["Method"] = method->name(); - out->Print(invocation_dict, "return grpc.experimental.$ArityMethodName$(request, target, '/$PackageQualifiedService$/$Method$',\n"); + args_dict["ArityMethodName"] = arity_method_name; + args_dict["PackageQualifiedService"] = package_qualified_service_name; + args_dict["Method"] = method->name(); + out->Print(args_dict, + "return " + "grpc.experimental.$ArityMethodName$($RequestParameter$, " + "target, '/$PackageQualifiedService$/$Method$',\n"); { IndentScope continuation_indent(out); StringMap serializer_dict; serializer_dict["RequestModuleAndClass"] = request_module_and_class; serializer_dict["ResponseModuleAndClass"] = response_module_and_class; - out->Print(serializer_dict, "$RequestModuleAndClass$.SerializeToString,\n"); + out->Print(serializer_dict, + "$RequestModuleAndClass$.SerializeToString,\n"); out->Print(serializer_dict, "$ResponseModuleAndClass$.FromString,\n"); out->Print("options, channel_credentials,\n"); - out->Print("call_credentials, compression, wait_for_ready, timeout, metadata)\n"); + out->Print( + "call_credentials, compression, wait_for_ready, timeout, " + "metadata)\n"); } } } @@ -730,7 +733,8 @@ bool PrivateGenerator::PrintGAServices(grpc_generator::Printer* out) { PrintServicer(service.get(), out) && PrintAddServicerToServer(package_qualified_service_name, service.get(), out) && - PrintServiceClass(package_qualified_service_name, service.get(), out))) { + PrintServiceClass(package_qualified_service_name, service.get(), + out))) { return false; } } diff --git a/src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py b/src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py index 20218fa798e..cddbbb24db9 100644 --- a/src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py +++ b/src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py @@ -510,12 +510,30 @@ class SimpleStubsPluginTest(unittest.TestCase): class Servicer(service_pb2_grpc.TestServiceServicer): def UnaryCall(self, request, context): - return SimpleStubsPluginTest.servicer_methods.UnaryCall(request, context) + return SimpleStubsPluginTest.servicer_methods.UnaryCall( + request, context) + + def StreamingOutputCall(self, request, context): + return SimpleStubsPluginTest.servicer_methods.StreamingOutputCall( + request, context) + + def StreamingInputCall(self, request_iterator, context): + return SimpleStubsPluginTest.servicer_methods.StreamingInputCall( + request_iterator, context) + + def FullDuplexCall(self, request_iterator, context): + return SimpleStubsPluginTest.servicer_methods.FullDuplexCall( + request_iterator, context) + + def HalfDuplexCall(self, request_iterator, context): + return SimpleStubsPluginTest.servicer_methods.HalfDuplexCall( + request_iterator, context) def setUp(self): super(SimpleStubsPluginTest, self).setUp() self._server = test_common.test_server() - service_pb2_grpc.add_TestServiceServicer_to_server(self.Servicer(), self._server) + service_pb2_grpc.add_TestServiceServicer_to_server( + self.Servicer(), self._server) self._port = self._server.add_insecure_port('[::]:0') self._server.start() self._target = 'localhost:{}'.format(self._port) @@ -524,13 +542,58 @@ class SimpleStubsPluginTest(unittest.TestCase): self._server.stop(None) super(SimpleStubsPluginTest, self).tearDown() - def testUnaryCallSimple(self): + def testUnaryCall(self): request = request_pb2.SimpleRequest(response_size=13) response = service_pb2_grpc.TestService.UnaryCall(request, self._target) expected_response = self.servicer_methods.UnaryCall( request, 'not a real context!') self.assertEqual(expected_response, response) + def testStreamingOutputCall(self): + request = _streaming_output_request() + expected_responses = self.servicer_methods.StreamingOutputCall( + request, 'not a real RpcContext!') + responses = service_pb2_grpc.TestService.StreamingOutputCall( + request, self._target) + for expected_response, response in moves.zip_longest( + expected_responses, responses): + self.assertEqual(expected_response, response) + + def testStreamingInputCall(self): + response = service_pb2_grpc.TestService.StreamingInputCall( + _streaming_input_request_iterator(), self._target) + expected_response = self.servicer_methods.StreamingInputCall( + _streaming_input_request_iterator(), 'not a real RpcContext!') + self.assertEqual(expected_response, response) + + def testFullDuplexCall(self): + responses = service_pb2_grpc.TestService.FullDuplexCall( + _full_duplex_request_iterator(), self._target) + expected_responses = self.servicer_methods.FullDuplexCall( + _full_duplex_request_iterator(), 'not a real RpcContext!') + for expected_response, response in moves.zip_longest( + expected_responses, responses): + self.assertEqual(expected_response, response) + + def testHalfDuplexCall(self): + + def half_duplex_request_iterator(): + request = request_pb2.StreamingOutputCallRequest() + request.response_parameters.add(size=1, interval_us=0) + yield request + request = request_pb2.StreamingOutputCallRequest() + request.response_parameters.add(size=2, interval_us=0) + request.response_parameters.add(size=3, interval_us=0) + yield request + + responses = service_pb2_grpc.TestService.HalfDuplexCall( + half_duplex_request_iterator(), self._target) + expected_responses = self.servicer_methods.HalfDuplexCall( + half_duplex_request_iterator(), 'not a real RpcContext!') + for expected_response, response in moves.zip_longest( + expected_responses, responses): + self.assertEqual(expected_response, response) + if __name__ == '__main__': unittest.main(verbosity=2)