diff --git a/examples/python/helloworld/helloworld_pb2.py b/examples/python/helloworld/helloworld_pb2.py index f5b4f2d27dc..6ee31ad94a2 100644 --- a/examples/python/helloworld/helloworld_pb2.py +++ b/examples/python/helloworld/helloworld_pb2.py @@ -1,11 +1,12 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: helloworld.proto +# Protobuf Python Version: 4.25.0 """Generated protocol buffer code.""" -from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -13,18 +14,18 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10helloworld.proto\x12\nhelloworld\"\x1c\n\x0cHelloRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\"\x1d\n\nHelloReply\x12\x0f\n\x07message\x18\x01 \x01(\t2I\n\x07Greeter\x12>\n\x08SayHello\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00\x42\x36\n\x1bio.grpc.examples.helloworldB\x0fHelloWorldProtoP\x01\xa2\x02\x03HLWb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10helloworld.proto\x12\nhelloworld\"\x1c\n\x0cHelloRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\"\x1d\n\nHelloReply\x12\x0f\n\x07message\x18\x01 \x01(\t2\xe4\x01\n\x07Greeter\x12>\n\x08SayHello\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00\x12K\n\x13SayHelloStreamReply\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00\x30\x01\x12L\n\x12SayHelloBidiStream\x12\x18.helloworld.HelloRequest\x1a\x16.helloworld.HelloReply\"\x00(\x01\x30\x01\x42\x36\n\x1bio.grpc.examples.helloworldB\x0fHelloWorldProtoP\x01\xa2\x02\x03HLWb\x06proto3') -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'helloworld_pb2', globals()) +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'helloworld_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'\n\033io.grpc.examples.helloworldB\017HelloWorldProtoP\001\242\002\003HLW' - _HELLOREQUEST._serialized_start=32 - _HELLOREQUEST._serialized_end=60 - _HELLOREPLY._serialized_start=62 - _HELLOREPLY._serialized_end=91 - _GREETER._serialized_start=93 - _GREETER._serialized_end=166 + _globals['DESCRIPTOR']._options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\033io.grpc.examples.helloworldB\017HelloWorldProtoP\001\242\002\003HLW' + _globals['_HELLOREQUEST']._serialized_start=32 + _globals['_HELLOREQUEST']._serialized_end=60 + _globals['_HELLOREPLY']._serialized_start=62 + _globals['_HELLOREPLY']._serialized_end=91 + _globals['_GREETER']._serialized_start=94 + _globals['_GREETER']._serialized_end=322 # @@protoc_insertion_point(module_scope) diff --git a/examples/python/helloworld/helloworld_pb2.pyi b/examples/python/helloworld/helloworld_pb2.pyi index 8c4b5b22805..bf0bd395ad5 100644 --- a/examples/python/helloworld/helloworld_pb2.pyi +++ b/examples/python/helloworld/helloworld_pb2.pyi @@ -4,14 +4,14 @@ from typing import ClassVar as _ClassVar, Optional as _Optional DESCRIPTOR: _descriptor.FileDescriptor -class HelloReply(_message.Message): - __slots__ = ["message"] - MESSAGE_FIELD_NUMBER: _ClassVar[int] - message: str - def __init__(self, message: _Optional[str] = ...) -> None: ... - class HelloRequest(_message.Message): - __slots__ = ["name"] + __slots__ = ("name",) NAME_FIELD_NUMBER: _ClassVar[int] name: str def __init__(self, name: _Optional[str] = ...) -> None: ... + +class HelloReply(_message.Message): + __slots__ = ("message",) + MESSAGE_FIELD_NUMBER: _ClassVar[int] + message: str + def __init__(self, message: _Optional[str] = ...) -> None: ... diff --git a/examples/python/helloworld/helloworld_pb2_grpc.py b/examples/python/helloworld/helloworld_pb2_grpc.py index 47c186976e1..68bcfef1756 100644 --- a/examples/python/helloworld/helloworld_pb2_grpc.py +++ b/examples/python/helloworld/helloworld_pb2_grpc.py @@ -19,7 +19,17 @@ class GreeterStub(object): '/helloworld.Greeter/SayHello', request_serializer=helloworld__pb2.HelloRequest.SerializeToString, response_deserializer=helloworld__pb2.HelloReply.FromString, - ) + _registered_method=True) + self.SayHelloStreamReply = channel.unary_stream( + '/helloworld.Greeter/SayHelloStreamReply', + request_serializer=helloworld__pb2.HelloRequest.SerializeToString, + response_deserializer=helloworld__pb2.HelloReply.FromString, + _registered_method=True) + self.SayHelloBidiStream = channel.stream_stream( + '/helloworld.Greeter/SayHelloBidiStream', + request_serializer=helloworld__pb2.HelloRequest.SerializeToString, + response_deserializer=helloworld__pb2.HelloReply.FromString, + _registered_method=True) class GreeterServicer(object): @@ -33,6 +43,18 @@ class GreeterServicer(object): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def SayHelloStreamReply(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SayHelloBidiStream(self, request_iterator, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_GreeterServicer_to_server(servicer, server): rpc_method_handlers = { @@ -41,6 +63,16 @@ def add_GreeterServicer_to_server(servicer, server): request_deserializer=helloworld__pb2.HelloRequest.FromString, response_serializer=helloworld__pb2.HelloReply.SerializeToString, ), + 'SayHelloStreamReply': grpc.unary_stream_rpc_method_handler( + servicer.SayHelloStreamReply, + request_deserializer=helloworld__pb2.HelloRequest.FromString, + response_serializer=helloworld__pb2.HelloReply.SerializeToString, + ), + 'SayHelloBidiStream': grpc.stream_stream_rpc_method_handler( + servicer.SayHelloBidiStream, + request_deserializer=helloworld__pb2.HelloRequest.FromString, + response_serializer=helloworld__pb2.HelloReply.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'helloworld.Greeter', rpc_method_handlers) @@ -63,8 +95,72 @@ class Greeter(object): wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/helloworld.Greeter/SayHello', + return grpc.experimental.unary_unary( + request, + target, + '/helloworld.Greeter/SayHello', + helloworld__pb2.HelloRequest.SerializeToString, + helloworld__pb2.HelloReply.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SayHelloStreamReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/helloworld.Greeter/SayHelloStreamReply', + helloworld__pb2.HelloRequest.SerializeToString, + helloworld__pb2.HelloReply.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SayHelloBidiStream(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_stream( + request_iterator, + target, + '/helloworld.Greeter/SayHelloBidiStream', helloworld__pb2.HelloRequest.SerializeToString, helloworld__pb2.HelloReply.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/src/compiler/python_generator.cc b/src/compiler/python_generator.cc index 753fe1c8887..b7a3115bce0 100644 --- a/src/compiler/python_generator.cc +++ b/src/compiler/python_generator.cc @@ -467,7 +467,7 @@ bool PrivateGenerator::PrintStub( out->Print( method_dict, "response_deserializer=$ResponseModuleAndClass$.FromString,\n"); - out->Print(")\n"); + out->Print("_registered_method=True)\n"); } } } @@ -642,22 +642,27 @@ bool PrivateGenerator::PrintServiceClass( args_dict["ArityMethodName"] = arity_method_name; args_dict["PackageQualifiedService"] = package_qualified_service_name; args_dict["Method"] = method->name(); - out->Print(args_dict, - "return " - "grpc.experimental.$ArityMethodName$($RequestParameter$, " - "target, '/$PackageQualifiedService$/$Method$',\n"); + out->Print(args_dict, "return grpc.experimental.$ArityMethodName$(\n"); { IndentScope continuation_indent(out); StringMap serializer_dict; + out->Print(args_dict, "$RequestParameter$,\n"); + out->Print("target,\n"); + out->Print(args_dict, "'/$PackageQualifiedService$/$Method$',\n"); serializer_dict["RequestModuleAndClass"] = request_module_and_class; serializer_dict["ResponseModuleAndClass"] = response_module_and_class; out->Print(serializer_dict, "$RequestModuleAndClass$.SerializeToString,\n"); out->Print(serializer_dict, "$ResponseModuleAndClass$.FromString,\n"); - out->Print("options, channel_credentials,\n"); - out->Print( - "insecure, call_credentials, compression, wait_for_ready, " - "timeout, metadata)\n"); + out->Print("options,\n"); + out->Print("channel_credentials,\n"); + out->Print("insecure,\n"); + out->Print("call_credentials,\n"); + out->Print("compression,\n"); + out->Print("wait_for_ready,\n"); + out->Print("timeout,\n"); + out->Print("metadata,\n"); + out->Print("_registered_method=True)\n"); } } } diff --git a/src/python/.gitignore b/src/python/.gitignore index 095ab8bbae1..61363e8cb8d 100644 --- a/src/python/.gitignore +++ b/src/python/.gitignore @@ -1,4 +1,6 @@ -gens/ +build/ +grpc_root/ +third_party/ *_pb2.py *_pb2.pyi *_pb2_grpc.py diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py index 83ded1b5df8..e0ec581f9d4 100644 --- a/src/python/grpcio/grpc/__init__.py +++ b/src/python/grpcio/grpc/__init__.py @@ -1004,6 +1004,7 @@ class Channel(abc.ABC): method, request_serializer=None, response_deserializer=None, + _registered_method=False, ): """Creates a UnaryUnaryMultiCallable for a unary-unary method. @@ -1014,6 +1015,8 @@ class Channel(abc.ABC): response_deserializer: Optional :term:`deserializer` for deserializing the response message. Response goes undeserialized in case None is passed. + _registered_method: Implementation Private. A bool representing whether the method + is registered. Returns: A UnaryUnaryMultiCallable value for the named unary-unary method. @@ -1026,6 +1029,7 @@ class Channel(abc.ABC): method, request_serializer=None, response_deserializer=None, + _registered_method=False, ): """Creates a UnaryStreamMultiCallable for a unary-stream method. @@ -1036,6 +1040,8 @@ class Channel(abc.ABC): response_deserializer: Optional :term:`deserializer` for deserializing the response message. Response goes undeserialized in case None is passed. + _registered_method: Implementation Private. A bool representing whether the method + is registered. Returns: A UnaryStreamMultiCallable value for the name unary-stream method. @@ -1048,6 +1054,7 @@ class Channel(abc.ABC): method, request_serializer=None, response_deserializer=None, + _registered_method=False, ): """Creates a StreamUnaryMultiCallable for a stream-unary method. @@ -1058,6 +1065,8 @@ class Channel(abc.ABC): response_deserializer: Optional :term:`deserializer` for deserializing the response message. Response goes undeserialized in case None is passed. + _registered_method: Implementation Private. A bool representing whether the method + is registered. Returns: A StreamUnaryMultiCallable value for the named stream-unary method. @@ -1070,6 +1079,7 @@ class Channel(abc.ABC): method, request_serializer=None, response_deserializer=None, + _registered_method=False, ): """Creates a StreamStreamMultiCallable for a stream-stream method. @@ -1080,6 +1090,8 @@ class Channel(abc.ABC): response_deserializer: Optional :term:`deserializer` for deserializing the response message. Response goes undeserialized in case None is passed. + _registered_method: Implementation Private. A bool representing whether the method + is registered. Returns: A StreamStreamMultiCallable value for the named stream-stream method. diff --git a/src/python/grpcio/grpc/_channel.py b/src/python/grpcio/grpc/_channel.py index 79c85a1b94c..bf29982ca65 100644 --- a/src/python/grpcio/grpc/_channel.py +++ b/src/python/grpcio/grpc/_channel.py @@ -24,6 +24,7 @@ import types from typing import ( Any, Callable, + Dict, Iterator, List, Optional, @@ -1054,6 +1055,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): _request_serializer: Optional[SerializingFunction] _response_deserializer: Optional[DeserializingFunction] _context: Any + _registered_call_handle: Optional[int] __slots__ = [ "_channel", @@ -1074,6 +1076,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): target: bytes, request_serializer: Optional[SerializingFunction], response_deserializer: Optional[DeserializingFunction], + _registered_call_handle: Optional[int], ): self._channel = channel self._managed_call = managed_call @@ -1082,6 +1085,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): self._request_serializer = request_serializer self._response_deserializer = response_deserializer self._context = cygrpc.build_census_context() + self._registered_call_handle = _registered_call_handle def _prepare( self, @@ -1153,6 +1157,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): ), ), self._context, + self._registered_call_handle, ) event = call.next_event() _handle_event(event, state, self._response_deserializer) @@ -1221,6 +1226,7 @@ class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): (operations,), event_handler, self._context, + self._registered_call_handle, ) return _MultiThreadedRendezvous( state, call, self._response_deserializer, deadline @@ -1234,6 +1240,7 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): _request_serializer: Optional[SerializingFunction] _response_deserializer: Optional[DeserializingFunction] _context: Any + _registered_call_handle: Optional[int] __slots__ = [ "_channel", @@ -1252,6 +1259,7 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): target: bytes, request_serializer: SerializingFunction, response_deserializer: DeserializingFunction, + _registered_call_handle: Optional[int], ): self._channel = channel self._method = method @@ -1259,6 +1267,7 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): self._request_serializer = request_serializer self._response_deserializer = response_deserializer self._context = cygrpc.build_census_context() + self._registered_call_handle = _registered_call_handle def __call__( # pylint: disable=too-many-locals self, @@ -1317,6 +1326,7 @@ class _SingleThreadedUnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): call_credentials, operations_and_tags, self._context, + self._registered_call_handle, ) return _SingleThreadedRendezvous( state, call, self._response_deserializer, deadline @@ -1331,6 +1341,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): _request_serializer: Optional[SerializingFunction] _response_deserializer: Optional[DeserializingFunction] _context: Any + _registered_call_handle: Optional[int] __slots__ = [ "_channel", @@ -1351,6 +1362,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): target: bytes, request_serializer: SerializingFunction, response_deserializer: DeserializingFunction, + _registered_call_handle: Optional[int], ): self._channel = channel self._managed_call = managed_call @@ -1359,6 +1371,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): self._request_serializer = request_serializer self._response_deserializer = response_deserializer self._context = cygrpc.build_census_context() + self._registered_call_handle = _registered_call_handle def __call__( # pylint: disable=too-many-locals self, @@ -1408,6 +1421,7 @@ class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable): operations, _event_handler(state, self._response_deserializer), self._context, + self._registered_call_handle, ) return _MultiThreadedRendezvous( state, call, self._response_deserializer, deadline @@ -1422,6 +1436,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): _request_serializer: Optional[SerializingFunction] _response_deserializer: Optional[DeserializingFunction] _context: Any + _registered_call_handle: Optional[int] __slots__ = [ "_channel", @@ -1442,6 +1457,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): target: bytes, request_serializer: Optional[SerializingFunction], response_deserializer: Optional[DeserializingFunction], + _registered_call_handle: Optional[int], ): self._channel = channel self._managed_call = managed_call @@ -1450,6 +1466,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): self._request_serializer = request_serializer self._response_deserializer = response_deserializer self._context = cygrpc.build_census_context() + self._registered_call_handle = _registered_call_handle def _blocking( self, @@ -1482,6 +1499,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): augmented_metadata, initial_metadata_flags ), self._context, + self._registered_call_handle, ) _consume_request_iterator( request_iterator, state, call, self._request_serializer, None @@ -1572,6 +1590,7 @@ class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable): ), event_handler, self._context, + self._registered_call_handle, ) _consume_request_iterator( request_iterator, @@ -1593,6 +1612,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): _request_serializer: Optional[SerializingFunction] _response_deserializer: Optional[DeserializingFunction] _context: Any + _registered_call_handle: Optional[int] __slots__ = [ "_channel", @@ -1611,8 +1631,9 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): managed_call: IntegratedCallFactory, method: bytes, target: bytes, - request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None, + request_serializer: Optional[SerializingFunction], + response_deserializer: Optional[DeserializingFunction], + _registered_call_handle: Optional[int], ): self._channel = channel self._managed_call = managed_call @@ -1621,6 +1642,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): self._request_serializer = request_serializer self._response_deserializer = response_deserializer self._context = cygrpc.build_census_context() + self._registered_call_handle = _registered_call_handle def __call__( self, @@ -1662,6 +1684,7 @@ class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable): operations, event_handler, self._context, + self._registered_call_handle, ) _consume_request_iterator( request_iterator, @@ -1751,7 +1774,8 @@ def _channel_managed_call_management(state: _ChannelCallState): credentials: Optional[cygrpc.CallCredentials], operations: Sequence[Sequence[cygrpc.Operation]], event_handler: UserTag, - context, + context: Any, + _registered_call_handle: Optional[int], ) -> cygrpc.IntegratedCall: """Creates a cygrpc.IntegratedCall. @@ -1768,6 +1792,8 @@ def _channel_managed_call_management(state: _ChannelCallState): event_handler: A behavior to call to handle the events resultant from the operations on the call. context: Context object for distributed tracing. + _registered_call_handle: An int representing the call handle of the + method, or None if the method is not registered. Returns: A cygrpc.IntegratedCall with which to conduct an RPC. """ @@ -1788,6 +1814,7 @@ def _channel_managed_call_management(state: _ChannelCallState): credentials, operations_and_tags, context, + _registered_call_handle, ) if state.managed_calls == 0: state.managed_calls = 1 @@ -2021,6 +2048,7 @@ class Channel(grpc.Channel): _call_state: _ChannelCallState _connectivity_state: _ChannelConnectivityState _target: str + _registered_call_handles: Dict[str, int] def __init__( self, @@ -2055,6 +2083,22 @@ class Channel(grpc.Channel): if cygrpc.g_gevent_activated: cygrpc.gevent_increment_channel_count() + def _get_registered_call_handle(self, method: str) -> int: + """ + Get the registered call handle for a method. + + This is a semi-private method. It is intended for use only by gRPC generated code. + + This method is not thread-safe. + + Args: + method: Required, the method name for the RPC. + + Returns: + The registered call handle pointer in the form of a Python Long. + """ + return self._channel.get_registered_call_handle(_common.encode(method)) + def _process_python_options( self, python_options: Sequence[ChannelArgumentType] ) -> None: @@ -2078,12 +2122,17 @@ class Channel(grpc.Channel): ) -> None: _unsubscribe(self._connectivity_state, callback) + # pylint: disable=arguments-differ def unary_unary( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> grpc.UnaryUnaryMultiCallable: + _registered_call_handle = None + if _registered_method: + _registered_call_handle = self._get_registered_call_handle(method) return _UnaryUnaryMultiCallable( self._channel, _channel_managed_call_management(self._call_state), @@ -2091,14 +2140,20 @@ class Channel(grpc.Channel): _common.encode(self._target), request_serializer, response_deserializer, + _registered_call_handle, ) + # pylint: disable=arguments-differ def unary_stream( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> grpc.UnaryStreamMultiCallable: + _registered_call_handle = None + if _registered_method: + _registered_call_handle = self._get_registered_call_handle(method) # NOTE(rbellevi): Benchmarks have shown that running a unary-stream RPC # on a single Python thread results in an appreciable speed-up. However, # due to slight differences in capability, the multi-threaded variant @@ -2110,6 +2165,7 @@ class Channel(grpc.Channel): _common.encode(self._target), request_serializer, response_deserializer, + _registered_call_handle, ) else: return _UnaryStreamMultiCallable( @@ -2119,14 +2175,20 @@ class Channel(grpc.Channel): _common.encode(self._target), request_serializer, response_deserializer, + _registered_call_handle, ) + # pylint: disable=arguments-differ def stream_unary( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> grpc.StreamUnaryMultiCallable: + _registered_call_handle = None + if _registered_method: + _registered_call_handle = self._get_registered_call_handle(method) return _StreamUnaryMultiCallable( self._channel, _channel_managed_call_management(self._call_state), @@ -2134,14 +2196,20 @@ class Channel(grpc.Channel): _common.encode(self._target), request_serializer, response_deserializer, + _registered_call_handle, ) + # pylint: disable=arguments-differ def stream_stream( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> grpc.StreamStreamMultiCallable: + _registered_call_handle = None + if _registered_method: + _registered_call_handle = self._get_registered_call_handle(method) return _StreamStreamMultiCallable( self._channel, _channel_managed_call_management(self._call_state), @@ -2149,6 +2217,7 @@ class Channel(grpc.Channel): _common.encode(self._target), request_serializer, response_deserializer, + _registered_call_handle, ) def _unsubscribe_all(self) -> None: diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi index 6e5416a9e31..96d03e181b9 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pxd.pxi @@ -74,6 +74,13 @@ cdef class SegregatedCall: cdef class Channel: cdef _ChannelState _state + cdef dict _registered_call_handles # TODO(https://github.com/grpc/grpc/issues/15662): Eliminate this. cdef tuple _arguments + + +cdef class CallHandle: + + cdef void *c_call_handle + cdef object method diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi index f6db36ebde1..dde3b166789 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi @@ -101,6 +101,25 @@ cdef class _ChannelState: self.connectivity_due = set() self.closed_reason = None +cdef class CallHandle: + + def __cinit__(self, _ChannelState channel_state, object method): + self.method = method + cpython.Py_INCREF(method) + # Note that since we always pass None for host, we set the + # second-to-last parameter of grpc_channel_register_call to a fixed + # NULL value. + self.c_call_handle = grpc_channel_register_call( + channel_state.c_channel, method, NULL, NULL) + + def __dealloc__(self): + cpython.Py_DECREF(self.method) + + @property + def call_handle(self): + return cpython.PyLong_FromVoidPtr(self.c_call_handle) + + cdef tuple _operate(grpc_call *c_call, object operations, object user_tag): cdef grpc_call_error c_call_error @@ -199,7 +218,7 @@ cdef void _call( grpc_completion_queue *c_completion_queue, on_success, int flags, method, host, object deadline, CallCredentials credentials, object operationses_and_user_tags, object metadata, - object context) except *: + object context, object registered_call_handle) except *: """Invokes an RPC. Args: @@ -226,6 +245,8 @@ cdef void _call( must be present in the first element of this value. metadata: The metadata for this call. context: Context object for distributed tracing. + registered_call_handle: An int representing the call handle of the method, or + None if the method is not registered. """ cdef grpc_slice method_slice cdef grpc_slice host_slice @@ -242,10 +263,16 @@ cdef void _call( else: host_slice = _slice_from_bytes(host) host_slice_ptr = &host_slice - call_state.c_call = grpc_channel_create_call( - channel_state.c_channel, NULL, flags, - c_completion_queue, method_slice, host_slice_ptr, - _timespec_from_time(deadline), NULL) + if registered_call_handle: + call_state.c_call = grpc_channel_create_registered_call( + channel_state.c_channel, NULL, flags, + c_completion_queue, cpython.PyLong_AsVoidPtr(registered_call_handle), + _timespec_from_time(deadline), NULL) + else: + call_state.c_call = grpc_channel_create_call( + channel_state.c_channel, NULL, flags, + c_completion_queue, method_slice, host_slice_ptr, + _timespec_from_time(deadline), NULL) grpc_slice_unref(method_slice) if host_slice_ptr: grpc_slice_unref(host_slice) @@ -309,7 +336,7 @@ cdef class IntegratedCall: cdef IntegratedCall _integrated_call( _ChannelState state, int flags, method, host, object deadline, object metadata, CallCredentials credentials, operationses_and_user_tags, - object context): + object context, object registered_call_handle): call_state = _CallState() def on_success(started_tags): @@ -318,7 +345,8 @@ cdef IntegratedCall _integrated_call( _call( state, call_state, state.c_call_completion_queue, on_success, flags, - method, host, deadline, credentials, operationses_and_user_tags, metadata, context) + method, host, deadline, credentials, operationses_and_user_tags, + metadata, context, registered_call_handle) return IntegratedCall(state, call_state) @@ -371,7 +399,7 @@ cdef class SegregatedCall: cdef SegregatedCall _segregated_call( _ChannelState state, int flags, method, host, object deadline, object metadata, CallCredentials credentials, operationses_and_user_tags, - object context): + object context, object registered_call_handle): cdef _CallState call_state = _CallState() cdef SegregatedCall segregated_call cdef grpc_completion_queue *c_completion_queue @@ -389,7 +417,7 @@ cdef SegregatedCall _segregated_call( _call( state, call_state, c_completion_queue, on_success, flags, method, host, deadline, credentials, operationses_and_user_tags, metadata, - context) + context, registered_call_handle) except: _destroy_c_completion_queue(c_completion_queue) raise @@ -486,6 +514,7 @@ cdef class Channel: else grpc_insecure_credentials_create()) self._state.c_channel = grpc_channel_create( target, c_channel_credentials, channel_args.c_args()) + self._registered_call_handles = {} grpc_channel_credentials_release(c_channel_credentials) def target(self): @@ -499,10 +528,10 @@ cdef class Channel: def integrated_call( self, int flags, method, host, object deadline, object metadata, CallCredentials credentials, operationses_and_tags, - object context = None): + object context = None, object registered_call_handle = None): return _integrated_call( self._state, flags, method, host, deadline, metadata, credentials, - operationses_and_tags, context) + operationses_and_tags, context, registered_call_handle) def next_call_event(self): def on_success(tag): @@ -521,10 +550,10 @@ cdef class Channel: def segregated_call( self, int flags, method, host, object deadline, object metadata, CallCredentials credentials, operationses_and_tags, - object context = None): + object context = None, object registered_call_handle = None): return _segregated_call( self._state, flags, method, host, deadline, metadata, credentials, - operationses_and_tags, context) + operationses_and_tags, context, registered_call_handle) def check_connectivity_state(self, bint try_to_connect): with self._state.condition: @@ -543,3 +572,19 @@ cdef class Channel: def close_on_fork(self, code, details): _close(self, code, details, True) + + def get_registered_call_handle(self, method): + """ + Get or registers a call handler for a method. + + This method is not thread-safe. + + Args: + method: Required, the method name for the RPC. + + Returns: + The registered call handle pointer in the form of a Python Long. + """ + if method not in self._registered_call_handles.keys(): + self._registered_call_handles[method] = CallHandle(self._state, method) + return self._registered_call_handles[method].call_handle diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi index e1bc87d4abd..29149e9893a 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxi @@ -433,6 +433,12 @@ cdef extern from "grpc/grpc.h": grpc_channel *channel, grpc_call *parent_call, uint32_t propagation_mask, grpc_completion_queue *completion_queue, grpc_slice method, const grpc_slice *host, gpr_timespec deadline, void *reserved) nogil + void *grpc_channel_register_call( + grpc_channel *channel, const char *method, const char *host, void *reserved) nogil + grpc_call *grpc_channel_create_registered_call( + grpc_channel *channel, grpc_call *parent_call, uint32_t propagation_mask, + grpc_completion_queue *completion_queue, void* registered_call_handle, + gpr_timespec deadline, void *reserved) nogil grpc_connectivity_state grpc_channel_check_connectivity_state( grpc_channel *channel, int try_to_connect) nogil void grpc_channel_watch_connectivity_state( diff --git a/src/python/grpcio/grpc/_interceptor.py b/src/python/grpcio/grpc/_interceptor.py index 36bce4e3ba5..94abafebaa6 100644 --- a/src/python/grpcio/grpc/_interceptor.py +++ b/src/python/grpcio/grpc/_interceptor.py @@ -684,57 +684,85 @@ class _Channel(grpc.Channel): def unsubscribe(self, callback: Callable): self._channel.unsubscribe(callback) + # pylint: disable=arguments-differ def unary_unary( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> grpc.UnaryUnaryMultiCallable: + # pytype: disable=wrong-arg-count thunk = lambda m: self._channel.unary_unary( - m, request_serializer, response_deserializer + m, + request_serializer, + response_deserializer, + _registered_method, ) + # pytype: enable=wrong-arg-count if isinstance(self._interceptor, grpc.UnaryUnaryClientInterceptor): return _UnaryUnaryMultiCallable(thunk, method, self._interceptor) else: return thunk(method) + # pylint: disable=arguments-differ def unary_stream( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> grpc.UnaryStreamMultiCallable: + # pytype: disable=wrong-arg-count thunk = lambda m: self._channel.unary_stream( - m, request_serializer, response_deserializer + m, + request_serializer, + response_deserializer, + _registered_method, ) + # pytype: enable=wrong-arg-count if isinstance(self._interceptor, grpc.UnaryStreamClientInterceptor): return _UnaryStreamMultiCallable(thunk, method, self._interceptor) else: return thunk(method) + # pylint: disable=arguments-differ def stream_unary( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> grpc.StreamUnaryMultiCallable: + # pytype: disable=wrong-arg-count thunk = lambda m: self._channel.stream_unary( - m, request_serializer, response_deserializer + m, + request_serializer, + response_deserializer, + _registered_method, ) + # pytype: enable=wrong-arg-count if isinstance(self._interceptor, grpc.StreamUnaryClientInterceptor): return _StreamUnaryMultiCallable(thunk, method, self._interceptor) else: return thunk(method) + # pylint: disable=arguments-differ def stream_stream( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> grpc.StreamStreamMultiCallable: + # pytype: disable=wrong-arg-count thunk = lambda m: self._channel.stream_stream( - m, request_serializer, response_deserializer + m, + request_serializer, + response_deserializer, + _registered_method, ) + # pytype: enable=wrong-arg-count if isinstance(self._interceptor, grpc.StreamStreamClientInterceptor): return _StreamStreamMultiCallable(thunk, method, self._interceptor) else: diff --git a/src/python/grpcio/grpc/_simple_stubs.py b/src/python/grpcio/grpc/_simple_stubs.py index 7772860957b..3e88670aa08 100644 --- a/src/python/grpcio/grpc/_simple_stubs.py +++ b/src/python/grpcio/grpc/_simple_stubs.py @@ -159,7 +159,19 @@ class ChannelCache: channel_credentials: Optional[grpc.ChannelCredentials], insecure: bool, compression: Optional[grpc.Compression], - ) -> grpc.Channel: + method: str, + _registered_method: bool, + ) -> Tuple[grpc.Channel, Optional[int]]: + """Get a channel from cache or creates a new channel. + + This method also takes care of register method for channel, + which means we'll register a new call handle if we're calling a + non-registered method for an existing channel. + + Returns: + A tuple with two items. The first item is the channel, second item is + the call handle if the method is registered, None if it's not registered. + """ if insecure and channel_credentials: raise ValueError( "The insecure option is mutually exclusive with " @@ -176,18 +188,25 @@ class ChannelCache: key = (target, options, channel_credentials, compression) with self._lock: channel_data = self._mapping.get(key, None) + call_handle = None if channel_data is not None: channel = channel_data[0] + # Register a new call handle if we're calling a registered method for an + # existing channel and this method is not registered. + if _registered_method: + call_handle = channel._get_registered_call_handle(method) self._mapping.pop(key) self._mapping[key] = ( channel, datetime.datetime.now() + _EVICTION_PERIOD, ) - return channel + return channel, call_handle else: channel = _create_channel( target, options, channel_credentials, compression ) + if _registered_method: + call_handle = channel._get_registered_call_handle(method) self._mapping[key] = ( channel, datetime.datetime.now() + _EVICTION_PERIOD, @@ -197,7 +216,7 @@ class ChannelCache: or len(self._mapping) >= _MAXIMUM_CHANNELS ): self._condition.notify() - return channel + return channel, call_handle def _test_only_channel_count(self) -> int: with self._lock: @@ -205,6 +224,7 @@ class ChannelCache: @experimental_api +# pylint: disable=too-many-locals def unary_unary( request: RequestType, target: str, @@ -219,6 +239,7 @@ def unary_unary( wait_for_ready: Optional[bool] = None, timeout: Optional[float] = _DEFAULT_TIMEOUT, metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None, + _registered_method: Optional[bool] = False, ) -> ResponseType: """Invokes a unary-unary RPC without an explicitly specified channel. @@ -272,11 +293,17 @@ def unary_unary( Returns: The response to the RPC. """ - channel = ChannelCache.get().get_channel( - target, options, channel_credentials, insecure, compression + channel, method_handle = ChannelCache.get().get_channel( + target, + options, + channel_credentials, + insecure, + compression, + method, + _registered_method, ) multicallable = channel.unary_unary( - method, request_serializer, response_deserializer + method, request_serializer, response_deserializer, method_handle ) wait_for_ready = wait_for_ready if wait_for_ready is not None else True return multicallable( @@ -289,6 +316,7 @@ def unary_unary( @experimental_api +# pylint: disable=too-many-locals def unary_stream( request: RequestType, target: str, @@ -303,6 +331,7 @@ def unary_stream( wait_for_ready: Optional[bool] = None, timeout: Optional[float] = _DEFAULT_TIMEOUT, metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None, + _registered_method: Optional[bool] = False, ) -> Iterator[ResponseType]: """Invokes a unary-stream RPC without an explicitly specified channel. @@ -355,11 +384,17 @@ def unary_stream( Returns: An iterator of responses. """ - channel = ChannelCache.get().get_channel( - target, options, channel_credentials, insecure, compression + channel, method_handle = ChannelCache.get().get_channel( + target, + options, + channel_credentials, + insecure, + compression, + method, + _registered_method, ) multicallable = channel.unary_stream( - method, request_serializer, response_deserializer + method, request_serializer, response_deserializer, method_handle ) wait_for_ready = wait_for_ready if wait_for_ready is not None else True return multicallable( @@ -372,6 +407,7 @@ def unary_stream( @experimental_api +# pylint: disable=too-many-locals def stream_unary( request_iterator: Iterator[RequestType], target: str, @@ -386,6 +422,7 @@ def stream_unary( wait_for_ready: Optional[bool] = None, timeout: Optional[float] = _DEFAULT_TIMEOUT, metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None, + _registered_method: Optional[bool] = False, ) -> ResponseType: """Invokes a stream-unary RPC without an explicitly specified channel. @@ -438,11 +475,17 @@ def stream_unary( Returns: The response to the RPC. """ - channel = ChannelCache.get().get_channel( - target, options, channel_credentials, insecure, compression + channel, method_handle = ChannelCache.get().get_channel( + target, + options, + channel_credentials, + insecure, + compression, + method, + _registered_method, ) multicallable = channel.stream_unary( - method, request_serializer, response_deserializer + method, request_serializer, response_deserializer, method_handle ) wait_for_ready = wait_for_ready if wait_for_ready is not None else True return multicallable( @@ -455,6 +498,7 @@ def stream_unary( @experimental_api +# pylint: disable=too-many-locals def stream_stream( request_iterator: Iterator[RequestType], target: str, @@ -469,6 +513,7 @@ def stream_stream( wait_for_ready: Optional[bool] = None, timeout: Optional[float] = _DEFAULT_TIMEOUT, metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None, + _registered_method: Optional[bool] = False, ) -> Iterator[ResponseType]: """Invokes a stream-stream RPC without an explicitly specified channel. @@ -521,11 +566,17 @@ def stream_stream( Returns: An iterator of responses. """ - channel = ChannelCache.get().get_channel( - target, options, channel_credentials, insecure, compression + channel, method_handle = ChannelCache.get().get_channel( + target, + options, + channel_credentials, + insecure, + compression, + method, + _registered_method, ) multicallable = channel.stream_stream( - method, request_serializer, response_deserializer + method, request_serializer, response_deserializer, method_handle ) wait_for_ready = wait_for_ready if wait_for_ready is not None else True return multicallable( diff --git a/src/python/grpcio/grpc/aio/_base_channel.py b/src/python/grpcio/grpc/aio/_base_channel.py index 2fb8e75d3a9..8bc32b8b641 100644 --- a/src/python/grpcio/grpc/aio/_base_channel.py +++ b/src/python/grpcio/grpc/aio/_base_channel.py @@ -272,6 +272,7 @@ class Channel(abc.ABC): method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> UnaryUnaryMultiCallable: """Creates a UnaryUnaryMultiCallable for a unary-unary method. @@ -282,6 +283,8 @@ class Channel(abc.ABC): response_deserializer: Optional :term:`deserializer` for deserializing the response message. Response goes undeserialized in case None is passed. + _registered_method: Implementation Private. Optional: A bool representing + whether the method is registered. Returns: A UnaryUnaryMultiCallable value for the named unary-unary method. @@ -293,6 +296,7 @@ class Channel(abc.ABC): method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> UnaryStreamMultiCallable: """Creates a UnaryStreamMultiCallable for a unary-stream method. @@ -303,6 +307,8 @@ class Channel(abc.ABC): response_deserializer: Optional :term:`deserializer` for deserializing the response message. Response goes undeserialized in case None is passed. + _registered_method: Implementation Private. Optional: A bool representing + whether the method is registered. Returns: A UnarySteramMultiCallable value for the named unary-stream method. @@ -314,6 +320,7 @@ class Channel(abc.ABC): method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> StreamUnaryMultiCallable: """Creates a StreamUnaryMultiCallable for a stream-unary method. @@ -324,6 +331,8 @@ class Channel(abc.ABC): response_deserializer: Optional :term:`deserializer` for deserializing the response message. Response goes undeserialized in case None is passed. + _registered_method: Implementation Private. Optional: A bool representing + whether the method is registered. Returns: A StreamUnaryMultiCallable value for the named stream-unary method. @@ -335,6 +344,7 @@ class Channel(abc.ABC): method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> StreamStreamMultiCallable: """Creates a StreamStreamMultiCallable for a stream-stream method. @@ -345,6 +355,8 @@ class Channel(abc.ABC): response_deserializer: Optional :term:`deserializer` for deserializing the response message. Response goes undeserialized in case None is passed. + _registered_method: Implementation Private. Optional: A bool representing + whether the method is registered. Returns: A StreamStreamMultiCallable value for the named stream-stream method. diff --git a/src/python/grpcio/grpc/aio/_channel.py b/src/python/grpcio/grpc/aio/_channel.py index bea64c27fa3..ea4de20965a 100644 --- a/src/python/grpcio/grpc/aio/_channel.py +++ b/src/python/grpcio/grpc/aio/_channel.py @@ -478,11 +478,20 @@ class Channel(_base_channel.Channel): await self.wait_for_state_change(state) state = self.get_state(try_to_connect=True) + # TODO(xuanwn): Implement this method after we have + # observability for Asyncio. + def _get_registered_call_handle(self, method: str) -> int: + pass + + # TODO(xuanwn): Implement _registered_method after we have + # observability for Asyncio. + # pylint: disable=arguments-differ,unused-argument def unary_unary( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> UnaryUnaryMultiCallable: return UnaryUnaryMultiCallable( self._channel, @@ -494,11 +503,15 @@ class Channel(_base_channel.Channel): self._loop, ) + # TODO(xuanwn): Implement _registered_method after we have + # observability for Asyncio. + # pylint: disable=arguments-differ,unused-argument def unary_stream( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> UnaryStreamMultiCallable: return UnaryStreamMultiCallable( self._channel, @@ -510,11 +523,15 @@ class Channel(_base_channel.Channel): self._loop, ) + # TODO(xuanwn): Implement _registered_method after we have + # observability for Asyncio. + # pylint: disable=arguments-differ,unused-argument def stream_unary( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> StreamUnaryMultiCallable: return StreamUnaryMultiCallable( self._channel, @@ -526,11 +543,15 @@ class Channel(_base_channel.Channel): self._loop, ) + # TODO(xuanwn): Implement _registered_method after we have + # observability for Asyncio. + # pylint: disable=arguments-differ,unused-argument def stream_stream( self, method: str, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + _registered_method: Optional[bool] = False, ) -> StreamStreamMultiCallable: return StreamStreamMultiCallable( self._channel, diff --git a/src/python/grpcio_testing/grpc_testing/_channel/_channel.py b/src/python/grpcio_testing/grpc_testing/_channel/_channel.py index 170533f63ea..3f12e1f4df8 100644 --- a/src/python/grpcio_testing/grpc_testing/_channel/_channel.py +++ b/src/python/grpcio_testing/grpc_testing/_channel/_channel.py @@ -31,23 +31,42 @@ class TestingChannel(grpc_testing.Channel): def unsubscribe(self, callback): raise NotImplementedError() + def _get_registered_call_handle(self, method: str) -> int: + pass + def unary_unary( - self, method, request_serializer=None, response_deserializer=None + self, + method, + request_serializer=None, + response_deserializer=None, + _registered_method=False, ): return _multi_callable.UnaryUnary(method, self._state) def unary_stream( - self, method, request_serializer=None, response_deserializer=None + self, + method, + request_serializer=None, + response_deserializer=None, + _registered_method=False, ): return _multi_callable.UnaryStream(method, self._state) def stream_unary( - self, method, request_serializer=None, response_deserializer=None + self, + method, + request_serializer=None, + response_deserializer=None, + _registered_method=False, ): return _multi_callable.StreamUnary(method, self._state) def stream_stream( - self, method, request_serializer=None, response_deserializer=None + self, + method, + request_serializer=None, + response_deserializer=None, + _registered_method=False, ): return _multi_callable.StreamStream(method, self._state) diff --git a/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py b/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py index 78333fc62c7..2379ab59a3c 100644 --- a/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py +++ b/src/python/grpcio_tests/tests/channelz/_channelz_servicer_test.py @@ -99,16 +99,20 @@ class ChannelzServicerTest(unittest.TestCase): def _send_successful_unary_unary(self, idx): _, r = ( self._pairs[idx] - .channel.unary_unary(_SUCCESSFUL_UNARY_UNARY) + .channel.unary_unary( + _SUCCESSFUL_UNARY_UNARY, + _registered_method=True, + ) .with_call(_REQUEST) ) self.assertEqual(r.code(), grpc.StatusCode.OK) def _send_failed_unary_unary(self, idx): try: - self._pairs[idx].channel.unary_unary(_FAILED_UNARY_UNARY).with_call( - _REQUEST - ) + self._pairs[idx].channel.unary_unary( + _FAILED_UNARY_UNARY, + _registered_method=True, + ).with_call(_REQUEST) except grpc.RpcError: return else: @@ -117,7 +121,10 @@ class ChannelzServicerTest(unittest.TestCase): def _send_successful_stream_stream(self, idx): response_iterator = ( self._pairs[idx] - .channel.stream_stream(_SUCCESSFUL_STREAM_STREAM) + .channel.stream_stream( + _SUCCESSFUL_STREAM_STREAM, + _registered_method=True, + ) .__call__(iter([_REQUEST] * test_constants.STREAM_LENGTH)) ) cnt = 0 diff --git a/src/python/grpcio_tests/tests/csds/csds_test.py b/src/python/grpcio_tests/tests/csds/csds_test.py index 365115bcb99..5a882b2d679 100644 --- a/src/python/grpcio_tests/tests/csds/csds_test.py +++ b/src/python/grpcio_tests/tests/csds/csds_test.py @@ -86,7 +86,10 @@ class TestCsds(unittest.TestCase): # Force the XdsClient to initialize and request a resource with self.assertRaises(grpc.RpcError) as rpc_error: - dummy_channel.unary_unary("")(b"", wait_for_ready=False, timeout=1) + dummy_channel.unary_unary( + "", + _registered_method=True, + )(b"", wait_for_ready=False, timeout=1) self.assertEqual( grpc.StatusCode.DEADLINE_EXCEEDED, rpc_error.exception.code() ) diff --git a/src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py b/src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py index 9a6538cca35..c6e41b3ae8c 100644 --- a/src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py +++ b/src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py @@ -543,6 +543,17 @@ class PythonPluginTest(unittest.TestCase): ) service.server.stop(None) + def testRegisteredMethod(self): + """Tests that we're setting _registered_call_handle when create call using generated stub.""" + service = _CreateService() + self.assertTrue(service.stub.UnaryCall._registered_call_handle) + self.assertTrue( + service.stub.StreamingOutputCall._registered_call_handle + ) + self.assertTrue(service.stub.StreamingInputCall._registered_call_handle) + self.assertTrue(service.stub.FullDuplexCall._registered_call_handle) + service.server.stop(None) + @unittest.skipIf( sys.version_info[0] < 3 or sys.version_info[1] < 6, diff --git a/src/python/grpcio_tests/tests/qps/benchmark_client.py b/src/python/grpcio_tests/tests/qps/benchmark_client.py index e5aafc4142f..9310f8a8c61 100644 --- a/src/python/grpcio_tests/tests/qps/benchmark_client.py +++ b/src/python/grpcio_tests/tests/qps/benchmark_client.py @@ -32,13 +32,16 @@ _TIMEOUT = 60 * 60 * 24 class GenericStub(object): def __init__(self, channel): self.UnaryCall = channel.unary_unary( - "/grpc.testing.BenchmarkService/UnaryCall" + "/grpc.testing.BenchmarkService/UnaryCall", + _registered_method=True, ) self.StreamingFromServer = channel.unary_stream( - "/grpc.testing.BenchmarkService/StreamingFromServer" + "/grpc.testing.BenchmarkService/StreamingFromServer", + _registered_method=True, ) self.StreamingCall = channel.stream_stream( - "/grpc.testing.BenchmarkService/StreamingCall" + "/grpc.testing.BenchmarkService/StreamingCall", + _registered_method=True, ) diff --git a/src/python/grpcio_tests/tests/status/_grpc_status_test.py b/src/python/grpcio_tests/tests/status/_grpc_status_test.py index 2573e961f18..031bdbe4d53 100644 --- a/src/python/grpcio_tests/tests/status/_grpc_status_test.py +++ b/src/python/grpcio_tests/tests/status/_grpc_status_test.py @@ -138,7 +138,10 @@ class StatusTest(unittest.TestCase): self._channel.close() def test_status_ok(self): - _, call = self._channel.unary_unary(_STATUS_OK).with_call(_REQUEST) + _, call = self._channel.unary_unary( + _STATUS_OK, + _registered_method=True, + ).with_call(_REQUEST) # Succeed RPC doesn't have status status = rpc_status.from_call(call) @@ -146,7 +149,10 @@ class StatusTest(unittest.TestCase): def test_status_not_ok(self): with self.assertRaises(grpc.RpcError) as exception_context: - self._channel.unary_unary(_STATUS_NOT_OK).with_call(_REQUEST) + self._channel.unary_unary( + _STATUS_NOT_OK, + _registered_method=True, + ).with_call(_REQUEST) rpc_error = exception_context.exception self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) @@ -156,7 +162,10 @@ class StatusTest(unittest.TestCase): def test_error_details(self): with self.assertRaises(grpc.RpcError) as exception_context: - self._channel.unary_unary(_ERROR_DETAILS).with_call(_REQUEST) + self._channel.unary_unary( + _ERROR_DETAILS, + _registered_method=True, + ).with_call(_REQUEST) rpc_error = exception_context.exception status = rpc_status.from_call(rpc_error) @@ -173,7 +182,10 @@ class StatusTest(unittest.TestCase): def test_code_message_validation(self): with self.assertRaises(grpc.RpcError) as exception_context: - self._channel.unary_unary(_INCONSISTENT).with_call(_REQUEST) + self._channel.unary_unary( + _INCONSISTENT, + _registered_method=True, + ).with_call(_REQUEST) rpc_error = exception_context.exception self.assertEqual(rpc_error.code(), grpc.StatusCode.NOT_FOUND) @@ -182,7 +194,10 @@ class StatusTest(unittest.TestCase): def test_invalid_code(self): with self.assertRaises(grpc.RpcError) as exception_context: - self._channel.unary_unary(_INVALID_CODE).with_call(_REQUEST) + self._channel.unary_unary( + _INVALID_CODE, + _registered_method=True, + ).with_call(_REQUEST) rpc_error = exception_context.exception self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN) # Invalid status code exception raised during coversion diff --git a/src/python/grpcio_tests/tests/unit/_abort_test.py b/src/python/grpcio_tests/tests/unit/_abort_test.py index 731bb741bec..46f48bd1cae 100644 --- a/src/python/grpcio_tests/tests/unit/_abort_test.py +++ b/src/python/grpcio_tests/tests/unit/_abort_test.py @@ -107,7 +107,10 @@ class AbortTest(unittest.TestCase): def test_abort(self): with self.assertRaises(grpc.RpcError) as exception_context: - self._channel.unary_unary(_ABORT)(_REQUEST) + self._channel.unary_unary( + _ABORT, + _registered_method=True, + )(_REQUEST) rpc_error = exception_context.exception self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) @@ -124,7 +127,10 @@ class AbortTest(unittest.TestCase): # Servicer will abort() after creating a local ref to do_not_leak_me. with self.assertRaises(grpc.RpcError): - self._channel.unary_unary(_ABORT)(_REQUEST) + self._channel.unary_unary( + _ABORT, + _registered_method=True, + )(_REQUEST) # Server may still have a stack frame reference to the exception even # after client sees error, so ensure server has shutdown. @@ -134,7 +140,10 @@ class AbortTest(unittest.TestCase): def test_abort_with_status(self): with self.assertRaises(grpc.RpcError) as exception_context: - self._channel.unary_unary(_ABORT_WITH_STATUS)(_REQUEST) + self._channel.unary_unary( + _ABORT_WITH_STATUS, + _registered_method=True, + )(_REQUEST) rpc_error = exception_context.exception self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) @@ -143,7 +152,10 @@ class AbortTest(unittest.TestCase): def test_invalid_code(self): with self.assertRaises(grpc.RpcError) as exception_context: - self._channel.unary_unary(_INVALID_CODE)(_REQUEST) + self._channel.unary_unary( + _INVALID_CODE, + _registered_method=True, + )(_REQUEST) rpc_error = exception_context.exception self.assertEqual(rpc_error.code(), grpc.StatusCode.UNKNOWN) diff --git a/src/python/grpcio_tests/tests/unit/_auth_context_test.py b/src/python/grpcio_tests/tests/unit/_auth_context_test.py index 039c908c3e5..0e5e2017fba 100644 --- a/src/python/grpcio_tests/tests/unit/_auth_context_test.py +++ b/src/python/grpcio_tests/tests/unit/_auth_context_test.py @@ -78,7 +78,10 @@ class AuthContextTest(unittest.TestCase): server.start() with grpc.insecure_channel("localhost:%d" % port) as channel: - response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + response = channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + )(_REQUEST) server.stop(None) auth_data = pickle.loads(response) @@ -115,7 +118,10 @@ class AuthContextTest(unittest.TestCase): channel_creds, options=_PROPERTY_OPTIONS, ) - response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + response = channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + )(_REQUEST) channel.close() server.stop(None) @@ -161,7 +167,10 @@ class AuthContextTest(unittest.TestCase): options=_PROPERTY_OPTIONS, ) - response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + response = channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + )(_REQUEST) channel.close() server.stop(None) @@ -180,7 +189,10 @@ class AuthContextTest(unittest.TestCase): channel = grpc.secure_channel( "localhost:{}".format(port), channel_creds, options=channel_options ) - response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + response = channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + )(_REQUEST) auth_data = pickle.loads(response) self.assertEqual( expect_ssl_session_reused, diff --git a/src/python/grpcio_tests/tests/unit/_channel_close_test.py b/src/python/grpcio_tests/tests/unit/_channel_close_test.py index 4e5f215af89..b2ba4e7c887 100644 --- a/src/python/grpcio_tests/tests/unit/_channel_close_test.py +++ b/src/python/grpcio_tests/tests/unit/_channel_close_test.py @@ -123,7 +123,10 @@ class ChannelCloseTest(unittest.TestCase): def test_close_immediately_after_call_invocation(self): channel = grpc.insecure_channel("localhost:{}".format(self._port)) - multi_callable = channel.stream_stream(_STREAM_URI) + multi_callable = channel.stream_stream( + _STREAM_URI, + _registered_method=True, + ) request_iterator = _Pipe(()) response_iterator = multi_callable(request_iterator) channel.close() @@ -133,7 +136,10 @@ class ChannelCloseTest(unittest.TestCase): def test_close_while_call_active(self): channel = grpc.insecure_channel("localhost:{}".format(self._port)) - multi_callable = channel.stream_stream(_STREAM_URI) + multi_callable = channel.stream_stream( + _STREAM_URI, + _registered_method=True, + ) request_iterator = _Pipe((b"abc",)) response_iterator = multi_callable(request_iterator) next(response_iterator) @@ -146,7 +152,10 @@ class ChannelCloseTest(unittest.TestCase): with grpc.insecure_channel( "localhost:{}".format(self._port) ) as channel: # pylint: disable=bad-continuation - multi_callable = channel.stream_stream(_STREAM_URI) + multi_callable = channel.stream_stream( + _STREAM_URI, + _registered_method=True, + ) request_iterator = _Pipe((b"abc",)) response_iterator = multi_callable(request_iterator) next(response_iterator) @@ -158,7 +167,10 @@ class ChannelCloseTest(unittest.TestCase): with grpc.insecure_channel( "localhost:{}".format(self._port) ) as channel: # pylint: disable=bad-continuation - multi_callable = channel.stream_stream(_STREAM_URI) + multi_callable = channel.stream_stream( + _STREAM_URI, + _registered_method=True, + ) request_iterators = tuple( _Pipe((b"abc",)) for _ in range(test_constants.THREAD_CONCURRENCY) @@ -176,7 +188,10 @@ class ChannelCloseTest(unittest.TestCase): def test_many_concurrent_closes(self): channel = grpc.insecure_channel("localhost:{}".format(self._port)) - multi_callable = channel.stream_stream(_STREAM_URI) + multi_callable = channel.stream_stream( + _STREAM_URI, + _registered_method=True, + ) request_iterator = _Pipe((b"abc",)) response_iterator = multi_callable(request_iterator) next(response_iterator) @@ -203,10 +218,16 @@ class ChannelCloseTest(unittest.TestCase): with grpc.insecure_channel( "localhost:{}".format(self._port) ) as channel: - stream_multi_callable = channel.stream_stream(_STREAM_URI) + stream_multi_callable = channel.stream_stream( + _STREAM_URI, + _registered_method=True, + ) endless_iterator = itertools.repeat(b"abc") stream_response_iterator = stream_multi_callable(endless_iterator) - future = channel.unary_unary(_UNARY_URI).future(b"abc") + future = channel.unary_unary( + _UNARY_URI, + _registered_method=True, + ).future(b"abc") def on_done_callback(future): raise Exception("This should not cause a deadlock.") diff --git a/src/python/grpcio_tests/tests/unit/_compression_test.py b/src/python/grpcio_tests/tests/unit/_compression_test.py index be2a528ea9b..9fdeca030d8 100644 --- a/src/python/grpcio_tests/tests/unit/_compression_test.py +++ b/src/python/grpcio_tests/tests/unit/_compression_test.py @@ -237,7 +237,10 @@ def _get_compression_ratios( def _unary_unary_client(channel, multicallable_kwargs, message): - multi_callable = channel.unary_unary(_UNARY_UNARY) + multi_callable = channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + ) response = multi_callable(message, **multicallable_kwargs) if response != message: raise RuntimeError( @@ -246,7 +249,10 @@ def _unary_unary_client(channel, multicallable_kwargs, message): def _unary_stream_client(channel, multicallable_kwargs, message): - multi_callable = channel.unary_stream(_UNARY_STREAM) + multi_callable = channel.unary_stream( + _UNARY_STREAM, + _registered_method=True, + ) response_iterator = multi_callable(message, **multicallable_kwargs) for response in response_iterator: if response != message: @@ -256,7 +262,10 @@ def _unary_stream_client(channel, multicallable_kwargs, message): def _stream_unary_client(channel, multicallable_kwargs, message): - multi_callable = channel.stream_unary(_STREAM_UNARY) + multi_callable = channel.stream_unary( + _STREAM_UNARY, + _registered_method=True, + ) requests = (_REQUEST for _ in range(_STREAM_LENGTH)) response = multi_callable(requests, **multicallable_kwargs) if response != message: @@ -266,7 +275,10 @@ def _stream_unary_client(channel, multicallable_kwargs, message): def _stream_stream_client(channel, multicallable_kwargs, message): - multi_callable = channel.stream_stream(_STREAM_STREAM) + multi_callable = channel.stream_stream( + _STREAM_STREAM, + _registered_method=True, + ) request_prefix = str(0).encode("ascii") * 100 requests = ( request_prefix + str(i).encode("ascii") for i in range(_STREAM_LENGTH) diff --git a/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py b/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py index 6f3b601ceb2..5a23d5dc69e 100644 --- a/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py +++ b/src/python/grpcio_tests/tests/unit/_contextvars_propagation_test.py @@ -116,7 +116,10 @@ class ContextVarsPropagationTest(unittest.TestCase): local_credentials, call_credentials ) with grpc.secure_channel(target, composite_credentials) as channel: - stub = channel.unary_unary(_UNARY_UNARY) + stub = channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + ) response = stub(_REQUEST, wait_for_ready=True) self.assertEqual(_REQUEST, response) @@ -142,7 +145,10 @@ class ContextVarsPropagationTest(unittest.TestCase): with grpc.secure_channel( target, composite_credentials ) as channel: - stub = channel.unary_unary(_UNARY_UNARY) + stub = channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + ) wait_group.done() wait_group.wait() for i in range(_RPC_COUNT): diff --git a/src/python/grpcio_tests/tests/unit/_dns_resolver_test.py b/src/python/grpcio_tests/tests/unit/_dns_resolver_test.py index 62a95a02135..bcd7e6da849 100644 --- a/src/python/grpcio_tests/tests/unit/_dns_resolver_test.py +++ b/src/python/grpcio_tests/tests/unit/_dns_resolver_test.py @@ -55,7 +55,10 @@ class DNSResolverTest(unittest.TestCase): "loopback46.unittest.grpc.io:%d" % self._port ) as channel: self.assertEqual( - channel.unary_unary(_METHOD)( + channel.unary_unary( + _METHOD, + _registered_method=True, + )( _REQUEST, timeout=10, ), diff --git a/src/python/grpcio_tests/tests/unit/_empty_message_test.py b/src/python/grpcio_tests/tests/unit/_empty_message_test.py index e2dc1594202..a303aa8b3e9 100644 --- a/src/python/grpcio_tests/tests/unit/_empty_message_test.py +++ b/src/python/grpcio_tests/tests/unit/_empty_message_test.py @@ -96,25 +96,33 @@ class EmptyMessageTest(unittest.TestCase): self._channel.close() def testUnaryUnary(self): - response = self._channel.unary_unary(_UNARY_UNARY)(_REQUEST) + response = self._channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + )(_REQUEST) self.assertEqual(_RESPONSE, response) def testUnaryStream(self): - response_iterator = self._channel.unary_stream(_UNARY_STREAM)(_REQUEST) + response_iterator = self._channel.unary_stream( + _UNARY_STREAM, + _registered_method=True, + )(_REQUEST) self.assertSequenceEqual( [_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator) ) def testStreamUnary(self): - response = self._channel.stream_unary(_STREAM_UNARY)( - iter([_REQUEST] * test_constants.STREAM_LENGTH) - ) + response = self._channel.stream_unary( + _STREAM_UNARY, + _registered_method=True, + )(iter([_REQUEST] * test_constants.STREAM_LENGTH)) self.assertEqual(_RESPONSE, response) def testStreamStream(self): - response_iterator = self._channel.stream_stream(_STREAM_STREAM)( - iter([_REQUEST] * test_constants.STREAM_LENGTH) - ) + response_iterator = self._channel.stream_stream( + _STREAM_STREAM, + _registered_method=True, + )(iter([_REQUEST] * test_constants.STREAM_LENGTH)) self.assertSequenceEqual( [_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator) ) diff --git a/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py b/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py index 4f07477fac9..334b7ed5a3a 100644 --- a/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py +++ b/src/python/grpcio_tests/tests/unit/_error_message_encoding_test.py @@ -73,7 +73,10 @@ class ErrorMessageEncodingTest(unittest.TestCase): def testMessageEncoding(self): for message in _UNICODE_ERROR_MESSAGES: - multi_callable = self._channel.unary_unary(_UNARY_UNARY) + multi_callable = self._channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + ) with self.assertRaises(grpc.RpcError) as cm: multi_callable(message.encode("utf-8")) diff --git a/src/python/grpcio_tests/tests/unit/_exit_scenarios.py b/src/python/grpcio_tests/tests/unit/_exit_scenarios.py index c1f9816df08..1b7e2e5ac84 100644 --- a/src/python/grpcio_tests/tests/unit/_exit_scenarios.py +++ b/src/python/grpcio_tests/tests/unit/_exit_scenarios.py @@ -210,14 +210,20 @@ if __name__ == "__main__": method = TEST_TO_METHOD[args.scenario] if args.scenario == IN_FLIGHT_UNARY_UNARY_CALL: - multi_callable = channel.unary_unary(method) + multi_callable = channel.unary_unary( + method, + _registered_method=True, + ) future = multi_callable.future(REQUEST) result, call = multi_callable.with_call(REQUEST) elif ( args.scenario == IN_FLIGHT_UNARY_STREAM_CALL or args.scenario == IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL ): - multi_callable = channel.unary_stream(method) + multi_callable = channel.unary_stream( + method, + _registered_method=True, + ) response_iterator = multi_callable(REQUEST) for response in response_iterator: pass @@ -225,7 +231,10 @@ if __name__ == "__main__": args.scenario == IN_FLIGHT_STREAM_UNARY_CALL or args.scenario == IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL ): - multi_callable = channel.stream_unary(method) + multi_callable = channel.stream_unary( + method, + _registered_method=True, + ) future = multi_callable.future(infinite_request_iterator()) result, call = multi_callable.with_call( iter([REQUEST] * test_constants.STREAM_LENGTH) @@ -234,7 +243,10 @@ if __name__ == "__main__": args.scenario == IN_FLIGHT_STREAM_STREAM_CALL or args.scenario == IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL ): - multi_callable = channel.stream_stream(method) + multi_callable = channel.stream_stream( + method, + _registered_method=True, + ) response_iterator = multi_callable(infinite_request_iterator()) for response in response_iterator: pass diff --git a/src/python/grpcio_tests/tests/unit/_interceptor_test.py b/src/python/grpcio_tests/tests/unit/_interceptor_test.py index 9bbff1f6bee..72e299b5887 100644 --- a/src/python/grpcio_tests/tests/unit/_interceptor_test.py +++ b/src/python/grpcio_tests/tests/unit/_interceptor_test.py @@ -231,7 +231,7 @@ class _GenericHandler(grpc.GenericRpcHandler): def _unary_unary_multi_callable(channel): - return channel.unary_unary(_UNARY_UNARY) + return channel.unary_unary(_UNARY_UNARY, _registered_method=True) def _unary_stream_multi_callable(channel): @@ -239,6 +239,7 @@ def _unary_stream_multi_callable(channel): _UNARY_STREAM, request_serializer=_SERIALIZE_REQUEST, response_deserializer=_DESERIALIZE_RESPONSE, + _registered_method=True, ) @@ -247,11 +248,12 @@ def _stream_unary_multi_callable(channel): _STREAM_UNARY, request_serializer=_SERIALIZE_REQUEST, response_deserializer=_DESERIALIZE_RESPONSE, + _registered_method=True, ) def _stream_stream_multi_callable(channel): - return channel.stream_stream(_STREAM_STREAM) + return channel.stream_stream(_STREAM_STREAM, _registered_method=True) class _ClientCallDetails( @@ -562,7 +564,7 @@ class InterceptorTest(unittest.TestCase): self._record[:] = [] multi_callable = _unary_unary_multi_callable(channel) - multi_callable.with_call( + response, call = multi_callable.with_call( request, metadata=( ( diff --git a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py index a19966131c5..58d1e589fde 100644 --- a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py +++ b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py @@ -32,7 +32,7 @@ _STREAM_STREAM = "/test/StreamStream" def _unary_unary_multi_callable(channel): - return channel.unary_unary(_UNARY_UNARY) + return channel.unary_unary(_UNARY_UNARY, _registered_method=True) def _unary_stream_multi_callable(channel): @@ -40,6 +40,7 @@ def _unary_stream_multi_callable(channel): _UNARY_STREAM, request_serializer=_SERIALIZE_REQUEST, response_deserializer=_DESERIALIZE_RESPONSE, + _registered_method=True, ) @@ -48,11 +49,15 @@ def _stream_unary_multi_callable(channel): _STREAM_UNARY, request_serializer=_SERIALIZE_REQUEST, response_deserializer=_DESERIALIZE_RESPONSE, + _registered_method=True, ) def _stream_stream_multi_callable(channel): - return channel.stream_stream(_STREAM_STREAM) + return channel.stream_stream( + _STREAM_STREAM, + _registered_method=True, + ) class InvalidMetadataTest(unittest.TestCase): diff --git a/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py b/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py index b22ab016593..cb903e7b812 100644 --- a/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py +++ b/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py @@ -219,7 +219,10 @@ class FailAfterFewIterationsCounter(object): def _unary_unary_multi_callable(channel): - return channel.unary_unary(_UNARY_UNARY) + return channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + ) def _unary_stream_multi_callable(channel): @@ -227,6 +230,7 @@ def _unary_stream_multi_callable(channel): _UNARY_STREAM, request_serializer=_SERIALIZE_REQUEST, response_deserializer=_DESERIALIZE_RESPONSE, + _registered_method=True, ) @@ -235,19 +239,29 @@ def _stream_unary_multi_callable(channel): _STREAM_UNARY, request_serializer=_SERIALIZE_REQUEST, response_deserializer=_DESERIALIZE_RESPONSE, + _registered_method=True, ) def _stream_stream_multi_callable(channel): - return channel.stream_stream(_STREAM_STREAM) + return channel.stream_stream( + _STREAM_STREAM, + _registered_method=True, + ) def _defective_handler_multi_callable(channel): - return channel.unary_unary(_DEFECTIVE_GENERIC_RPC_HANDLER) + return channel.unary_unary( + _DEFECTIVE_GENERIC_RPC_HANDLER, + _registered_method=True, + ) def _defective_nested_exception_handler_multi_callable(channel): - return channel.unary_unary(_UNARY_UNARY_NESTED_EXCEPTION) + return channel.unary_unary( + _UNARY_UNARY_NESTED_EXCEPTION, + _registered_method=True, + ) class InvocationDefectsTest(unittest.TestCase): diff --git a/src/python/grpcio_tests/tests/unit/_local_credentials_test.py b/src/python/grpcio_tests/tests/unit/_local_credentials_test.py index 165f6ca16eb..9c5b425eaf7 100644 --- a/src/python/grpcio_tests/tests/unit/_local_credentials_test.py +++ b/src/python/grpcio_tests/tests/unit/_local_credentials_test.py @@ -53,9 +53,10 @@ class LocalCredentialsTest(unittest.TestCase): ) as channel: self.assertEqual( b"abc", - channel.unary_unary("/test/method")( - b"abc", wait_for_ready=True - ), + channel.unary_unary( + "/test/method", + _registered_method=True, + )(b"abc", wait_for_ready=True), ) server.stop(None) @@ -77,9 +78,10 @@ class LocalCredentialsTest(unittest.TestCase): with grpc.secure_channel(server_addr, channel_creds) as channel: self.assertEqual( b"abc", - channel.unary_unary("/test/method")( - b"abc", wait_for_ready=True - ), + channel.unary_unary( + "/test/method", + _registered_method=True, + )(b"abc", wait_for_ready=True), ) server.stop(None) diff --git a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py index 3c530058dc3..320deb7e5f6 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py @@ -207,45 +207,53 @@ class MetadataCodeDetailsTest(unittest.TestCase): self._server.start() self._channel = grpc.insecure_channel("localhost:{}".format(port)) + unary_unary_method_name = "/".join( + ( + "", + _SERVICE, + _UNARY_UNARY, + ) + ) self._unary_unary = self._channel.unary_unary( - "/".join( - ( - "", - _SERVICE, - _UNARY_UNARY, - ) - ), + unary_unary_method_name, request_serializer=_REQUEST_SERIALIZER, response_deserializer=_RESPONSE_DESERIALIZER, + _registered_method=True, + ) + unary_stream_method_name = "/".join( + ( + "", + _SERVICE, + _UNARY_STREAM, + ) ) self._unary_stream = self._channel.unary_stream( - "/".join( - ( - "", - _SERVICE, - _UNARY_STREAM, - ) - ), + unary_stream_method_name, + _registered_method=True, + ) + stream_unary_method_name = "/".join( + ( + "", + _SERVICE, + _STREAM_UNARY, + ) ) self._stream_unary = self._channel.stream_unary( - "/".join( - ( - "", - _SERVICE, - _STREAM_UNARY, - ) - ), + stream_unary_method_name, + _registered_method=True, + ) + stream_stream_method_name = "/".join( + ( + "", + _SERVICE, + _STREAM_STREAM, + ) ) self._stream_stream = self._channel.stream_stream( - "/".join( - ( - "", - _SERVICE, - _STREAM_STREAM, - ) - ), + stream_stream_method_name, request_serializer=_REQUEST_SERIALIZER, response_deserializer=_RESPONSE_DESERIALIZER, + _registered_method=True, ) def tearDown(self): @@ -828,16 +836,18 @@ class InspectContextTest(unittest.TestCase): self._server.start() self._channel = grpc.insecure_channel("localhost:{}".format(port)) + unary_unary_method_name = "/".join( + ( + "", + _SERVICE, + _UNARY_UNARY, + ) + ) self._unary_unary = self._channel.unary_unary( - "/".join( - ( - "", - _SERVICE, - _UNARY_UNARY, - ) - ), + unary_unary_method_name, request_serializer=_REQUEST_SERIALIZER, response_deserializer=_RESPONSE_DESERIALIZER, + _registered_method=True, ) def tearDown(self): diff --git a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py index a67a496860f..2cd9ad9bd89 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py @@ -110,7 +110,10 @@ def create_phony_channel(): def perform_unary_unary_call(channel, wait_for_ready=None): - channel.unary_unary(_UNARY_UNARY).__call__( + channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + ).__call__( _REQUEST, timeout=test_constants.LONG_TIMEOUT, wait_for_ready=wait_for_ready, @@ -118,7 +121,10 @@ def perform_unary_unary_call(channel, wait_for_ready=None): def perform_unary_unary_with_call(channel, wait_for_ready=None): - channel.unary_unary(_UNARY_UNARY).with_call( + channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + ).with_call( _REQUEST, timeout=test_constants.LONG_TIMEOUT, wait_for_ready=wait_for_ready, @@ -126,7 +132,10 @@ def perform_unary_unary_with_call(channel, wait_for_ready=None): def perform_unary_unary_future(channel, wait_for_ready=None): - channel.unary_unary(_UNARY_UNARY).future( + channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + ).future( _REQUEST, timeout=test_constants.LONG_TIMEOUT, wait_for_ready=wait_for_ready, @@ -134,7 +143,10 @@ def perform_unary_unary_future(channel, wait_for_ready=None): def perform_unary_stream_call(channel, wait_for_ready=None): - response_iterator = channel.unary_stream(_UNARY_STREAM).__call__( + response_iterator = channel.unary_stream( + _UNARY_STREAM, + _registered_method=True, + ).__call__( _REQUEST, timeout=test_constants.LONG_TIMEOUT, wait_for_ready=wait_for_ready, @@ -144,7 +156,10 @@ def perform_unary_stream_call(channel, wait_for_ready=None): def perform_stream_unary_call(channel, wait_for_ready=None): - channel.stream_unary(_STREAM_UNARY).__call__( + channel.stream_unary( + _STREAM_UNARY, + _registered_method=True, + ).__call__( iter([_REQUEST] * test_constants.STREAM_LENGTH), timeout=test_constants.LONG_TIMEOUT, wait_for_ready=wait_for_ready, @@ -152,7 +167,10 @@ def perform_stream_unary_call(channel, wait_for_ready=None): def perform_stream_unary_with_call(channel, wait_for_ready=None): - channel.stream_unary(_STREAM_UNARY).with_call( + channel.stream_unary( + _STREAM_UNARY, + _registered_method=True, + ).with_call( iter([_REQUEST] * test_constants.STREAM_LENGTH), timeout=test_constants.LONG_TIMEOUT, wait_for_ready=wait_for_ready, @@ -160,7 +178,10 @@ def perform_stream_unary_with_call(channel, wait_for_ready=None): def perform_stream_unary_future(channel, wait_for_ready=None): - channel.stream_unary(_STREAM_UNARY).future( + channel.stream_unary( + _STREAM_UNARY, + _registered_method=True, + ).future( iter([_REQUEST] * test_constants.STREAM_LENGTH), timeout=test_constants.LONG_TIMEOUT, wait_for_ready=wait_for_ready, @@ -168,7 +189,9 @@ def perform_stream_unary_future(channel, wait_for_ready=None): def perform_stream_stream_call(channel, wait_for_ready=None): - response_iterator = channel.stream_stream(_STREAM_STREAM).__call__( + response_iterator = channel.stream_stream( + _STREAM_STREAM, _registered_method=True + ).__call__( iter([_REQUEST] * test_constants.STREAM_LENGTH), timeout=test_constants.LONG_TIMEOUT, wait_for_ready=wait_for_ready, diff --git a/src/python/grpcio_tests/tests/unit/_metadata_test.py b/src/python/grpcio_tests/tests/unit/_metadata_test.py index 7110177fa18..b9b7502972c 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_test.py @@ -195,7 +195,9 @@ class MetadataTest(unittest.TestCase): self._channel.close() def testUnaryUnary(self): - multi_callable = self._channel.unary_unary(_UNARY_UNARY) + multi_callable = self._channel.unary_unary( + _UNARY_UNARY, _registered_method=True + ) unused_response, call = multi_callable.with_call( _REQUEST, metadata=_INVOCATION_METADATA ) @@ -211,7 +213,9 @@ class MetadataTest(unittest.TestCase): ) def testUnaryStream(self): - multi_callable = self._channel.unary_stream(_UNARY_STREAM) + multi_callable = self._channel.unary_stream( + _UNARY_STREAM, _registered_method=True + ) call = multi_callable(_REQUEST, metadata=_INVOCATION_METADATA) self.assertTrue( test_common.metadata_transmitted( @@ -227,7 +231,9 @@ class MetadataTest(unittest.TestCase): ) def testStreamUnary(self): - multi_callable = self._channel.stream_unary(_STREAM_UNARY) + multi_callable = self._channel.stream_unary( + _STREAM_UNARY, _registered_method=True + ) unused_response, call = multi_callable.with_call( iter([_REQUEST] * test_constants.STREAM_LENGTH), metadata=_INVOCATION_METADATA, @@ -244,7 +250,9 @@ class MetadataTest(unittest.TestCase): ) def testStreamStream(self): - multi_callable = self._channel.stream_stream(_STREAM_STREAM) + multi_callable = self._channel.stream_stream( + _STREAM_STREAM, _registered_method=True + ) call = multi_callable( iter([_REQUEST] * test_constants.STREAM_LENGTH), metadata=_INVOCATION_METADATA, diff --git a/src/python/grpcio_tests/tests/unit/_reconnect_test.py b/src/python/grpcio_tests/tests/unit/_reconnect_test.py index d412533251c..8d4fadaa3da 100644 --- a/src/python/grpcio_tests/tests/unit/_reconnect_test.py +++ b/src/python/grpcio_tests/tests/unit/_reconnect_test.py @@ -52,7 +52,10 @@ class ReconnectTest(unittest.TestCase): server.add_insecure_port(addr) server.start() channel = grpc.insecure_channel(addr) - multi_callable = channel.unary_unary(_UNARY_UNARY) + multi_callable = channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + ) self.assertEqual(_RESPONSE, multi_callable(_REQUEST)) server.stop(None) # By default, the channel connectivity is checked every 5s diff --git a/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py b/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py index 3fc04f06a18..93e72c84f07 100644 --- a/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py +++ b/src/python/grpcio_tests/tests/unit/_resource_exhausted_test.py @@ -149,7 +149,10 @@ class ResourceExhaustedTest(unittest.TestCase): self._channel.close() def testUnaryUnary(self): - multi_callable = self._channel.unary_unary(_UNARY_UNARY) + multi_callable = self._channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + ) futures = [] for _ in range(test_constants.THREAD_CONCURRENCY): futures.append(multi_callable.future(_REQUEST)) @@ -178,7 +181,10 @@ class ResourceExhaustedTest(unittest.TestCase): self.assertEqual(_RESPONSE, multi_callable(_REQUEST)) def testUnaryStream(self): - multi_callable = self._channel.unary_stream(_UNARY_STREAM) + multi_callable = self._channel.unary_stream( + _UNARY_STREAM, + _registered_method=True, + ) calls = [] for _ in range(test_constants.THREAD_CONCURRENCY): calls.append(multi_callable(_REQUEST)) @@ -205,7 +211,10 @@ class ResourceExhaustedTest(unittest.TestCase): self.assertEqual(_RESPONSE, response) def testStreamUnary(self): - multi_callable = self._channel.stream_unary(_STREAM_UNARY) + multi_callable = self._channel.stream_unary( + _STREAM_UNARY, + _registered_method=True, + ) futures = [] request = iter([_REQUEST] * test_constants.STREAM_LENGTH) for _ in range(test_constants.THREAD_CONCURRENCY): @@ -236,7 +245,10 @@ class ResourceExhaustedTest(unittest.TestCase): self.assertEqual(_RESPONSE, multi_callable(request)) def testStreamStream(self): - multi_callable = self._channel.stream_stream(_STREAM_STREAM) + multi_callable = self._channel.stream_stream( + _STREAM_STREAM, + _registered_method=True, + ) calls = [] request = iter([_REQUEST] * test_constants.STREAM_LENGTH) for _ in range(test_constants.THREAD_CONCURRENCY): diff --git a/src/python/grpcio_tests/tests/unit/_rpc_test_helpers.py b/src/python/grpcio_tests/tests/unit/_rpc_test_helpers.py index 1027be1c677..d85eb30c8e7 100644 --- a/src/python/grpcio_tests/tests/unit/_rpc_test_helpers.py +++ b/src/python/grpcio_tests/tests/unit/_rpc_test_helpers.py @@ -277,7 +277,10 @@ class _GenericHandler(grpc.GenericRpcHandler): def unary_unary_multi_callable(channel): - return channel.unary_unary(_UNARY_UNARY) + return channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + ) def unary_stream_multi_callable(channel): @@ -285,6 +288,7 @@ def unary_stream_multi_callable(channel): _UNARY_STREAM, request_serializer=_SERIALIZE_REQUEST, response_deserializer=_DESERIALIZE_RESPONSE, + _registered_method=True, ) @@ -293,6 +297,7 @@ def unary_stream_non_blocking_multi_callable(channel): _UNARY_STREAM_NON_BLOCKING, request_serializer=_SERIALIZE_REQUEST, response_deserializer=_DESERIALIZE_RESPONSE, + _registered_method=True, ) @@ -301,15 +306,22 @@ def stream_unary_multi_callable(channel): _STREAM_UNARY, request_serializer=_SERIALIZE_REQUEST, response_deserializer=_DESERIALIZE_RESPONSE, + _registered_method=True, ) def stream_stream_multi_callable(channel): - return channel.stream_stream(_STREAM_STREAM) + return channel.stream_stream( + _STREAM_STREAM, + _registered_method=True, + ) def stream_stream_non_blocking_multi_callable(channel): - return channel.stream_stream(_STREAM_STREAM_NON_BLOCKING) + return channel.stream_stream( + _STREAM_STREAM_NON_BLOCKING, + _registered_method=True, + ) class BaseRPCTest(object): diff --git a/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py b/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py index 9190f108f7b..34d51fd72ec 100644 --- a/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py +++ b/src/python/grpcio_tests/tests/unit/_server_shutdown_scenarios.py @@ -81,7 +81,10 @@ def run_test(args): thread.start() port = port_queue.get() channel = grpc.insecure_channel("localhost:%d" % port) - multi_callable = channel.unary_unary(FORK_EXIT) + multi_callable = channel.unary_unary( + FORK_EXIT, + _registered_method=True, + ) result, call = multi_callable.with_call(REQUEST, wait_for_ready=True) os.wait() else: diff --git a/src/python/grpcio_tests/tests/unit/_session_cache_test.py b/src/python/grpcio_tests/tests/unit/_session_cache_test.py index acf671d1ca9..e4bc64f7d11 100644 --- a/src/python/grpcio_tests/tests/unit/_session_cache_test.py +++ b/src/python/grpcio_tests/tests/unit/_session_cache_test.py @@ -77,7 +77,10 @@ class SSLSessionCacheTest(unittest.TestCase): channel = grpc.secure_channel( "localhost:{}".format(port), channel_creds, options=channel_options ) - response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + response = channel.unary_unary( + _UNARY_UNARY, + _registered_method=True, + )(_REQUEST) auth_data = pickle.loads(response) self.assertEqual( expect_ssl_session_reused, diff --git a/src/python/grpcio_tests/tests/unit/_signal_client.py b/src/python/grpcio_tests/tests/unit/_signal_client.py index 56563c20075..34c3da0c933 100644 --- a/src/python/grpcio_tests/tests/unit/_signal_client.py +++ b/src/python/grpcio_tests/tests/unit/_signal_client.py @@ -53,7 +53,10 @@ def main_unary(server_target): """Initiate a unary RPC to be interrupted by a SIGINT.""" global per_process_rpc_future # pylint: disable=global-statement with grpc.insecure_channel(server_target) as channel: - multicallable = channel.unary_unary(UNARY_UNARY) + multicallable = channel.unary_unary( + UNARY_UNARY, + _registered_method=True, + ) signal.signal(signal.SIGINT, handle_sigint) per_process_rpc_future = multicallable.future( _MESSAGE, wait_for_ready=True @@ -67,9 +70,10 @@ def main_streaming(server_target): global per_process_rpc_future # pylint: disable=global-statement with grpc.insecure_channel(server_target) as channel: signal.signal(signal.SIGINT, handle_sigint) - per_process_rpc_future = channel.unary_stream(UNARY_STREAM)( - _MESSAGE, wait_for_ready=True - ) + per_process_rpc_future = channel.unary_stream( + UNARY_STREAM, + _registered_method=True, + )(_MESSAGE, wait_for_ready=True) for result in per_process_rpc_future: pass assert False, _ASSERTION_MESSAGE @@ -79,7 +83,10 @@ def main_unary_with_exception(server_target): """Initiate a unary RPC with a signal handler that will raise.""" channel = grpc.insecure_channel(server_target) try: - channel.unary_unary(UNARY_UNARY)(_MESSAGE, wait_for_ready=True) + channel.unary_unary( + UNARY_UNARY, + _registered_method=True, + )(_MESSAGE, wait_for_ready=True) except KeyboardInterrupt: sys.stderr.write("Running signal handler.\n") sys.stderr.flush() @@ -92,9 +99,10 @@ def main_streaming_with_exception(server_target): """Initiate a streaming RPC with a signal handler that will raise.""" channel = grpc.insecure_channel(server_target) try: - for _ in channel.unary_stream(UNARY_STREAM)( - _MESSAGE, wait_for_ready=True - ): + for _ in channel.unary_stream( + UNARY_STREAM, + _registered_method=True, + )(_MESSAGE, wait_for_ready=True): pass except KeyboardInterrupt: sys.stderr.write("Running signal handler.\n") diff --git a/src/python/grpcio_tests/tests/unit/_xds_credentials_test.py b/src/python/grpcio_tests/tests/unit/_xds_credentials_test.py index 977d564888d..6d8b2b6a041 100644 --- a/src/python/grpcio_tests/tests/unit/_xds_credentials_test.py +++ b/src/python/grpcio_tests/tests/unit/_xds_credentials_test.py @@ -71,9 +71,10 @@ class XdsCredentialsTest(unittest.TestCase): server_address, channel_creds, options=override_options ) as channel: request = b"abc" - response = channel.unary_unary("/test/method")( - request, wait_for_ready=True - ) + response = channel.unary_unary( + "/test/method", + _registered_method=True, + )(request, wait_for_ready=True) self.assertEqual(response, request) def test_xds_creds_fallback_insecure(self): @@ -89,9 +90,10 @@ class XdsCredentialsTest(unittest.TestCase): channel_creds = grpc.xds_channel_credentials(channel_fallback_creds) with grpc.secure_channel(server_address, channel_creds) as channel: request = b"abc" - response = channel.unary_unary("/test/method")( - request, wait_for_ready=True - ) + response = channel.unary_unary( + "/test/method", + _registered_method=True, + )(request, wait_for_ready=True) self.assertEqual(response, request) def test_start_xds_server(self): diff --git a/src/python/grpcio_tests/tests_gevent/unit/close_channel_test.py b/src/python/grpcio_tests/tests_gevent/unit/close_channel_test.py index 47fdb2c22e7..f09c47734e0 100644 --- a/src/python/grpcio_tests/tests_gevent/unit/close_channel_test.py +++ b/src/python/grpcio_tests/tests_gevent/unit/close_channel_test.py @@ -65,6 +65,7 @@ class CloseChannelTest(unittest.TestCase): _UNARY_CALL_METHOD_WITH_SLEEP, request_serializer=messages_pb2.SimpleRequest.SerializeToString, response_deserializer=messages_pb2.SimpleResponse.FromString, + _registered_method=True, ) greenlet = group.spawn(self._run_client, UnaryCallWithSleep) # release loop so that greenlet can take control @@ -78,6 +79,7 @@ class CloseChannelTest(unittest.TestCase): _UNARY_CALL_METHOD_WITH_SLEEP, request_serializer=messages_pb2.SimpleRequest.SerializeToString, response_deserializer=messages_pb2.SimpleResponse.FromString, + _registered_method=True, ) greenlet = group.spawn(self._run_client, UnaryCallWithSleep) # release loop so that greenlet can take control diff --git a/src/python/grpcio_tests/tests_py3_only/unit/_leak_test.py b/src/python/grpcio_tests/tests_py3_only/unit/_leak_test.py index c917bd10521..06fa6466014 100644 --- a/src/python/grpcio_tests/tests_py3_only/unit/_leak_test.py +++ b/src/python/grpcio_tests/tests_py3_only/unit/_leak_test.py @@ -68,7 +68,10 @@ def _start_a_test_server(): def _perform_an_rpc(address): channel = grpc.insecure_channel(address) - multicallable = channel.unary_unary(_TEST_METHOD) + multicallable = channel.unary_unary( + _TEST_METHOD, + _registered_method=True, + ) response = multicallable(_REQUEST) assert _REQUEST == response diff --git a/src/python/grpcio_tests/tests_py3_only/unit/_simple_stubs_test.py b/src/python/grpcio_tests/tests_py3_only/unit/_simple_stubs_test.py index 771097936f6..adcc1299e9c 100644 --- a/src/python/grpcio_tests/tests_py3_only/unit/_simple_stubs_test.py +++ b/src/python/grpcio_tests/tests_py3_only/unit/_simple_stubs_test.py @@ -193,6 +193,7 @@ class SimpleStubsTest(unittest.TestCase): _UNARY_UNARY, channel_credentials=grpc.experimental.insecure_channel_credentials(), timeout=None, + _registered_method=0, ) self.assertEqual(_REQUEST, response) @@ -205,6 +206,7 @@ class SimpleStubsTest(unittest.TestCase): _UNARY_UNARY, channel_credentials=grpc.local_channel_credentials(), timeout=None, + _registered_method=0, ) self.assertEqual(_REQUEST, response) @@ -213,7 +215,10 @@ class SimpleStubsTest(unittest.TestCase): target = f"localhost:{port}" test_name = inspect.stack()[0][3] args = (_REQUEST, target, _UNARY_UNARY) - kwargs = {"channel_credentials": grpc.local_channel_credentials()} + kwargs = { + "channel_credentials": grpc.local_channel_credentials(), + "_registered_method": True, + } def _invoke(seed: str): run_kwargs = dict(kwargs) @@ -230,6 +235,7 @@ class SimpleStubsTest(unittest.TestCase): target, _UNARY_UNARY, channel_credentials=grpc.local_channel_credentials(), + _registered_method=0, ) self.assert_eventually( lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count() @@ -250,6 +256,7 @@ class SimpleStubsTest(unittest.TestCase): _UNARY_UNARY, options=options, channel_credentials=grpc.local_channel_credentials(), + _registered_method=0, ) self.assert_eventually( lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count() @@ -265,6 +272,7 @@ class SimpleStubsTest(unittest.TestCase): target, _UNARY_STREAM, channel_credentials=grpc.local_channel_credentials(), + _registered_method=0, ): self.assertEqual(_REQUEST, response) @@ -280,6 +288,7 @@ class SimpleStubsTest(unittest.TestCase): target, _STREAM_UNARY, channel_credentials=grpc.local_channel_credentials(), + _registered_method=0, ) self.assertEqual(_REQUEST, response) @@ -295,6 +304,7 @@ class SimpleStubsTest(unittest.TestCase): target, _STREAM_STREAM, channel_credentials=grpc.local_channel_credentials(), + _registered_method=0, ): self.assertEqual(_REQUEST, response) @@ -319,14 +329,22 @@ class SimpleStubsTest(unittest.TestCase): with _server(server_creds) as port: target = f"localhost:{port}" response = grpc.experimental.unary_unary( - _REQUEST, target, _UNARY_UNARY, options=_property_options + _REQUEST, + target, + _UNARY_UNARY, + options=_property_options, + _registered_method=0, ) def test_insecure_sugar(self): with _server(None) as port: target = f"localhost:{port}" response = grpc.experimental.unary_unary( - _REQUEST, target, _UNARY_UNARY, insecure=True + _REQUEST, + target, + _UNARY_UNARY, + insecure=True, + _registered_method=0, ) self.assertEqual(_REQUEST, response) @@ -340,14 +358,24 @@ class SimpleStubsTest(unittest.TestCase): _UNARY_UNARY, insecure=True, channel_credentials=grpc.local_channel_credentials(), + _registered_method=0, ) def test_default_wait_for_ready(self): addr, port, sock = get_socket() sock.close() target = f"{addr}:{port}" - channel = grpc._simple_stubs.ChannelCache.get().get_channel( - target, (), None, True, None + ( + channel, + unused_method_handle, + ) = grpc._simple_stubs.ChannelCache.get().get_channel( + target=target, + options=(), + channel_credentials=None, + insecure=True, + compression=None, + method=_UNARY_UNARY, + _registered_method=True, ) rpc_finished_event = threading.Event() rpc_failed_event = threading.Event() @@ -376,7 +404,12 @@ class SimpleStubsTest(unittest.TestCase): def _send_rpc(): try: response = grpc.experimental.unary_unary( - _REQUEST, target, _UNARY_UNARY, timeout=None, insecure=True + _REQUEST, + target, + _UNARY_UNARY, + timeout=None, + insecure=True, + _registered_method=0, ) rpc_finished_event.set() except Exception as e: @@ -399,6 +432,7 @@ class SimpleStubsTest(unittest.TestCase): target, _BLACK_HOLE, insecure=True, + _registered_method=0, **invocation_args, ) self.assertEqual(