[Python O11Y] Reapply registered method change (#35850)

This reverts commit a18279db2e.

<!--

If you know who should review your pull request, please assign it to that
person, otherwise the pull request would get assigned randomly.

If your pull request is for a specific language, please add the appropriate
lang label.

-->

Closes #35850

PiperOrigin-RevId: 607476066
pull/35921/head^2
Xuan Wang 9 months ago committed by Copybara-Service
parent cf79445171
commit 2c9b599e5e
  1. 27
      examples/python/helloworld/helloworld_pb2.py
  2. 14
      examples/python/helloworld/helloworld_pb2.pyi
  3. 104
      examples/python/helloworld/helloworld_pb2_grpc.py
  4. 23
      src/compiler/python_generator.cc
  5. 4
      src/python/.gitignore
  6. 12
      src/python/grpcio/grpc/__init__.py
  7. 75
      src/python/grpcio/grpc/_channel.py
  8. 7
      src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi
  9. 71
      src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
  10. 6
      src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi
  11. 36
      src/python/grpcio/grpc/_interceptor.py
  12. 81
      src/python/grpcio/grpc/_simple_stubs.py
  13. 12
      src/python/grpcio/grpc/aio/_base_channel.py
  14. 21
      src/python/grpcio/grpc/aio/_channel.py
  15. 27
      src/python/grpcio_testing/grpc_testing/_channel/_channel.py
  16. 17
      src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py
  17. 5
      src/python/grpcio_tests/tests/csds/csds_test.py
  18. 11
      src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py
  19. 9
      src/python/grpcio_tests/tests/qps/benchmark_client.py
  20. 25
      src/python/grpcio_tests/tests/status/_grpc_status_test.py
  21. 20
      src/python/grpcio_tests/tests/unit/_abort_test.py
  22. 20
      src/python/grpcio_tests/tests/unit/_auth_context_test.py
  23. 35
      src/python/grpcio_tests/tests/unit/_channel_close_test.py
  24. 20
      src/python/grpcio_tests/tests/unit/_compression_test.py
  25. 10
      src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py
  26. 5
      src/python/grpcio_tests/tests/unit/_dns_resolver_test.py
  27. 24
      src/python/grpcio_tests/tests/unit/_empty_message_test.py
  28. 5
      src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py
  29. 20
      src/python/grpcio_tests/tests/unit/_exit_scenarios.py
  30. 8
      src/python/grpcio_tests/tests/unit/_interceptor_test.py
  31. 9
      src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py
  32. 22
      src/python/grpcio_tests/tests/unit/_invocation_defects_test.py
  33. 14
      src/python/grpcio_tests/tests/unit/_local_credentials_test.py
  34. 80
      src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
  35. 39
      src/python/grpcio_tests/tests/unit/_metadata_flags_test.py
  36. 16
      src/python/grpcio_tests/tests/unit/_metadata_test.py
  37. 5
      src/python/grpcio_tests/tests/unit/_reconnect_test.py
  38. 20
      src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py
  39. 18
      src/python/grpcio_tests/tests/unit/_rpc_test_helpers.py
  40. 5
      src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py
  41. 5
      src/python/grpcio_tests/tests/unit/_session_cache_test.py
  42. 24
      src/python/grpcio_tests/tests/unit/_signal_client.py
  43. 14
      src/python/grpcio_tests/tests/unit/_xds_credentials_test.py
  44. 2
      src/python/grpcio_tests/tests_gevent/unit/close_channel_test.py
  45. 5
      src/python/grpcio_tests/tests_py3_only/unit/_leak_test.py
  46. 46
      src/python/grpcio_tests/tests_py3_only/unit/_simple_stubs_test.py

@ -1,11 +1,12 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: helloworld.proto
# Protobuf Python Version: 4.25.0
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
@ -13,18 +14,18 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10helloworld.proto\x12\nhelloworld\"\x1c\n\x0cHelloRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\"\x1d\n\nHelloReply\x12\x0f\n\x07message\x18\x01 \x01(\t2I\n\x07Greeter\x12>\n\x08SayHello\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00\x42\x36\n\x1bio.grpc.examples.helloworldB\x0fHelloWorldProtoP\x01\xa2\x02\x03HLWb\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10helloworld.proto\x12\nhelloworld\"\x1c\n\x0cHelloRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\"\x1d\n\nHelloReply\x12\x0f\n\x07message\x18\x01 \x01(\t2\xe4\x01\n\x07Greeter\x12>\n\x08SayHello\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00\x12K\n\x13SayHelloStreamReply\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00\x30\x01\x12L\n\x12SayHelloBidiStream\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00(\x01\x30\x01\x42\x36\n\x1bio.grpc.examples.helloworldB\x0fHelloWorldProtoP\x01\xa2\x02\x03HLWb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'helloworld_pb2', globals())
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'helloworld_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\033io.grpc.examples.helloworldB\017HelloWorldProtoP\001\242\002\003HLW'
_HELLOREQUEST._serialized_start=32
_HELLOREQUEST._serialized_end=60
_HELLOREPLY._serialized_start=62
_HELLOREPLY._serialized_end=91
_GREETER._serialized_start=93
_GREETER._serialized_end=166
_globals['DESCRIPTOR']._options = None
_globals['DESCRIPTOR']._serialized_options = b'\n\033io.grpc.examples.helloworldB\017HelloWorldProtoP\001\242\002\003HLW'
_globals['_HELLOREQUEST']._serialized_start=32
_globals['_HELLOREQUEST']._serialized_end=60
_globals['_HELLOREPLY']._serialized_start=62
_globals['_HELLOREPLY']._serialized_end=91
_globals['_GREETER']._serialized_start=94
_globals['_GREETER']._serialized_end=322
# @@protoc_insertion_point(module_scope)

@ -4,14 +4,14 @@ from typing import ClassVar as _ClassVar, Optional as _Optional
DESCRIPTOR: _descriptor.FileDescriptor
class HelloReply(_message.Message):
__slots__ = ["message"]
MESSAGE_FIELD_NUMBER: _ClassVar[int]
message: str
def __init__(self, message: _Optional[str] = ...) -> None: ...
class HelloRequest(_message.Message):
__slots__ = ["name"]
__slots__ = ("name",)
NAME_FIELD_NUMBER: _ClassVar[int]
name: str
def __init__(self, name: _Optional[str] = ...) -> None: ...
class HelloReply(_message.Message):
__slots__ = ("message",)
MESSAGE_FIELD_NUMBER: _ClassVar[int]
message: str
def __init__(self, message: _Optional[str] = ...) -> None: ...

@ -19,7 +19,17 @@ class GreeterStub(object):
'/helloworld.Greeter/SayHello',
request_serializer=helloworld__pb2.HelloRequest.SerializeToString,
response_deserializer=helloworld__pb2.HelloReply.FromString,
)
_registered_method=True)
self.SayHelloStreamReply = channel.unary_stream(
'/helloworld.Greeter/SayHelloStreamReply',
request_serializer=helloworld__pb2.HelloRequest.SerializeToString,
response_deserializer=helloworld__pb2.HelloReply.FromString,
_registered_method=True)
self.SayHelloBidiStream = channel.stream_stream(
'/helloworld.Greeter/SayHelloBidiStream',
request_serializer=helloworld__pb2.HelloRequest.SerializeToString,
response_deserializer=helloworld__pb2.HelloReply.FromString,
_registered_method=True)
class GreeterServicer(object):
@ -33,6 +43,18 @@ class GreeterServicer(object):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SayHelloStreamReply(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def SayHelloBidiStream(self, request_iterator, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_GreeterServicer_to_server(servicer, server):
rpc_method_handlers = {
@ -41,6 +63,16 @@ def add_GreeterServicer_to_server(servicer, server):
request_deserializer=helloworld__pb2.HelloRequest.FromString,
response_serializer=helloworld__pb2.HelloReply.SerializeToString,
),
'SayHelloStreamReply': grpc.unary_stream_rpc_method_handler(
servicer.SayHelloStreamReply,
request_deserializer=helloworld__pb2.HelloRequest.FromString,
response_serializer=helloworld__pb2.HelloReply.SerializeToString,
),
'SayHelloBidiStream': grpc.stream_stream_rpc_method_handler(
servicer.SayHelloBidiStream,
request_deserializer=helloworld__pb2.HelloRequest.FromString,
response_serializer=helloworld__pb2.HelloReply.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'helloworld.Greeter', rpc_method_handlers)
@ -63,8 +95,72 @@ class Greeter(object):
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/helloworld.Greeter/SayHello',
return grpc.experimental.unary_unary(
request,
target,
'/helloworld.Greeter/SayHello',
helloworld__pb2.HelloRequest.SerializeToString,
helloworld__pb2.HelloReply.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SayHelloStreamReply(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_stream(
request,
target,
'/helloworld.Greeter/SayHelloStreamReply',
helloworld__pb2.HelloRequest.SerializeToString,
helloworld__pb2.HelloReply.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def SayHelloBidiStream(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(
request_iterator,
target,
'/helloworld.Greeter/SayHelloBidiStream',
helloworld__pb2.HelloRequest.SerializeToString,
helloworld__pb2.HelloReply.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

@ -467,7 +467,7 @@ bool PrivateGenerator::PrintStub(
out->Print(
method_dict,
"response_deserializer=$ResponseModuleAndClass$.FromString,\n");
out->Print(")\n");
out->Print("_registered_method=True)\n");
}
}
}
@ -642,22 +642,27 @@ bool PrivateGenerator::PrintServiceClass(
args_dict["ArityMethodName"] = arity_method_name;
args_dict["PackageQualifiedService"] = package_qualified_service_name;
args_dict["Method"] = method->name();
out->Print(args_dict,
"return "
"grpc.experimental.$ArityMethodName$($RequestParameter$, "
"target, '/$PackageQualifiedService$/$Method$',\n");
out->Print(args_dict, "return grpc.experimental.$ArityMethodName$(\n");
{
IndentScope continuation_indent(out);
StringMap serializer_dict;
out->Print(args_dict, "$RequestParameter$,\n");
out->Print("target,\n");
out->Print(args_dict, "'/$PackageQualifiedService$/$Method$',\n");
serializer_dict["RequestModuleAndClass"] = request_module_and_class;
serializer_dict["ResponseModuleAndClass"] = response_module_and_class;
out->Print(serializer_dict,
"$RequestModuleAndClass$.SerializeToString,\n");
out->Print(serializer_dict, "$ResponseModuleAndClass$.FromString,\n");
out->Print("options, channel_credentials,\n");
out->Print(
"insecure, call_credentials, compression, wait_for_ready, "
"timeout, metadata)\n");
out->Print("options,\n");
out->Print("channel_credentials,\n");
out->Print("insecure,\n");
out->Print("call_credentials,\n");
out->Print("compression,\n");
out->Print("wait_for_ready,\n");
out->Print("timeout,\n");
out->Print("metadata,\n");
out->Print("_registered_method=True)\n");
}
}
}

@ -1,4 +1,6 @@
gens/
build/
grpc_root/
third_party/
*_pb2.py
*_pb2.pyi
*_pb2_grpc.py

@ -1004,6 +1004,7 @@ class Channel(abc.ABC):
method,
request_serializer=None,
response_deserializer=None,
_registered_method=False,
):
"""Creates a UnaryUnaryMultiCallable for a unary-unary method.
@ -1014,6 +1015,8 @@ class Channel(abc.ABC):
response_deserializer: Optional :term:`deserializer` for deserializing the
response message. Response goes undeserialized in case None
is passed.
_registered_method: Implementation Private. A bool representing whether the method
is registered.
Returns:
A UnaryUnaryMultiCallable value for the named unary-unary method.
@ -1026,6 +1029,7 @@ class Channel(abc.ABC):
method,
request_serializer=None,
response_deserializer=None,
_registered_method=False,
):
"""Creates a UnaryStreamMultiCallable for a unary-stream method.
@ -1036,6 +1040,8 @@ class Channel(abc.ABC):
response_deserializer: Optional :term:`deserializer` for deserializing the
response message. Response goes undeserialized in case None is
passed.
_registered_method: Implementation Private. A bool representing whether the method
is registered.
Returns:
A UnaryStreamMultiCallable value for the name unary-stream method.
@ -1048,6 +1054,7 @@ class Channel(abc.ABC):
method,
request_serializer=None,
response_deserializer=None,
_registered_method=False,
):
"""Creates a StreamUnaryMultiCallable for a stream-unary method.
@ -1058,6 +1065,8 @@ class Channel(abc.ABC):
response_deserializer: Optional :term:`deserializer` for deserializing the
response message. Response goes undeserialized in case None is
passed.
_registered_method: Implementation Private. A bool representing whether the method
is registered.
Returns:
A StreamUnaryMultiCallable value for the named stream-unary method.
@ -1070,6 +1079,7 @@ class Channel(abc.ABC):
method,
request_serializer=None,
response_deserializer=None,
_registered_method=False,
):
"""Creates a StreamStreamMultiCallable for a stream-stream method.
@ -1080,6 +1090,8 @@ class Channel(abc.ABC):
response_deserializer: Optional :term:`deserializer` for deserializing the
response message. Response goes undeserialized in case None
is passed.
_registered_method: Implementation Private. A bool representing whether the method
is registered.
Returns:
A StreamStreamMultiCallable value for the named stream-stream method.

@ -24,6 +24,7 @@ import types
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Optional,
@ -1054,6 +1055,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
_request_serializer: Optional[SerializingFunction]
_response_deserializer: Optional[DeserializingFunction]
_context: Any
_registered_call_handle: Optional[int]
__slots__ = [
"_channel",
@ -1074,6 +1076,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
target: bytes,
request_serializer: Optional[SerializingFunction],
response_deserializer: Optional[DeserializingFunction],
_registered_call_handle: Optional[int],
):
self._channel = channel
self._managed_call = managed_call
@ -1082,6 +1085,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._context = cygrpc.build_census_context()
self._registered_call_handle = _registered_call_handle
def _prepare(
self,
@ -1153,6 +1157,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
),
),
self._context,
self._registered_call_handle,
)
event = call.next_event()
_handle_event(event, state, self._response_deserializer)
@ -1221,6 +1226,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
(operations,),
event_handler,
self._context,
self._registered_call_handle,
)
return _MultiThreadedRendezvous(
state, call, self._response_deserializer, deadline
@ -1234,6 +1240,7 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
_request_serializer: Optional[SerializingFunction]
_response_deserializer: Optional[DeserializingFunction]
_context: Any
_registered_call_handle: Optional[int]
__slots__ = [
"_channel",
@ -1252,6 +1259,7 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
target: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
_registered_call_handle: Optional[int],
):
self._channel = channel
self._method = method
@ -1259,6 +1267,7 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._context = cygrpc.build_census_context()
self._registered_call_handle = _registered_call_handle
def __call__( # pylint: disable=too-many-locals
self,
@ -1317,6 +1326,7 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
call_credentials,
operations_and_tags,
self._context,
self._registered_call_handle,
)
return _SingleThreadedRendezvous(
state, call, self._response_deserializer, deadline
@ -1331,6 +1341,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
_request_serializer: Optional[SerializingFunction]
_response_deserializer: Optional[DeserializingFunction]
_context: Any
_registered_call_handle: Optional[int]
__slots__ = [
"_channel",
@ -1351,6 +1362,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
target: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
_registered_call_handle: Optional[int],
):
self._channel = channel
self._managed_call = managed_call
@ -1359,6 +1371,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._context = cygrpc.build_census_context()
self._registered_call_handle = _registered_call_handle
def __call__( # pylint: disable=too-many-locals
self,
@ -1408,6 +1421,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
operations,
_event_handler(state, self._response_deserializer),
self._context,
self._registered_call_handle,
)
return _MultiThreadedRendezvous(
state, call, self._response_deserializer, deadline
@ -1422,6 +1436,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
_request_serializer: Optional[SerializingFunction]
_response_deserializer: Optional[DeserializingFunction]
_context: Any
_registered_call_handle: Optional[int]
__slots__ = [
"_channel",
@ -1442,6 +1457,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
target: bytes,
request_serializer: Optional[SerializingFunction],
response_deserializer: Optional[DeserializingFunction],
_registered_call_handle: Optional[int],
):
self._channel = channel
self._managed_call = managed_call
@ -1450,6 +1466,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._context = cygrpc.build_census_context()
self._registered_call_handle = _registered_call_handle
def _blocking(
self,
@ -1482,6 +1499,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
augmented_metadata, initial_metadata_flags
),
self._context,
self._registered_call_handle,
)
_consume_request_iterator(
request_iterator, state, call, self._request_serializer, None
@ -1572,6 +1590,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
),
event_handler,
self._context,
self._registered_call_handle,
)
_consume_request_iterator(
request_iterator,
@ -1593,6 +1612,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
_request_serializer: Optional[SerializingFunction]
_response_deserializer: Optional[DeserializingFunction]
_context: Any
_registered_call_handle: Optional[int]
__slots__ = [
"_channel",
@ -1611,8 +1631,9 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
managed_call: IntegratedCallFactory,
method: bytes,
target: bytes,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
request_serializer: Optional[SerializingFunction],
response_deserializer: Optional[DeserializingFunction],
_registered_call_handle: Optional[int],
):
self._channel = channel
self._managed_call = managed_call
@ -1621,6 +1642,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._context = cygrpc.build_census_context()
self._registered_call_handle = _registered_call_handle
def __call__(
self,
@ -1662,6 +1684,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
operations,
event_handler,
self._context,
self._registered_call_handle,
)
_consume_request_iterator(
request_iterator,
@ -1751,7 +1774,8 @@ def _channel_managed_call_management(state: _ChannelCallState):
credentials: Optional[cygrpc.CallCredentials],
operations: Sequence[Sequence[cygrpc.Operation]],
event_handler: UserTag,
context,
context: Any,
_registered_call_handle: Optional[int],
) -> cygrpc.IntegratedCall:
"""Creates a cygrpc.IntegratedCall.
@ -1768,6 +1792,8 @@ def _channel_managed_call_management(state: _ChannelCallState):
event_handler: A behavior to call to handle the events resultant from
the operations on the call.
context: Context object for distributed tracing.
_registered_call_handle: An int representing the call handle of the
method, or None if the method is not registered.
Returns:
A cygrpc.IntegratedCall with which to conduct an RPC.
"""
@ -1788,6 +1814,7 @@ def _channel_managed_call_management(state: _ChannelCallState):
credentials,
operations_and_tags,
context,
_registered_call_handle,
)
if state.managed_calls == 0:
state.managed_calls = 1
@ -2021,6 +2048,7 @@ class Channel(grpc.Channel):
_call_state: _ChannelCallState
_connectivity_state: _ChannelConnectivityState
_target: str
_registered_call_handles: Dict[str, int]
def __init__(
self,
@ -2055,6 +2083,22 @@ class Channel(grpc.Channel):
if cygrpc.g_gevent_activated:
cygrpc.gevent_increment_channel_count()
def _get_registered_call_handle(self, method: str) -> int:
"""
Get the registered call handle for a method.
This is a semi-private method. It is intended for use only by gRPC generated code.
This method is not thread-safe.
Args:
method: Required, the method name for the RPC.
Returns:
The registered call handle pointer in the form of a Python Long.
"""
return self._channel.get_registered_call_handle(_common.encode(method))
def _process_python_options(
self, python_options: Sequence[ChannelArgumentType]
) -> None:
@ -2078,12 +2122,17 @@ class Channel(grpc.Channel):
) -> None:
_unsubscribe(self._connectivity_state, callback)
# pylint: disable=arguments-differ
def unary_unary(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.UnaryUnaryMultiCallable:
_registered_call_handle = None
if _registered_method:
_registered_call_handle = self._get_registered_call_handle(method)
return _UnaryUnaryMultiCallable(
self._channel,
_channel_managed_call_management(self._call_state),
@ -2091,14 +2140,20 @@ class Channel(grpc.Channel):
_common.encode(self._target),
request_serializer,
response_deserializer,
_registered_call_handle,
)
# pylint: disable=arguments-differ
def unary_stream(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.UnaryStreamMultiCallable:
_registered_call_handle = None
if _registered_method:
_registered_call_handle = self._get_registered_call_handle(method)
# NOTE(rbellevi): Benchmarks have shown that running a unary-stream RPC
# on a single Python thread results in an appreciable speed-up. However,
# due to slight differences in capability, the multi-threaded variant
@ -2110,6 +2165,7 @@ class Channel(grpc.Channel):
_common.encode(self._target),
request_serializer,
response_deserializer,
_registered_call_handle,
)
else:
return _UnaryStreamMultiCallable(
@ -2119,14 +2175,20 @@ class Channel(grpc.Channel):
_common.encode(self._target),
request_serializer,
response_deserializer,
_registered_call_handle,
)
# pylint: disable=arguments-differ
def stream_unary(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.StreamUnaryMultiCallable:
_registered_call_handle = None
if _registered_method:
_registered_call_handle = self._get_registered_call_handle(method)
return _StreamUnaryMultiCallable(
self._channel,
_channel_managed_call_management(self._call_state),
@ -2134,14 +2196,20 @@ class Channel(grpc.Channel):
_common.encode(self._target),
request_serializer,
response_deserializer,
_registered_call_handle,
)
# pylint: disable=arguments-differ
def stream_stream(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.StreamStreamMultiCallable:
_registered_call_handle = None
if _registered_method:
_registered_call_handle = self._get_registered_call_handle(method)
return _StreamStreamMultiCallable(
self._channel,
_channel_managed_call_management(self._call_state),
@ -2149,6 +2217,7 @@ class Channel(grpc.Channel):
_common.encode(self._target),
request_serializer,
response_deserializer,
_registered_call_handle,
)
def _unsubscribe_all(self) -> None:

@ -74,6 +74,13 @@ cdef class SegregatedCall:
cdef class Channel:
cdef _ChannelState _state
cdef dict _registered_call_handles
# TODO(https://github.com/grpc/grpc/issues/15662): Eliminate this.
cdef tuple _arguments
cdef class CallHandle:
cdef void *c_call_handle
cdef object method

@ -101,6 +101,25 @@ cdef class _ChannelState:
self.connectivity_due = set()
self.closed_reason = None
cdef class CallHandle:
def __cinit__(self, _ChannelState channel_state, object method):
self.method = method
cpython.Py_INCREF(method)
# Note that since we always pass None for host, we set the
# second-to-last parameter of grpc_channel_register_call to a fixed
# NULL value.
self.c_call_handle = grpc_channel_register_call(
channel_state.c_channel, <const char *>method, NULL, NULL)
def __dealloc__(self):
cpython.Py_DECREF(self.method)
@property
def call_handle(self):
return cpython.PyLong_FromVoidPtr(self.c_call_handle)
cdef tuple _operate(grpc_call *c_call, object operations, object user_tag):
cdef grpc_call_error c_call_error
@ -199,7 +218,7 @@ cdef void _call(
grpc_completion_queue *c_completion_queue, on_success, int flags, method,
host, object deadline, CallCredentials credentials,
object operationses_and_user_tags, object metadata,
object context) except *:
object context, object registered_call_handle) except *:
"""Invokes an RPC.
Args:
@ -226,6 +245,8 @@ cdef void _call(
must be present in the first element of this value.
metadata: The metadata for this call.
context: Context object for distributed tracing.
registered_call_handle: An int representing the call handle of the method, or
None if the method is not registered.
"""
cdef grpc_slice method_slice
cdef grpc_slice host_slice
@ -242,10 +263,16 @@ cdef void _call(
else:
host_slice = _slice_from_bytes(host)
host_slice_ptr = &host_slice
call_state.c_call = grpc_channel_create_call(
channel_state.c_channel, NULL, flags,
c_completion_queue, method_slice, host_slice_ptr,
_timespec_from_time(deadline), NULL)
if registered_call_handle:
call_state.c_call = grpc_channel_create_registered_call(
channel_state.c_channel, NULL, flags,
c_completion_queue, cpython.PyLong_AsVoidPtr(registered_call_handle),
_timespec_from_time(deadline), NULL)
else:
call_state.c_call = grpc_channel_create_call(
channel_state.c_channel, NULL, flags,
c_completion_queue, method_slice, host_slice_ptr,
_timespec_from_time(deadline), NULL)
grpc_slice_unref(method_slice)
if host_slice_ptr:
grpc_slice_unref(host_slice)
@ -309,7 +336,7 @@ cdef class IntegratedCall:
cdef IntegratedCall _integrated_call(
_ChannelState state, int flags, method, host, object deadline,
object metadata, CallCredentials credentials, operationses_and_user_tags,
object context):
object context, object registered_call_handle):
call_state = _CallState()
def on_success(started_tags):
@ -318,7 +345,8 @@ cdef IntegratedCall _integrated_call(
_call(
state, call_state, state.c_call_completion_queue, on_success, flags,
method, host, deadline, credentials, operationses_and_user_tags, metadata, context)
method, host, deadline, credentials, operationses_and_user_tags,
metadata, context, registered_call_handle)
return IntegratedCall(state, call_state)
@ -371,7 +399,7 @@ cdef class SegregatedCall:
cdef SegregatedCall _segregated_call(
_ChannelState state, int flags, method, host, object deadline,
object metadata, CallCredentials credentials, operationses_and_user_tags,
object context):
object context, object registered_call_handle):
cdef _CallState call_state = _CallState()
cdef SegregatedCall segregated_call
cdef grpc_completion_queue *c_completion_queue
@ -389,7 +417,7 @@ cdef SegregatedCall _segregated_call(
_call(
state, call_state, c_completion_queue, on_success, flags, method, host,
deadline, credentials, operationses_and_user_tags, metadata,
context)
context, registered_call_handle)
except:
_destroy_c_completion_queue(c_completion_queue)
raise
@ -486,6 +514,7 @@ cdef class Channel:
else grpc_insecure_credentials_create())
self._state.c_channel = grpc_channel_create(
<char *>target, c_channel_credentials, channel_args.c_args())
self._registered_call_handles = {}
grpc_channel_credentials_release(c_channel_credentials)
def target(self):
@ -499,10 +528,10 @@ cdef class Channel:
def integrated_call(
self, int flags, method, host, object deadline, object metadata,
CallCredentials credentials, operationses_and_tags,
object context = None):
object context = None, object registered_call_handle = None):
return _integrated_call(
self._state, flags, method, host, deadline, metadata, credentials,
operationses_and_tags, context)
operationses_and_tags, context, registered_call_handle)
def next_call_event(self):
def on_success(tag):
@ -521,10 +550,10 @@ cdef class Channel:
def segregated_call(
self, int flags, method, host, object deadline, object metadata,
CallCredentials credentials, operationses_and_tags,
object context = None):
object context = None, object registered_call_handle = None):
return _segregated_call(
self._state, flags, method, host, deadline, metadata, credentials,
operationses_and_tags, context)
operationses_and_tags, context, registered_call_handle)
def check_connectivity_state(self, bint try_to_connect):
with self._state.condition:
@ -543,3 +572,19 @@ cdef class Channel:
def close_on_fork(self, code, details):
_close(self, code, details, True)
def get_registered_call_handle(self, method):
"""
Get or registers a call handler for a method.
This method is not thread-safe.
Args:
method: Required, the method name for the RPC.
Returns:
The registered call handle pointer in the form of a Python Long.
"""
if method not in self._registered_call_handles.keys():
self._registered_call_handles[method] = CallHandle(self._state, method)
return self._registered_call_handles[method].call_handle

