pull/22254/head
Richard Belleville 5 years ago
parent b5d69d2cd1
commit f921f01262
  1. 38
      src/compiler/python_generator.cc
  2. 69
      src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py

@ -575,10 +575,9 @@ bool PrivateGenerator::PrintAddServicerToServer(
* file, with no suffixes. Since this class merely acts as a namespace, it
* should never be instantiated.
*/
bool PrivateGenerator::PrintServiceClass(const grpc::string& package_qualified_service_name,
const grpc_generator::Service* service,
grpc_generator::Printer* out)
{
bool PrivateGenerator::PrintServiceClass(
const grpc::string& package_qualified_service_name,
const grpc_generator::Service* service, grpc_generator::Printer* out) {
StringMap dict;
dict["Service"] = service->name();
out->Print("\n\n");
@ -609,16 +608,15 @@ bool PrivateGenerator::PrintServiceClass(const grpc::string& package_qualified_s
method_dict["Method"] = method->name();
out->Print("@staticmethod\n");
out->Print(method_dict, "def $Method$(");
grpc::string request_parameter(
method->ClientStreaming() ? "request_iterator" : "request");
StringMap args_dict;
args_dict["RequestParameter"] = request_parameter;
{
IndentScope args_indent(out);
IndentScope args_double_indent(out);
grpc::string request_parameter(method->ClientStreaming() ? "request_iterator" : "request");
StringMap args_dict;
args_dict["RequestParameter"] = request_parameter;
out->Print(args_dict, "$RequestParameter$,\n");
out->Print("target,\n");
out->Print("request_serializer=None,\n");
out->Print("request_deserializer=None,\n");
out->Print("options=(),\n");
out->Print("channel_credentials=None,\n");
out->Print("call_credentials=None,\n");
@ -632,20 +630,25 @@ bool PrivateGenerator::PrintServiceClass(const grpc::string& package_qualified_s
grpc::string arity_method_name =
grpc::string(method->ClientStreaming() ? "stream" : "unary") + "_" +
grpc::string(method->ServerStreaming() ? "stream" : "unary");
StringMap invocation_dict;
invocation_dict["ArityMethodName"] = arity_method_name;
invocation_dict["PackageQualifiedService"] = package_qualified_service_name;
invocation_dict["Method"] = method->name();
out->Print(invocation_dict, "return grpc.experimental.$ArityMethodName$(request, target, '/$PackageQualifiedService$/$Method$',\n");
args_dict["ArityMethodName"] = arity_method_name;
args_dict["PackageQualifiedService"] = package_qualified_service_name;
args_dict["Method"] = method->name();
out->Print(args_dict,
"return "
"grpc.experimental.$ArityMethodName$($RequestParameter$, "
"target, '/$PackageQualifiedService$/$Method$',\n");
{
IndentScope continuation_indent(out);
StringMap serializer_dict;
serializer_dict["RequestModuleAndClass"] = request_module_and_class;
serializer_dict["ResponseModuleAndClass"] = response_module_and_class;
out->Print(serializer_dict, "$RequestModuleAndClass$.SerializeToString,\n");
out->Print(serializer_dict,
"$RequestModuleAndClass$.SerializeToString,\n");
out->Print(serializer_dict, "$ResponseModuleAndClass$.FromString,\n");
out->Print("options, channel_credentials,\n");
out->Print("call_credentials, compression, wait_for_ready, timeout, metadata)\n");
out->Print(
"call_credentials, compression, wait_for_ready, timeout, "
"metadata)\n");
}
}
}
@ -730,7 +733,8 @@ bool PrivateGenerator::PrintGAServices(grpc_generator::Printer* out) {
PrintServicer(service.get(), out) &&
PrintAddServicerToServer(package_qualified_service_name,
service.get(), out) &&
PrintServiceClass(package_qualified_service_name, service.get(), out))) {
PrintServiceClass(package_qualified_service_name, service.get(),
out))) {
return false;
}
}

@ -510,12 +510,30 @@ class SimpleStubsPluginTest(unittest.TestCase):
class Servicer(service_pb2_grpc.TestServiceServicer):
def UnaryCall(self, request, context):
return SimpleStubsPluginTest.servicer_methods.UnaryCall(request, context)
return SimpleStubsPluginTest.servicer_methods.UnaryCall(
request, context)
def StreamingOutputCall(self, request, context):
return SimpleStubsPluginTest.servicer_methods.StreamingOutputCall(
request, context)
def StreamingInputCall(self, request_iterator, context):
return SimpleStubsPluginTest.servicer_methods.StreamingInputCall(
request_iterator, context)
def FullDuplexCall(self, request_iterator, context):
return SimpleStubsPluginTest.servicer_methods.FullDuplexCall(
request_iterator, context)
def HalfDuplexCall(self, request_iterator, context):
return SimpleStubsPluginTest.servicer_methods.HalfDuplexCall(
request_iterator, context)
def setUp(self):
super(SimpleStubsPluginTest, self).setUp()
self._server = test_common.test_server()
service_pb2_grpc.add_TestServiceServicer_to_server(self.Servicer(), self._server)
service_pb2_grpc.add_TestServiceServicer_to_server(
self.Servicer(), self._server)
self._port = self._server.add_insecure_port('[::]:0')
self._server.start()
self._target = 'localhost:{}'.format(self._port)
@ -524,13 +542,58 @@ class SimpleStubsPluginTest(unittest.TestCase):
self._server.stop(None)
super(SimpleStubsPluginTest, self).tearDown()
def testUnaryCallSimple(self):
def testUnaryCall(self):
request = request_pb2.SimpleRequest(response_size=13)
response = service_pb2_grpc.TestService.UnaryCall(request, self._target)
expected_response = self.servicer_methods.UnaryCall(
request, 'not a real context!')
self.assertEqual(expected_response, response)
def testStreamingOutputCall(self):
request = _streaming_output_request()
expected_responses = self.servicer_methods.StreamingOutputCall(
request, 'not a real RpcContext!')
responses = service_pb2_grpc.TestService.StreamingOutputCall(
request, self._target)
for expected_response, response in moves.zip_longest(
expected_responses, responses):
self.assertEqual(expected_response, response)
def testStreamingInputCall(self):
response = service_pb2_grpc.TestService.StreamingInputCall(
_streaming_input_request_iterator(), self._target)
expected_response = self.servicer_methods.StreamingInputCall(
_streaming_input_request_iterator(), 'not a real RpcContext!')
self.assertEqual(expected_response, response)
def testFullDuplexCall(self):
responses = service_pb2_grpc.TestService.FullDuplexCall(
_full_duplex_request_iterator(), self._target)
expected_responses = self.servicer_methods.FullDuplexCall(
_full_duplex_request_iterator(), 'not a real RpcContext!')
for expected_response, response in moves.zip_longest(
expected_responses, responses):
self.assertEqual(expected_response, response)
def testHalfDuplexCall(self):
def half_duplex_request_iterator():
request = request_pb2.StreamingOutputCallRequest()
request.response_parameters.add(size=1, interval_us=0)
yield request
request = request_pb2.StreamingOutputCallRequest()
request.response_parameters.add(size=2, interval_us=0)
request.response_parameters.add(size=3, interval_us=0)
yield request
responses = service_pb2_grpc.TestService.HalfDuplexCall(
half_duplex_request_iterator(), self._target)
expected_responses = self.servicer_methods.HalfDuplexCall(
half_duplex_request_iterator(), 'not a real RpcContext!')
for expected_response, response in moves.zip_longest(
expected_responses, responses):
self.assertEqual(expected_response, response)
if __name__ == '__main__':
unittest.main(verbosity=2)

Loading…
Cancel
Save