[Python O11Y] Reapply registered method change (#35850)

This reverts commit a18279db2e.

<!--

If you know who should review your pull request, please assign it to that
person, otherwise the pull request would get assigned randomly.

If your pull request is for a specific language, please add the appropriate
lang label.

-->

Closes #35850

PiperOrigin-RevId: 607476066
pull/35921/head^2
Xuan Wang 10 months ago committed by Copybara-Service
parent cf79445171
commit 2c9b599e5e
  1. 27
      examples/python/helloworld/helloworld_pb2.py
  2. 14
      examples/python/helloworld/helloworld_pb2.pyi
  3. 104
      examples/python/helloworld/helloworld_pb2_grpc.py
  4. 23
      src/compiler/python_generator.cc
  5. 4
      src/python/.gitignore
  6. 12
      src/python/grpcio/grpc/__init__.py
  7. 75
      src/python/grpcio/grpc/_channel.py
  8. 7
      src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi
  9. 71
      src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
  10. 6
      src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi
  11. 36
      src/python/grpcio/grpc/_interceptor.py
  12. 81
      src/python/grpcio/grpc/_simple_stubs.py
  13. 12
      src/python/grpcio/grpc/aio/_base_channel.py
  14. 21
      src/python/grpcio/grpc/aio/_channel.py
  15. 27
      src/python/grpcio_testing/grpc_testing/_channel/_channel.py
  16. 17
      src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py
  17. 5
      src/python/grpcio_tests/tests/csds/csds_test.py
  18. 11
      src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py
  19. 9
      src/python/grpcio_tests/tests/qps/benchmark_client.py
  20. 25
      src/python/grpcio_tests/tests/status/_grpc_status_test.py
  21. 20
      src/python/grpcio_tests/tests/unit/_abort_test.py
  22. 20
      src/python/grpcio_tests/tests/unit/_auth_context_test.py
  23. 35
      src/python/grpcio_tests/tests/unit/_channel_close_test.py
  24. 20
      src/python/grpcio_tests/tests/unit/_compression_test.py
  25. 10
      src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py
  26. 5
      src/python/grpcio_tests/tests/unit/_dns_resolver_test.py
  27. 24
      src/python/grpcio_tests/tests/unit/_empty_message_test.py
  28. 5
      src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py
  29. 20
      src/python/grpcio_tests/tests/unit/_exit_scenarios.py
  30. 8
      src/python/grpcio_tests/tests/unit/_interceptor_test.py
  31. 9
      src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py
  32. 22
      src/python/grpcio_tests/tests/unit/_invocation_defects_test.py
  33. 14
      src/python/grpcio_tests/tests/unit/_local_credentials_test.py
  34. 80
      src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
  35. 39
      src/python/grpcio_tests/tests/unit/_metadata_flags_test.py
  36. 16
      src/python/grpcio_tests/tests/unit/_metadata_test.py
  37. 5
      src/python/grpcio_tests/tests/unit/_reconnect_test.py
  38. 20
      src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py
  39. 18
      src/python/grpcio_tests/tests/unit/_rpc_test_helpers.py
  40. 5
      src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py
  41. 5
      src/python/grpcio_tests/tests/unit/_session_cache_test.py
  42. 24
      src/python/grpcio_tests/tests/unit/_signal_client.py
  43. 14
      src/python/grpcio_tests/tests/unit/_xds_credentials_test.py
  44. 2
      src/python/grpcio_tests/tests_gevent/unit/close_channel_test.py
  45. 5
      src/python/grpcio_tests/tests_py3_only/unit/_leak_test.py
  46. 46
      src/python/grpcio_tests/tests_py3_only/unit/_simple_stubs_test.py

@ -1,11 +1,12 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT! # Generated by the protocol buffer compiler. DO NOT EDIT!
# source: helloworld.proto # source: helloworld.proto
# Protobuf Python Version: 4.25.0
"""Generated protocol buffer code.""" """Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
@ -13,18 +14,18 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10helloworld.proto\x12\nhelloworld\"\x1c\n\x0cHelloRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\"\x1d\n\nHelloReply\x12\x0f\n\x07message\x18\x01 \x01(\t2I\n\x07Greeter\x12>\n\x08SayHello\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00\x42\x36\n\x1bio.grpc.examples.helloworldB\x0fHelloWorldProtoP\x01\xa2\x02\x03HLWb\x06proto3') DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10helloworld.proto\x12\nhelloworld\"\x1c\n\x0cHelloRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\"\x1d\n\nHelloReply\x12\x0f\n\x07message\x18\x01 \x01(\t2\xe4\x01\n\x07Greeter\x12>\n\x08SayHello\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00\x12K\n\x13SayHelloStreamReply\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00\x30\x01\x12L\n\x12SayHelloBidiStream\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00(\x01\x30\x01\x42\x36\n\x1bio.grpc.examples.helloworldB\x0fHelloWorldProtoP\x01\xa2\x02\x03HLWb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _globals = globals()
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'helloworld_pb2', globals()) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'helloworld_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False: if _descriptor._USE_C_DESCRIPTORS == False:
_globals['DESCRIPTOR']._options = None
DESCRIPTOR._options = None _globals['DESCRIPTOR']._serialized_options = b'\n\033io.grpc.examples.helloworldB\017HelloWorldProtoP\001\242\002\003HLW'
DESCRIPTOR._serialized_options = b'\n\033io.grpc.examples.helloworldB\017HelloWorldProtoP\001\242\002\003HLW' _globals['_HELLOREQUEST']._serialized_start=32
_HELLOREQUEST._serialized_start=32 _globals['_HELLOREQUEST']._serialized_end=60
_HELLOREQUEST._serialized_end=60 _globals['_HELLOREPLY']._serialized_start=62
_HELLOREPLY._serialized_start=62 _globals['_HELLOREPLY']._serialized_end=91
_HELLOREPLY._serialized_end=91 _globals['_GREETER']._serialized_start=94
_GREETER._serialized_start=93 _globals['_GREETER']._serialized_end=322
_GREETER._serialized_end=166
# @@protoc_insertion_point(module_scope) # @@protoc_insertion_point(module_scope)

@ -4,14 +4,14 @@ from typing import ClassVar as _ClassVar, Optional as _Optional
DESCRIPTOR: _descriptor.FileDescriptor DESCRIPTOR: _descriptor.FileDescriptor
class HelloReply(_message.Message):
__slots__ = ["message"]
MESSAGE_FIELD_NUMBER: _ClassVar[int]
message: str
def __init__(self, message: _Optional[str] = ...) -> None: ...
class HelloRequest(_message.Message): class HelloRequest(_message.Message):
__slots__ = ["name"] __slots__ = ("name",)
NAME_FIELD_NUMBER: _ClassVar[int] NAME_FIELD_NUMBER: _ClassVar[int]
name: str name: str
def __init__(self, name: _Optional[str] = ...) -> None: ... def __init__(self, name: _Optional[str] = ...) -> None: ...
class HelloReply(_message.Message):
__slots__ = ("message",)
MESSAGE_FIELD_NUMBER: _ClassVar[int]
message: str
def __init__(self, message: _Optional[str] = ...) -> None: ...

@ -19,7 +19,17 @@ class GreeterStub(object):
'/helloworld.Greeter/SayHello', '/helloworld.Greeter/SayHello',
request_serializer=helloworld__pb2.HelloRequest.SerializeToString, request_serializer=helloworld__pb2.HelloRequest.SerializeToString,
response_deserializer=helloworld__pb2.HelloReply.FromString, response_deserializer=helloworld__pb2.HelloReply.FromString,
) _registered_method=True)
self.SayHelloStreamReply = channel.unary_stream(
'/helloworld.Greeter/SayHelloStreamReply',
request_serializer=helloworld__pb2.HelloRequest.SerializeToString,
response_deserializer=helloworld__pb2.HelloReply.FromString,
_registered_method=True)
self.SayHelloBidiStream = channel.stream_stream(
'/helloworld.Greeter/SayHelloBidiStream',
request_serializer=helloworld__pb2.HelloRequest.SerializeToString,
response_deserializer=helloworld__pb2.HelloReply.FromString,
_registered_method=True)
class GreeterServicer(object): class GreeterServicer(object):
@ -33,6 +43,18 @@ class GreeterServicer(object):
context.set_details('Method not implemented!') context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!') raise NotImplementedError('Method not implemented!')
def SayHelloStreamReply(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SayHelloBidiStream(self, request_iterator, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_GreeterServicer_to_server(servicer, server): def add_GreeterServicer_to_server(servicer, server):
rpc_method_handlers = { rpc_method_handlers = {
@ -41,6 +63,16 @@ def add_GreeterServicer_to_server(servicer, server):
request_deserializer=helloworld__pb2.HelloRequest.FromString, request_deserializer=helloworld__pb2.HelloRequest.FromString,
response_serializer=helloworld__pb2.HelloReply.SerializeToString, response_serializer=helloworld__pb2.HelloReply.SerializeToString,
), ),
'SayHelloStreamReply': grpc.unary_stream_rpc_method_handler(
servicer.SayHelloStreamReply,
request_deserializer=helloworld__pb2.HelloRequest.FromString,
response_serializer=helloworld__pb2.HelloReply.SerializeToString,
),
'SayHelloBidiStream': grpc.stream_stream_rpc_method_handler(
servicer.SayHelloBidiStream,
request_deserializer=helloworld__pb2.HelloRequest.FromString,
response_serializer=helloworld__pb2.HelloReply.SerializeToString,
),
} }
generic_handler = grpc.method_handlers_generic_handler( generic_handler = grpc.method_handlers_generic_handler(
'helloworld.Greeter', rpc_method_handlers) 'helloworld.Greeter', rpc_method_handlers)
@ -63,8 +95,72 @@ class Greeter(object):
wait_for_ready=None, wait_for_ready=None,
timeout=None, timeout=None,
metadata=None): metadata=None):
return grpc.experimental.unary_unary(request, target, '/helloworld.Greeter/SayHello', return grpc.experimental.unary_unary(
request,
target,
'/helloworld.Greeter/SayHello',
helloworld__pb2.HelloRequest.SerializeToString,
helloworld__pb2.HelloReply.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SayHelloStreamReply(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_stream(
request,
target,
'/helloworld.Greeter/SayHelloStreamReply',
helloworld__pb2.HelloRequest.SerializeToString,
helloworld__pb2.HelloReply.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SayHelloBidiStream(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(
request_iterator,
target,
'/helloworld.Greeter/SayHelloBidiStream',
helloworld__pb2.HelloRequest.SerializeToString, helloworld__pb2.HelloRequest.SerializeToString,
helloworld__pb2.HelloReply.FromString, helloworld__pb2.HelloReply.FromString,
options, channel_credentials, options,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

@ -467,7 +467,7 @@ bool PrivateGenerator::PrintStub(
out->Print( out->Print(
method_dict, method_dict,
"response_deserializer=$ResponseModuleAndClass$.FromString,\n"); "response_deserializer=$ResponseModuleAndClass$.FromString,\n");
out->Print(")\n"); out->Print("_registered_method=True)\n");
} }
} }
} }
@ -642,22 +642,27 @@ bool PrivateGenerator::PrintServiceClass(
args_dict["ArityMethodName"] = arity_method_name; args_dict["ArityMethodName"] = arity_method_name;
args_dict["PackageQualifiedService"] = package_qualified_service_name; args_dict["PackageQualifiedService"] = package_qualified_service_name;
args_dict["Method"] = method->name(); args_dict["Method"] = method->name();
out->Print(args_dict, out->Print(args_dict, "return grpc.experimental.$ArityMethodName$(\n");
"return "
"grpc.experimental.$ArityMethodName$($RequestParameter$, "
"target, '/$PackageQualifiedService$/$Method$',\n");
{ {
IndentScope continuation_indent(out); IndentScope continuation_indent(out);
StringMap serializer_dict; StringMap serializer_dict;
out->Print(args_dict, "$RequestParameter$,\n");
out->Print("target,\n");
out->Print(args_dict, "'/$PackageQualifiedService$/$Method$',\n");
serializer_dict["RequestModuleAndClass"] = request_module_and_class; serializer_dict["RequestModuleAndClass"] = request_module_and_class;
serializer_dict["ResponseModuleAndClass"] = response_module_and_class; serializer_dict["ResponseModuleAndClass"] = response_module_and_class;
out->Print(serializer_dict, out->Print(serializer_dict,
"$RequestModuleAndClass$.SerializeToString,\n"); "$RequestModuleAndClass$.SerializeToString,\n");
out->Print(serializer_dict, "$ResponseModuleAndClass$.FromString,\n"); out->Print(serializer_dict, "$ResponseModuleAndClass$.FromString,\n");
out->Print("options, channel_credentials,\n"); out->Print("options,\n");
out->Print( out->Print("channel_credentials,\n");
"insecure, call_credentials, compression, wait_for_ready, " out->Print("insecure,\n");
"timeout, metadata)\n"); out->Print("call_credentials,\n");
out->Print("compression,\n");
out->Print("wait_for_ready,\n");
out->Print("timeout,\n");
out->Print("metadata,\n");
out->Print("_registered_method=True)\n");
} }
} }
} }

@ -1,4 +1,6 @@
gens/ build/
grpc_root/
third_party/
*_pb2.py *_pb2.py
*_pb2.pyi *_pb2.pyi
*_pb2_grpc.py *_pb2_grpc.py

@ -1004,6 +1004,7 @@ class Channel(abc.ABC):
method, method,
request_serializer=None, request_serializer=None,
response_deserializer=None, response_deserializer=None,
_registered_method=False,
): ):
"""Creates a UnaryUnaryMultiCallable for a unary-unary method. """Creates a UnaryUnaryMultiCallable for a unary-unary method.
@ -1014,6 +1015,8 @@ class Channel(abc.ABC):
response_deserializer: Optional :term:`deserializer` for deserializing the response_deserializer: Optional :term:`deserializer` for deserializing the
response message. Response goes undeserialized in case None response message. Response goes undeserialized in case None
is passed. is passed.
_registered_method: Implementation Private. A bool representing whether the method
is registered.
Returns: Returns:
A UnaryUnaryMultiCallable value for the named unary-unary method. A UnaryUnaryMultiCallable value for the named unary-unary method.
@ -1026,6 +1029,7 @@ class Channel(abc.ABC):
method, method,
request_serializer=None, request_serializer=None,
response_deserializer=None, response_deserializer=None,
_registered_method=False,
): ):
"""Creates a UnaryStreamMultiCallable for a unary-stream method. """Creates a UnaryStreamMultiCallable for a unary-stream method.
@ -1036,6 +1040,8 @@ class Channel(abc.ABC):
response_deserializer: Optional :term:`deserializer` for deserializing the response_deserializer: Optional :term:`deserializer` for deserializing the
response message. Response goes undeserialized in case None is response message. Response goes undeserialized in case None is
passed. passed.
_registered_method: Implementation Private. A bool representing whether the method
is registered.
Returns: Returns:
A UnaryStreamMultiCallable value for the name unary-stream method. A UnaryStreamMultiCallable value for the name unary-stream method.
@ -1048,6 +1054,7 @@ class Channel(abc.ABC):
method, method,
request_serializer=None, request_serializer=None,
response_deserializer=None, response_deserializer=None,
_registered_method=False,
): ):
"""Creates a StreamUnaryMultiCallable for a stream-unary method. """Creates a StreamUnaryMultiCallable for a stream-unary method.
@ -1058,6 +1065,8 @@ class Channel(abc.ABC):
response_deserializer: Optional :term:`deserializer` for deserializing the response_deserializer: Optional :term:`deserializer` for deserializing the
response message. Response goes undeserialized in case None is response message. Response goes undeserialized in case None is
passed. passed.
_registered_method: Implementation Private. A bool representing whether the method
is registered.
Returns: Returns:
A StreamUnaryMultiCallable value for the named stream-unary method. A StreamUnaryMultiCallable value for the named stream-unary method.
@ -1070,6 +1079,7 @@ class Channel(abc.ABC):
method, method,
request_serializer=None, request_serializer=None,
response_deserializer=None, response_deserializer=None,
_registered_method=False,
): ):
"""Creates a StreamStreamMultiCallable for a stream-stream method. """Creates a StreamStreamMultiCallable for a stream-stream method.
@ -1080,6 +1090,8 @@ class Channel(abc.ABC):
response_deserializer: Optional :term:`deserializer` for deserializing the response_deserializer: Optional :term:`deserializer` for deserializing the
response message. Response goes undeserialized in case None response message. Response goes undeserialized in case None
is passed. is passed.
_registered_method: Implementation Private. A bool representing whether the method
is registered.
Returns: Returns:
A StreamStreamMultiCallable value for the named stream-stream method. A StreamStreamMultiCallable value for the named stream-stream method.

@ -24,6 +24,7 @@ import types
from typing import ( from typing import (
Any, Any,
Callable, Callable,
Dict,
Iterator, Iterator,
List, List,
Optional, Optional,
@ -1054,6 +1055,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
_request_serializer: Optional[SerializingFunction] _request_serializer: Optional[SerializingFunction]
_response_deserializer: Optional[DeserializingFunction] _response_deserializer: Optional[DeserializingFunction]
_context: Any _context: Any
_registered_call_handle: Optional[int]
__slots__ = [ __slots__ = [
"_channel", "_channel",
@ -1074,6 +1076,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
target: bytes, target: bytes,
request_serializer: Optional[SerializingFunction], request_serializer: Optional[SerializingFunction],
response_deserializer: Optional[DeserializingFunction], response_deserializer: Optional[DeserializingFunction],
_registered_call_handle: Optional[int],
): ):
self._channel = channel self._channel = channel
self._managed_call = managed_call self._managed_call = managed_call
@ -1082,6 +1085,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
self._request_serializer = request_serializer self._request_serializer = request_serializer
self._response_deserializer = response_deserializer self._response_deserializer = response_deserializer
self._context = cygrpc.build_census_context() self._context = cygrpc.build_census_context()
self._registered_call_handle = _registered_call_handle
def _prepare( def _prepare(
self, self,
@ -1153,6 +1157,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
), ),
), ),
self._context, self._context,
self._registered_call_handle,
) )
event = call.next_event() event = call.next_event()
_handle_event(event, state, self._response_deserializer) _handle_event(event, state, self._response_deserializer)
@ -1221,6 +1226,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
(operations,), (operations,),
event_handler, event_handler,
self._context, self._context,
self._registered_call_handle,
) )
return _MultiThreadedRendezvous( return _MultiThreadedRendezvous(
state, call, self._response_deserializer, deadline state, call, self._response_deserializer, deadline
@ -1234,6 +1240,7 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
_request_serializer: Optional[SerializingFunction] _request_serializer: Optional[SerializingFunction]
_response_deserializer: Optional[DeserializingFunction] _response_deserializer: Optional[DeserializingFunction]
_context: Any _context: Any
_registered_call_handle: Optional[int]
__slots__ = [ __slots__ = [
"_channel", "_channel",
@ -1252,6 +1259,7 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
target: bytes, target: bytes,
request_serializer: SerializingFunction, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction, response_deserializer: DeserializingFunction,
_registered_call_handle: Optional[int],
): ):
self._channel = channel self._channel = channel
self._method = method self._method = method
@ -1259,6 +1267,7 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
self._request_serializer = request_serializer self._request_serializer = request_serializer
self._response_deserializer = response_deserializer self._response_deserializer = response_deserializer
self._context = cygrpc.build_census_context() self._context = cygrpc.build_census_context()
self._registered_call_handle = _registered_call_handle
def __call__( # pylint: disable=too-many-locals def __call__( # pylint: disable=too-many-locals
self, self,
@ -1317,6 +1326,7 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
call_credentials, call_credentials,
operations_and_tags, operations_and_tags,
self._context, self._context,
self._registered_call_handle,
) )
return _SingleThreadedRendezvous( return _SingleThreadedRendezvous(
state, call, self._response_deserializer, deadline state, call, self._response_deserializer, deadline
@ -1331,6 +1341,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
_request_serializer: Optional[SerializingFunction] _request_serializer: Optional[SerializingFunction]
_response_deserializer: Optional[DeserializingFunction] _response_deserializer: Optional[DeserializingFunction]
_context: Any _context: Any
_registered_call_handle: Optional[int]
__slots__ = [ __slots__ = [
"_channel", "_channel",
@ -1351,6 +1362,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
target: bytes, target: bytes,
request_serializer: SerializingFunction, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction, response_deserializer: DeserializingFunction,
_registered_call_handle: Optional[int],
): ):
self._channel = channel self._channel = channel
self._managed_call = managed_call self._managed_call = managed_call
@ -1359,6 +1371,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
self._request_serializer = request_serializer self._request_serializer = request_serializer
self._response_deserializer = response_deserializer self._response_deserializer = response_deserializer
self._context = cygrpc.build_census_context() self._context = cygrpc.build_census_context()
self._registered_call_handle = _registered_call_handle
def __call__( # pylint: disable=too-many-locals def __call__( # pylint: disable=too-many-locals
self, self,
@ -1408,6 +1421,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
operations, operations,
_event_handler(state, self._response_deserializer), _event_handler(state, self._response_deserializer),
self._context, self._context,
self._registered_call_handle,
) )
return _MultiThreadedRendezvous( return _MultiThreadedRendezvous(
state, call, self._response_deserializer, deadline state, call, self._response_deserializer, deadline
@ -1422,6 +1436,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
_request_serializer: Optional[SerializingFunction] _request_serializer: Optional[SerializingFunction]
_response_deserializer: Optional[DeserializingFunction] _response_deserializer: Optional[DeserializingFunction]
_context: Any _context: Any
_registered_call_handle: Optional[int]
__slots__ = [ __slots__ = [
"_channel", "_channel",
@ -1442,6 +1457,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
target: bytes, target: bytes,
request_serializer: Optional[SerializingFunction], request_serializer: Optional[SerializingFunction],
response_deserializer: Optional[DeserializingFunction], response_deserializer: Optional[DeserializingFunction],
_registered_call_handle: Optional[int],
): ):
self._channel = channel self._channel = channel
self._managed_call = managed_call self._managed_call = managed_call
@ -1450,6 +1466,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
self._request_serializer = request_serializer self._request_serializer = request_serializer
self._response_deserializer = response_deserializer self._response_deserializer = response_deserializer
self._context = cygrpc.build_census_context() self._context = cygrpc.build_census_context()
self._registered_call_handle = _registered_call_handle
def _blocking( def _blocking(
self, self,
@ -1482,6 +1499,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
augmented_metadata, initial_metadata_flags augmented_metadata, initial_metadata_flags
), ),
self._context, self._context,
self._registered_call_handle,
) )
_consume_request_iterator( _consume_request_iterator(
request_iterator, state, call, self._request_serializer, None request_iterator, state, call, self._request_serializer, None
@ -1572,6 +1590,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
), ),
event_handler, event_handler,
self._context, self._context,
self._registered_call_handle,
) )
_consume_request_iterator( _consume_request_iterator(
request_iterator, request_iterator,
@ -1593,6 +1612,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
_request_serializer: Optional[SerializingFunction] _request_serializer: Optional[SerializingFunction]
_response_deserializer: Optional[DeserializingFunction] _response_deserializer: Optional[DeserializingFunction]
_context: Any _context: Any
_registered_call_handle: Optional[int]
__slots__ = [ __slots__ = [
"_channel", "_channel",
@ -1611,8 +1631,9 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
managed_call: IntegratedCallFactory, managed_call: IntegratedCallFactory,
method: bytes, method: bytes,
target: bytes, target: bytes,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction],
response_deserializer: Optional[DeserializingFunction] = None, response_deserializer: Optional[DeserializingFunction],
_registered_call_handle: Optional[int],
): ):
self._channel = channel self._channel = channel
self._managed_call = managed_call self._managed_call = managed_call
@ -1621,6 +1642,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
self._request_serializer = request_serializer self._request_serializer = request_serializer
self._response_deserializer = response_deserializer self._response_deserializer = response_deserializer
self._context = cygrpc.build_census_context() self._context = cygrpc.build_census_context()
self._registered_call_handle = _registered_call_handle
def __call__( def __call__(
self, self,
@ -1662,6 +1684,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
operations, operations,
event_handler, event_handler,
self._context, self._context,
self._registered_call_handle,
) )
_consume_request_iterator( _consume_request_iterator(
request_iterator, request_iterator,
@ -1751,7 +1774,8 @@ def _channel_managed_call_management(state: _ChannelCallState):
credentials: Optional[cygrpc.CallCredentials], credentials: Optional[cygrpc.CallCredentials],
operations: Sequence[Sequence[cygrpc.Operation]], operations: Sequence[Sequence[cygrpc.Operation]],
event_handler: UserTag, event_handler: UserTag,
context, context: Any,
_registered_call_handle: Optional[int],
) -> cygrpc.IntegratedCall: ) -> cygrpc.IntegratedCall:
"""Creates a cygrpc.IntegratedCall. """Creates a cygrpc.IntegratedCall.
@ -1768,6 +1792,8 @@ def _channel_managed_call_management(state: _ChannelCallState):
event_handler: A behavior to call to handle the events resultant from event_handler: A behavior to call to handle the events resultant from
the operations on the call. the operations on the call.
context: Context object for distributed tracing. context: Context object for distributed tracing.
_registered_call_handle: An int representing the call handle of the
method, or None if the method is not registered.
Returns: Returns:
A cygrpc.IntegratedCall with which to conduct an RPC. A cygrpc.IntegratedCall with which to conduct an RPC.
""" """
@ -1788,6 +1814,7 @@ def _channel_managed_call_management(state: _ChannelCallState):
credentials, credentials,
operations_and_tags, operations_and_tags,
context, context,
_registered_call_handle,
) )
if state.managed_calls == 0: if state.managed_calls == 0:
state.managed_calls = 1 state.managed_calls = 1
@ -2021,6 +2048,7 @@ class Channel(grpc.Channel):
_call_state: _ChannelCallState _call_state: _ChannelCallState
_connectivity_state: _ChannelConnectivityState _connectivity_state: _ChannelConnectivityState
_target: str _target: str
_registered_call_handles: Dict[str, int]
def __init__( def __init__(
self, self,
@ -2055,6 +2083,22 @@ class Channel(grpc.Channel):
if cygrpc.g_gevent_activated: if cygrpc.g_gevent_activated:
cygrpc.gevent_increment_channel_count() cygrpc.gevent_increment_channel_count()
def _get_registered_call_handle(self, method: str) -> int:
"""
Get the registered call handle for a method.
This is a semi-private method. It is intended for use only by gRPC generated code.
This method is not thread-safe.
Args:
method: Required, the method name for the RPC.
Returns:
The registered call handle pointer in the form of a Python Long.
"""
return self._channel.get_registered_call_handle(_common.encode(method))
def _process_python_options( def _process_python_options(
self, python_options: Sequence[ChannelArgumentType] self, python_options: Sequence[ChannelArgumentType]
) -> None: ) -> None:
@ -2078,12 +2122,17 @@ class Channel(grpc.Channel):
) -> None: ) -> None:
_unsubscribe(self._connectivity_state, callback) _unsubscribe(self._connectivity_state, callback)
# pylint: disable=arguments-differ
def unary_unary( def unary_unary(
self, self,
method: str, method: str,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.UnaryUnaryMultiCallable: ) -> grpc.UnaryUnaryMultiCallable:
_registered_call_handle = None
if _registered_method:
_registered_call_handle = self._get_registered_call_handle(method)
return _UnaryUnaryMultiCallable( return _UnaryUnaryMultiCallable(
self._channel, self._channel,
_channel_managed_call_management(self._call_state), _channel_managed_call_management(self._call_state),
@ -2091,14 +2140,20 @@ class Channel(grpc.Channel):
_common.encode(self._target), _common.encode(self._target),
request_serializer, request_serializer,
response_deserializer, response_deserializer,
_registered_call_handle,
) )
# pylint: disable=arguments-differ
def unary_stream( def unary_stream(
self, self,
method: str, method: str,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.UnaryStreamMultiCallable: ) -> grpc.UnaryStreamMultiCallable:
_registered_call_handle = None
if _registered_method:
_registered_call_handle = self._get_registered_call_handle(method)
# NOTE(rbellevi): Benchmarks have shown that running a unary-stream RPC # NOTE(rbellevi): Benchmarks have shown that running a unary-stream RPC
# on a single Python thread results in an appreciable speed-up. However, # on a single Python thread results in an appreciable speed-up. However,
# due to slight differences in capability, the multi-threaded variant # due to slight differences in capability, the multi-threaded variant
@ -2110,6 +2165,7 @@ class Channel(grpc.Channel):
_common.encode(self._target), _common.encode(self._target),
request_serializer, request_serializer,
response_deserializer, response_deserializer,
_registered_call_handle,
) )
else: else:
return _UnaryStreamMultiCallable( return _UnaryStreamMultiCallable(
@ -2119,14 +2175,20 @@ class Channel(grpc.Channel):
_common.encode(self._target), _common.encode(self._target),
request_serializer, request_serializer,
response_deserializer, response_deserializer,
_registered_call_handle,
) )
# pylint: disable=arguments-differ
def stream_unary( def stream_unary(
self, self,
method: str, method: str,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.StreamUnaryMultiCallable: ) -> grpc.StreamUnaryMultiCallable:
_registered_call_handle = None
if _registered_method:
_registered_call_handle = self._get_registered_call_handle(method)
return _StreamUnaryMultiCallable( return _StreamUnaryMultiCallable(
self._channel, self._channel,
_channel_managed_call_management(self._call_state), _channel_managed_call_management(self._call_state),
@ -2134,14 +2196,20 @@ class Channel(grpc.Channel):
_common.encode(self._target), _common.encode(self._target),
request_serializer, request_serializer,
response_deserializer, response_deserializer,
_registered_call_handle,
) )
# pylint: disable=arguments-differ
def stream_stream( def stream_stream(
self, self,
method: str, method: str,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.StreamStreamMultiCallable: ) -> grpc.StreamStreamMultiCallable:
_registered_call_handle = None
if _registered_method:
_registered_call_handle = self._get_registered_call_handle(method)
return _StreamStreamMultiCallable( return _StreamStreamMultiCallable(
self._channel, self._channel,
_channel_managed_call_management(self._call_state), _channel_managed_call_management(self._call_state),
@ -2149,6 +2217,7 @@ class Channel(grpc.Channel):
_common.encode(self._target), _common.encode(self._target),
request_serializer, request_serializer,
response_deserializer, response_deserializer,
_registered_call_handle,
) )
def _unsubscribe_all(self) -> None: def _unsubscribe_all(self) -> None:

@ -74,6 +74,13 @@ cdef class SegregatedCall:
cdef class Channel: cdef class Channel:
cdef _ChannelState _state cdef _ChannelState _state
cdef dict _registered_call_handles
# TODO(https://github.com/grpc/grpc/issues/15662): Eliminate this. # TODO(https://github.com/grpc/grpc/issues/15662): Eliminate this.
cdef tuple _arguments cdef tuple _arguments
cdef class CallHandle:
cdef void *c_call_handle
cdef object method

@ -101,6 +101,25 @@ cdef class _ChannelState:
self.connectivity_due = set() self.connectivity_due = set()
self.closed_reason = None self.closed_reason = None
cdef class CallHandle:
def __cinit__(self, _ChannelState channel_state, object method):
self.method = method
cpython.Py_INCREF(method)
# Note that since we always pass None for host, we set the
# second-to-last parameter of grpc_channel_register_call to a fixed
# NULL value.
self.c_call_handle = grpc_channel_register_call(
channel_state.c_channel, <const char *>method, NULL, NULL)
def __dealloc__(self):
cpython.Py_DECREF(self.method)
@property
def call_handle(self):
return cpython.PyLong_FromVoidPtr(self.c_call_handle)
cdef tuple _operate(grpc_call *c_call, object operations, object user_tag): cdef tuple _operate(grpc_call *c_call, object operations, object user_tag):
cdef grpc_call_error c_call_error cdef grpc_call_error c_call_error
@ -199,7 +218,7 @@ cdef void _call(
grpc_completion_queue *c_completion_queue, on_success, int flags, method, grpc_completion_queue *c_completion_queue, on_success, int flags, method,
host, object deadline, CallCredentials credentials, host, object deadline, CallCredentials credentials,
object operationses_and_user_tags, object metadata, object operationses_and_user_tags, object metadata,
object context) except *: object context, object registered_call_handle) except *:
"""Invokes an RPC. """Invokes an RPC.
Args: Args:
@ -226,6 +245,8 @@ cdef void _call(
must be present in the first element of this value. must be present in the first element of this value.
metadata: The metadata for this call. metadata: The metadata for this call.
context: Context object for distributed tracing. context: Context object for distributed tracing.
registered_call_handle: An int representing the call handle of the method, or
None if the method is not registered.
""" """
cdef grpc_slice method_slice cdef grpc_slice method_slice
cdef grpc_slice host_slice cdef grpc_slice host_slice
@ -242,10 +263,16 @@ cdef void _call(
else: else:
host_slice = _slice_from_bytes(host) host_slice = _slice_from_bytes(host)
host_slice_ptr = &host_slice host_slice_ptr = &host_slice
call_state.c_call = grpc_channel_create_call( if registered_call_handle:
channel_state.c_channel, NULL, flags, call_state.c_call = grpc_channel_create_registered_call(
c_completion_queue, method_slice, host_slice_ptr, channel_state.c_channel, NULL, flags,
_timespec_from_time(deadline), NULL) c_completion_queue, cpython.PyLong_AsVoidPtr(registered_call_handle),
_timespec_from_time(deadline), NULL)
else:
call_state.c_call = grpc_channel_create_call(
channel_state.c_channel, NULL, flags,
c_completion_queue, method_slice, host_slice_ptr,
_timespec_from_time(deadline), NULL)
grpc_slice_unref(method_slice) grpc_slice_unref(method_slice)
if host_slice_ptr: if host_slice_ptr:
grpc_slice_unref(host_slice) grpc_slice_unref(host_slice)
@ -309,7 +336,7 @@ cdef class IntegratedCall:
cdef IntegratedCall _integrated_call( cdef IntegratedCall _integrated_call(
_ChannelState state, int flags, method, host, object deadline, _ChannelState state, int flags, method, host, object deadline,
object metadata, CallCredentials credentials, operationses_and_user_tags, object metadata, CallCredentials credentials, operationses_and_user_tags,
object context): object context, object registered_call_handle):
call_state = _CallState() call_state = _CallState()
def on_success(started_tags): def on_success(started_tags):
@ -318,7 +345,8 @@ cdef IntegratedCall _integrated_call(
_call( _call(
state, call_state, state.c_call_completion_queue, on_success, flags, state, call_state, state.c_call_completion_queue, on_success, flags,
method, host, deadline, credentials, operationses_and_user_tags, metadata, context) method, host, deadline, credentials, operationses_and_user_tags,
metadata, context, registered_call_handle)
return IntegratedCall(state, call_state) return IntegratedCall(state, call_state)
@ -371,7 +399,7 @@ cdef class SegregatedCall:
cdef SegregatedCall _segregated_call( cdef SegregatedCall _segregated_call(
_ChannelState state, int flags, method, host, object deadline, _ChannelState state, int flags, method, host, object deadline,
object metadata, CallCredentials credentials, operationses_and_user_tags, object metadata, CallCredentials credentials, operationses_and_user_tags,
object context): object context, object registered_call_handle):
cdef _CallState call_state = _CallState() cdef _CallState call_state = _CallState()
cdef SegregatedCall segregated_call cdef SegregatedCall segregated_call
cdef grpc_completion_queue *c_completion_queue cdef grpc_completion_queue *c_completion_queue
@ -389,7 +417,7 @@ cdef SegregatedCall _segregated_call(
_call( _call(
state, call_state, c_completion_queue, on_success, flags, method, host, state, call_state, c_completion_queue, on_success, flags, method, host,
deadline, credentials, operationses_and_user_tags, metadata, deadline, credentials, operationses_and_user_tags, metadata,
context) context, registered_call_handle)
except: except:
_destroy_c_completion_queue(c_completion_queue) _destroy_c_completion_queue(c_completion_queue)
raise raise
@ -486,6 +514,7 @@ cdef class Channel:
else grpc_insecure_credentials_create()) else grpc_insecure_credentials_create())
self._state.c_channel = grpc_channel_create( self._state.c_channel = grpc_channel_create(
<char *>target, c_channel_credentials, channel_args.c_args()) <char *>target, c_channel_credentials, channel_args.c_args())
self._registered_call_handles = {}
grpc_channel_credentials_release(c_channel_credentials) grpc_channel_credentials_release(c_channel_credentials)
def target(self): def target(self):
@ -499,10 +528,10 @@ cdef class Channel:
def integrated_call( def integrated_call(
self, int flags, method, host, object deadline, object metadata, self, int flags, method, host, object deadline, object metadata,
CallCredentials credentials, operationses_and_tags, CallCredentials credentials, operationses_and_tags,
object context = None): object context = None, object registered_call_handle = None):
return _integrated_call( return _integrated_call(
self._state, flags, method, host, deadline, metadata, credentials, self._state, flags, method, host, deadline, metadata, credentials,
operationses_and_tags, context) operationses_and_tags, context, registered_call_handle)
def next_call_event(self): def next_call_event(self):
def on_success(tag): def on_success(tag):
@ -521,10 +550,10 @@ cdef class Channel:
def segregated_call( def segregated_call(
self, int flags, method, host, object deadline, object metadata, self, int flags, method, host, object deadline, object metadata,
CallCredentials credentials, operationses_and_tags, CallCredentials credentials, operationses_and_tags,
object context = None): object context = None, object registered_call_handle = None):
return _segregated_call( return _segregated_call(
self._state, flags, method, host, deadline, metadata, credentials, self._state, flags, method, host, deadline, metadata, credentials,
operationses_and_tags, context) operationses_and_tags, context, registered_call_handle)
def check_connectivity_state(self, bint try_to_connect): def check_connectivity_state(self, bint try_to_connect):
with self._state.condition: with self._state.condition:
@ -543,3 +572,19 @@ cdef class Channel:
def close_on_fork(self, code, details): def close_on_fork(self, code, details):
_close(self, code, details, True) _close(self, code, details, True)
def get_registered_call_handle(self, method):
"""
Get or registers a call handler for a method.
This method is not thread-safe.
Args:
method: Required, the method name for the RPC.
Returns:
The registered call handle pointer in the form of a Python Long.
"""
if method not in self._registered_call_handles.keys():
self._registered_call_handles[method] = CallHandle(self._state, method)
return self._registered_call_handles[method].call_handle

@ -433,6 +433,12 @@ cdef extern from "grpc/grpc.h":
grpc_channel *channel, grpc_call *parent_call, uint32_t propagation_mask, grpc_channel *channel, grpc_call *parent_call, uint32_t propagation_mask,
grpc_completion_queue *completion_queue, grpc_slice method, grpc_completion_queue *completion_queue, grpc_slice method,
const grpc_slice *host, gpr_timespec deadline, void *reserved) nogil const grpc_slice *host, gpr_timespec deadline, void *reserved) nogil
void *grpc_channel_register_call(
grpc_channel *channel, const char *method, const char *host, void *reserved) nogil
grpc_call *grpc_channel_create_registered_call(
grpc_channel *channel, grpc_call *parent_call, uint32_t propagation_mask,
grpc_completion_queue *completion_queue, void* registered_call_handle,
gpr_timespec deadline, void *reserved) nogil
grpc_connectivity_state grpc_channel_check_connectivity_state( grpc_connectivity_state grpc_channel_check_connectivity_state(
grpc_channel *channel, int try_to_connect) nogil grpc_channel *channel, int try_to_connect) nogil
void grpc_channel_watch_connectivity_state( void grpc_channel_watch_connectivity_state(

@ -684,57 +684,85 @@ class _Channel(grpc.Channel):
def unsubscribe(self, callback: Callable): def unsubscribe(self, callback: Callable):
self._channel.unsubscribe(callback) self._channel.unsubscribe(callback)
# pylint: disable=arguments-differ
def unary_unary( def unary_unary(
self, self,
method: str, method: str,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.UnaryUnaryMultiCallable: ) -> grpc.UnaryUnaryMultiCallable:
# pytype: disable=wrong-arg-count
thunk = lambda m: self._channel.unary_unary( thunk = lambda m: self._channel.unary_unary(
m, request_serializer, response_deserializer m,
request_serializer,
response_deserializer,
_registered_method,
) )
# pytype: enable=wrong-arg-count
if isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor): if isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor):
return _UnaryUnaryMultiCallable(thunk, method, self._interceptor) return _UnaryUnaryMultiCallable(thunk, method, self._interceptor)
else: else:
return thunk(method) return thunk(method)
# pylint: disable=arguments-differ
def unary_stream( def unary_stream(
self, self,
method: str, method: str,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.UnaryStreamMultiCallable: ) -> grpc.UnaryStreamMultiCallable:
# pytype: disable=wrong-arg-count
thunk = lambda m: self._channel.unary_stream( thunk = lambda m: self._channel.unary_stream(
m, request_serializer, response_deserializer m,
request_serializer,
response_deserializer,
_registered_method,
) )
# pytype: enable=wrong-arg-count
if isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor): if isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor):
return _UnaryStreamMultiCallable(thunk, method, self._interceptor) return _UnaryStreamMultiCallable(thunk, method, self._interceptor)
else: else:
return thunk(method) return thunk(method)
# pylint: disable=arguments-differ
def stream_unary( def stream_unary(
self, self,
method: str, method: str,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.StreamUnaryMultiCallable: ) -> grpc.StreamUnaryMultiCallable:
# pytype: disable=wrong-arg-count
thunk = lambda m: self._channel.stream_unary( thunk = lambda m: self._channel.stream_unary(
m, request_serializer, response_deserializer m,
request_serializer,
response_deserializer,
_registered_method,
) )
# pytype: enable=wrong-arg-count
if isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor): if isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor):
return _StreamUnaryMultiCallable(thunk, method, self._interceptor) return _StreamUnaryMultiCallable(thunk, method, self._interceptor)
else: else:
return thunk(method) return thunk(method)
# pylint: disable=arguments-differ
def stream_stream( def stream_stream(
self, self,
method: str, method: str,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.StreamStreamMultiCallable: ) -> grpc.StreamStreamMultiCallable:
# pytype: disable=wrong-arg-count
thunk = lambda m: self._channel.stream_stream( thunk = lambda m: self._channel.stream_stream(
m, request_serializer, response_deserializer m,
request_serializer,
response_deserializer,
_registered_method,
) )
# pytype: enable=wrong-arg-count
if isinstance(self._interceptor, grpc.StreamStreamClientInterceptor): if isinstance(self._interceptor, grpc.StreamStreamClientInterceptor):
return _StreamStreamMultiCallable(thunk, method, self._interceptor) return _StreamStreamMultiCallable(thunk, method, self._interceptor)
else: else:

@ -159,7 +159,19 @@ class ChannelCache:
channel_credentials: Optional[grpc.ChannelCredentials], channel_credentials: Optional[grpc.ChannelCredentials],
insecure: bool, insecure: bool,
compression: Optional[grpc.Compression], compression: Optional[grpc.Compression],
) -> grpc.Channel: method: str,
_registered_method: bool,
) -> Tuple[grpc.Channel, Optional[int]]:
"""Get a channel from cache or creates a new channel.
This method also takes care of register method for channel,
which means we'll register a new call handle if we're calling a
non-registered method for an existing channel.
Returns:
A tuple with two items. The first item is the channel, second item is
the call handle if the method is registered, None if it's not registered.
"""
if insecure and channel_credentials: if insecure and channel_credentials:
raise ValueError( raise ValueError(
"The insecure option is mutually exclusive with " "The insecure option is mutually exclusive with "
@ -176,18 +188,25 @@ class ChannelCache:
key = (target, options, channel_credentials, compression) key = (target, options, channel_credentials, compression)
with self._lock: with self._lock:
channel_data = self._mapping.get(key, None) channel_data = self._mapping.get(key, None)
call_handle = None
if channel_data is not None: if channel_data is not None:
channel = channel_data[0] channel = channel_data[0]
# Register a new call handle if we're calling a registered method for an
# existing channel and this method is not registered.
if _registered_method:
call_handle = channel._get_registered_call_handle(method)
self._mapping.pop(key) self._mapping.pop(key)
self._mapping[key] = ( self._mapping[key] = (
channel, channel,
datetime.datetime.now() + _EVICTION_PERIOD, datetime.datetime.now() + _EVICTION_PERIOD,
) )
return channel return channel, call_handle
else: else:
channel = _create_channel( channel = _create_channel(
target, options, channel_credentials, compression target, options, channel_credentials, compression
) )
if _registered_method:
call_handle = channel._get_registered_call_handle(method)
self._mapping[key] = ( self._mapping[key] = (
channel, channel,
datetime.datetime.now() + _EVICTION_PERIOD, datetime.datetime.now() + _EVICTION_PERIOD,
@ -197,7 +216,7 @@ class ChannelCache:
or len(self._mapping) >= _MAXIMUM_CHANNELS or len(self._mapping) >= _MAXIMUM_CHANNELS
): ):
self._condition.notify() self._condition.notify()
return channel return channel, call_handle
def _test_only_channel_count(self) -> int: def _test_only_channel_count(self) -> int:
with self._lock: with self._lock:
@ -205,6 +224,7 @@ class ChannelCache:
@experimental_api @experimental_api
# pylint: disable=too-many-locals
def unary_unary( def unary_unary(
request: RequestType, request: RequestType,
target: str, target: str,
@ -219,6 +239,7 @@ def unary_unary(
wait_for_ready: Optional[bool] = None, wait_for_ready: Optional[bool] = None,
timeout: Optional[float] = _DEFAULT_TIMEOUT, timeout: Optional[float] = _DEFAULT_TIMEOUT,
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None, metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None,
_registered_method: Optional[bool] = False,
) -> ResponseType: ) -> ResponseType:
"""Invokes a unary-unary RPC without an explicitly specified channel. """Invokes a unary-unary RPC without an explicitly specified channel.
@ -272,11 +293,17 @@ def unary_unary(
Returns: Returns:
The response to the RPC. The response to the RPC.
""" """
channel = ChannelCache.get().get_channel( channel, method_handle = ChannelCache.get().get_channel(
target, options, channel_credentials, insecure, compression target,
options,
channel_credentials,
insecure,
compression,
method,
_registered_method,
) )
multicallable = channel.unary_unary( multicallable = channel.unary_unary(
method, request_serializer, response_deserializer method, request_serializer, response_deserializer, method_handle
) )
wait_for_ready = wait_for_ready if wait_for_ready is not None else True wait_for_ready = wait_for_ready if wait_for_ready is not None else True
return multicallable( return multicallable(
@ -289,6 +316,7 @@ def unary_unary(
@experimental_api @experimental_api
# pylint: disable=too-many-locals
def unary_stream( def unary_stream(
request: RequestType, request: RequestType,
target: str, target: str,
@ -303,6 +331,7 @@ def unary_stream(
wait_for_ready: Optional[bool] = None, wait_for_ready: Optional[bool] = None,
timeout: Optional[float] = _DEFAULT_TIMEOUT, timeout: Optional[float] = _DEFAULT_TIMEOUT,
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None, metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None,
_registered_method: Optional[bool] = False,
) -> Iterator[ResponseType]: ) -> Iterator[ResponseType]:
"""Invokes a unary-stream RPC without an explicitly specified channel. """Invokes a unary-stream RPC without an explicitly specified channel.
@ -355,11 +384,17 @@ def unary_stream(
Returns: Returns:
An iterator of responses. An iterator of responses.
""" """
channel = ChannelCache.get().get_channel( channel, method_handle = ChannelCache.get().get_channel(
target, options, channel_credentials, insecure, compression target,
options,
channel_credentials,
insecure,
compression,
method,
_registered_method,
) )
multicallable = channel.unary_stream( multicallable = channel.unary_stream(
method, request_serializer, response_deserializer method, request_serializer, response_deserializer, method_handle
) )
wait_for_ready = wait_for_ready if wait_for_ready is not None else True wait_for_ready = wait_for_ready if wait_for_ready is not None else True
return multicallable( return multicallable(
@ -372,6 +407,7 @@ def unary_stream(
@experimental_api @experimental_api
# pylint: disable=too-many-locals
def stream_unary( def stream_unary(
request_iterator: Iterator[RequestType], request_iterator: Iterator[RequestType],
target: str, target: str,
@ -386,6 +422,7 @@ def stream_unary(
wait_for_ready: Optional[bool] = None, wait_for_ready: Optional[bool] = None,
timeout: Optional[float] = _DEFAULT_TIMEOUT, timeout: Optional[float] = _DEFAULT_TIMEOUT,
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None, metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None,
_registered_method: Optional[bool] = False,
) -> ResponseType: ) -> ResponseType:
"""Invokes a stream-unary RPC without an explicitly specified channel. """Invokes a stream-unary RPC without an explicitly specified channel.
@ -438,11 +475,17 @@ def stream_unary(
Returns: Returns:
The response to the RPC. The response to the RPC.
""" """
channel = ChannelCache.get().get_channel( channel, method_handle = ChannelCache.get().get_channel(
target, options, channel_credentials, insecure, compression target,
options,
channel_credentials,
insecure,
compression,
method,
_registered_method,
) )
multicallable = channel.stream_unary( multicallable = channel.stream_unary(
method, request_serializer, response_deserializer method, request_serializer, response_deserializer, method_handle
) )
wait_for_ready = wait_for_ready if wait_for_ready is not None else True wait_for_ready = wait_for_ready if wait_for_ready is not None else True
return multicallable( return multicallable(
@ -455,6 +498,7 @@ def stream_unary(
@experimental_api @experimental_api
# pylint: disable=too-many-locals
def stream_stream( def stream_stream(
request_iterator: Iterator[RequestType], request_iterator: Iterator[RequestType],
target: str, target: str,
@ -469,6 +513,7 @@ def stream_stream(
wait_for_ready: Optional[bool] = None, wait_for_ready: Optional[bool] = None,
timeout: Optional[float] = _DEFAULT_TIMEOUT, timeout: Optional[float] = _DEFAULT_TIMEOUT,
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None, metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None,
_registered_method: Optional[bool] = False,
) -> Iterator[ResponseType]: ) -> Iterator[ResponseType]:
"""Invokes a stream-stream RPC without an explicitly specified channel. """Invokes a stream-stream RPC without an explicitly specified channel.
@ -521,11 +566,17 @@ def stream_stream(
Returns: Returns:
An iterator of responses. An iterator of responses.
""" """
channel = ChannelCache.get().get_channel( channel, method_handle = ChannelCache.get().get_channel(
target, options, channel_credentials, insecure, compression target,
options,
channel_credentials,
insecure,
compression,
method,
_registered_method,
) )
multicallable = channel.stream_stream( multicallable = channel.stream_stream(
method, request_serializer, response_deserializer method, request_serializer, response_deserializer, method_handle
) )
wait_for_ready = wait_for_ready if wait_for_ready is not None else True wait_for_ready = wait_for_ready if wait_for_ready is not None else True
return multicallable( return multicallable(

@ -272,6 +272,7 @@ class Channel(abc.ABC):
method: str, method: str,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> UnaryUnaryMultiCallable: ) -> UnaryUnaryMultiCallable:
"""Creates a UnaryUnaryMultiCallable for a unary-unary method. """Creates a UnaryUnaryMultiCallable for a unary-unary method.
@ -282,6 +283,8 @@ class Channel(abc.ABC):
response_deserializer: Optional :term:`deserializer` for deserializing the response_deserializer: Optional :term:`deserializer` for deserializing the
response message. Response goes undeserialized in case None response message. Response goes undeserialized in case None
is passed. is passed.
_registered_method: Implementation Private. Optional: A bool representing
whether the method is registered.
Returns: Returns:
A UnaryUnaryMultiCallable value for the named unary-unary method. A UnaryUnaryMultiCallable value for the named unary-unary method.
@ -293,6 +296,7 @@ class Channel(abc.ABC):
method: str, method: str,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> UnaryStreamMultiCallable: ) -> UnaryStreamMultiCallable:
"""Creates a UnaryStreamMultiCallable for a unary-stream method. """Creates a UnaryStreamMultiCallable for a unary-stream method.
@ -303,6 +307,8 @@ class Channel(abc.ABC):
response_deserializer: Optional :term:`deserializer` for deserializing the response_deserializer: Optional :term:`deserializer` for deserializing the
response message. Response goes undeserialized in case None response message. Response goes undeserialized in case None
is passed. is passed.
_registered_method: Implementation Private. Optional: A bool representing
whether the method is registered.
Returns: Returns:
A UnarySteramMultiCallable value for the named unary-stream method. A UnarySteramMultiCallable value for the named unary-stream method.
@ -314,6 +320,7 @@ class Channel(abc.ABC):
method: str, method: str,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> StreamUnaryMultiCallable: ) -> StreamUnaryMultiCallable:
"""Creates a StreamUnaryMultiCallable for a stream-unary method. """Creates a StreamUnaryMultiCallable for a stream-unary method.
@ -324,6 +331,8 @@ class Channel(abc.ABC):
response_deserializer: Optional :term:`deserializer` for deserializing the response_deserializer: Optional :term:`deserializer` for deserializing the
response message. Response goes undeserialized in case None response message. Response goes undeserialized in case None
is passed. is passed.
_registered_method: Implementation Private. Optional: A bool representing
whether the method is registered.
Returns: Returns:
A StreamUnaryMultiCallable value for the named stream-unary method. A StreamUnaryMultiCallable value for the named stream-unary method.
@ -335,6 +344,7 @@ class Channel(abc.ABC):
method: str, method: str,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> StreamStreamMultiCallable: ) -> StreamStreamMultiCallable:
"""Creates a StreamStreamMultiCallable for a stream-stream method. """Creates a StreamStreamMultiCallable for a stream-stream method.
@ -345,6 +355,8 @@ class Channel(abc.ABC):
response_deserializer: Optional :term:`deserializer` for deserializing the response_deserializer: Optional :term:`deserializer` for deserializing the
response message. Response goes undeserialized in case None response message. Response goes undeserialized in case None
is passed. is passed.
_registered_method: Implementation Private. Optional: A bool representing
whether the method is registered.
Returns: Returns:
A StreamStreamMultiCallable value for the named stream-stream method. A StreamStreamMultiCallable value for the named stream-stream method.

@ -478,11 +478,20 @@ class Channel(_base_channel.Channel):
await self.wait_for_state_change(state) await self.wait_for_state_change(state)
state = self.get_state(try_to_connect=True) state = self.get_state(try_to_connect=True)
# TODO(xuanwn): Implement this method after we have
# observability for Asyncio.
def _get_registered_call_handle(self, method: str) -> int:
pass
# TODO(xuanwn): Implement _registered_method after we have
# observability for Asyncio.
# pylint: disable=arguments-differ,unused-argument
def unary_unary( def unary_unary(
self, self,
method: str, method: str,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> UnaryUnaryMultiCallable: ) -> UnaryUnaryMultiCallable:
return UnaryUnaryMultiCallable( return UnaryUnaryMultiCallable(
self._channel, self._channel,
@ -494,11 +503,15 @@ class Channel(_base_channel.Channel):
self._loop, self._loop,
) )
# TODO(xuanwn): Implement _registered_method after we have
# observability for Asyncio.
# pylint: disable=arguments-differ,unused-argument
def unary_stream( def unary_stream(
self, self,
method: str, method: str,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> UnaryStreamMultiCallable: ) -> UnaryStreamMultiCallable:
return UnaryStreamMultiCallable( return UnaryStreamMultiCallable(
self._channel, self._channel,
@ -510,11 +523,15 @@ class Channel(_base_channel.Channel):
self._loop, self._loop,
) )
# TODO(xuanwn): Implement _registered_method after we have
# observability for Asyncio.
# pylint: disable=arguments-differ,unused-argument
def stream_unary( def stream_unary(
self, self,
method: str, method: str,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> StreamUnaryMultiCallable: ) -> StreamUnaryMultiCallable:
return StreamUnaryMultiCallable( return StreamUnaryMultiCallable(
self._channel, self._channel,
@ -526,11 +543,15 @@ class Channel(_base_channel.Channel):
self._loop, self._loop,
) )
# TODO(xuanwn): Implement _registered_method after we have
# observability for Asyncio.
# pylint: disable=arguments-differ,unused-argument
def stream_stream( def stream_stream(
self, self,
method: str, method: str,
request_serializer: Optional[SerializingFunction] = None, request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> StreamStreamMultiCallable: ) -> StreamStreamMultiCallable:
return StreamStreamMultiCallable( return StreamStreamMultiCallable(
self._channel, self._channel,

@ -31,23 +31,42 @@ class TestingChannel(grpc_testing.Channel):
def unsubscribe(self, callback): def unsubscribe(self, callback):
raise NotImplementedError() raise NotImplementedError()
def _get_registered_call_handle(self, method: str) -> int:
pass
def unary_unary( def unary_unary(
self, method, request_serializer=None, response_deserializer=None self,
method,
request_serializer=None,
response_deserializer=None,
_registered_method=False,
): ):
return _multi_callable.UnaryUnary(method, self._state) return _multi_callable.UnaryUnary(method, self._state)
def unary_stream( def unary_stream(
self, method, request_serializer=None, response_deserializer=None self,
method,
request_serializer=None,
response_deserializer=None,
_registered_method=False,
): ):
return _multi_callable.UnaryStream(method, self._state) return _multi_callable.UnaryStream(method, self._state)
def stream_unary( def stream_unary(
self, method, request_serializer=None, response_deserializer=None self,
method,
request_serializer=None,
response_deserializer=None,
_registered_method=False,
): ):
return _multi_callable.StreamUnary(method, self._state) return _multi_callable.StreamUnary(method, self._state)
def stream_stream( def stream_stream(
self, method, request_serializer=None, response_deserializer=None self,
method,
request_serializer=None,
response_deserializer=None,
_registered_method=False,
): ):
return _multi_callable.StreamStream(method, self._state) return _multi_callable.StreamStream(method, self._state)

@ -99,16 +99,20 @@ class ChannelzServicerTest(unittest.TestCase):
def _send_successful_unary_unary(self, idx): def _send_successful_unary_unary(self, idx):
_, r = ( _, r = (
self._pairs[idx] self._pairs[idx]
.channel.unary_unary(_SUCCESSFUL_UNARY_UNARY) .channel.unary_unary(
_SUCCESSFUL_UNARY_UNARY,
_registered_method=True,
)
.with_call(_REQUEST) .with_call(_REQUEST)
) )
self.assertEqual(r.code(), grpc.StatusCode.OK) self.assertEqual(r.code(), grpc.StatusCode.OK)
def _send_failed_unary_unary(self, idx): def _send_failed_unary_unary(self, idx):
try: try:
self._pairs[idx].channel.unary_unary(_FAILED_UNARY_UNARY).with_call( self._pairs[idx].channel.unary_unary(
_REQUEST _FAILED_UNARY_UNARY,
) _registered_method=True,
).with_call(_REQUEST)
except grpc.RpcError: except grpc.RpcError:
return return
else: else:
@ -117,7 +121,10 @@ class ChannelzServicerTest(unittest.TestCase):
def _send_successful_stream_stream(self, idx): def _send_successful_stream_stream(self, idx):
response_iterator = ( response_iterator = (
self._pairs[idx] self._pairs[idx]
.channel.stream_stream(_SUCCESSFUL_STREAM_STREAM) .channel.stream_stream(
_SUCCESSFUL_STREAM_STREAM,
_registered_method=True,
)
.__call__(iter([_REQUEST] * test_constants.STREAM_LENGTH)) .__call__(iter([_REQUEST] * test_constants.STREAM_LENGTH))
) )
cnt = 0 cnt = 0

@ -86,7 +86,10 @@ class TestCsds(unittest.TestCase):
# Force the XdsClient to initialize and request a resource # Force the XdsClient to initialize and request a resource
with self.assertRaises(grpc.RpcError) as rpc_error: with self.assertRaises(grpc.RpcError) as rpc_error:
dummy_channel.unary_unary("")(b"", wait_for_ready=False, timeout=1) dummy_channel.unary_unary(
"",
_registered_method=True,
)(b"", wait_for_ready=False, timeout=1)
self.assertEqual( self.assertEqual(
grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.exception.code() grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.exception.code()
) )

@ -543,6 +543,17 @@ class PythonPluginTest(unittest.TestCase):
) )
service.server.stop(None) service.server.stop(None)
def testRegisteredMethod(self):
"""Tests that we're setting _registered_call_handle when create call using generated stub."""
service = _CreateService()
self.assertTrue(service.stub.UnaryCall._registered_call_handle)
self.assertTrue(
service.stub.StreamingOutputCall._registered_call_handle
)
self.assertTrue(service.stub.StreamingInputCall._registered_call_handle)
self.assertTrue(service.stub.FullDuplexCall._registered_call_handle)
service.server.stop(None)
@unittest.skipIf( @unittest.skipIf(
sys.version_info[0] < 3 or sys.version_info[1] < 6, sys.version_info[0] < 3 or sys.version_info[1] < 6,

@ -32,13 +32,16 @@ _TIMEOUT = 60 * 60 * 24
class GenericStub(object): class GenericStub(object):
def __init__(self, channel): def __init__(self, channel):
self.UnaryCall = channel.unary_unary( self.UnaryCall = channel.unary_unary(
"/grpc.testing.BenchmarkService/UnaryCall" "/grpc.testing.BenchmarkService/UnaryCall",
_registered_method=True,
) )
self.StreamingFromServer = channel.unary_stream( self.StreamingFromServer = channel.unary_stream(
"/grpc.testing.BenchmarkService/StreamingFromServer" "/grpc.testing.BenchmarkService/StreamingFromServer",
_registered_method=True,
) )
self.StreamingCall = channel.stream_stream( self.StreamingCall = channel.stream_stream(
"/grpc.testing.BenchmarkService/StreamingCall" "/grpc.testing.BenchmarkService/StreamingCall",
_registered_method=True,
) )

@ -138,7 +138,10 @@ class StatusTest(unittest.TestCase):
self._channel.close() self._channel.close()
def test_status_ok(self): def test_status_ok(self):
_, call = self._channel.unary_unary(_STATUS_OK).with_call(_REQUEST) _, call = self._channel.unary_unary(
_STATUS_OK,
_registered_method=True,
).with_call(_REQUEST)
# Succeed RPC doesn't have status # Succeed RPC doesn't have status
status = rpc_status.from_call(call) status = rpc_status.from_call(call)
@ -146,7 +149,10 @@ class StatusTest(unittest.TestCase):
def test_status_not_ok(self): def test_status_not_ok(self):
with self.assertRaises(grpc.RpcError) as exception_context: with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_STATUS_NOT_OK).with_call(_REQUEST) self._channel.unary_unary(
_STATUS_NOT_OK,
_registered_method=True,
).with_call(_REQUEST)
rpc_error = exception_context.exception rpc_error = exception_context.exception
self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL)
@ -156,7 +162,10 @@ class StatusTest(unittest.TestCase):
def test_error_details(self): def test_error_details(self):
with self.assertRaises(grpc.RpcError) as exception_context: with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_ERROR_DETAILS).with_call(_REQUEST) self._channel.unary_unary(
_ERROR_DETAILS,
_registered_method=True,
).with_call(_REQUEST)
rpc_error = exception_context.exception rpc_error = exception_context.exception
status = rpc_status.from_call(rpc_error) status = rpc_status.from_call(rpc_error)
@ -173,7 +182,10 @@ class StatusTest(unittest.TestCase):
def test_code_message_validation(self): def test_code_message_validation(self):
with self.assertRaises(grpc.RpcError) as exception_context: with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_INCONSISTENT).with_call(_REQUEST) self._channel.unary_unary(
_INCONSISTENT,
_registered_method=True,
).with_call(_REQUEST)
rpc_error = exception_context.exception rpc_error = exception_context.exception
self.assertEqual(rpc_error.code(), grpc.StatusCode.NOT_FOUND) self.assertEqual(rpc_error.code(), grpc.StatusCode.NOT_FOUND)
@ -182,7 +194,10 @@ class StatusTest(unittest.TestCase):
def test_invalid_code(self): def test_invalid_code(self):
with self.assertRaises(grpc.RpcError) as exception_context: with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_INVALID_CODE).with_call(_REQUEST) self._channel.unary_unary(
_INVALID_CODE,
_registered_method=True,
).with_call(_REQUEST)
rpc_error = exception_context.exception rpc_error = exception_context.exception
self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN) self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN)
# Invalid status code exception raised during coversion # Invalid status code exception raised during coversion

