Merge pull request #22580 from lidizheng/aio-iterator

[Aio] Accepts normal iterable of request messages
pull/22596/head
Lidi Zheng 5 years ago committed by GitHub
commit 4d91e531ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 15
      src/python/grpcio/grpc/experimental/aio/_base_channel.py
  2. 33
      src/python/grpcio/grpc/experimental/aio/_call.py
  3. 12
      src/python/grpcio/grpc/experimental/aio/_channel.py
  4. 5
      src/python/grpcio/grpc/experimental/aio/_typing.py
  5. 26
      src/python/grpcio_tests/tests_aio/unit/call_test.py

@ -14,12 +14,13 @@
"""Abstract base classes for Channel objects and Multicallable objects.""" """Abstract base classes for Channel objects and Multicallable objects."""
import abc import abc
from typing import Any, AsyncIterable, Optional from typing import Any, Optional
import grpc import grpc
from . import _base_call from . import _base_call
from ._typing import DeserializingFunction, MetadataType, SerializingFunction from ._typing import (DeserializingFunction, MetadataType, RequestIterableType,
SerializingFunction)
_IMMUTABLE_EMPTY_TUPLE = tuple() _IMMUTABLE_EMPTY_TUPLE = tuple()
@ -105,7 +106,7 @@ class StreamUnaryMultiCallable(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def __call__(self, def __call__(self,
request_async_iterator: Optional[AsyncIterable[Any]] = None, request_iterator: Optional[RequestIterableType] = None,
timeout: Optional[float] = None, timeout: Optional[float] = None,
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
credentials: Optional[grpc.CallCredentials] = None, credentials: Optional[grpc.CallCredentials] = None,
@ -115,7 +116,8 @@ class StreamUnaryMultiCallable(abc.ABC):
"""Asynchronously invokes the underlying RPC. """Asynchronously invokes the underlying RPC.
Args: Args:
request: The request value for the RPC. request_iterator: An optional async iterable or iterable of request
messages for the RPC.
timeout: An optional duration of time in seconds to allow timeout: An optional duration of time in seconds to allow
for the RPC. for the RPC.
metadata: Optional :term:`metadata` to be transmitted to the metadata: Optional :term:`metadata` to be transmitted to the
@ -142,7 +144,7 @@ class StreamStreamMultiCallable(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def __call__(self, def __call__(self,
request_async_iterator: Optional[AsyncIterable[Any]] = None, request_iterator: Optional[RequestIterableType] = None,
timeout: Optional[float] = None, timeout: Optional[float] = None,
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
credentials: Optional[grpc.CallCredentials] = None, credentials: Optional[grpc.CallCredentials] = None,
@ -152,7 +154,8 @@ class StreamStreamMultiCallable(abc.ABC):
"""Asynchronously invokes the underlying RPC. """Asynchronously invokes the underlying RPC.
Args: Args:
request: The request value for the RPC. request_iterator: An optional async iterable or iterable of request
messages for the RPC.
timeout: An optional duration of time in seconds to allow timeout: An optional duration of time in seconds to allow
for the RPC. for the RPC.
metadata: Optional :term:`metadata` to be transmitted to the metadata: Optional :term:`metadata` to be transmitted to the

@ -15,6 +15,7 @@
import asyncio import asyncio
import enum import enum
import inspect
import logging import logging
from functools import partial from functools import partial
from typing import AsyncIterable, Awaitable, Optional, Tuple from typing import AsyncIterable, Awaitable, Optional, Tuple
@ -25,8 +26,8 @@ from grpc._cython import cygrpc
from . import _base_call from . import _base_call
from ._typing import (DeserializingFunction, DoneCallbackType, MetadataType, from ._typing import (DeserializingFunction, DoneCallbackType, MetadataType,
MetadatumType, RequestType, ResponseType, MetadatumType, RequestIterableType, RequestType,
SerializingFunction) ResponseType, SerializingFunction)
__all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' __all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
@ -363,14 +364,14 @@ class _StreamRequestMixin(Call):
_request_style: _APIStyle _request_style: _APIStyle
def _init_stream_request_mixin( def _init_stream_request_mixin(
self, request_async_iterator: Optional[AsyncIterable[RequestType]]): self, request_iterator: Optional[RequestIterableType]):
self._metadata_sent = asyncio.Event(loop=self._loop) self._metadata_sent = asyncio.Event(loop=self._loop)
self._done_writing_flag = False self._done_writing_flag = False
# If user passes in an async iterator, create a consumer Task. # If user passes in an async iterator, create a consumer Task.
if request_async_iterator is not None: if request_iterator is not None:
self._async_request_poller = self._loop.create_task( self._async_request_poller = self._loop.create_task(
self._consume_request_iterator(request_async_iterator)) self._consume_request_iterator(request_iterator))
self._request_style = _APIStyle.ASYNC_GENERATOR self._request_style = _APIStyle.ASYNC_GENERATOR
else: else:
self._async_request_poller = None self._async_request_poller = None
@ -392,11 +393,17 @@ class _StreamRequestMixin(Call):
def _metadata_sent_observer(self): def _metadata_sent_observer(self):
self._metadata_sent.set() self._metadata_sent.set()
async def _consume_request_iterator( async def _consume_request_iterator(self,
self, request_async_iterator: AsyncIterable[RequestType]) -> None: request_iterator: RequestIterableType
) -> None:
try: try:
async for request in request_async_iterator: if inspect.isasyncgen(request_iterator):
async for request in request_iterator:
await self._write(request) await self._write(request)
else:
for request in request_iterator:
await self._write(request)
await self._done_writing() await self._done_writing()
except AioRpcError as rpc_error: except AioRpcError as rpc_error:
# Rpc status should be exposed through other API. Exceptions raised # Rpc status should be exposed through other API. Exceptions raised
@ -538,8 +545,7 @@ class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
""" """
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments
def __init__(self, def __init__(self, request_iterator: Optional[RequestIterableType],
request_async_iterator: Optional[AsyncIterable[RequestType]],
deadline: Optional[float], metadata: MetadataType, deadline: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials], credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
@ -550,7 +556,7 @@ class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
channel.call(method, deadline, credentials, wait_for_ready), channel.call(method, deadline, credentials, wait_for_ready),
metadata, request_serializer, response_deserializer, loop) metadata, request_serializer, response_deserializer, loop)
self._init_stream_request_mixin(request_async_iterator) self._init_stream_request_mixin(request_iterator)
self._init_unary_response_mixin(self._conduct_rpc()) self._init_unary_response_mixin(self._conduct_rpc())
async def _conduct_rpc(self) -> ResponseType: async def _conduct_rpc(self) -> ResponseType:
@ -577,8 +583,7 @@ class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call,
_initializer: asyncio.Task _initializer: asyncio.Task
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments
def __init__(self, def __init__(self, request_iterator: Optional[RequestIterableType],
request_async_iterator: Optional[AsyncIterable[RequestType]],
deadline: Optional[float], metadata: MetadataType, deadline: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials], credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
@ -589,7 +594,7 @@ class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call,
channel.call(method, deadline, credentials, wait_for_ready), channel.call(method, deadline, credentials, wait_for_ready),
metadata, request_serializer, response_deserializer, loop) metadata, request_serializer, response_deserializer, loop)
self._initializer = self._loop.create_task(self._prepare_rpc()) self._initializer = self._loop.create_task(self._prepare_rpc())
self._init_stream_request_mixin(request_async_iterator) self._init_stream_request_mixin(request_iterator)
self._init_stream_response_mixin(self._initializer) self._init_stream_response_mixin(self._initializer)
async def _prepare_rpc(self): async def _prepare_rpc(self):

@ -15,7 +15,7 @@
import asyncio import asyncio
import sys import sys
from typing import Any, AsyncIterable, Iterable, Optional, Sequence from typing import Any, Iterable, Optional, Sequence
import grpc import grpc
from grpc import _common, _compression, _grpcio_metadata from grpc import _common, _compression, _grpcio_metadata
@ -27,7 +27,7 @@ from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall,
from ._interceptor import (InterceptedUnaryUnaryCall, from ._interceptor import (InterceptedUnaryUnaryCall,
UnaryUnaryClientInterceptor) UnaryUnaryClientInterceptor)
from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType, from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType,
SerializingFunction) SerializingFunction, RequestIterableType)
from ._utils import _timeout_to_deadline from ._utils import _timeout_to_deadline
_IMMUTABLE_EMPTY_TUPLE = tuple() _IMMUTABLE_EMPTY_TUPLE = tuple()
@ -146,7 +146,7 @@ class StreamUnaryMultiCallable(_BaseMultiCallable,
_base_channel.StreamUnaryMultiCallable): _base_channel.StreamUnaryMultiCallable):
def __call__(self, def __call__(self,
request_async_iterator: Optional[AsyncIterable[Any]] = None, request_iterator: Optional[RequestIterableType] = None,
timeout: Optional[float] = None, timeout: Optional[float] = None,
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
credentials: Optional[grpc.CallCredentials] = None, credentials: Optional[grpc.CallCredentials] = None,
@ -158,7 +158,7 @@ class StreamUnaryMultiCallable(_BaseMultiCallable,
deadline = _timeout_to_deadline(timeout) deadline = _timeout_to_deadline(timeout)
call = StreamUnaryCall(request_async_iterator, deadline, metadata, call = StreamUnaryCall(request_iterator, deadline, metadata,
credentials, wait_for_ready, self._channel, credentials, wait_for_ready, self._channel,
self._method, self._request_serializer, self._method, self._request_serializer,
self._response_deserializer, self._loop) self._response_deserializer, self._loop)
@ -170,7 +170,7 @@ class StreamStreamMultiCallable(_BaseMultiCallable,
_base_channel.StreamStreamMultiCallable): _base_channel.StreamStreamMultiCallable):
def __call__(self, def __call__(self,
request_async_iterator: Optional[AsyncIterable[Any]] = None, request_iterator: Optional[RequestIterableType] = None,
timeout: Optional[float] = None, timeout: Optional[float] = None,
metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE, metadata: Optional[MetadataType] = _IMMUTABLE_EMPTY_TUPLE,
credentials: Optional[grpc.CallCredentials] = None, credentials: Optional[grpc.CallCredentials] = None,
@ -182,7 +182,7 @@ class StreamStreamMultiCallable(_BaseMultiCallable,
deadline = _timeout_to_deadline(timeout) deadline = _timeout_to_deadline(timeout)
call = StreamStreamCall(request_async_iterator, deadline, metadata, call = StreamStreamCall(request_iterator, deadline, metadata,
credentials, wait_for_ready, self._channel, credentials, wait_for_ready, self._channel,
self._method, self._request_serializer, self._method, self._request_serializer,
self._response_deserializer, self._loop) self._response_deserializer, self._loop)

@ -13,7 +13,9 @@
# limitations under the License. # limitations under the License.
"""Common types for gRPC Async API""" """Common types for gRPC Async API"""
from typing import Any, AnyStr, Callable, Sequence, Tuple, TypeVar from typing import (Any, AnyStr, AsyncIterable, Callable, Iterable, Sequence,
Tuple, TypeVar, Union)
from grpc._cython.cygrpc import EOF from grpc._cython.cygrpc import EOF
RequestType = TypeVar('RequestType') RequestType = TypeVar('RequestType')
@ -25,3 +27,4 @@ MetadataType = Sequence[MetadatumType]
ChannelArgumentType = Sequence[Tuple[str, Any]] ChannelArgumentType = Sequence[Tuple[str, Any]]
EOFType = type(EOF) EOFType = type(EOF)
DoneCallbackType = Callable[[Any], None] DoneCallbackType = Callable[[Any], None]
RequestIterableType = Union[Iterable[Any], AsyncIterable[Any]]

@ -559,6 +559,23 @@ class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase):
# No failures in the cancel later task! # No failures in the cancel later task!
await cancel_later_task await cancel_later_task
async def test_normal_iterable_requests(self):
# Prepares the request
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
requests = [request] * _NUM_STREAM_RESPONSES
# Sends out requests
call = self._stub.StreamingInputCall(requests)
# RPC should succeed
response = await call
self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
self.assertEqual(_NUM_STREAM_RESPONSES * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
# Prepares the request that stream in a ping-pong manner. # Prepares the request that stream in a ping-pong manner.
_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest() _STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest()
@ -738,6 +755,15 @@ class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase):
# No failures in the cancel later task! # No failures in the cancel later task!
await cancel_later_task await cancel_later_task
async def test_normal_iterable_requests(self):
requests = [_STREAM_OUTPUT_REQUEST_ONE_RESPONSE] * _NUM_STREAM_RESPONSES
call = self._stub.FullDuplexCall(iter(requests))
async for response in call:
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertEqual(await call.code(), grpc.StatusCode.OK)
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)

Loading…
Cancel
Save