From b5f06c216ef7016572ebecc1a75e247bbde432e4 Mon Sep 17 00:00:00 2001 From: Richard Belleville Date: Thu, 6 Feb 2020 16:10:25 -0800 Subject: [PATCH] Implement unary_stream --- src/python/grpcio/grpc/__init__.py | 4 +- src/python/grpcio/grpc/_simple_stubs.py | 50 ++++++++++++++----- .../tests/unit/py3_only/_simple_stubs_test.py | 19 +++++++ 3 files changed, 59 insertions(+), 14 deletions(-) diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py index 9d7b701a8d7..d3e860b4690 100644 --- a/src/python/grpcio/grpc/__init__.py +++ b/src/python/grpcio/grpc/__init__.py @@ -2036,8 +2036,8 @@ __all__ = ( ) if sys.version_info[0] > 2: - from grpc._simple_stubs import unary_unary - __all__ = __all__ + (unary_unary,) + from grpc._simple_stubs import unary_unary, unary_stream + __all__ = __all__ + (unary_unary, unary_stream) ############################### Extension Shims ################################ diff --git a/src/python/grpcio/grpc/_simple_stubs.py b/src/python/grpcio/grpc/_simple_stubs.py index 1ae06bd8d63..b9c6753fcd2 100644 --- a/src/python/grpcio/grpc/_simple_stubs.py +++ b/src/python/grpcio/grpc/_simple_stubs.py @@ -7,7 +7,7 @@ import logging import threading import grpc -from typing import Any, Callable, Optional, Sequence, Text, Tuple, Union +from typing import Any, AnyStr, Callable, Iterator, Optional, Sequence, Tuple, Union _LOGGER = logging.getLogger(__name__) @@ -26,8 +26,8 @@ if _MAXIMUM_CHANNELS_KEY in os.environ: else: _MAXIMUM_CHANNELS = 2 ** 8 -def _create_channel(target: Text, - options: Sequence[Tuple[Text, Text]], +def _create_channel(target: str, + options: Sequence[Tuple[str, str]], channel_credentials: Optional[grpc.ChannelCredentials], compression: Optional[grpc.Compression]) -> grpc.Channel: if channel_credentials is None: @@ -98,8 +98,8 @@ class ChannelCache: def get_channel(self, - target: Text, - options: Sequence[Tuple[Text, Text]], + target: str, + options: Sequence[Tuple[str, str]], channel_credentials: Optional[grpc.ChannelCredentials], compression: Optional[grpc.Compression]) -> grpc.Channel: key = (target, options, channel_credentials, compression) @@ -123,20 +123,19 @@ class ChannelCache: return len(self._mapping) -# TODO: s/Text/str/g def unary_unary(request: Any, - target: Text, - method: Text, + target: str, + method: str, request_serializer: Optional[Callable[[Any], bytes]] = None, request_deserializer: Optional[Callable[[bytes], Any]] = None, - options: Sequence[Tuple[Text, Text]] = (), + options: Sequence[Tuple[AnyStr, AnyStr]] = (), # TODO: Somehow make insecure_channel opt-in, not the default. channel_credentials: Optional[grpc.ChannelCredentials] = None, call_credentials: Optional[grpc.CallCredentials] = None, compression: Optional[grpc.Compression] = None, wait_for_ready: Optional[bool] = None, timeout: Optional[float] = None, - metadata: Optional[Sequence[Tuple[Text, Union[Text, bytes]]]] = None) -> Any: + metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None) -> Any: """Invokes a unary RPC without an explicitly specified channel. This is backed by a cache of channels evicted by a background thread @@ -144,8 +143,6 @@ def unary_unary(request: Any, TODO: Document the parameters and return value. """ - - # TODO: Warn if the timeout is greater than the channel eviction time. channel = ChannelCache.get().get_channel(target, options, channel_credentials, compression) multicallable = channel.unary_unary(method, request_serializer, request_deserializer) return multicallable(request, @@ -153,3 +150,32 @@ def unary_unary(request: Any, wait_for_ready=wait_for_ready, credentials=call_credentials, timeout=timeout) + + +def unary_stream(request: Any, + target: str, + method: str, + request_serializer: Optional[Callable[[Any], bytes]] = None, + request_deserializer: Optional[Callable[[bytes], Any]] = None, + options: Sequence[Tuple[AnyStr, AnyStr]] = (), + # TODO: Somehow make insecure_channel opt-in, not the default. + channel_credentials: Optional[grpc.ChannelCredentials] = None, + call_credentials: Optional[grpc.CallCredentials] = None, + compression: Optional[grpc.Compression] = None, + wait_for_ready: Optional[bool] = None, + timeout: Optional[float] = None, + metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None) -> Iterator[Any]: + """Invokes a unary-stream RPC without an explicitly specified channel. + + This is backed by a cache of channels evicted by a background thread + on a periodic basis. + + TODO: Document the parameters and return value. + """ + channel = ChannelCache.get().get_channel(target, options, channel_credentials, compression) + multicallable = channel.unary_stream(method, request_serializer, request_deserializer) + return multicallable(request, + metadata=metadata, + wait_for_ready=wait_for_ready, + credentials=call_credentials, + timeout=timeout) diff --git a/src/python/grpcio_tests/tests/unit/py3_only/_simple_stubs_test.py b/src/python/grpcio_tests/tests/unit/py3_only/_simple_stubs_test.py index 8a97a8f7d67..67d1b6c48f1 100644 --- a/src/python/grpcio_tests/tests/unit/py3_only/_simple_stubs_test.py +++ b/src/python/grpcio_tests/tests/unit/py3_only/_simple_stubs_test.py @@ -36,18 +36,27 @@ import grpc _CACHE_EPOCHS = 8 _CACHE_TRIALS = 6 +_SERVER_RESPONSE_COUNT = 10 _UNARY_UNARY = "/test/UnaryUnary" +_UNARY_STREAM = "/test/UnaryStream" def _unary_unary_handler(request, context): return request +def _unary_stream_handler(request, context): + for _ in range(_SERVER_RESPONSE_COUNT): + yield request + + class _GenericHandler(grpc.GenericRpcHandler): def service(self, handler_call_details): if handler_call_details.method == _UNARY_UNARY: return grpc.unary_unary_rpc_method_handler(_unary_unary_handler) + elif handler_call_details.method == _UNARY_STREAM: + return grpc.unary_stream_rpc_method_handler(_unary_stream_handler) else: raise NotImplementedError() @@ -176,6 +185,16 @@ class SimpleStubsTest(unittest.TestCase): lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count() <= _MAXIMUM_CHANNELS + 1, message=lambda: f"{grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()} channels remain") + def test_unary_stream(self): + with _server(grpc.local_server_credentials()) as (_, port): + target = f'localhost:{port}' + request = b'0000' + for response in grpc.unary_stream(request, + target, + _UNARY_STREAM, + channel_credentials=grpc.local_channel_credentials()): + self.assertEqual(request, response) + # TODO: Test request_serializer # TODO: Test request_deserializer