@ -107,7 +107,10 @@ class AbortTest(unittest.TestCase):
def test_abort(self): def test_abort(self):
with self.assertRaises(grpc.RpcError) as exception_context: with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_ABORT)(_REQUEST) self._channel.unary_unary(
_ABORT,
_registered_method=True,
)(_REQUEST)
rpc_error = exception_context.exception rpc_error = exception_context.exception
self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL)
@ -124,7 +127,10 @@ class AbortTest(unittest.TestCase):
# Servicer will abort() after creating a local ref to do_not_leak_me. # Servicer will abort() after creating a local ref to do_not_leak_me.
with self.assertRaises(grpc.RpcError): with self.assertRaises(grpc.RpcError):
self._channel.unary_unary(_ABORT)(_REQUEST) self._channel.unary_unary(
_ABORT,
_registered_method=True,
)(_REQUEST)
# Server may still have a stack frame reference to the exception even # Server may still have a stack frame reference to the exception even
# after client sees error, so ensure server has shutdown. # after client sees error, so ensure server has shutdown.
@ -134,7 +140,10 @@ class AbortTest(unittest.TestCase):
def test_abort_with_status(self): def test_abort_with_status(self):
with self.assertRaises(grpc.RpcError) as exception_context: with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_ABORT_WITH_STATUS)(_REQUEST) self._channel.unary_unary(
_ABORT_WITH_STATUS,
_registered_method=True,
)(_REQUEST)
rpc_error = exception_context.exception rpc_error = exception_context.exception
self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL)
@ -143,7 +152,10 @@ class AbortTest(unittest.TestCase):
def test_invalid_code(self): def test_invalid_code(self):
with self.assertRaises(grpc.RpcError) as exception_context: with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_INVALID_CODE)(_REQUEST) self._channel.unary_unary(
_INVALID_CODE,
_registered_method=True,
)(_REQUEST)
rpc_error = exception_context.exception rpc_error = exception_context.exception
self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN) self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN)

