diff --git a/src/python/grpcio/grpc/experimental/aio/_base_channel.py b/src/python/grpcio/grpc/experimental/aio/_base_channel.py index 1168c260e97..663afe096eb 100644 --- a/src/python/grpcio/grpc/experimental/aio/_base_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_base_channel.py @@ -14,12 +14,13 @@ """Abstract base classes for Channel objects and Multicallable objects.""" import abc -from typing import Any, AsyncIterable, Optional +from typing import Any, AsyncIterable, Iterable, Optional import grpc from . import _base_call -from ._typing import DeserializingFunction, MetadataType, SerializingFunction +from ._typing import (DeserializingFunction, MetadataType, RequestIterableType, + SerializingFunction) _IMMUTABLE_EMPTY_TUPLE = tuple() @@ -105,7 +106,7 @@ class StreamUnaryMultiCallable(abc.ABC): @abc.abstractmethod def __call__(self, - request_async_iterator: Optional[AsyncIterable[Any]] = None, + request_iterator: Optional[RequestIterableType] = None, timeout: Optional[float] = None, metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, credentials: Optional[grpc.CallCredentials] = None, @@ -115,7 +116,8 @@ class StreamUnaryMultiCallable(abc.ABC): """Asynchronously invokes the underlying RPC. Args: - request: The request value for the RPC. + request_iterator: An optional async iterable or iterable of request + messages for the RPC. timeout: An optional duration of time in seconds to allow for the RPC. metadata: Optional :term:`metadata` to be transmitted to the @@ -142,7 +144,7 @@ class StreamStreamMultiCallable(abc.ABC): @abc.abstractmethod def __call__(self, - request_async_iterator: Optional[AsyncIterable[Any]] = None, + request_iterator: Optional[RequestIterableType] = None, timeout: Optional[float] = None, metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, credentials: Optional[grpc.CallCredentials] = None, @@ -152,7 +154,8 @@ class StreamStreamMultiCallable(abc.ABC): """Asynchronously invokes the underlying RPC. Args: - request: The request value for the RPC. + request_iterator: An optional async iterable or iterable of request + messages for the RPC. timeout: An optional duration of time in seconds to allow for the RPC. metadata: Optional :term:`metadata` to be transmitted to the diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index 25e8f7eeaa8..8fafcde2a1b 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -15,6 +15,7 @@ import asyncio import enum +import inspect import logging from functools import partial from typing import AsyncIterable, Awaitable, Optional, Tuple @@ -25,8 +26,8 @@ from grpc._cython import cygrpc from . import _base_call from ._typing import (DeserializingFunction, DoneCallbackType, MetadataType, - MetadatumType, RequestType, ResponseType, - SerializingFunction) + MetadatumType, RequestIterableType, RequestType, + ResponseType, SerializingFunction) __all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' @@ -363,14 +364,14 @@ class _StreamRequestMixin(Call): _request_style: _APIStyle def _init_stream_request_mixin( - self, request_async_iterator: Optional[AsyncIterable[RequestType]]): + self, request_iterator: Optional[RequestIterableType]): self._metadata_sent = asyncio.Event(loop=self._loop) self._done_writing_flag = False # If user passes in an async iterator, create a consumer Task. - if request_async_iterator is not None: + if request_iterator is not None: self._async_request_poller = self._loop.create_task( - self._consume_request_iterator(request_async_iterator)) + self._consume_request_iterator(request_iterator)) self._request_style = _APIStyle.ASYNC_GENERATOR else: self._async_request_poller = None @@ -392,12 +393,18 @@ class _StreamRequestMixin(Call): def _metadata_sent_observer(self): self._metadata_sent.set() - async def _consume_request_iterator( - self, request_async_iterator: AsyncIterable[RequestType]) -> None: + async def _consume_request_iterator(self, + request_iterator: RequestIterableType + ) -> None: try: - async for request in request_async_iterator: - await self._write(request) - await self._done_writing() + if inspect.isasyncgen(request_iterator): + async for request in request_iterator: + await self._write(request) + await self._done_writing() + else: + for request in request_iterator: + await self._write(request) + await self._done_writing() except AioRpcError as rpc_error: # Rpc status should be exposed through other API. Exceptions raised # within this Task won't be retrieved by another coroutine. It's @@ -538,8 +545,7 @@ class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call, """ # pylint: disable=too-many-arguments - def __init__(self, - request_async_iterator: Optional[AsyncIterable[RequestType]], + def __init__(self, request_iterator: Optional[RequestIterableType], deadline: Optional[float], metadata: MetadataType, credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, @@ -550,7 +556,7 @@ class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call, channel.call(method, deadline, credentials, wait_for_ready), metadata, request_serializer, response_deserializer, loop) - self._init_stream_request_mixin(request_async_iterator) + self._init_stream_request_mixin(request_iterator) self._init_unary_response_mixin(self._conduct_rpc()) async def _conduct_rpc(self) -> ResponseType: @@ -577,8 +583,7 @@ class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call, _initializer: asyncio.Task # pylint: disable=too-many-arguments - def __init__(self, - request_async_iterator: Optional[AsyncIterable[RequestType]], + def __init__(self, request_iterator: Optional[RequestIterableType], deadline: Optional[float], metadata: MetadataType, credentials: Optional[grpc.CallCredentials], wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, @@ -589,7 +594,7 @@ class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call, channel.call(method, deadline, credentials, wait_for_ready), metadata, request_serializer, response_deserializer, loop) self._initializer = self._loop.create_task(self._prepare_rpc()) - self._init_stream_request_mixin(request_async_iterator) + self._init_stream_request_mixin(request_iterator) self._init_stream_response_mixin(self._initializer) async def _prepare_rpc(self): diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 24a38e1f3d0..859a6ddd846 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -27,7 +27,7 @@ from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall, from ._interceptor import (InterceptedUnaryUnaryCall, UnaryUnaryClientInterceptor) from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType, - SerializingFunction) + SerializingFunction, RequestIterableType) from ._utils import _timeout_to_deadline _IMMUTABLE_EMPTY_TUPLE = tuple() @@ -146,7 +146,7 @@ class StreamUnaryMultiCallable(_BaseMultiCallable, _base_channel.StreamUnaryMultiCallable): def __call__(self, - request_async_iterator: Optional[AsyncIterable[Any]] = None, + request_iterator: Optional[RequestIterableType] = None, timeout: Optional[float] = None, metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, credentials: Optional[grpc.CallCredentials] = None, @@ -158,7 +158,7 @@ class StreamUnaryMultiCallable(_BaseMultiCallable, deadline = _timeout_to_deadline(timeout) - call = StreamUnaryCall(request_async_iterator, deadline, metadata, + call = StreamUnaryCall(request_iterator, deadline, metadata, credentials, wait_for_ready, self._channel, self._method, self._request_serializer, self._response_deserializer, self._loop) @@ -170,7 +170,7 @@ class StreamStreamMultiCallable(_BaseMultiCallable, _base_channel.StreamStreamMultiCallable): def __call__(self, - request_async_iterator: Optional[AsyncIterable[Any]] = None, + request_iterator: Optional[RequestIterableType] = None, timeout: Optional[float] = None, metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, credentials: Optional[grpc.CallCredentials] = None, @@ -182,7 +182,7 @@ class StreamStreamMultiCallable(_BaseMultiCallable, deadline = _timeout_to_deadline(timeout) - call = StreamStreamCall(request_async_iterator, deadline, metadata, + call = StreamStreamCall(request_iterator, deadline, metadata, credentials, wait_for_ready, self._channel, self._method, self._request_serializer, self._response_deserializer, self._loop) diff --git a/src/python/grpcio/grpc/experimental/aio/_typing.py b/src/python/grpcio/grpc/experimental/aio/_typing.py index ccd2f529936..205f6dc6227 100644 --- a/src/python/grpcio/grpc/experimental/aio/_typing.py +++ b/src/python/grpcio/grpc/experimental/aio/_typing.py @@ -13,7 +13,9 @@ # limitations under the License. """Common types for gRPC Async API""" -from typing import Any, AnyStr, Callable, Sequence, Tuple, TypeVar +from typing import (Any, AnyStr, AsyncIterable, Callable, Iterable, Sequence, + Tuple, TypeVar, Union) + from grpc._cython.cygrpc import EOF RequestType = TypeVar('RequestType') @@ -25,3 +27,4 @@ MetadataType = Sequence[MetadatumType] ChannelArgumentType = Sequence[Tuple[str, Any]] EOFType = type(EOF) DoneCallbackType = Callable[[Any], None] +RequestIterableType = Union[Iterable[Any], AsyncIterable[Any]] diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index f64f4e44802..5b52f0e1724 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -559,6 +559,23 @@ class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase): # No failures in the cancel later task! await cancel_later_task + async def test_normal_iterable_requests(self): + # Prepares the request + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + requests = [request] * _NUM_STREAM_RESPONSES + + # Sends out requests + call = self._stub.StreamingInputCall(requests) + + # RPC should succeed + response = await call + self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) + self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + # Prepares the request that stream in a ping-pong manner. _STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest() @@ -738,6 +755,15 @@ class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase): # No failures in the cancel later task! await cancel_later_task + async def test_normal_iterable_requests(self): + requests = [_STREAM_OUTPUT_REQUEST_ONE_RESPONSE] * _NUM_STREAM_RESPONSES + + call = self._stub.FullDuplexCall(iter(requests)) + async for response in call: + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + if __name__ == '__main__': logging.basicConfig(level=logging.DEBUG)