Merge pull request #22580 from lidizheng/aio-iterator

[Aio] Accepts normal iterable of request messages
pull/22596/head
Lidi Zheng 5 years ago committed by GitHub
commit 4d91e531ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 15
      src/python/grpcio/grpc/experimental/aio/_base_channel.py
  2. 33
      src/python/grpcio/grpc/experimental/aio/_call.py
  3. 12
      src/python/grpcio/grpc/experimental/aio/_channel.py
  4. 5
      src/python/grpcio/grpc/experimental/aio/_typing.py
  5. 26
      src/python/grpcio_tests/tests_aio/unit/call_test.py

@ -14,12 +14,13 @@
"""Abstract base classes for Channel objects and Multicallable objects."""
import abc
from typing import Any, AsyncIterable, Optional
from typing import Any, Optional
import grpc
from . import _base_call
from ._typing import DeserializingFunction, MetadataType, SerializingFunction
from ._typing import (DeserializingFunction, MetadataType, RequestIterableType,
SerializingFunction)
_IMMUTABLE_EMPTY_TUPLE = tuple()
@ -105,7 +106,7 @@ class StreamUnaryMultiCallable(abc.ABC):
@abc.abstractmethod
def __call__(self,
request_async_iterator: Optional[AsyncIterable[Any]] = None,
request_iterator: Optional[RequestIterableType] = None,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
credentials: Optional[grpc.CallCredentials] = None,
@ -115,7 +116,8 @@ class StreamUnaryMultiCallable(abc.ABC):
"""Asynchronously invokes the underlying RPC.
Args:
request: The request value for the RPC.
request_iterator: An optional async iterable or iterable of request
messages for the RPC.
timeout: An optional duration of time in seconds to allow
for the RPC.
metadata: Optional :term:`metadata` to be transmitted to the
@ -142,7 +144,7 @@ class StreamStreamMultiCallable(abc.ABC):
@abc.abstractmethod
def __call__(self,
request_async_iterator: Optional[AsyncIterable[Any]] = None,
request_iterator: Optional[RequestIterableType] = None,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
credentials: Optional[grpc.CallCredentials] = None,
@ -152,7 +154,8 @@ class StreamStreamMultiCallable(abc.ABC):
"""Asynchronously invokes the underlying RPC.
Args:
request: The request value for the RPC.
request_iterator: An optional async iterable or iterable of request
messages for the RPC.
timeout: An optional duration of time in seconds to allow
for the RPC.
metadata: Optional :term:`metadata` to be transmitted to the

@ -15,6 +15,7 @@
import asyncio
import enum
import inspect
import logging
from functools import partial
from typing import AsyncIterable, Awaitable, Optional, Tuple
@ -25,8 +26,8 @@ from grpc._cython import cygrpc
from . import _base_call
from ._typing import (DeserializingFunction, DoneCallbackType, MetadataType,
MetadatumType, RequestType, ResponseType,
SerializingFunction)
MetadatumType, RequestIterableType, RequestType,
ResponseType, SerializingFunction)
__all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
@ -363,14 +364,14 @@ class _StreamRequestMixin(Call):
_request_style: _APIStyle
def _init_stream_request_mixin(
self, request_async_iterator: Optional[AsyncIterable[RequestType]]):
self, request_iterator: Optional[RequestIterableType]):
self._metadata_sent = asyncio.Event(loop=self._loop)
self._done_writing_flag = False
# If user passes in an async iterator, create a consumer Task.
if request_async_iterator is not None:
if request_iterator is not None:
self._async_request_poller = self._loop.create_task(
self._consume_request_iterator(request_async_iterator))
self._consume_request_iterator(request_iterator))
self._request_style = _APIStyle.ASYNC_GENERATOR
else:
self._async_request_poller = None
@ -392,11 +393,17 @@ class _StreamRequestMixin(Call):
def _metadata_sent_observer(self):
self._metadata_sent.set()
async def _consume_request_iterator(
self, request_async_iterator: AsyncIterable[RequestType]) -> None:
async def _consume_request_iterator(self,
request_iterator: RequestIterableType
) -> None:
try:
async for request in request_async_iterator:
if inspect.isasyncgen(request_iterator):
async for request in request_iterator:
await self._write(request)
else:
for request in request_iterator:
await self._write(request)
await self._done_writing()
except AioRpcError as rpc_error:
# Rpc status should be exposed through other API. Exceptions raised
@ -538,8 +545,7 @@ class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
"""
# pylint: disable=too-many-arguments
def __init__(self,
request_async_iterator: Optional[AsyncIterable[RequestType]],
def __init__(self, request_iterator: Optional[RequestIterableType],
deadline: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
@ -550,7 +556,7 @@ class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
channel.call(method, deadline, credentials, wait_for_ready),
metadata, request_serializer, response_deserializer, loop)
self._init_stream_request_mixin(request_async_iterator)
self._init_stream_request_mixin(request_iterator)
self._init_unary_response_mixin(self._conduct_rpc())
async def _conduct_rpc(self) -> ResponseType:
@ -577,8 +583,7 @@ class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call,
_initializer: asyncio.Task
# pylint: disable=too-many-arguments
def __init__(self,
request_async_iterator: Optional[AsyncIterable[RequestType]],
def __init__(self, request_iterator: Optional[RequestIterableType],
deadline: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
@ -589,7 +594,7 @@ class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call,
channel.call(method, deadline, credentials, wait_for_ready),
metadata, request_serializer, response_deserializer, loop)
self._initializer = self._loop.create_task(self._prepare_rpc())
self._init_stream_request_mixin(request_async_iterator)
self._init_stream_request_mixin(request_iterator)
self._init_stream_response_mixin(self._initializer)
async def _prepare_rpc(self):

@ -15,7 +15,7 @@
import asyncio
import sys
from typing import Any, AsyncIterable, Iterable, Optional, Sequence
from typing import Any, Iterable, Optional, Sequence
import grpc
from grpc import _common, _compression, _grpcio_metadata
@ -27,7 +27,7 @@ from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall,
from ._interceptor import (InterceptedUnaryUnaryCall,
UnaryUnaryClientInterceptor)
from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType,
SerializingFunction)
SerializingFunction, RequestIterableType)
from ._utils import _timeout_to_deadline
_IMMUTABLE_EMPTY_TUPLE = tuple()
@ -146,7 +146,7 @@ class StreamUnaryMultiCallable(_BaseMultiCallable,
_base_channel.StreamUnaryMultiCallable):
def __call__(self,
request_async_iterator: Optional[AsyncIterable[Any]] = None,
request_iterator: Optional[RequestIterableType] = None,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
credentials: Optional[grpc.CallCredentials] = None,
@ -158,7 +158,7 @@ class StreamUnaryMultiCallable(_BaseMultiCallable,
deadline = _timeout_to_deadline(timeout)
call = StreamUnaryCall(request_async_iterator, deadline, metadata,
call = StreamUnaryCall(request_iterator, deadline, metadata,
credentials, wait_for_ready, self._channel,
self._method, self._request_serializer,
self._response_deserializer, self._loop)
@ -170,7 +170,7 @@ class StreamStreamMultiCallable(_BaseMultiCallable,
_base_channel.StreamStreamMultiCallable):
def __call__(self,
request_async_iterator: Optional[AsyncIterable[Any]] = None,
request_iterator: Optional[RequestIterableType] = None,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
credentials: Optional[grpc.CallCredentials] = None,
@ -182,7 +182,7 @@ class StreamStreamMultiCallable(_BaseMultiCallable,
deadline = _timeout_to_deadline(timeout)
call = StreamStreamCall(request_async_iterator, deadline, metadata,
call = StreamStreamCall(request_iterator, deadline, metadata,
credentials, wait_for_ready, self._channel,
self._method, self._request_serializer,
self._response_deserializer, self._loop)

@ -13,7 +13,9 @@
# limitations under the License.
"""Common types for gRPC Async API"""
from typing import Any, AnyStr, Callable, Sequence, Tuple, TypeVar
from typing import (Any, AnyStr, AsyncIterable, Callable, Iterable, Sequence,
Tuple, TypeVar, Union)
from grpc._cython.cygrpc import EOF
RequestType = TypeVar('RequestType')
@ -25,3 +27,4 @@ MetadataType = Sequence[MetadatumType]
ChannelArgumentType = Sequence[Tuple[str, Any]]
EOFType = type(EOF)
DoneCallbackType = Callable[[Any], None]
RequestIterableType = Union[Iterable[Any], AsyncIterable[Any]]

@ -559,6 +559,23 @@ class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase):
# No failures in the cancel later task!
await cancel_later_task
async def test_normal_iterable_requests(self):
# Prepares the request
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
requests = [request] * _NUM_STREAM_RESPONSES
# Sends out requests
call = self._stub.StreamingInputCall(requests)
# RPC should succeed
response = await call
self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
# Prepares the request that stream in a ping-pong manner.
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest()
@ -738,6 +755,15 @@ class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase):
# No failures in the cancel later task!
await cancel_later_task
async def test_normal_iterable_requests(self):
requests = [_STREAM_OUTPUT_REQUEST_ONE_RESPONSE] * _NUM_STREAM_RESPONSES
call = self._stub.FullDuplexCall(iter(requests))
async for response in call:
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertEqual(await call.code(), grpc.StatusCode.OK)
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)

Loading…
Cancel
Save