@ -78,7 +78,10 @@ class AuthContextTest(unittest.TestCase):
server.start() server.start()
with grpc.insecure_channel("localhost:%d" % port) as channel: with grpc.insecure_channel("localhost:%d" % port) as channel:
response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) response = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)(_REQUEST)
server.stop(None) server.stop(None)
auth_data = pickle.loads(response) auth_data = pickle.loads(response)
@ -115,7 +118,10 @@ class AuthContextTest(unittest.TestCase):
channel_creds, channel_creds,
options=_PROPERTY_OPTIONS, options=_PROPERTY_OPTIONS,
) )
response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) response = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)(_REQUEST)
channel.close() channel.close()
server.stop(None) server.stop(None)
@ -161,7 +167,10 @@ class AuthContextTest(unittest.TestCase):
options=_PROPERTY_OPTIONS, options=_PROPERTY_OPTIONS,
) )
response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) response = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)(_REQUEST)
channel.close() channel.close()
server.stop(None) server.stop(None)
@ -180,7 +189,10 @@ class AuthContextTest(unittest.TestCase):
channel = grpc.secure_channel( channel = grpc.secure_channel(
"localhost:{}".format(port), channel_creds, options=channel_options "localhost:{}".format(port), channel_creds, options=channel_options
) )
response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) response = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)(_REQUEST)
auth_data = pickle.loads(response) auth_data = pickle.loads(response)
self.assertEqual( self.assertEqual(
expect_ssl_session_reused, expect_ssl_session_reused,

@ -123,7 +123,10 @@ class ChannelCloseTest(unittest.TestCase):
def test_close_immediately_after_call_invocation(self): def test_close_immediately_after_call_invocation(self):
channel = grpc.insecure_channel("localhost:{}".format(self._port)) channel = grpc.insecure_channel("localhost:{}".format(self._port))
multi_callable = channel.stream_stream(_STREAM_URI) multi_callable = channel.stream_stream(
_STREAM_URI,
_registered_method=True,
)
request_iterator = _Pipe(()) request_iterator = _Pipe(())
response_iterator = multi_callable(request_iterator) response_iterator = multi_callable(request_iterator)
channel.close() channel.close()
@ -133,7 +136,10 @@ class ChannelCloseTest(unittest.TestCase):
def test_close_while_call_active(self): def test_close_while_call_active(self):
channel = grpc.insecure_channel("localhost:{}".format(self._port)) channel = grpc.insecure_channel("localhost:{}".format(self._port))
multi_callable = channel.stream_stream(_STREAM_URI) multi_callable = channel.stream_stream(
_STREAM_URI,
_registered_method=True,
)
request_iterator = _Pipe((b"abc",)) request_iterator = _Pipe((b"abc",))
response_iterator = multi_callable(request_iterator) response_iterator = multi_callable(request_iterator)
next(response_iterator) next(response_iterator)
@ -146,7 +152,10 @@ class ChannelCloseTest(unittest.TestCase):
with grpc.insecure_channel( with grpc.insecure_channel(
"localhost:{}".format(self._port) "localhost:{}".format(self._port)
) as channel: # pylint: disable=bad-continuation ) as channel: # pylint: disable=bad-continuation
multi_callable = channel.stream_stream(_STREAM_URI) multi_callable = channel.stream_stream(
_STREAM_URI,
_registered_method=True,
)
request_iterator = _Pipe((b"abc",)) request_iterator = _Pipe((b"abc",))
response_iterator = multi_callable(request_iterator) response_iterator = multi_callable(request_iterator)
next(response_iterator) next(response_iterator)
@ -158,7 +167,10 @@ class ChannelCloseTest(unittest.TestCase):
with grpc.insecure_channel( with grpc.insecure_channel(
"localhost:{}".format(self._port) "localhost:{}".format(self._port)
) as channel: # pylint: disable=bad-continuation ) as channel: # pylint: disable=bad-continuation
multi_callable = channel.stream_stream(_STREAM_URI) multi_callable = channel.stream_stream(
_STREAM_URI,
_registered_method=True,
)
request_iterators = tuple( request_iterators = tuple(
_Pipe((b"abc",)) _Pipe((b"abc",))
for _ in range(test_constants.THREAD_CONCURRENCY) for _ in range(test_constants.THREAD_CONCURRENCY)
@ -176,7 +188,10 @@ class ChannelCloseTest(unittest.TestCase):
def test_many_concurrent_closes(self): def test_many_concurrent_closes(self):
channel = grpc.insecure_channel("localhost:{}".format(self._port)) channel = grpc.insecure_channel("localhost:{}".format(self._port))
multi_callable = channel.stream_stream(_STREAM_URI) multi_callable = channel.stream_stream(
_STREAM_URI,
_registered_method=True,
)
request_iterator = _Pipe((b"abc",)) request_iterator = _Pipe((b"abc",))
response_iterator = multi_callable(request_iterator) response_iterator = multi_callable(request_iterator)
next(response_iterator) next(response_iterator)
@ -203,10 +218,16 @@ class ChannelCloseTest(unittest.TestCase):
with grpc.insecure_channel( with grpc.insecure_channel(
"localhost:{}".format(self._port) "localhost:{}".format(self._port)
) as channel: ) as channel:
stream_multi_callable = channel.stream_stream(_STREAM_URI) stream_multi_callable = channel.stream_stream(
_STREAM_URI,
_registered_method=True,
)
endless_iterator = itertools.repeat(b"abc") endless_iterator = itertools.repeat(b"abc")
stream_response_iterator = stream_multi_callable(endless_iterator) stream_response_iterator = stream_multi_callable(endless_iterator)
future = channel.unary_unary(_UNARY_URI).future(b"abc") future = channel.unary_unary(
_UNARY_URI,
_registered_method=True,
).future(b"abc")
def on_done_callback(future): def on_done_callback(future):
raise Exception("This should not cause a deadlock.") raise Exception("This should not cause a deadlock.")

@ -237,7 +237,10 @@ def _get_compression_ratios(
def _unary_unary_client(channel, multicallable_kwargs, message): def _unary_unary_client(channel, multicallable_kwargs, message):
multi_callable = channel.unary_unary(_UNARY_UNARY) multi_callable = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
response = multi_callable(message, **multicallable_kwargs) response = multi_callable(message, **multicallable_kwargs)
if response != message: if response != message:
raise RuntimeError( raise RuntimeError(
@ -246,7 +249,10 @@ def _unary_unary_client(channel, multicallable_kwargs, message):
def _unary_stream_client(channel, multicallable_kwargs, message): def _unary_stream_client(channel, multicallable_kwargs, message):
multi_callable = channel.unary_stream(_UNARY_STREAM) multi_callable = channel.unary_stream(
_UNARY_STREAM,
_registered_method=True,
)
response_iterator = multi_callable(message, **multicallable_kwargs) response_iterator = multi_callable(message, **multicallable_kwargs)
for response in response_iterator: for response in response_iterator:
if response != message: if response != message:
@ -256,7 +262,10 @@ def _unary_stream_client(channel, multicallable_kwargs, message):
def _stream_unary_client(channel, multicallable_kwargs, message): def _stream_unary_client(channel, multicallable_kwargs, message):
multi_callable = channel.stream_unary(_STREAM_UNARY) multi_callable = channel.stream_unary(
_STREAM_UNARY,
_registered_method=True,
)
requests = (_REQUEST for _ in range(_STREAM_LENGTH)) requests = (_REQUEST for _ in range(_STREAM_LENGTH))
response = multi_callable(requests, **multicallable_kwargs) response = multi_callable(requests, **multicallable_kwargs)
if response != message: if response != message:
@ -266,7 +275,10 @@ def _stream_unary_client(channel, multicallable_kwargs, message):
def _stream_stream_client(channel, multicallable_kwargs, message): def _stream_stream_client(channel, multicallable_kwargs, message):
multi_callable = channel.stream_stream(_STREAM_STREAM) multi_callable = channel.stream_stream(
_STREAM_STREAM,
_registered_method=True,
)
request_prefix = str(0).encode("ascii") * 100 request_prefix = str(0).encode("ascii") * 100
requests = ( requests = (
request_prefix + str(i).encode("ascii") for i in range(_STREAM_LENGTH) request_prefix + str(i).encode("ascii") for i in range(_STREAM_LENGTH)

@ -116,7 +116,10 @@ class ContextVarsPropagationTest(unittest.TestCase):
local_credentials, call_credentials local_credentials, call_credentials
) )
with grpc.secure_channel(target, composite_credentials) as channel: with grpc.secure_channel(target, composite_credentials) as channel:
stub = channel.unary_unary(_UNARY_UNARY) stub = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
response = stub(_REQUEST, wait_for_ready=True) response = stub(_REQUEST, wait_for_ready=True)
self.assertEqual(_REQUEST, response) self.assertEqual(_REQUEST, response)
@ -142,7 +145,10 @@ class ContextVarsPropagationTest(unittest.TestCase):
with grpc.secure_channel( with grpc.secure_channel(
target, composite_credentials target, composite_credentials
) as channel: ) as channel:
stub = channel.unary_unary(_UNARY_UNARY) stub = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
wait_group.done() wait_group.done()
wait_group.wait() wait_group.wait()
for i in range(_RPC_COUNT): for i in range(_RPC_COUNT):

@ -55,7 +55,10 @@ class DNSResolverTest(unittest.TestCase):
"loopback46.unittest.grpc.io:%d" % self._port "loopback46.unittest.grpc.io:%d" % self._port
) as channel: ) as channel:
self.assertEqual( self.assertEqual(
channel.unary_unary(_METHOD)( channel.unary_unary(
_METHOD,
_registered_method=True,
)(
_REQUEST, _REQUEST,
timeout=10, timeout=10,
), ),

@ -96,25 +96,33 @@ class EmptyMessageTest(unittest.TestCase):
self._channel.close() self._channel.close()
def testUnaryUnary(self): def testUnaryUnary(self):
response = self._channel.unary_unary(_UNARY_UNARY)(_REQUEST) response = self._channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)(_REQUEST)
self.assertEqual(_RESPONSE, response) self.assertEqual(_RESPONSE, response)
def testUnaryStream(self): def testUnaryStream(self):
response_iterator = self._channel.unary_stream(_UNARY_STREAM)(_REQUEST) response_iterator = self._channel.unary_stream(
_UNARY_STREAM,
_registered_method=True,
)(_REQUEST)
self.assertSequenceEqual( self.assertSequenceEqual(
[_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator) [_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator)
) )
def testStreamUnary(self): def testStreamUnary(self):
response = self._channel.stream_unary(_STREAM_UNARY)( response = self._channel.stream_unary(
iter([_REQUEST] * test_constants.STREAM_LENGTH) _STREAM_UNARY,
) _registered_method=True,
)(iter([_REQUEST] * test_constants.STREAM_LENGTH))
self.assertEqual(_RESPONSE, response) self.assertEqual(_RESPONSE, response)
def testStreamStream(self): def testStreamStream(self):
response_iterator = self._channel.stream_stream(_STREAM_STREAM)( response_iterator = self._channel.stream_stream(
iter([_REQUEST] * test_constants.STREAM_LENGTH) _STREAM_STREAM,
) _registered_method=True,
)(iter([_REQUEST] * test_constants.STREAM_LENGTH))
self.assertSequenceEqual( self.assertSequenceEqual(
[_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator) [_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator)
) )

@ -73,7 +73,10 @@ class ErrorMessageEncodingTest(unittest.TestCase):
def testMessageEncoding(self): def testMessageEncoding(self):
for message in _UNICODE_ERROR_MESSAGES: for message in _UNICODE_ERROR_MESSAGES:
multi_callable = self._channel.unary_unary(_UNARY_UNARY) multi_callable = self._channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
with self.assertRaises(grpc.RpcError) as cm: with self.assertRaises(grpc.RpcError) as cm:
multi_callable(message.encode("utf-8")) multi_callable(message.encode("utf-8"))

@ -210,14 +210,20 @@ if __name__ == "__main__":
method = TEST_TO_METHOD[args.scenario] method = TEST_TO_METHOD[args.scenario]
if args.scenario == IN_FLIGHT_UNARY_UNARY_CALL: if args.scenario == IN_FLIGHT_UNARY_UNARY_CALL:
multi_callable = channel.unary_unary(method) multi_callable = channel.unary_unary(
method,
_registered_method=True,
)
future = multi_callable.future(REQUEST) future = multi_callable.future(REQUEST)
result, call = multi_callable.with_call(REQUEST) result, call = multi_callable.with_call(REQUEST)
elif ( elif (
args.scenario == IN_FLIGHT_UNARY_STREAM_CALL args.scenario == IN_FLIGHT_UNARY_STREAM_CALL
or args.scenario == IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL or args.scenario == IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL
): ):
multi_callable = channel.unary_stream(method) multi_callable = channel.unary_stream(
method,
_registered_method=True,
)
response_iterator = multi_callable(REQUEST) response_iterator = multi_callable(REQUEST)
for response in response_iterator: for response in response_iterator:
pass pass
@ -225,7 +231,10 @@ if __name__ == "__main__":
args.scenario == IN_FLIGHT_STREAM_UNARY_CALL args.scenario == IN_FLIGHT_STREAM_UNARY_CALL
or args.scenario == IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL or args.scenario == IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL
): ):
multi_callable = channel.stream_unary(method) multi_callable = channel.stream_unary(
method,
_registered_method=True,
)
future = multi_callable.future(infinite_request_iterator()) future = multi_callable.future(infinite_request_iterator())
result, call = multi_callable.with_call( result, call = multi_callable.with_call(
iter([REQUEST] * test_constants.STREAM_LENGTH) iter([REQUEST] * test_constants.STREAM_LENGTH)
@ -234,7 +243,10 @@ if __name__ == "__main__":
args.scenario == IN_FLIGHT_STREAM_STREAM_CALL args.scenario == IN_FLIGHT_STREAM_STREAM_CALL
or args.scenario == IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL or args.scenario == IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL
): ):
multi_callable = channel.stream_stream(method) multi_callable = channel.stream_stream(
method,
_registered_method=True,
)
response_iterator = multi_callable(infinite_request_iterator()) response_iterator = multi_callable(infinite_request_iterator())
for response in response_iterator: for response in response_iterator:
pass pass

@ -231,7 +231,7 @@ class _GenericHandler(grpc.GenericRpcHandler):
def _unary_unary_multi_callable(channel): def _unary_unary_multi_callable(channel):
return channel.unary_unary(_UNARY_UNARY) return channel.unary_unary(_UNARY_UNARY, _registered_method=True)
def _unary_stream_multi_callable(channel): def _unary_stream_multi_callable(channel):
@ -239,6 +239,7 @@ def _unary_stream_multi_callable(channel):
_UNARY_STREAM, _UNARY_STREAM,
request_serializer=_SERIALIZE_REQUEST, request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE, response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
) )
@ -247,11 +248,12 @@ def _stream_unary_multi_callable(channel):
_STREAM_UNARY, _STREAM_UNARY,
request_serializer=_SERIALIZE_REQUEST, request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE, response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
) )
def _stream_stream_multi_callable(channel): def _stream_stream_multi_callable(channel):
return channel.stream_stream(_STREAM_STREAM) return channel.stream_stream(_STREAM_STREAM, _registered_method=True)
class _ClientCallDetails( class _ClientCallDetails(
@ -562,7 +564,7 @@ class InterceptorTest(unittest.TestCase):
self._record[:] = [] self._record[:] = []
multi_callable = _unary_unary_multi_callable(channel) multi_callable = _unary_unary_multi_callable(channel)
multi_callable.with_call( response, call = multi_callable.with_call(
request, request,
metadata=( metadata=(
( (

@ -32,7 +32,7 @@ _STREAM_STREAM = "/test/StreamStream"
def _unary_unary_multi_callable(channel): def _unary_unary_multi_callable(channel):
return channel.unary_unary(_UNARY_UNARY) return channel.unary_unary(_UNARY_UNARY, _registered_method=True)
def _unary_stream_multi_callable(channel): def _unary_stream_multi_callable(channel):
@ -40,6 +40,7 @@ def _unary_stream_multi_callable(channel):
_UNARY_STREAM, _UNARY_STREAM,
request_serializer=_SERIALIZE_REQUEST, request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE, response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
) )
@ -48,11 +49,15 @@ def _stream_unary_multi_callable(channel):
_STREAM_UNARY, _STREAM_UNARY,
request_serializer=_SERIALIZE_REQUEST, request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE, response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
) )
def _stream_stream_multi_callable(channel): def _stream_stream_multi_callable(channel):
return channel.stream_stream(_STREAM_STREAM) return channel.stream_stream(
_STREAM_STREAM,
_registered_method=True,
)
class InvalidMetadataTest(unittest.TestCase): class InvalidMetadataTest(unittest.TestCase):

@ -219,7 +219,10 @@ class FailAfterFewIterationsCounter(object):
def _unary_unary_multi_callable(channel): def _unary_unary_multi_callable(channel):
return channel.unary_unary(_UNARY_UNARY) return channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
def _unary_stream_multi_callable(channel): def _unary_stream_multi_callable(channel):
@ -227,6 +230,7 @@ def _unary_stream_multi_callable(channel):
_UNARY_STREAM, _UNARY_STREAM,
request_serializer=_SERIALIZE_REQUEST, request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE, response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
) )
@ -235,19 +239,29 @@ def _stream_unary_multi_callable(channel):
_STREAM_UNARY, _STREAM_UNARY,
request_serializer=_SERIALIZE_REQUEST, request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE, response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
) )
def _stream_stream_multi_callable(channel): def _stream_stream_multi_callable(channel):
return channel.stream_stream(_STREAM_STREAM) return channel.stream_stream(
_STREAM_STREAM,
_registered_method=True,
)
def _defective_handler_multi_callable(channel): def _defective_handler_multi_callable(channel):
return channel.unary_unary(_DEFECTIVE_GENERIC_RPC_HANDLER) return channel.unary_unary(
_DEFECTIVE_GENERIC_RPC_HANDLER,
_registered_method=True,
)
def _defective_nested_exception_handler_multi_callable(channel): def _defective_nested_exception_handler_multi_callable(channel):
return channel.unary_unary(_UNARY_UNARY_NESTED_EXCEPTION) return channel.unary_unary(
_UNARY_UNARY_NESTED_EXCEPTION,
_registered_method=True,
)
class InvocationDefectsTest(unittest.TestCase): class InvocationDefectsTest(unittest.TestCase):

@ -53,9 +53,10 @@ class LocalCredentialsTest(unittest.TestCase):
) as channel: ) as channel:
self.assertEqual( self.assertEqual(
b"abc", b"abc",
channel.unary_unary("/test/method")( channel.unary_unary(
b"abc", wait_for_ready=True "/test/method",
), _registered_method=True,
)(b"abc", wait_for_ready=True),
) )
server.stop(None) server.stop(None)
@ -77,9 +78,10 @@ class LocalCredentialsTest(unittest.TestCase):
with grpc.secure_channel(server_addr, channel_creds) as channel: with grpc.secure_channel(server_addr, channel_creds) as channel:
self.assertEqual( self.assertEqual(
b"abc", b"abc",
channel.unary_unary("/test/method")( channel.unary_unary(
b"abc", wait_for_ready=True "/test/method",
), _registered_method=True,
)(b"abc", wait_for_ready=True),
) )
server.stop(None) server.stop(None)

