Implement unary_stream

pull/21954/head
Richard Belleville 5 years ago
parent 4ac50ceed6
commit b5f06c216e
  1. 4
      src/python/grpcio/grpc/__init__.py
  2. 50
      src/python/grpcio/grpc/_simple_stubs.py
  3. 19
      src/python/grpcio_tests/tests/unit/py3_only/_simple_stubs_test.py

@ -2036,8 +2036,8 @@ __all__ = (
) )
if sys.version_info[0] > 2: if sys.version_info[0] > 2:
from grpc._simple_stubs import unary_unary from grpc._simple_stubs import unary_unary, unary_stream
__all__ = __all__ + (unary_unary,) __all__ = __all__ + (unary_unary, unary_stream)
############################### Extension Shims ################################ ############################### Extension Shims ################################

@ -7,7 +7,7 @@ import logging
import threading import threading
import grpc import grpc
from typing import Any, Callable, Optional, Sequence, Text, Tuple, Union from typing import Any, AnyStr, Callable, Iterator, Optional, Sequence, Tuple, Union
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -26,8 +26,8 @@ if _MAXIMUM_CHANNELS_KEY in os.environ:
else: else:
_MAXIMUM_CHANNELS = 2 ** 8 _MAXIMUM_CHANNELS = 2 ** 8
def _create_channel(target: Text, def _create_channel(target: str,
options: Sequence[Tuple[Text, Text]], options: Sequence[Tuple[str, str]],
channel_credentials: Optional[grpc.ChannelCredentials], channel_credentials: Optional[grpc.ChannelCredentials],
compression: Optional[grpc.Compression]) -> grpc.Channel: compression: Optional[grpc.Compression]) -> grpc.Channel:
if channel_credentials is None: if channel_credentials is None:
@ -98,8 +98,8 @@ class ChannelCache:
def get_channel(self, def get_channel(self,
target: Text, target: str,
options: Sequence[Tuple[Text, Text]], options: Sequence[Tuple[str, str]],
channel_credentials: Optional[grpc.ChannelCredentials], channel_credentials: Optional[grpc.ChannelCredentials],
compression: Optional[grpc.Compression]) -> grpc.Channel: compression: Optional[grpc.Compression]) -> grpc.Channel:
key = (target, options, channel_credentials, compression) key = (target, options, channel_credentials, compression)
@ -123,20 +123,19 @@ class ChannelCache:
return len(self._mapping) return len(self._mapping)
# TODO: s/Text/str/g
def unary_unary(request: Any, def unary_unary(request: Any,
target: Text, target: str,
method: Text, method: str,
request_serializer: Optional[Callable[[Any], bytes]] = None, request_serializer: Optional[Callable[[Any], bytes]] = None,
request_deserializer: Optional[Callable[[bytes], Any]] = None, request_deserializer: Optional[Callable[[bytes], Any]] = None,
options: Sequence[Tuple[Text, Text]] = (), options: Sequence[Tuple[AnyStr, AnyStr]] = (),
# TODO: Somehow make insecure_channel opt-in, not the default. # TODO: Somehow make insecure_channel opt-in, not the default.
channel_credentials: Optional[grpc.ChannelCredentials] = None, channel_credentials: Optional[grpc.ChannelCredentials] = None,
call_credentials: Optional[grpc.CallCredentials] = None, call_credentials: Optional[grpc.CallCredentials] = None,
compression: Optional[grpc.Compression] = None, compression: Optional[grpc.Compression] = None,
wait_for_ready: Optional[bool] = None, wait_for_ready: Optional[bool] = None,
timeout: Optional[float] = None, timeout: Optional[float] = None,
metadata: Optional[Sequence[Tuple[Text, Union[Text, bytes]]]] = None) -> Any: metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None) -> Any:
"""Invokes a unary RPC without an explicitly specified channel. """Invokes a unary RPC without an explicitly specified channel.
This is backed by a cache of channels evicted by a background thread This is backed by a cache of channels evicted by a background thread
@ -144,8 +143,6 @@ def unary_unary(request: Any,
TODO: Document the parameters and return value. TODO: Document the parameters and return value.
""" """
# TODO: Warn if the timeout is greater than the channel eviction time.
channel = ChannelCache.get().get_channel(target, options, channel_credentials, compression) channel = ChannelCache.get().get_channel(target, options, channel_credentials, compression)
multicallable = channel.unary_unary(method, request_serializer, request_deserializer) multicallable = channel.unary_unary(method, request_serializer, request_deserializer)
return multicallable(request, return multicallable(request,
@ -153,3 +150,32 @@ def unary_unary(request: Any,
wait_for_ready=wait_for_ready, wait_for_ready=wait_for_ready,
credentials=call_credentials, credentials=call_credentials,
timeout=timeout) timeout=timeout)
def unary_stream(request: Any,
target: str,
method: str,
request_serializer: Optional[Callable[[Any], bytes]] = None,
request_deserializer: Optional[Callable[[bytes], Any]] = None,
options: Sequence[Tuple[AnyStr, AnyStr]] = (),
# TODO: Somehow make insecure_channel opt-in, not the default.
channel_credentials: Optional[grpc.ChannelCredentials] = None,
call_credentials: Optional[grpc.CallCredentials] = None,
compression: Optional[grpc.Compression] = None,
wait_for_ready: Optional[bool] = None,
timeout: Optional[float] = None,
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None) -> Iterator[Any]:
"""Invokes a unary-stream RPC without an explicitly specified channel.
This is backed by a cache of channels evicted by a background thread
on a periodic basis.
TODO: Document the parameters and return value.
"""
channel = ChannelCache.get().get_channel(target, options, channel_credentials, compression)
multicallable = channel.unary_stream(method, request_serializer, request_deserializer)
return multicallable(request,
metadata=metadata,
wait_for_ready=wait_for_ready,
credentials=call_credentials,
timeout=timeout)

@ -36,18 +36,27 @@ import grpc
_CACHE_EPOCHS = 8 _CACHE_EPOCHS = 8
_CACHE_TRIALS = 6 _CACHE_TRIALS = 6
_SERVER_RESPONSE_COUNT = 10
_UNARY_UNARY = "/test/UnaryUnary" _UNARY_UNARY = "/test/UnaryUnary"
_UNARY_STREAM = "/test/UnaryStream"
def _unary_unary_handler(request, context): def _unary_unary_handler(request, context):
return request return request
def _unary_stream_handler(request, context):
for _ in range(_SERVER_RESPONSE_COUNT):
yield request
class _GenericHandler(grpc.GenericRpcHandler): class _GenericHandler(grpc.GenericRpcHandler):
def service(self, handler_call_details): def service(self, handler_call_details):
if handler_call_details.method == _UNARY_UNARY: if handler_call_details.method == _UNARY_UNARY:
return grpc.unary_unary_rpc_method_handler(_unary_unary_handler) return grpc.unary_unary_rpc_method_handler(_unary_unary_handler)
elif handler_call_details.method == _UNARY_STREAM:
return grpc.unary_stream_rpc_method_handler(_unary_stream_handler)
else: else:
raise NotImplementedError() raise NotImplementedError()
@ -176,6 +185,16 @@ class SimpleStubsTest(unittest.TestCase):
lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count() <= _MAXIMUM_CHANNELS + 1, lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count() <= _MAXIMUM_CHANNELS + 1,
message=lambda: f"{grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()} channels remain") message=lambda: f"{grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()} channels remain")
def test_unary_stream(self):
with _server(grpc.local_server_credentials()) as (_, port):
target = f'localhost:{port}'
request = b'0000'
for response in grpc.unary_stream(request,
target,
_UNARY_STREAM,
channel_credentials=grpc.local_channel_credentials()):
self.assertEqual(request, response)
# TODO: Test request_serializer # TODO: Test request_serializer
# TODO: Test request_deserializer # TODO: Test request_deserializer

Loading…
Cancel
Save