Implement unary_stream

pull/21954/head
Richard Belleville 5 years ago
parent 4ac50ceed6
commit b5f06c216e
  1. 4
      src/python/grpcio/grpc/__init__.py
  2. 50
      src/python/grpcio/grpc/_simple_stubs.py
  3. 19
      src/python/grpcio_tests/tests/unit/py3_only/_simple_stubs_test.py

@ -2036,8 +2036,8 @@ __all__ = (
)
if sys.version_info[0] > 2:
from grpc._simple_stubs import unary_unary
__all__ = __all__ + (unary_unary,)
from grpc._simple_stubs import unary_unary, unary_stream
__all__ = __all__ + (unary_unary, unary_stream)
############################### Extension Shims ################################

@ -7,7 +7,7 @@ import logging
import threading
import grpc
from typing import Any, Callable, Optional, Sequence, Text, Tuple, Union
from typing import Any, AnyStr, Callable, Iterator, Optional, Sequence, Tuple, Union
_LOGGER = logging.getLogger(__name__)
@ -26,8 +26,8 @@ if _MAXIMUM_CHANNELS_KEY in os.environ:
else:
_MAXIMUM_CHANNELS = 2 ** 8
def _create_channel(target: Text,
options: Sequence[Tuple[Text, Text]],
def _create_channel(target: str,
options: Sequence[Tuple[str, str]],
channel_credentials: Optional[grpc.ChannelCredentials],
compression: Optional[grpc.Compression]) -> grpc.Channel:
if channel_credentials is None:
@ -98,8 +98,8 @@ class ChannelCache:
def get_channel(self,
target: Text,
options: Sequence[Tuple[Text, Text]],
target: str,
options: Sequence[Tuple[str, str]],
channel_credentials: Optional[grpc.ChannelCredentials],
compression: Optional[grpc.Compression]) -> grpc.Channel:
key = (target, options, channel_credentials, compression)
@ -123,20 +123,19 @@ class ChannelCache:
return len(self._mapping)
# TODO: s/Text/str/g
def unary_unary(request: Any,
target: Text,
method: Text,
target: str,
method: str,
request_serializer: Optional[Callable[[Any], bytes]] = None,
request_deserializer: Optional[Callable[[bytes], Any]] = None,
options: Sequence[Tuple[Text, Text]] = (),
options: Sequence[Tuple[AnyStr, AnyStr]] = (),
# TODO: Somehow make insecure_channel opt-in, not the default.
channel_credentials: Optional[grpc.ChannelCredentials] = None,
call_credentials: Optional[grpc.CallCredentials] = None,
compression: Optional[grpc.Compression] = None,
wait_for_ready: Optional[bool] = None,
timeout: Optional[float] = None,
metadata: Optional[Sequence[Tuple[Text, Union[Text, bytes]]]] = None) -> Any:
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None) -> Any:
"""Invokes a unary RPC without an explicitly specified channel.
This is backed by a cache of channels evicted by a background thread
@ -144,8 +143,6 @@ def unary_unary(request: Any,
TODO: Document the parameters and return value.
"""
# TODO: Warn if the timeout is greater than the channel eviction time.
channel = ChannelCache.get().get_channel(target, options, channel_credentials, compression)
multicallable = channel.unary_unary(method, request_serializer, request_deserializer)
return multicallable(request,
@ -153,3 +150,32 @@ def unary_unary(request: Any,
wait_for_ready=wait_for_ready,
credentials=call_credentials,
timeout=timeout)
def unary_stream(request: Any,
target: str,
method: str,
request_serializer: Optional[Callable[[Any], bytes]] = None,
request_deserializer: Optional[Callable[[bytes], Any]] = None,
options: Sequence[Tuple[AnyStr, AnyStr]] = (),
# TODO: Somehow make insecure_channel opt-in, not the default.
channel_credentials: Optional[grpc.ChannelCredentials] = None,
call_credentials: Optional[grpc.CallCredentials] = None,
compression: Optional[grpc.Compression] = None,
wait_for_ready: Optional[bool] = None,
timeout: Optional[float] = None,
metadata: Optional[Sequence[Tuple[str, Union[str, bytes]]]] = None) -> Iterator[Any]:
"""Invokes a unary-stream RPC without an explicitly specified channel.
This is backed by a cache of channels evicted by a background thread
on a periodic basis.
TODO: Document the parameters and return value.
"""
channel = ChannelCache.get().get_channel(target, options, channel_credentials, compression)
multicallable = channel.unary_stream(method, request_serializer, request_deserializer)
return multicallable(request,
metadata=metadata,
wait_for_ready=wait_for_ready,
credentials=call_credentials,
timeout=timeout)

@ -36,18 +36,27 @@ import grpc
_CACHE_EPOCHS = 8
_CACHE_TRIALS = 6
_SERVER_RESPONSE_COUNT = 10
_UNARY_UNARY = "/test/UnaryUnary"
_UNARY_STREAM = "/test/UnaryStream"
def _unary_unary_handler(request, context):
return request
def _unary_stream_handler(request, context):
for _ in range(_SERVER_RESPONSE_COUNT):
yield request
class _GenericHandler(grpc.GenericRpcHandler):
def service(self, handler_call_details):
if handler_call_details.method == _UNARY_UNARY:
return grpc.unary_unary_rpc_method_handler(_unary_unary_handler)
elif handler_call_details.method == _UNARY_STREAM:
return grpc.unary_stream_rpc_method_handler(_unary_stream_handler)
else:
raise NotImplementedError()
@ -176,6 +185,16 @@ class SimpleStubsTest(unittest.TestCase):
lambda: grpc._simple_stubs.ChannelCache.get()._test_only_channel_count() <= _MAXIMUM_CHANNELS + 1,
message=lambda: f"{grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()} channels remain")
def test_unary_stream(self):
with _server(grpc.local_server_credentials()) as (_, port):
target = f'localhost:{port}'
request = b'0000'
for response in grpc.unary_stream(request,
target,
_UNARY_STREAM,
channel_credentials=grpc.local_channel_credentials()):
self.assertEqual(request, response)
# TODO: Test request_serializer
# TODO: Test request_deserializer

Loading…
Cancel
Save