@ -207,45 +207,53 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self._server.start() self._server.start()
self._channel = grpc.insecure_channel("localhost:{}".format(port)) self._channel = grpc.insecure_channel("localhost:{}".format(port))
unary_unary_method_name = "/".join(
(
"",
_SERVICE,
_UNARY_UNARY,
)
)
self._unary_unary = self._channel.unary_unary( self._unary_unary = self._channel.unary_unary(
"/".join( unary_unary_method_name,
(
"",
_SERVICE,
_UNARY_UNARY,
)
),
request_serializer=_REQUEST_SERIALIZER, request_serializer=_REQUEST_SERIALIZER,
response_deserializer=_RESPONSE_DESERIALIZER, response_deserializer=_RESPONSE_DESERIALIZER,
_registered_method=True,
)
unary_stream_method_name = "/".join(
(
"",
_SERVICE,
_UNARY_STREAM,
)
) )
self._unary_stream = self._channel.unary_stream( self._unary_stream = self._channel.unary_stream(
"/".join( unary_stream_method_name,
( _registered_method=True,
"", )
_SERVICE, stream_unary_method_name = "/".join(
_UNARY_STREAM, (
) "",
), _SERVICE,
_STREAM_UNARY,
)
) )
self._stream_unary = self._channel.stream_unary( self._stream_unary = self._channel.stream_unary(
"/".join( stream_unary_method_name,
( _registered_method=True,
"", )
_SERVICE, stream_stream_method_name = "/".join(
_STREAM_UNARY, (
) "",
), _SERVICE,
_STREAM_STREAM,
)
) )
self._stream_stream = self._channel.stream_stream( self._stream_stream = self._channel.stream_stream(
"/".join( stream_stream_method_name,
(
"",
_SERVICE,
_STREAM_STREAM,
)
),
request_serializer=_REQUEST_SERIALIZER, request_serializer=_REQUEST_SERIALIZER,
response_deserializer=_RESPONSE_DESERIALIZER, response_deserializer=_RESPONSE_DESERIALIZER,
_registered_method=True,
) )
def tearDown(self): def tearDown(self):
@ -828,16 +836,18 @@ class InspectContextTest(unittest.TestCase):
self._server.start() self._server.start()
self._channel = grpc.insecure_channel("localhost:{}".format(port)) self._channel = grpc.insecure_channel("localhost:{}".format(port))
unary_unary_method_name = "/".join(
(
"",
_SERVICE,
_UNARY_UNARY,
)
)
self._unary_unary = self._channel.unary_unary( self._unary_unary = self._channel.unary_unary(
"/".join( unary_unary_method_name,
(
"",
_SERVICE,
_UNARY_UNARY,
)
),
request_serializer=_REQUEST_SERIALIZER, request_serializer=_REQUEST_SERIALIZER,
response_deserializer=_RESPONSE_DESERIALIZER, response_deserializer=_RESPONSE_DESERIALIZER,
_registered_method=True,
) )
def tearDown(self): def tearDown(self):