@ -433,6 +433,12 @@ cdef extern from "grpc/grpc.h":
grpc_channel *channel, grpc_call *parent_call, uint32_t propagation_mask,
grpc_completion_queue *completion_queue, grpc_slice method,
const grpc_slice *host, gpr_timespec deadline, void *reserved) nogil
void *grpc_channel_register_call(
grpc_channel *channel, const char *method, const char *host, void *reserved) nogil
grpc_call *grpc_channel_create_registered_call(
grpc_channel *channel, grpc_call *parent_call, uint32_t propagation_mask,
grpc_completion_queue *completion_queue, void* registered_call_handle,
gpr_timespec deadline, void *reserved) nogil
grpc_connectivity_state grpc_channel_check_connectivity_state(
grpc_channel *channel, int try_to_connect) nogil
void grpc_channel_watch_connectivity_state(

@ -684,57 +684,85 @@ class _Channel(grpc.Channel):
def unsubscribe(self, callback: Callable):
self._channel.unsubscribe(callback)
# pylint: disable=arguments-differ
def unary_unary(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.UnaryUnaryMultiCallable:
# pytype: disable=wrong-arg-count
thunk = lambda m: self._channel.unary_unary(
m, request_serializer, response_deserializer
m,
request_serializer,
response_deserializer,
_registered_method,
)
# pytype: enable=wrong-arg-count
if isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor):
return _UnaryUnaryMultiCallable(thunk, method, self._interceptor)
else:
return thunk(method)
# pylint: disable=arguments-differ
def unary_stream(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.UnaryStreamMultiCallable:
# pytype: disable=wrong-arg-count
thunk = lambda m: self._channel.unary_stream(
m, request_serializer, response_deserializer
m,
request_serializer,
response_deserializer,
_registered_method,
)
# pytype: enable=wrong-arg-count
if isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor):
return _UnaryStreamMultiCallable(thunk, method, self._interceptor)
else:
return thunk(method)
# pylint: disable=arguments-differ
def stream_unary(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.StreamUnaryMultiCallable:
# pytype: disable=wrong-arg-count
thunk = lambda m: self._channel.stream_unary(
m, request_serializer, response_deserializer
m,
request_serializer,
response_deserializer,
_registered_method,
)
# pytype: enable=wrong-arg-count
if isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor):
return _StreamUnaryMultiCallable(thunk, method, self._interceptor)
else:
return thunk(method)
# pylint: disable=arguments-differ
def stream_stream(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> grpc.StreamStreamMultiCallable:
# pytype: disable=wrong-arg-count
thunk = lambda m: self._channel.stream_stream(
m, request_serializer, response_deserializer
m,
request_serializer,
response_deserializer,
_registered_method,
)
# pytype: enable=wrong-arg-count
if isinstance(self._interceptor, grpc.StreamStreamClientInterceptor):
return _StreamStreamMultiCallable(thunk, method, self._interceptor)
else:

@ -159,7 +159,19 @@ class ChannelCache:
channel_credentials: Optional[grpc.ChannelCredentials],
insecure: bool,
compression: Optional[grpc.Compression],
) -> grpc.Channel:
method: str,
_registered_method: bool,
) -> Tuple[grpc.Channel, Optional[int]]:
"""Get a channel from cache or creates a new channel.
This method also takes care of register method for channel,
which means we'll register a new call handle if we're calling a
non-registered method for an existing channel.
Returns:
A tuple with two items. The first item is the channel, second item is
the call handle if the method is registered, None if it's not registered.
"""
if insecure and channel_credentials:
raise ValueError(
"The insecure option is mutually exclusive with "
@ -176,18 +188,25 @@ class ChannelCache:
key = (target, options, channel_credentials, compression)
with self._lock:
channel_data = self._mapping.get(key, None)
call_handle = None
if channel_data is not None:
channel = channel_data[0]
# Register a new call handle if we're calling a registered method for an
# existing channel and this method is not registered.
if _registered_method:
call_handle = channel._get_registered_call_handle(method)
self._mapping.pop(key)
self._mapping[key] = (
channel,
datetime.datetime.now() + _EVICTION_PERIOD,
)
return channel
return channel, call_handle
else:
channel = _create_channel(
target, options, channel_credentials, compression
)
if _registered_method:
call_handle = channel._get_registered_call_handle(method)
self._mapping[key] = (
channel,
datetime.datetime.now() + _EVICTION_PERIOD,
@ -197,7 +216,7 @@ class ChannelCache:
or len(self._mapping) >= _MAXIMUM_CHANNELS
):
self._condition.notify()
return channel
return channel, call_handle
def _test_only_channel_count(self) -> int:
with self._lock:
@ -205,6 +224,7 @@ class ChannelCache:
@experimental_api
# pylint: disable=too-many-locals
def unary_unary(
request: RequestType,
target: str,
@ -219,6 +239,7 @@ def unary_unary(
wait_for_ready: Optional[bool] = None,
timeout: Optional[float] = _DEFAULT_TIMEOUT,
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None,
_registered_method: Optional[bool] = False,
) -> ResponseType:
"""Invokes a unary-unary RPC without an explicitly specified channel.
@ -272,11 +293,17 @@ def unary_unary(
Returns:
The response to the RPC.
"""
channel = ChannelCache.get().get_channel(
target, options, channel_credentials, insecure, compression
channel, method_handle = ChannelCache.get().get_channel(
target,
options,
channel_credentials,
insecure,
compression,
method,
_registered_method,
)
multicallable = channel.unary_unary(
method, request_serializer, response_deserializer
method, request_serializer, response_deserializer, method_handle
)
wait_for_ready = wait_for_ready if wait_for_ready is not None else True
return multicallable(
@ -289,6 +316,7 @@ def unary_unary(
@experimental_api
# pylint: disable=too-many-locals
def unary_stream(
request: RequestType,
target: str,
@ -303,6 +331,7 @@ def unary_stream(
wait_for_ready: Optional[bool] = None,
timeout: Optional[float] = _DEFAULT_TIMEOUT,
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None,
_registered_method: Optional[bool] = False,
) -> Iterator[ResponseType]:
"""Invokes a unary-stream RPC without an explicitly specified channel.
@ -355,11 +384,17 @@ def unary_stream(
Returns:
An iterator of responses.
"""
channel = ChannelCache.get().get_channel(
target, options, channel_credentials, insecure, compression
channel, method_handle = ChannelCache.get().get_channel(
target,
options,
channel_credentials,
insecure,
compression,
method,
_registered_method,
)
multicallable = channel.unary_stream(
method, request_serializer, response_deserializer
method, request_serializer, response_deserializer, method_handle
)
wait_for_ready = wait_for_ready if wait_for_ready is not None else True
return multicallable(
@ -372,6 +407,7 @@ def unary_stream(
@experimental_api
# pylint: disable=too-many-locals
def stream_unary(
request_iterator: Iterator[RequestType],
target: str,
@ -386,6 +422,7 @@ def stream_unary(
wait_for_ready: Optional[bool] = None,
timeout: Optional[float] = _DEFAULT_TIMEOUT,
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None,
_registered_method: Optional[bool] = False,
) -> ResponseType:
"""Invokes a stream-unary RPC without an explicitly specified channel.
@ -438,11 +475,17 @@ def stream_unary(
Returns:
The response to the RPC.
"""
channel = ChannelCache.get().get_channel(
target, options, channel_credentials, insecure, compression
channel, method_handle = ChannelCache.get().get_channel(
target,
options,
channel_credentials,
insecure,
compression,
method,
_registered_method,
)
multicallable = channel.stream_unary(
method, request_serializer, response_deserializer
method, request_serializer, response_deserializer, method_handle
)
wait_for_ready = wait_for_ready if wait_for_ready is not None else True
return multicallable(
@ -455,6 +498,7 @@ def stream_unary(
@experimental_api
# pylint: disable=too-many-locals
def stream_stream(
request_iterator: Iterator[RequestType],
target: str,
@ -469,6 +513,7 @@ def stream_stream(
wait_for_ready: Optional[bool] = None,
timeout: Optional[float] = _DEFAULT_TIMEOUT,
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None,
_registered_method: Optional[bool] = False,
) -> Iterator[ResponseType]:
"""Invokes a stream-stream RPC without an explicitly specified channel.
@ -521,11 +566,17 @@ def stream_stream(
Returns:
An iterator of responses.
"""
channel = ChannelCache.get().get_channel(
target, options, channel_credentials, insecure, compression
channel, method_handle = ChannelCache.get().get_channel(
target,
options,
channel_credentials,
insecure,
compression,
method,
_registered_method,
)
multicallable = channel.stream_stream(
method, request_serializer, response_deserializer
method, request_serializer, response_deserializer, method_handle
)
wait_for_ready = wait_for_ready if wait_for_ready is not None else True
return multicallable(

@ -272,6 +272,7 @@ class Channel(abc.ABC):
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> UnaryUnaryMultiCallable:
"""Creates a UnaryUnaryMultiCallable for a unary-unary method.
@ -282,6 +283,8 @@ class Channel(abc.ABC):
response_deserializer: Optional :term:`deserializer` for deserializing the
response message. Response goes undeserialized in case None
is passed.
_registered_method: Implementation Private. Optional: A bool representing
whether the method is registered.
Returns:
A UnaryUnaryMultiCallable value for the named unary-unary method.
@ -293,6 +296,7 @@ class Channel(abc.ABC):
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> UnaryStreamMultiCallable:
"""Creates a UnaryStreamMultiCallable for a unary-stream method.
@ -303,6 +307,8 @@ class Channel(abc.ABC):
response_deserializer: Optional :term:`deserializer` for deserializing the
response message. Response goes undeserialized in case None
is passed.
_registered_method: Implementation Private. Optional: A bool representing
whether the method is registered.
Returns:
A UnarySteramMultiCallable value for the named unary-stream method.
@ -314,6 +320,7 @@ class Channel(abc.ABC):
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> StreamUnaryMultiCallable:
"""Creates a StreamUnaryMultiCallable for a stream-unary method.
@ -324,6 +331,8 @@ class Channel(abc.ABC):
response_deserializer: Optional :term:`deserializer` for deserializing the
response message. Response goes undeserialized in case None
is passed.
_registered_method: Implementation Private. Optional: A bool representing
whether the method is registered.
Returns:
A StreamUnaryMultiCallable value for the named stream-unary method.
@ -335,6 +344,7 @@ class Channel(abc.ABC):
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> StreamStreamMultiCallable:
"""Creates a StreamStreamMultiCallable for a stream-stream method.
@ -345,6 +355,8 @@ class Channel(abc.ABC):
response_deserializer: Optional :term:`deserializer` for deserializing the
response message. Response goes undeserialized in case None
is passed.
_registered_method: Implementation Private. Optional: A bool representing
whether the method is registered.
Returns:
A StreamStreamMultiCallable value for the named stream-stream method.

@ -478,11 +478,20 @@ class Channel(_base_channel.Channel):
await self.wait_for_state_change(state)
state = self.get_state(try_to_connect=True)
# TODO(xuanwn): Implement this method after we have
# observability for Asyncio.
def _get_registered_call_handle(self, method: str) -> int:
pass
# TODO(xuanwn): Implement _registered_method after we have
# observability for Asyncio.
# pylint: disable=arguments-differ,unused-argument
def unary_unary(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> UnaryUnaryMultiCallable:
return UnaryUnaryMultiCallable(
self._channel,
@ -494,11 +503,15 @@ class Channel(_base_channel.Channel):
self._loop,
)
# TODO(xuanwn): Implement _registered_method after we have
# observability for Asyncio.
# pylint: disable=arguments-differ,unused-argument
def unary_stream(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> UnaryStreamMultiCallable:
return UnaryStreamMultiCallable(
self._channel,
@ -510,11 +523,15 @@ class Channel(_base_channel.Channel):
self._loop,
)
# TODO(xuanwn): Implement _registered_method after we have
# observability for Asyncio.
# pylint: disable=arguments-differ,unused-argument
def stream_unary(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> StreamUnaryMultiCallable:
return StreamUnaryMultiCallable(
self._channel,
@ -526,11 +543,15 @@ class Channel(_base_channel.Channel):
self._loop,
)
# TODO(xuanwn): Implement _registered_method after we have
# observability for Asyncio.
# pylint: disable=arguments-differ,unused-argument
def stream_stream(
self,
method: str,
request_serializer: Optional[SerializingFunction] = None,
response_deserializer: Optional[DeserializingFunction] = None,
_registered_method: Optional[bool] = False,
) -> StreamStreamMultiCallable:
return StreamStreamMultiCallable(
self._channel,

@ -31,23 +31,42 @@ class TestingChannel(grpc_testing.Channel):
def unsubscribe(self, callback):
raise NotImplementedError()
def _get_registered_call_handle(self, method: str) -> int:
pass
def unary_unary(
self, method, request_serializer=None, response_deserializer=None
self,
method,
request_serializer=None,
response_deserializer=None,
_registered_method=False,
):
return _multi_callable.UnaryUnary(method, self._state)
def unary_stream(
self, method, request_serializer=None, response_deserializer=None
self,
method,
request_serializer=None,
response_deserializer=None,
_registered_method=False,
):
return _multi_callable.UnaryStream(method, self._state)
def stream_unary(
self, method, request_serializer=None, response_deserializer=None
self,
method,
request_serializer=None,
response_deserializer=None,
_registered_method=False,
):
return _multi_callable.StreamUnary(method, self._state)
def stream_stream(
self, method, request_serializer=None, response_deserializer=None
self,
method,
request_serializer=None,
response_deserializer=None,
_registered_method=False,
):
return _multi_callable.StreamStream(method, self._state)

@ -99,16 +99,20 @@ class ChannelzServicerTest(unittest.TestCase):
def _send_successful_unary_unary(self, idx):
_, r = (
self._pairs[idx]
.channel.unary_unary(_SUCCESSFUL_UNARY_UNARY)
.channel.unary_unary(
_SUCCESSFUL_UNARY_UNARY,
_registered_method=True,
)
.with_call(_REQUEST)
)
self.assertEqual(r.code(), grpc.StatusCode.OK)
def _send_failed_unary_unary(self, idx):
try:
self._pairs[idx].channel.unary_unary(_FAILED_UNARY_UNARY).with_call(
_REQUEST
)
self._pairs[idx].channel.unary_unary(
_FAILED_UNARY_UNARY,
_registered_method=True,
).with_call(_REQUEST)
except grpc.RpcError:
return
else:
@ -117,7 +121,10 @@ class ChannelzServicerTest(unittest.TestCase):
def _send_successful_stream_stream(self, idx):
response_iterator = (
self._pairs[idx]
.channel.stream_stream(_SUCCESSFUL_STREAM_STREAM)
.channel.stream_stream(
_SUCCESSFUL_STREAM_STREAM,
_registered_method=True,
)
.__call__(iter([_REQUEST] * test_constants.STREAM_LENGTH))
)
cnt = 0

@ -86,7 +86,10 @@ class TestCsds(unittest.TestCase):
# Force the XdsClient to initialize and request a resource
with self.assertRaises(grpc.RpcError) as rpc_error:
dummy_channel.unary_unary("")(b"", wait_for_ready=False, timeout=1)
dummy_channel.unary_unary(
"",
_registered_method=True,
)(b"", wait_for_ready=False, timeout=1)
self.assertEqual(
grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.exception.code()
)

@ -543,6 +543,17 @@ class PythonPluginTest(unittest.TestCase):
)
service.server.stop(None)
def testRegisteredMethod(self):
"""Tests that we're setting _registered_call_handle when create call using generated stub."""
service = _CreateService()
self.assertTrue(service.stub.UnaryCall._registered_call_handle)
self.assertTrue(
service.stub.StreamingOutputCall._registered_call_handle
)
self.assertTrue(service.stub.StreamingInputCall._registered_call_handle)
self.assertTrue(service.stub.FullDuplexCall._registered_call_handle)
service.server.stop(None)
@unittest.skipIf(
sys.version_info[0] < 3 or sys.version_info[1] < 6,

@ -32,13 +32,16 @@ _TIMEOUT = 60 * 60 * 24
class GenericStub(object):
def __init__(self, channel):
self.UnaryCall = channel.unary_unary(
"/grpc.testing.BenchmarkService/UnaryCall"
"/grpc.testing.BenchmarkService/UnaryCall",
_registered_method=True,
)
self.StreamingFromServer = channel.unary_stream(
"/grpc.testing.BenchmarkService/StreamingFromServer"
"/grpc.testing.BenchmarkService/StreamingFromServer",
_registered_method=True,
)
self.StreamingCall = channel.stream_stream(
"/grpc.testing.BenchmarkService/StreamingCall"
"/grpc.testing.BenchmarkService/StreamingCall",
_registered_method=True,
)

@ -138,7 +138,10 @@ class StatusTest(unittest.TestCase):
self._channel.close()
def test_status_ok(self):
_, call = self._channel.unary_unary(_STATUS_OK).with_call(_REQUEST)
_, call = self._channel.unary_unary(
_STATUS_OK,
_registered_method=True,
).with_call(_REQUEST)
# Succeed RPC doesn't have status
status = rpc_status.from_call(call)
@ -146,7 +149,10 @@ class StatusTest(unittest.TestCase):
def test_status_not_ok(self):
with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_STATUS_NOT_OK).with_call(_REQUEST)
self._channel.unary_unary(
_STATUS_NOT_OK,
_registered_method=True,
).with_call(_REQUEST)
rpc_error = exception_context.exception
self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL)
@ -156,7 +162,10 @@ class StatusTest(unittest.TestCase):
def test_error_details(self):
with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_ERROR_DETAILS).with_call(_REQUEST)
self._channel.unary_unary(
_ERROR_DETAILS,
_registered_method=True,
).with_call(_REQUEST)
rpc_error = exception_context.exception
status = rpc_status.from_call(rpc_error)
@ -173,7 +182,10 @@ class StatusTest(unittest.TestCase):
def test_code_message_validation(self):
with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_INCONSISTENT).with_call(_REQUEST)
self._channel.unary_unary(
_INCONSISTENT,
_registered_method=True,
).with_call(_REQUEST)
rpc_error = exception_context.exception
self.assertEqual(rpc_error.code(), grpc.StatusCode.NOT_FOUND)
@ -182,7 +194,10 @@ class StatusTest(unittest.TestCase):
def test_invalid_code(self):
with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_INVALID_CODE).with_call(_REQUEST)
self._channel.unary_unary(
_INVALID_CODE,
_registered_method=True,
).with_call(_REQUEST)
rpc_error = exception_context.exception
self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN)
# Invalid status code exception raised during coversion

@ -107,7 +107,10 @@ class AbortTest(unittest.TestCase):
def test_abort(self):
with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_ABORT)(_REQUEST)
self._channel.unary_unary(
_ABORT,
_registered_method=True,
)(_REQUEST)
rpc_error = exception_context.exception
self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL)
@ -124,7 +127,10 @@ class AbortTest(unittest.TestCase):
# Servicer will abort() after creating a local ref to do_not_leak_me.
with self.assertRaises(grpc.RpcError):
self._channel.unary_unary(_ABORT)(_REQUEST)
self._channel.unary_unary(
_ABORT,
_registered_method=True,
)(_REQUEST)
# Server may still have a stack frame reference to the exception even
# after client sees error, so ensure server has shutdown.
@ -134,7 +140,10 @@ class AbortTest(unittest.TestCase):
def test_abort_with_status(self):
with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_ABORT_WITH_STATUS)(_REQUEST)
self._channel.unary_unary(
_ABORT_WITH_STATUS,
_registered_method=True,
)(_REQUEST)
rpc_error = exception_context.exception
self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL)
@ -143,7 +152,10 @@ class AbortTest(unittest.TestCase):
def test_invalid_code(self):
with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_INVALID_CODE)(_REQUEST)
self._channel.unary_unary(
_INVALID_CODE,
_registered_method=True,
)(_REQUEST)
rpc_error = exception_context.exception
self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN)

@ -78,7 +78,10 @@ class AuthContextTest(unittest.TestCase):
server.start()
with grpc.insecure_channel("localhost:%d" % port) as channel:
response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
response = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)(_REQUEST)
server.stop(None)
auth_data = pickle.loads(response)
@ -115,7 +118,10 @@ class AuthContextTest(unittest.TestCase):
channel_creds,
options=_PROPERTY_OPTIONS,
)
response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
response = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)(_REQUEST)
channel.close()
server.stop(None)
@ -161,7 +167,10 @@ class AuthContextTest(unittest.TestCase):
options=_PROPERTY_OPTIONS,
)
response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
response = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)(_REQUEST)
channel.close()
server.stop(None)
@ -180,7 +189,10 @@ class AuthContextTest(unittest.TestCase):
channel = grpc.secure_channel(
"localhost:{}".format(port), channel_creds, options=channel_options
)
response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
response = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)(_REQUEST)
auth_data = pickle.loads(response)
self.assertEqual(
expect_ssl_session_reused,

@ -123,7 +123,10 @@ class ChannelCloseTest(unittest.TestCase):
def test_close_immediately_after_call_invocation(self):
channel = grpc.insecure_channel("localhost:{}".format(self._port))
multi_callable = channel.stream_stream(_STREAM_URI)
multi_callable = channel.stream_stream(
_STREAM_URI,
_registered_method=True,
)
request_iterator = _Pipe(())
response_iterator = multi_callable(request_iterator)
channel.close()
@ -133,7 +136,10 @@ class ChannelCloseTest(unittest.TestCase):
def test_close_while_call_active(self):
channel = grpc.insecure_channel("localhost:{}".format(self._port))
multi_callable = channel.stream_stream(_STREAM_URI)
multi_callable = channel.stream_stream(
_STREAM_URI,
_registered_method=True,
)
request_iterator = _Pipe((b"abc",))
response_iterator = multi_callable(request_iterator)
next(response_iterator)
@ -146,7 +152,10 @@ class ChannelCloseTest(unittest.TestCase):
with grpc.insecure_channel(
"localhost:{}".format(self._port)
) as channel: # pylint: disable=bad-continuation
multi_callable = channel.stream_stream(_STREAM_URI)
multi_callable = channel.stream_stream(
_STREAM_URI,
_registered_method=True,
)
request_iterator = _Pipe((b"abc",))
response_iterator = multi_callable(request_iterator)
next(response_iterator)
@ -158,7 +167,10 @@ class ChannelCloseTest(unittest.TestCase):
with grpc.insecure_channel(
"localhost:{}".format(self._port)
) as channel: # pylint: disable=bad-continuation
multi_callable = channel.stream_stream(_STREAM_URI)
multi_callable = channel.stream_stream(
_STREAM_URI,
_registered_method=True,
)
request_iterators = tuple(
_Pipe((b"abc",))
for _ in range(test_constants.THREAD_CONCURRENCY)
@ -176,7 +188,10 @@ class ChannelCloseTest(unittest.TestCase):
def test_many_concurrent_closes(self):
channel = grpc.insecure_channel("localhost:{}".format(self._port))
multi_callable = channel.stream_stream(_STREAM_URI)
multi_callable = channel.stream_stream(
_STREAM_URI,
_registered_method=True,
)
request_iterator = _Pipe((b"abc",))
response_iterator = multi_callable(request_iterator)
next(response_iterator)
@ -203,10 +218,16 @@ class ChannelCloseTest(unittest.TestCase):
with grpc.insecure_channel(
"localhost:{}".format(self._port)
) as channel:
stream_multi_callable = channel.stream_stream(_STREAM_URI)
stream_multi_callable = channel.stream_stream(
_STREAM_URI,
_registered_method=True,
)
endless_iterator = itertools.repeat(b"abc")
stream_response_iterator = stream_multi_callable(endless_iterator)
future = channel.unary_unary(_UNARY_URI).future(b"abc")
future = channel.unary_unary(
_UNARY_URI,
_registered_method=True,
).future(b"abc")
def on_done_callback(future):
raise Exception("This should not cause a deadlock.")

@ -237,7 +237,10 @@ def _get_compression_ratios(
def _unary_unary_client(channel, multicallable_kwargs, message):
multi_callable = channel.unary_unary(_UNARY_UNARY)
multi_callable = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
response = multi_callable(message, **multicallable_kwargs)
if response != message:
raise RuntimeError(
@ -246,7 +249,10 @@ def _unary_unary_client(channel, multicallable_kwargs, message):
def _unary_stream_client(channel, multicallable_kwargs, message):
multi_callable = channel.unary_stream(_UNARY_STREAM)
multi_callable = channel.unary_stream(
_UNARY_STREAM,
_registered_method=True,
)
response_iterator = multi_callable(message, **multicallable_kwargs)
for response in response_iterator:
if response != message:
@ -256,7 +262,10 @@ def _unary_stream_client(channel, multicallable_kwargs, message):
def _stream_unary_client(channel, multicallable_kwargs, message):
multi_callable = channel.stream_unary(_STREAM_UNARY)
multi_callable = channel.stream_unary(
_STREAM_UNARY,
_registered_method=True,
)
requests = (_REQUEST for _ in range(_STREAM_LENGTH))
response = multi_callable(requests, **multicallable_kwargs)
if response != message:
@ -266,7 +275,10 @@ def _stream_unary_client(channel, multicallable_kwargs, message):
def _stream_stream_client(channel, multicallable_kwargs, message):
multi_callable = channel.stream_stream(_STREAM_STREAM)
multi_callable = channel.stream_stream(
_STREAM_STREAM,
_registered_method=True,
)
request_prefix = str(0).encode("ascii") * 100
requests = (
request_prefix + str(i).encode("ascii") for i in range(_STREAM_LENGTH)

@ -116,7 +116,10 @@ class ContextVarsPropagationTest(unittest.TestCase):
local_credentials, call_credentials
)
with grpc.secure_channel(target, composite_credentials) as channel:
stub = channel.unary_unary(_UNARY_UNARY)
stub = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
response = stub(_REQUEST, wait_for_ready=True)
self.assertEqual(_REQUEST, response)
@ -142,7 +145,10 @@ class ContextVarsPropagationTest(unittest.TestCase):
with grpc.secure_channel(
target, composite_credentials
) as channel:
stub = channel.unary_unary(_UNARY_UNARY)
stub = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
wait_group.done()
wait_group.wait()
for i in range(_RPC_COUNT):

@ -55,7 +55,10 @@ class DNSResolverTest(unittest.TestCase):
"loopback46.unittest.grpc.io:%d" % self._port
) as channel:
self.assertEqual(
channel.unary_unary(_METHOD)(
channel.unary_unary(
_METHOD,
_registered_method=True,
)(
_REQUEST,
timeout=10,
),

@ -96,25 +96,33 @@ class EmptyMessageTest(unittest.TestCase):
self._channel.close()
def testUnaryUnary(self):
response = self._channel.unary_unary(_UNARY_UNARY)(_REQUEST)
response = self._channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)(_REQUEST)
self.assertEqual(_RESPONSE, response)
def testUnaryStream(self):
response_iterator = self._channel.unary_stream(_UNARY_STREAM)(_REQUEST)
response_iterator = self._channel.unary_stream(
_UNARY_STREAM,
_registered_method=True,
)(_REQUEST)
self.assertSequenceEqual(
[_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator)
)
def testStreamUnary(self):
response = self._channel.stream_unary(_STREAM_UNARY)(
iter([_REQUEST] * test_constants.STREAM_LENGTH)
)
response = self._channel.stream_unary(
_STREAM_UNARY,
_registered_method=True,
)(iter([_REQUEST] * test_constants.STREAM_LENGTH))
self.assertEqual(_RESPONSE, response)
def testStreamStream(self):
response_iterator = self._channel.stream_stream(_STREAM_STREAM)(
iter([_REQUEST] * test_constants.STREAM_LENGTH)
)
response_iterator = self._channel.stream_stream(
_STREAM_STREAM,
_registered_method=True,
)(iter([_REQUEST] * test_constants.STREAM_LENGTH))
self.assertSequenceEqual(
[_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator)
)

@ -73,7 +73,10 @@ class ErrorMessageEncodingTest(unittest.TestCase):
def testMessageEncoding(self):
for message in _UNICODE_ERROR_MESSAGES:
multi_callable = self._channel.unary_unary(_UNARY_UNARY)
multi_callable = self._channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
with self.assertRaises(grpc.RpcError) as cm:
multi_callable(message.encode("utf-8"))

@ -210,14 +210,20 @@ if __name__ == "__main__":
method = TEST_TO_METHOD[args.scenario]
if args.scenario == IN_FLIGHT_UNARY_UNARY_CALL:
multi_callable = channel.unary_unary(method)
multi_callable = channel.unary_unary(
method,
_registered_method=True,
)
future = multi_callable.future(REQUEST)
result, call = multi_callable.with_call(REQUEST)
elif (
args.scenario == IN_FLIGHT_UNARY_STREAM_CALL
or args.scenario == IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL
):
multi_callable = channel.unary_stream(method)
multi_callable = channel.unary_stream(
method,
_registered_method=True,
)
response_iterator = multi_callable(REQUEST)
for response in response_iterator:
pass
@ -225,7 +231,10 @@ if __name__ == "__main__":
args.scenario == IN_FLIGHT_STREAM_UNARY_CALL
or args.scenario == IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL
):
multi_callable = channel.stream_unary(method)
multi_callable = channel.stream_unary(
method,
_registered_method=True,
)
future = multi_callable.future(infinite_request_iterator())
result, call = multi_callable.with_call(
iter([REQUEST] * test_constants.STREAM_LENGTH)
@ -234,7 +243,10 @@ if __name__ == "__main__":
args.scenario == IN_FLIGHT_STREAM_STREAM_CALL
or args.scenario == IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL
):
multi_callable = channel.stream_stream(method)
multi_callable = channel.stream_stream(
method,
_registered_method=True,
)
response_iterator = multi_callable(infinite_request_iterator())
for response in response_iterator:
pass

@ -231,7 +231,7 @@ class _GenericHandler(grpc.GenericRpcHandler):
def _unary_unary_multi_callable(channel):
return channel.unary_unary(_UNARY_UNARY)
return channel.unary_unary(_UNARY_UNARY, _registered_method=True)
def _unary_stream_multi_callable(channel):
@ -239,6 +239,7 @@ def _unary_stream_multi_callable(channel):
_UNARY_STREAM,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
)
@ -247,11 +248,12 @@ def _stream_unary_multi_callable(channel):
_STREAM_UNARY,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
)
def _stream_stream_multi_callable(channel):
return channel.stream_stream(_STREAM_STREAM)
return channel.stream_stream(_STREAM_STREAM, _registered_method=True)
class _ClientCallDetails(
@ -562,7 +564,7 @@ class InterceptorTest(unittest.TestCase):
self._record[:] = []
multi_callable = _unary_unary_multi_callable(channel)
multi_callable.with_call(
response, call = multi_callable.with_call(
request,
metadata=(
(

@ -32,7 +32,7 @@ _STREAM_STREAM = "/test/StreamStream"
def _unary_unary_multi_callable(channel):
return channel.unary_unary(_UNARY_UNARY)
return channel.unary_unary(_UNARY_UNARY, _registered_method=True)
def _unary_stream_multi_callable(channel):
@ -40,6 +40,7 @@ def _unary_stream_multi_callable(channel):
_UNARY_STREAM,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
)
@ -48,11 +49,15 @@ def _stream_unary_multi_callable(channel):
_STREAM_UNARY,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
)
def _stream_stream_multi_callable(channel):
return channel.stream_stream(_STREAM_STREAM)
return channel.stream_stream(
_STREAM_STREAM,
_registered_method=True,
)
class InvalidMetadataTest(unittest.TestCase):

@ -219,7 +219,10 @@ class FailAfterFewIterationsCounter(object):
def _unary_unary_multi_callable(channel):
return channel.unary_unary(_UNARY_UNARY)
return channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
def _unary_stream_multi_callable(channel):
@ -227,6 +230,7 @@ def _unary_stream_multi_callable(channel):
_UNARY_STREAM,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
)
@ -235,19 +239,29 @@ def _stream_unary_multi_callable(channel):
_STREAM_UNARY,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
)
def _stream_stream_multi_callable(channel):
return channel.stream_stream(_STREAM_STREAM)
return channel.stream_stream(
_STREAM_STREAM,
_registered_method=True,
)
def _defective_handler_multi_callable(channel):
return channel.unary_unary(_DEFECTIVE_GENERIC_RPC_HANDLER)
return channel.unary_unary(
_DEFECTIVE_GENERIC_RPC_HANDLER,
_registered_method=True,
)
def _defective_nested_exception_handler_multi_callable(channel):
return channel.unary_unary(_UNARY_UNARY_NESTED_EXCEPTION)
return channel.unary_unary(
_UNARY_UNARY_NESTED_EXCEPTION,
_registered_method=True,
)
class InvocationDefectsTest(unittest.TestCase):

@ -53,9 +53,10 @@ class LocalCredentialsTest(unittest.TestCase):
) as channel:
self.assertEqual(
b"abc",
channel.unary_unary("/test/method")(
b"abc", wait_for_ready=True
),
channel.unary_unary(
"/test/method",
_registered_method=True,
)(b"abc", wait_for_ready=True),
)
server.stop(None)
@ -77,9 +78,10 @@ class LocalCredentialsTest(unittest.TestCase):
with grpc.secure_channel(server_addr, channel_creds) as channel:
self.assertEqual(
b"abc",
channel.unary_unary("/test/method")(
b"abc", wait_for_ready=True
),
channel.unary_unary(
"/test/method",
_registered_method=True,
)(b"abc", wait_for_ready=True),
)
server.stop(None)

@ -207,45 +207,53 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self._server.start()
self._channel = grpc.insecure_channel("localhost:{}".format(port))
unary_unary_method_name = "/".join(
(
"",
_SERVICE,
_UNARY_UNARY,
)
)
self._unary_unary = self._channel.unary_unary(
"/".join(
(
"",
_SERVICE,
_UNARY_UNARY,
)
),
unary_unary_method_name,
request_serializer=_REQUEST_SERIALIZER,
response_deserializer=_RESPONSE_DESERIALIZER,
_registered_method=True,
)
unary_stream_method_name = "/".join(
(
"",
_SERVICE,
_UNARY_STREAM,
)
)
self._unary_stream = self._channel.unary_stream(
"/".join(
(
"",
_SERVICE,
_UNARY_STREAM,
)
),
unary_stream_method_name,
_registered_method=True,
)
stream_unary_method_name = "/".join(
(
"",
_SERVICE,
_STREAM_UNARY,
)
)
self._stream_unary = self._channel.stream_unary(
"/".join(
(
"",
_SERVICE,
_STREAM_UNARY,
)
),
stream_unary_method_name,
_registered_method=True,
)
stream_stream_method_name = "/".join(
(
"",
_SERVICE,
_STREAM_STREAM,
)
)
self._stream_stream = self._channel.stream_stream(
"/".join(
(
"",
_SERVICE,
_STREAM_STREAM,
)
),
stream_stream_method_name,
request_serializer=_REQUEST_SERIALIZER,
response_deserializer=_RESPONSE_DESERIALIZER,
_registered_method=True,
)
def tearDown(self):
@ -828,16 +836,18 @@ class InspectContextTest(unittest.TestCase):
self._server.start()
self._channel = grpc.insecure_channel("localhost:{}".format(port))
unary_unary_method_name = "/".join(
(
"",
_SERVICE,
_UNARY_UNARY,
)
)
self._unary_unary = self._channel.unary_unary(
"/".join(
(
"",
_SERVICE,
_UNARY_UNARY,
)
),
unary_unary_method_name,
request_serializer=_REQUEST_SERIALIZER,
response_deserializer=_RESPONSE_DESERIALIZER,
_registered_method=True,
)
def tearDown(self):

@ -110,7 +110,10 @@ def create_phony_channel():
def perform_unary_unary_call(channel, wait_for_ready=None):
channel.unary_unary(_UNARY_UNARY).__call__(
channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
).__call__(
_REQUEST,
timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready,
@ -118,7 +121,10 @@ def perform_unary_unary_call(channel, wait_for_ready=None):
def perform_unary_unary_with_call(channel, wait_for_ready=None):
channel.unary_unary(_UNARY_UNARY).with_call(
channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
).with_call(
_REQUEST,
timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready,
@ -126,7 +132,10 @@ def perform_unary_unary_with_call(channel, wait_for_ready=None):
def perform_unary_unary_future(channel, wait_for_ready=None):
channel.unary_unary(_UNARY_UNARY).future(
channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
).future(
_REQUEST,
timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready,
@ -134,7 +143,10 @@ def perform_unary_unary_future(channel, wait_for_ready=None):
def perform_unary_stream_call(channel, wait_for_ready=None):
response_iterator = channel.unary_stream(_UNARY_STREAM).__call__(
response_iterator = channel.unary_stream(
_UNARY_STREAM,
_registered_method=True,
).__call__(
_REQUEST,
timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready,
@ -144,7 +156,10 @@ def perform_unary_stream_call(channel, wait_for_ready=None):
def perform_stream_unary_call(channel, wait_for_ready=None):
channel.stream_unary(_STREAM_UNARY).__call__(
channel.stream_unary(
_STREAM_UNARY,
_registered_method=True,
).__call__(
iter([_REQUEST] * test_constants.STREAM_LENGTH),
timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready,
@ -152,7 +167,10 @@ def perform_stream_unary_call(channel, wait_for_ready=None):
def perform_stream_unary_with_call(channel, wait_for_ready=None):
channel.stream_unary(_STREAM_UNARY).with_call(
channel.stream_unary(
_STREAM_UNARY,
_registered_method=True,
).with_call(
iter([_REQUEST] * test_constants.STREAM_LENGTH),
timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready,
@ -160,7 +178,10 @@ def perform_stream_unary_with_call(channel, wait_for_ready=None):
def perform_stream_unary_future(channel, wait_for_ready=None):
channel.stream_unary(_STREAM_UNARY).future(
channel.stream_unary(
_STREAM_UNARY,
_registered_method=True,
).future(
iter([_REQUEST] * test_constants.STREAM_LENGTH),
timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready,
@ -168,7 +189,9 @@ def perform_stream_unary_future(channel, wait_for_ready=None):
def perform_stream_stream_call(channel, wait_for_ready=None):
response_iterator = channel.stream_stream(_STREAM_STREAM).__call__(
response_iterator = channel.stream_stream(
_STREAM_STREAM, _registered_method=True
).__call__(
iter([_REQUEST] * test_constants.STREAM_LENGTH),
timeout=test_constants.LONG_TIMEOUT,
wait_for_ready=wait_for_ready,

@ -195,7 +195,9 @@ class MetadataTest(unittest.TestCase):
self._channel.close()
def testUnaryUnary(self):
multi_callable = self._channel.unary_unary(_UNARY_UNARY)
multi_callable = self._channel.unary_unary(
_UNARY_UNARY, _registered_method=True
)
unused_response, call = multi_callable.with_call(
_REQUEST, metadata=_INVOCATION_METADATA
)
@ -211,7 +213,9 @@ class MetadataTest(unittest.TestCase):
)
def testUnaryStream(self):
multi_callable = self._channel.unary_stream(_UNARY_STREAM)
multi_callable = self._channel.unary_stream(
_UNARY_STREAM, _registered_method=True
)
call = multi_callable(_REQUEST, metadata=_INVOCATION_METADATA)
self.assertTrue(
test_common.metadata_transmitted(
@ -227,7 +231,9 @@ class MetadataTest(unittest.TestCase):
)
def testStreamUnary(self):
multi_callable = self._channel.stream_unary(_STREAM_UNARY)
multi_callable = self._channel.stream_unary(
_STREAM_UNARY, _registered_method=True
)
unused_response, call = multi_callable.with_call(
iter([_REQUEST] * test_constants.STREAM_LENGTH),
metadata=_INVOCATION_METADATA,
@ -244,7 +250,9 @@ class MetadataTest(unittest.TestCase):
)
def testStreamStream(self):
multi_callable = self._channel.stream_stream(_STREAM_STREAM)
multi_callable = self._channel.stream_stream(
_STREAM_STREAM, _registered_method=True
)
call = multi_callable(
iter([_REQUEST] * test_constants.STREAM_LENGTH),
metadata=_INVOCATION_METADATA,

@ -52,7 +52,10 @@ class ReconnectTest(unittest.TestCase):
server.add_insecure_port(addr)
server.start()
channel = grpc.insecure_channel(addr)
multi_callable = channel.unary_unary(_UNARY_UNARY)
multi_callable = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
self.assertEqual(_RESPONSE, multi_callable(_REQUEST))
server.stop(None)
# By default, the channel connectivity is checked every 5s

@ -149,7 +149,10 @@ class ResourceExhaustedTest(unittest.TestCase):
self._channel.close()
def testUnaryUnary(self):
multi_callable = self._channel.unary_unary(_UNARY_UNARY)
multi_callable = self._channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
futures = []
for _ in range(test_constants.THREAD_CONCURRENCY):
futures.append(multi_callable.future(_REQUEST))
@ -178,7 +181,10 @@ class ResourceExhaustedTest(unittest.TestCase):
self.assertEqual(_RESPONSE, multi_callable(_REQUEST))
def testUnaryStream(self):
multi_callable = self._channel.unary_stream(_UNARY_STREAM)
multi_callable = self._channel.unary_stream(
_UNARY_STREAM,
_registered_method=True,
)
calls = []
for _ in range(test_constants.THREAD_CONCURRENCY):
calls.append(multi_callable(_REQUEST))
@ -205,7 +211,10 @@ class ResourceExhaustedTest(unittest.TestCase):
self.assertEqual(_RESPONSE, response)
def testStreamUnary(self):
multi_callable = self._channel.stream_unary(_STREAM_UNARY)
multi_callable = self._channel.stream_unary(
_STREAM_UNARY,
_registered_method=True,
)
futures = []
request = iter([_REQUEST] * test_constants.STREAM_LENGTH)
for _ in range(test_constants.THREAD_CONCURRENCY):
@ -236,7 +245,10 @@ class ResourceExhaustedTest(unittest.TestCase):
self.assertEqual(_RESPONSE, multi_callable(request))
def testStreamStream(self):
multi_callable = self._channel.stream_stream(_STREAM_STREAM)
multi_callable = self._channel.stream_stream(
_STREAM_STREAM,
_registered_method=True,
)
calls = []
request = iter([_REQUEST] * test_constants.STREAM_LENGTH)
for _ in range(test_constants.THREAD_CONCURRENCY):

@ -277,7 +277,10 @@ class _GenericHandler(grpc.GenericRpcHandler):
def unary_unary_multi_callable(channel):
return channel.unary_unary(_UNARY_UNARY)
return channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)
def unary_stream_multi_callable(channel):
@ -285,6 +288,7 @@ def unary_stream_multi_callable(channel):
_UNARY_STREAM,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
)
@ -293,6 +297,7 @@ def unary_stream_non_blocking_multi_callable(channel):
_UNARY_STREAM_NON_BLOCKING,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
)
@ -301,15 +306,22 @@ def stream_unary_multi_callable(channel):
_STREAM_UNARY,
request_serializer=_SERIALIZE_REQUEST,
response_deserializer=_DESERIALIZE_RESPONSE,
_registered_method=True,
)
def stream_stream_multi_callable(channel):
return channel.stream_stream(_STREAM_STREAM)
return channel.stream_stream(
_STREAM_STREAM,
_registered_method=True,
)
def stream_stream_non_blocking_multi_callable(channel):
return channel.stream_stream(_STREAM_STREAM_NON_BLOCKING)
return channel.stream_stream(
_STREAM_STREAM_NON_BLOCKING,
_registered_method=True,
)
class BaseRPCTest(object):

@ -81,7 +81,10 @@ def run_test(args):
thread.start()
port = port_queue.get()
channel = grpc.insecure_channel("localhost:%d" % port)
multi_callable = channel.unary_unary(FORK_EXIT)
multi_callable = channel.unary_unary(
FORK_EXIT,
_registered_method=True,
)
result, call = multi_callable.with_call(REQUEST, wait_for_ready=True)
os.wait()
else:

@ -77,7 +77,10 @@ class SSLSessionCacheTest(unittest.TestCase):
channel = grpc.secure_channel(
"localhost:{}".format(port), channel_creds, options=channel_options
)
response = channel.unary_unary(_UNARY_UNARY)(_REQUEST)
response = channel.unary_unary(
_UNARY_UNARY,
_registered_method=True,
)(_REQUEST)
auth_data = pickle.loads(response)
self.assertEqual(
expect_ssl_session_reused,

@ -53,7 +53,10 @@ def main_unary(server_target):
"""Initiate a unary RPC to be interrupted by a SIGINT."""
global per_process_rpc_future # pylint: disable=global-statement
with grpc.insecure_channel(server_target) as channel:
multicallable = channel.unary_unary(UNARY_UNARY)
multicallable = channel.unary_unary(
UNARY_UNARY,
_registered_method=True,
)
signal.signal(signal.SIGINT, handle_sigint)
per_process_rpc_future = multicallable.future(
_MESSAGE, wait_for_ready=True
@ -67,9 +70,10 @@ def main_streaming(server_target):
global per_process_rpc_future # pylint: disable=global-statement
with grpc.insecure_channel(server_target) as channel:
signal.signal(signal.SIGINT, handle_sigint)
per_process_rpc_future = channel.unary_stream(UNARY_STREAM)(
_MESSAGE, wait_for_ready=True
)
per_process_rpc_future = channel.unary_stream(
UNARY_STREAM,
_registered_method=True,
)(_MESSAGE, wait_for_ready=True)
for result in per_process_rpc_future:
pass
assert False, _ASSERTION_MESSAGE
@ -79,7 +83,10 @@ def main_unary_with_exception(server_target):
"""Initiate a unary RPC with a signal handler that will raise."""
channel = grpc.insecure_channel(server_target)
try:
channel.unary_unary(UNARY_UNARY)(_MESSAGE, wait_for_ready=True)
channel.unary_unary(
UNARY_UNARY,
_registered_method=True,
)(_MESSAGE, wait_for_ready=True)
except KeyboardInterrupt:
sys.stderr.write("Running signal handler.\n")
sys.stderr.flush()
@ -92,9 +99,10 @@ def main_streaming_with_exception(server_target):
"""Initiate a streaming RPC with a signal handler that will raise."""
channel = grpc.insecure_channel(server_target)
try:
for _ in channel.unary_stream(UNARY_STREAM)(
_MESSAGE, wait_for_ready=True
):
for _ in channel.unary_stream(
UNARY_STREAM,
_registered_method=True,
)(_MESSAGE, wait_for_ready=True):
pass
except KeyboardInterrupt:
sys.stderr.write("Running signal handler.\n")

@ -71,9 +71,10 @@ class XdsCredentialsTest(unittest.TestCase):
server_address, channel_creds, options=override_options
) as channel:
request = b"abc"
response = channel.unary_unary("/test/method")(
request, wait_for_ready=True
)
response = channel.unary_unary(
"/test/method",
_registered_method=True,
)(request, wait_for_ready=True)
self.assertEqual(response, request)
def test_xds_creds_fallback_insecure(self):
@ -89,9 +90,10 @@ class XdsCredentialsTest(unittest.TestCase):
channel_creds = grpc.xds_channel_credentials(channel_fallback_creds)
with grpc.secure_channel(server_address, channel_creds) as channel:
request = b"abc"
response = channel.unary_unary("/test/method")(
request, wait_for_ready=True
)
response = channel.unary_unary(
"/test/method",
_registered_method=True,
)(request, wait_for_ready=True)
self.assertEqual(response, request)
def test_start_xds_server(self):

@ -65,6 +65,7 @@ class CloseChannelTest(unittest.TestCase):
_UNARY_CALL_METHOD_WITH_SLEEP,
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString,
_registered_method=True,
)
greenlet = group.spawn(self._run_client, UnaryCallWithSleep)
# release loop so that greenlet can take control
@ -78,6 +79,7 @@ class CloseChannelTest(unittest.TestCase):
_UNARY_CALL_METHOD_WITH_SLEEP,
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString,
_registered_method=True,
)
greenlet = group.spawn(self._run_client, UnaryCallWithSleep)
# release loop so that greenlet can take control

@ -68,7 +68,10 @@ def _start_a_test_server():
def _perform_an_rpc(address):
channel = grpc.insecure_channel(address)
multicallable = channel.unary_unary(_TEST_METHOD)
multicallable = channel.unary_unary(
_TEST_METHOD,
_registered_method=True,
)
response = multicallable(_REQUEST)
assert _REQUEST == response

@ -193,6 +193,7 @@ class SimpleStubsTest(unittest.TestCase):
_UNARY_UNARY,
channel_credentials=grpc.experimental.insecure_channel_credentials(),
timeout=None,
_registered_method=0,
)
self.assertEqual(_REQUEST, response)
@ -205,6 +206,7 @@ class SimpleStubsTest(unittest.TestCase):
_UNARY_UNARY,
channel_credentials=grpc.local_channel_credentials(),
timeout=None,
_registered_method=0,
)
self.assertEqual(_REQUEST, response)
@ -213,7 +215,10 @@ class SimpleStubsTest(unittest.TestCase):
target = f"localhost:{port}"
test_name = inspect.stack()[0][3]
args = (_REQUEST, target, _UNARY_UNARY)
kwargs = {"channel_credentials": grpc.local_channel_credentials()}
kwargs = {
"channel_credentials": grpc.local_channel_credentials(),
"_registered_method": True,
}
def _invoke(seed: str):
run_kwargs = dict(kwargs)
@ -230,6 +235,7 @@ class SimpleStubsTest(unittest.TestCase):
target,
_UNARY_UNARY,
channel_credentials=grpc.local_channel_credentials(),
_registered_method=0,
)
self.assert_eventually(
lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()
@ -250,6 +256,7 @@ class SimpleStubsTest(unittest.TestCase):
_UNARY_UNARY,
options=options,
channel_credentials=grpc.local_channel_credentials(),
_registered_method=0,
)
self.assert_eventually(
lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()
@ -265,6 +272,7 @@ class SimpleStubsTest(unittest.TestCase):
target,
_UNARY_STREAM,
channel_credentials=grpc.local_channel_credentials(),
_registered_method=0,
):
self.assertEqual(_REQUEST, response)
@ -280,6 +288,7 @@ class SimpleStubsTest(unittest.TestCase):
target,
_STREAM_UNARY,
channel_credentials=grpc.local_channel_credentials(),
_registered_method=0,
)
self.assertEqual(_REQUEST, response)
@ -295,6 +304,7 @@ class SimpleStubsTest(unittest.TestCase):
target,
_STREAM_STREAM,
channel_credentials=grpc.local_channel_credentials(),
_registered_method=0,
):
self.assertEqual(_REQUEST, response)
@ -319,14 +329,22 @@ class SimpleStubsTest(unittest.TestCase):
with _server(server_creds) as port:
target = f"localhost:{port}"
response = grpc.experimental.unary_unary(
_REQUEST, target, _UNARY_UNARY, options=_property_options
_REQUEST,
target,
_UNARY_UNARY,
options=_property_options,
_registered_method=0,
)
def test_insecure_sugar(self):
with _server(None) as port:
target = f"localhost:{port}"
response = grpc.experimental.unary_unary(
_REQUEST, target, _UNARY_UNARY, insecure=True
_REQUEST,
target,
_UNARY_UNARY,
insecure=True,
_registered_method=0,
)
self.assertEqual(_REQUEST, response)
@ -340,14 +358,24 @@ class SimpleStubsTest(unittest.TestCase):
_UNARY_UNARY,
insecure=True,
channel_credentials=grpc.local_channel_credentials(),
_registered_method=0,
)
def test_default_wait_for_ready(self):
addr, port, sock = get_socket()
sock.close()
target = f"{addr}:{port}"
channel = grpc._simple_stubs.ChannelCache.get().get_channel(
target, (), None, True, None
(
channel,
unused_method_handle,
) = grpc._simple_stubs.ChannelCache.get().get_channel(
target=target,
options=(),
channel_credentials=None,
insecure=True,
compression=None,
method=_UNARY_UNARY,
_registered_method=True,
)
rpc_finished_event = threading.Event()
rpc_failed_event = threading.Event()
@ -376,7 +404,12 @@ class SimpleStubsTest(unittest.TestCase):
def _send_rpc():
try:
response = grpc.experimental.unary_unary(
_REQUEST, target, _UNARY_UNARY, timeout=None, insecure=True
_REQUEST,
target,
_UNARY_UNARY,
timeout=None,
insecure=True,
_registered_method=0,
)
rpc_finished_event.set()
except Exception as e:
@ -399,6 +432,7 @@ class SimpleStubsTest(unittest.TestCase):
target,
_BLACK_HOLE,
insecure=True,
_registered_method=0,
**invocation_args,
)
self.assertEqual(

Loading…
Cancel
Save