@ -110,7 +110,10 @@ def create_phony_channel():
def perform_unary_unary_call(channel, wait_for_ready=None): def perform_unary_unary_call(channel, wait_for_ready=None):
channel.unary_unary(_UNARY_UNARY).__call__( channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
).__call__(
_REQUEST, _REQUEST,
timeout=test_constants.LONG_TIMEOUT, timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready, wait_for_ready=wait_for_ready,
@ -118,7 +121,10 @@ def perform_unary_unary_call(channel, wait_for_ready=None):
def perform_unary_unary_with_call(channel, wait_for_ready=None): def perform_unary_unary_with_call(channel, wait_for_ready=None):
channel.unary_unary(_UNARY_UNARY).with_call( channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
).with_call(
_REQUEST, _REQUEST,
timeout=test_constants.LONG_TIMEOUT, timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready, wait_for_ready=wait_for_ready,
@ -126,7 +132,10 @@ def perform_unary_unary_with_call(channel, wait_for_ready=None):
def perform_unary_unary_future(channel, wait_for_ready=None): def perform_unary_unary_future(channel, wait_for_ready=None):
channel.unary_unary(_UNARY_UNARY).future( channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
).future(
_REQUEST, _REQUEST,
timeout=test_constants.LONG_TIMEOUT, timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready, wait_for_ready=wait_for_ready,
@ -134,7 +143,10 @@ def perform_unary_unary_future(channel, wait_for_ready=None):
def perform_unary_stream_call(channel, wait_for_ready=None): def perform_unary_stream_call(channel, wait_for_ready=None):
response_iterator = channel.unary_stream(_UNARY_STREAM).__call__( response_iterator = channel.unary_stream(
_UNARY_STREAM,
_registered_method=True,
).__call__(
_REQUEST, _REQUEST,
timeout=test_constants.LONG_TIMEOUT, timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready, wait_for_ready=wait_for_ready,
@ -144,7 +156,10 @@ def perform_unary_stream_call(channel, wait_for_ready=None):
def perform_stream_unary_call(channel, wait_for_ready=None): def perform_stream_unary_call(channel, wait_for_ready=None):
channel.stream_unary(_STREAM_UNARY).__call__( channel.stream_unary(
_STREAM_UNARY,
_registered_method=True,
).__call__(
iter([_REQUEST] * test_constants.STREAM_LENGTH), iter([_REQUEST] * test_constants.STREAM_LENGTH),
timeout=test_constants.LONG_TIMEOUT, timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready, wait_for_ready=wait_for_ready,
@ -152,7 +167,10 @@ def perform_stream_unary_call(channel, wait_for_ready=None):
def perform_stream_unary_with_call(channel, wait_for_ready=None): def perform_stream_unary_with_call(channel, wait_for_ready=None):
channel.stream_unary(_STREAM_UNARY).with_call( channel.stream_unary(
_STREAM_UNARY,
_registered_method=True,
).with_call(
iter([_REQUEST] * test_constants.STREAM_LENGTH), iter([_REQUEST] * test_constants.STREAM_LENGTH),
timeout=test_constants.LONG_TIMEOUT, timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready, wait_for_ready=wait_for_ready,
@ -160,7 +178,10 @@ def perform_stream_unary_with_call(channel, wait_for_ready=None):
def perform_stream_unary_future(channel, wait_for_ready=None): def perform_stream_unary_future(channel, wait_for_ready=None):
channel.stream_unary(_STREAM_UNARY).future( channel.stream_unary(
_STREAM_UNARY,
_registered_method=True,
).future(
iter([_REQUEST] * test_constants.STREAM_LENGTH), iter([_REQUEST] * test_constants.STREAM_LENGTH),
timeout=test_constants.LONG_TIMEOUT, timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready, wait_for_ready=wait_for_ready,
@ -168,7 +189,9 @@ def perform_stream_unary_future(channel, wait_for_ready=None):
def perform_stream_stream_call(channel, wait_for_ready=None): def perform_stream_stream_call(channel, wait_for_ready=None):
response_iterator = channel.stream_stream(_STREAM_STREAM).__call__( response_iterator = channel.stream_stream(
_STREAM_STREAM, _registered_method=True
).__call__(
iter([_REQUEST] * test_constants.STREAM_LENGTH), iter([_REQUEST] * test_constants.STREAM_LENGTH),
timeout=test_constants.LONG_TIMEOUT, timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready, wait_for_ready=wait_for_ready,

@ -195,7 +195,9 @@ class MetadataTest(unittest.TestCase):
self._channel.close() self._channel.close()
def testUnaryUnary(self): def testUnaryUnary(self):
multi_callable = self._channel.unary_unary(_UNARY_UNARY) multi_callable = self._channel.unary_unary(
_UNARY_UNARY, _registered_method=True
)
unused_response, call = multi_callable.with_call( unused_response, call = multi_callable.with_call(
_REQUEST, metadata=_INVOCATION_METADATA _REQUEST, metadata=_INVOCATION_METADATA
) )
@ -211,7 +213,9 @@ class MetadataTest(unittest.TestCase):
) )
def testUnaryStream(self): def testUnaryStream(self):
multi_callable = self._channel.unary_stream(_UNARY_STREAM) multi_callable = self._channel.unary_stream(
_UNARY_STREAM, _registered_method=True
)
call = multi_callable(_REQUEST, metadata=_INVOCATION_METADATA) call = multi_callable(_REQUEST, metadata=_INVOCATION_METADATA)
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted( test_common.metadata_transmitted(
@ -227,7 +231,9 @@ class MetadataTest(unittest.TestCase):
) )
def testStreamUnary(self): def testStreamUnary(self):
multi_callable = self._channel.stream_unary(_STREAM_UNARY) multi_callable = self._channel.stream_unary(
_STREAM_UNARY, _registered_method=True
)
unused_response, call = multi_callable.with_call( unused_response, call = multi_callable.with_call(
iter([_REQUEST] * test_constants.STREAM_LENGTH), iter([_REQUEST] * test_constants.STREAM_LENGTH),
metadata=_INVOCATION_METADATA, metadata=_INVOCATION_METADATA,
@ -244,7 +250,9 @@ class MetadataTest(unittest.TestCase):
) )
def testStreamStream(self): def testStreamStream(self):
multi_callable = self._channel.stream_stream(_STREAM_STREAM) multi_callable = self._channel.stream_stream(
_STREAM_STREAM, _registered_method=True
)
call = multi_callable( call = multi_callable(
iter([_REQUEST] * test_constants.STREAM_LENGTH), iter([_REQUEST] * test_constants.STREAM_LENGTH),
metadata=_INVOCATION_METADATA, metadata=_INVOCATION_METADATA,

@ -52,7 +52,10 @@ class ReconnectTest(unittest.TestCase):
server.add_insecure_port(addr) server.add_insecure_port(addr)
server.start() server.start()
channel = grpc.insecure_channel(addr) channel = grpc.insecure_channel(addr)
multi_callable = channel.unary_unary(_UNARY_UNARY) multi_callable = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
self.assertEqual(_RESPONSE, multi_callable(_REQUEST)) self.assertEqual(_RESPONSE, multi_callable(_REQUEST))
server.stop(None) server.stop(None)
# By default, the channel connectivity is checked every 5s # By default, the channel connectivity is checked every 5s

@ -149,7 +149,10 @@ class ResourceExhaustedTest(unittest.TestCase):
self._channel.close() self._channel.close()
def testUnaryUnary(self): def testUnaryUnary(self):
multi_callable = self._channel.unary_unary(_UNARY_UNARY) multi_callable = self._channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
futures = [] futures = []
for _ in range(test_constants.THREAD_CONCURRENCY): for _ in range(test_constants.THREAD_CONCURRENCY):
futures.append(multi_callable.future(_REQUEST)) futures.append(multi_callable.future(_REQUEST))
@ -178,7 +181,10 @@ class ResourceExhaustedTest(unittest.TestCase):
self.assertEqual(_RESPONSE, multi_callable(_REQUEST)) self.assertEqual(_RESPONSE, multi_callable(_REQUEST))
def testUnaryStream(self): def testUnaryStream(self):
multi_callable = self._channel.unary_stream(_UNARY_STREAM) multi_callable = self._channel.unary_stream(
_UNARY_STREAM,
_registered_method=True,
)
calls = [] calls = []
for _ in range(test_constants.THREAD_CONCURRENCY): for _ in range(test_constants.THREAD_CONCURRENCY):
calls.append(multi_callable(_REQUEST)) calls.append(multi_callable(_REQUEST))
@ -205,7 +211,10 @@ class ResourceExhaustedTest(unittest.TestCase):
self.assertEqual(_RESPONSE, response) self.assertEqual(_RESPONSE, response)
def testStreamUnary(self): def testStreamUnary(self):
multi_callable = self._channel.stream_unary(_STREAM_UNARY) multi_callable = self._channel.stream_unary(
_STREAM_UNARY,
_registered_method=True,
)
futures = [] futures = []
request = iter([_REQUEST] * test_constants.STREAM_LENGTH) request = iter([_REQUEST] * test_constants.STREAM_LENGTH)
for _ in range(test_constants.THREAD_CONCURRENCY): for _ in range(test_constants.THREAD_CONCURRENCY):
@ -236,7 +245,10 @@ class ResourceExhaustedTest(unittest.TestCase):
self.assertEqual(_RESPONSE, multi_callable(request)) self.assertEqual(_RESPONSE, multi_callable(request))
def testStreamStream(self): def testStreamStream(self):
multi_callable = self._channel.stream_stream(_STREAM_STREAM) multi_callable = self._channel.stream_stream(
_STREAM_STREAM,
_registered_method=True,
)
calls = [] calls = []
request = iter([_REQUEST] * test_constants.STREAM_LENGTH) request = iter([_REQUEST] * test_constants.STREAM_LENGTH)
for _ in range(test_constants.THREAD_CONCURRENCY): for _ in range(test_constants.THREAD_CONCURRENCY):

@ -277,7 +277,10 @@ class _GenericHandler(grpc.GenericRpcHandler):
def unary_unary_multi_callable(channel): def unary_unary_multi_callable(channel):
return channel.unary_unary(_UNARY_UNARY) return channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
def unary_stream_multi_callable(channel): def unary_stream_multi_callable(channel):
@ -285,6 +288,7 @@ def unary_stream_multi_callable(channel):
_UNARY_STREAM, _UNARY_STREAM,
request_serializer=_SERIALIZE_REQUEST, request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE, response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
) )
@ -293,6 +297,7 @@ def unary_stream_non_blocking_multi_callable(channel):
_UNARY_STREAM_NON_BLOCKING, _UNARY_STREAM_NON_BLOCKING,
request_serializer=_SERIALIZE_REQUEST, request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE, response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
) )
@ -301,15 +306,22 @@ def stream_unary_multi_callable(channel):
_STREAM_UNARY, _STREAM_UNARY,
request_serializer=_SERIALIZE_REQUEST, request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE, response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
) )
def stream_stream_multi_callable(channel): def stream_stream_multi_callable(channel):
return channel.stream_stream(_STREAM_STREAM) return channel.stream_stream(
_STREAM_STREAM,
_registered_method=True,
)
def stream_stream_non_blocking_multi_callable(channel): def stream_stream_non_blocking_multi_callable(channel):
return channel.stream_stream(_STREAM_STREAM_NON_BLOCKING) return channel.stream_stream(
_STREAM_STREAM_NON_BLOCKING,
_registered_method=True,
)
class BaseRPCTest(object): class BaseRPCTest(object):

@ -81,7 +81,10 @@ def run_test(args):
thread.start() thread.start()
port = port_queue.get() port = port_queue.get()
channel = grpc.insecure_channel("localhost:%d" % port) channel = grpc.insecure_channel("localhost:%d" % port)
multi_callable = channel.unary_unary(FORK_EXIT) multi_callable = channel.unary_unary(
FORK_EXIT,
_registered_method=True,
)
result, call = multi_callable.with_call(REQUEST, wait_for_ready=True) result, call = multi_callable.with_call(REQUEST, wait_for_ready=True)
os.wait() os.wait()
else: else:

@ -77,7 +77,10 @@ class SSLSessionCacheTest(unittest.TestCase):
channel = grpc.secure_channel( channel = grpc.secure_channel(
"localhost:{}".format(port), channel_creds, options=channel_options "localhost:{}".format(port), channel_creds, options=channel_options
) )
response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) response = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)(_REQUEST)
auth_data = pickle.loads(response) auth_data = pickle.loads(response)
self.assertEqual( self.assertEqual(
expect_ssl_session_reused, expect_ssl_session_reused,

@ -53,7 +53,10 @@ def main_unary(server_target):
"""Initiate a unary RPC to be interrupted by a SIGINT.""" """Initiate a unary RPC to be interrupted by a SIGINT."""
global per_process_rpc_future # pylint: disable=global-statement global per_process_rpc_future # pylint: disable=global-statement
with grpc.insecure_channel(server_target) as channel: with grpc.insecure_channel(server_target) as channel:
multicallable = channel.unary_unary(UNARY_UNARY) multicallable = channel.unary_unary(
UNARY_UNARY,
_registered_method=True,
)
signal.signal(signal.SIGINT, handle_sigint) signal.signal(signal.SIGINT, handle_sigint)
per_process_rpc_future = multicallable.future( per_process_rpc_future = multicallable.future(
_MESSAGE, wait_for_ready=True _MESSAGE, wait_for_ready=True
@ -67,9 +70,10 @@ def main_streaming(server_target):
global per_process_rpc_future # pylint: disable=global-statement global per_process_rpc_future # pylint: disable=global-statement
with grpc.insecure_channel(server_target) as channel: with grpc.insecure_channel(server_target) as channel:
signal.signal(signal.SIGINT, handle_sigint) signal.signal(signal.SIGINT, handle_sigint)
per_process_rpc_future = channel.unary_stream(UNARY_STREAM)( per_process_rpc_future = channel.unary_stream(
_MESSAGE, wait_for_ready=True UNARY_STREAM,
) _registered_method=True,
)(_MESSAGE, wait_for_ready=True)
for result in per_process_rpc_future: for result in per_process_rpc_future:
pass pass
assert False, _ASSERTION_MESSAGE assert False, _ASSERTION_MESSAGE
@ -79,7 +83,10 @@ def main_unary_with_exception(server_target):
"""Initiate a unary RPC with a signal handler that will raise.""" """Initiate a unary RPC with a signal handler that will raise."""
channel = grpc.insecure_channel(server_target) channel = grpc.insecure_channel(server_target)
try: try:
channel.unary_unary(UNARY_UNARY)(_MESSAGE, wait_for_ready=True) channel.unary_unary(
UNARY_UNARY,
_registered_method=True,
)(_MESSAGE, wait_for_ready=True)
except KeyboardInterrupt: except KeyboardInterrupt:
sys.stderr.write("Running signal handler.\n") sys.stderr.write("Running signal handler.\n")
sys.stderr.flush() sys.stderr.flush()
@ -92,9 +99,10 @@ def main_streaming_with_exception(server_target):
"""Initiate a streaming RPC with a signal handler that will raise.""" """Initiate a streaming RPC with a signal handler that will raise."""
channel = grpc.insecure_channel(server_target) channel = grpc.insecure_channel(server_target)
try: try:
for _ in channel.unary_stream(UNARY_STREAM)( for _ in channel.unary_stream(
_MESSAGE, wait_for_ready=True UNARY_STREAM,
): _registered_method=True,
)(_MESSAGE, wait_for_ready=True):
pass pass
except KeyboardInterrupt: except KeyboardInterrupt:
sys.stderr.write("Running signal handler.\n") sys.stderr.write("Running signal handler.\n")

@ -71,9 +71,10 @@ class XdsCredentialsTest(unittest.TestCase):
server_address, channel_creds, options=override_options server_address, channel_creds, options=override_options
) as channel: ) as channel:
request = b"abc" request = b"abc"
response = channel.unary_unary("/test/method")( response = channel.unary_unary(
request, wait_for_ready=True "/test/method",
) _registered_method=True,
)(request, wait_for_ready=True)
self.assertEqual(response, request) self.assertEqual(response, request)
def test_xds_creds_fallback_insecure(self): def test_xds_creds_fallback_insecure(self):
@ -89,9 +90,10 @@ class XdsCredentialsTest(unittest.TestCase):
channel_creds = grpc.xds_channel_credentials(channel_fallback_creds) channel_creds = grpc.xds_channel_credentials(channel_fallback_creds)
with grpc.secure_channel(server_address, channel_creds) as channel: with grpc.secure_channel(server_address, channel_creds) as channel:
request = b"abc" request = b"abc"
response = channel.unary_unary("/test/method")( response = channel.unary_unary(
request, wait_for_ready=True "/test/method",
) _registered_method=True,
)(request, wait_for_ready=True)
self.assertEqual(response, request) self.assertEqual(response, request)
def test_start_xds_server(self): def test_start_xds_server(self):

@ -65,6 +65,7 @@ class CloseChannelTest(unittest.TestCase):
_UNARY_CALL_METHOD_WITH_SLEEP, _UNARY_CALL_METHOD_WITH_SLEEP,
request_serializer=messages_pb2.SimpleRequest.SerializeToString, request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString, response_deserializer=messages_pb2.SimpleResponse.FromString,
_registered_method=True,
) )
greenlet = group.spawn(self._run_client, UnaryCallWithSleep) greenlet = group.spawn(self._run_client, UnaryCallWithSleep)
# release loop so that greenlet can take control # release loop so that greenlet can take control
@ -78,6 +79,7 @@ class CloseChannelTest(unittest.TestCase):
_UNARY_CALL_METHOD_WITH_SLEEP, _UNARY_CALL_METHOD_WITH_SLEEP,
request_serializer=messages_pb2.SimpleRequest.SerializeToString, request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString, response_deserializer=messages_pb2.SimpleResponse.FromString,
_registered_method=True,
) )
greenlet = group.spawn(self._run_client, UnaryCallWithSleep) greenlet = group.spawn(self._run_client, UnaryCallWithSleep)
# release loop so that greenlet can take control # release loop so that greenlet can take control

@ -68,7 +68,10 @@ def _start_a_test_server():
def _perform_an_rpc(address): def _perform_an_rpc(address):
channel = grpc.insecure_channel(address) channel = grpc.insecure_channel(address)
multicallable = channel.unary_unary(_TEST_METHOD) multicallable = channel.unary_unary(
_TEST_METHOD,
_registered_method=True,
)
response = multicallable(_REQUEST) response = multicallable(_REQUEST)
assert _REQUEST == response assert _REQUEST == response

@ -193,6 +193,7 @@ class SimpleStubsTest(unittest.TestCase):
_UNARY_UNARY, _UNARY_UNARY,
channel_credentials=grpc.experimental.insecure_channel_credentials(), channel_credentials=grpc.experimental.insecure_channel_credentials(),
timeout=None, timeout=None,
_registered_method=0,
) )
self.assertEqual(_REQUEST, response) self.assertEqual(_REQUEST, response)
@ -205,6 +206,7 @@ class SimpleStubsTest(unittest.TestCase):
_UNARY_UNARY, _UNARY_UNARY,
channel_credentials=grpc.local_channel_credentials(), channel_credentials=grpc.local_channel_credentials(),
timeout=None, timeout=None,
_registered_method=0,
) )
self.assertEqual(_REQUEST, response) self.assertEqual(_REQUEST, response)
@ -213,7 +215,10 @@ class SimpleStubsTest(unittest.TestCase):
target = f"localhost:{port}" target = f"localhost:{port}"
test_name = inspect.stack()[0][3] test_name = inspect.stack()[0][3]
args = (_REQUEST, target, _UNARY_UNARY) args = (_REQUEST, target, _UNARY_UNARY)
kwargs = {"channel_credentials": grpc.local_channel_credentials()} kwargs = {
"channel_credentials": grpc.local_channel_credentials(),
"_registered_method": True,
}
def _invoke(seed: str): def _invoke(seed: str):
run_kwargs = dict(kwargs) run_kwargs = dict(kwargs)
@ -230,6 +235,7 @@ class SimpleStubsTest(unittest.TestCase):
target, target,
_UNARY_UNARY, _UNARY_UNARY,
channel_credentials=grpc.local_channel_credentials(), channel_credentials=grpc.local_channel_credentials(),
_registered_method=0,
) )
self.assert_eventually( self.assert_eventually(
lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count() lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()
@ -250,6 +256,7 @@ class SimpleStubsTest(unittest.TestCase):
_UNARY_UNARY, _UNARY_UNARY,
options=options, options=options,
channel_credentials=grpc.local_channel_credentials(), channel_credentials=grpc.local_channel_credentials(),
_registered_method=0,
) )
self.assert_eventually( self.assert_eventually(
lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count() lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()
@ -265,6 +272,7 @@ class SimpleStubsTest(unittest.TestCase):
target, target,
_UNARY_STREAM, _UNARY_STREAM,
channel_credentials=grpc.local_channel_credentials(), channel_credentials=grpc.local_channel_credentials(),
_registered_method=0,
): ):
self.assertEqual(_REQUEST, response) self.assertEqual(_REQUEST, response)
@ -280,6 +288,7 @@ class SimpleStubsTest(unittest.TestCase):
target, target,
_STREAM_UNARY, _STREAM_UNARY,
channel_credentials=grpc.local_channel_credentials(), channel_credentials=grpc.local_channel_credentials(),
_registered_method=0,
) )
self.assertEqual(_REQUEST, response) self.assertEqual(_REQUEST, response)
@ -295,6 +304,7 @@ class SimpleStubsTest(unittest.TestCase):
target, target,
_STREAM_STREAM, _STREAM_STREAM,
channel_credentials=grpc.local_channel_credentials(), channel_credentials=grpc.local_channel_credentials(),
_registered_method=0,
): ):
self.assertEqual(_REQUEST, response) self.assertEqual(_REQUEST, response)
@ -319,14 +329,22 @@ class SimpleStubsTest(unittest.TestCase):
with _server(server_creds) as port: with _server(server_creds) as port:
target = f"localhost:{port}" target = f"localhost:{port}"
response = grpc.experimental.unary_unary( response = grpc.experimental.unary_unary(
_REQUEST, target, _UNARY_UNARY, options=_property_options _REQUEST,
target,
_UNARY_UNARY,
options=_property_options,
_registered_method=0,
) )
def test_insecure_sugar(self): def test_insecure_sugar(self):
with _server(None) as port: with _server(None) as port:
target = f"localhost:{port}" target = f"localhost:{port}"
response = grpc.experimental.unary_unary( response = grpc.experimental.unary_unary(
_REQUEST, target, _UNARY_UNARY, insecure=True _REQUEST,
target,
_UNARY_UNARY,
insecure=True,
_registered_method=0,
) )
self.assertEqual(_REQUEST, response) self.assertEqual(_REQUEST, response)
@ -340,14 +358,24 @@ class SimpleStubsTest(unittest.TestCase):
_UNARY_UNARY, _UNARY_UNARY,
insecure=True, insecure=True,
channel_credentials=grpc.local_channel_credentials(), channel_credentials=grpc.local_channel_credentials(),
_registered_method=0,
) )
def test_default_wait_for_ready(self): def test_default_wait_for_ready(self):
addr, port, sock = get_socket() addr, port, sock = get_socket()
sock.close() sock.close()
target = f"{addr}:{port}" target = f"{addr}:{port}"
channel = grpc._simple_stubs.ChannelCache.get().get_channel( (
target, (), None, True, None channel,
unused_method_handle,
) = grpc._simple_stubs.ChannelCache.get().get_channel(
target=target,
options=(),
channel_credentials=None,
insecure=True,
compression=None,
method=_UNARY_UNARY,
_registered_method=True,
) )
rpc_finished_event = threading.Event() rpc_finished_event = threading.Event()
rpc_failed_event = threading.Event() rpc_failed_event = threading.Event()
@ -376,7 +404,12 @@ class SimpleStubsTest(unittest.TestCase):
def _send_rpc(): def _send_rpc():
try: try:
response = grpc.experimental.unary_unary( response = grpc.experimental.unary_unary(
_REQUEST, target, _UNARY_UNARY, timeout=None, insecure=True _REQUEST,
target,
_UNARY_UNARY,
timeout=None,
insecure=True,
_registered_method=0,
) )
rpc_finished_event.set() rpc_finished_event.set()
except Exception as e: except Exception as e:
@ -399,6 +432,7 @@ class SimpleStubsTest(unittest.TestCase):
target, target,
_BLACK_HOLE, _BLACK_HOLE,
insecure=True, insecure=True,
_registered_method=0,
**invocation_args, **invocation_args,
) )
self.assertEqual( self.assertEqual(

Loading…
Cancel
Save