diff --git a/.pylintrc b/.pylintrc index fcc8e73cb41..4924d3651e4 100644 --- a/.pylintrc +++ b/.pylintrc @@ -20,6 +20,7 @@ dummy-variables-rgx=^ignored_|^unused_ # be what works for us at the moment (excepting the dead-code-walking Beta # API). max-args=7 +max-parents=8 [MISCELLANEOUS] diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index b2f8a819c12..a57608b1ddd 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -15,15 +15,15 @@ import asyncio from functools import partial -from typing import AsyncIterable, Dict, Optional +from typing import AsyncIterable, Awaitable, Dict, Optional import grpc from grpc import _common from grpc._cython import cygrpc from . import _base_call -from ._typing import (DeserializingFunction, MetadataType, RequestType, - ResponseType, SerializingFunction, DoneCallbackType) +from ._typing import (DeserializingFunction, DoneCallbackType, MetadataType, + RequestType, ResponseType, SerializingFunction) __all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' @@ -145,7 +145,7 @@ def _create_rpc_error(initial_metadata: Optional[MetadataType], status.trailing_metadata()) -class Call(_base_call.Call): +class Call: """Base implementation of client RPC Call object. Implements logic around final status, metadata and cancellation. @@ -153,11 +153,19 @@ class Call(_base_call.Call): _loop: asyncio.AbstractEventLoop _code: grpc.StatusCode _cython_call: cygrpc._AioCall + _metadata: MetadataType + _request_serializer: SerializingFunction + _response_deserializer: DeserializingFunction - def __init__(self, cython_call: cygrpc._AioCall, + def __init__(self, cython_call: cygrpc._AioCall, metadata: MetadataType, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, loop: asyncio.AbstractEventLoop) -> None: self._loop = loop self._cython_call = cython_call + self._metadata = metadata + self._request_serializer = request_serializer + self._response_deserializer = response_deserializer def __del__(self) -> None: if not self._cython_call.done(): @@ -221,63 +229,24 @@ class Call(_base_call.Call): return self._repr() -class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): - """Object for managing unary-unary RPC calls. +class _UnaryResponseMixin(Call): + _call_finisher: asyncio.Task - Returned when an instance of `UnaryUnaryMultiCallable` object is called. - """ - _request: RequestType - _metadata: Optional[MetadataType] - _request_serializer: SerializingFunction - _response_deserializer: DeserializingFunction - _call: asyncio.Task - - # pylint: disable=too-many-arguments - def __init__(self, request: RequestType, deadline: Optional[float], - metadata: MetadataType, - credentials: Optional[grpc.CallCredentials], - channel: cygrpc.AioChannel, method: bytes, - request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction, - loop: asyncio.AbstractEventLoop) -> None: - super().__init__(channel.call(method, deadline, credentials), loop) - self._request = request - self._metadata = metadata - self._request_serializer = request_serializer - self._response_deserializer = response_deserializer - self._call = loop.create_task(self._invoke()) + def _init_unary_response_mixin(self, + response_coro: Awaitable[ResponseType]): + self._call_finisher = self._loop.create_task(response_coro) def cancel(self) -> bool: if super().cancel(): - self._call.cancel() + self._call_finisher.cancel() return True else: return False - async def _invoke(self) -> ResponseType: - serialized_request = _common.serialize(self._request, - self._request_serializer) - - # NOTE(lidiz) asyncio.CancelledError is not a good transport for status, - # because the asyncio.Task class do not cache the exception object. - # https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785 - try: - serialized_response = await self._cython_call.unary_unary( - serialized_request, self._metadata) - except asyncio.CancelledError: - if not self.cancelled(): - self.cancel() - - # Raises here if RPC failed or cancelled - await self._raise_for_status() - - return _common.deserialize(serialized_response, - self._response_deserializer) - def __await__(self) -> ResponseType: """Wait till the ongoing RPC request finishes.""" try: - response = yield from self._call + response = yield from self._call_finisher except asyncio.CancelledError: # Even if we caught all other CancelledError, there is still # this corner case. If the application cancels immediately after @@ -289,53 +258,21 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): return response -class UnaryStreamCall(Call, _base_call.UnaryStreamCall): - """Object for managing unary-stream RPC calls. - - Returned when an instance of `UnaryStreamMultiCallable` object is called. - """ - _request: RequestType - _metadata: MetadataType - _request_serializer: SerializingFunction - _response_deserializer: DeserializingFunction - _send_unary_request_task: asyncio.Task +class _StreamResponseMixin(Call): _message_aiter: AsyncIterable[ResponseType] + _prerequisite: asyncio.Task - # pylint: disable=too-many-arguments - def __init__(self, request: RequestType, deadline: Optional[float], - metadata: MetadataType, - credentials: Optional[grpc.CallCredentials], - channel: cygrpc.AioChannel, method: bytes, - request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction, - loop: asyncio.AbstractEventLoop) -> None: - super().__init__(channel.call(method, deadline, credentials), loop) - self._request = request - self._metadata = metadata - self._request_serializer = request_serializer - self._response_deserializer = response_deserializer - self._send_unary_request_task = loop.create_task( - self._send_unary_request()) + def _init_stream_response_mixin(self, prerequisite: asyncio.Task): self._message_aiter = None + self._prerequisite = prerequisite def cancel(self) -> bool: if super().cancel(): - self._send_unary_request_task.cancel() + self._prerequisite.cancel() return True else: return False - async def _send_unary_request(self) -> ResponseType: - serialized_request = _common.serialize(self._request, - self._request_serializer) - try: - await self._cython_call.initiate_unary_stream( - serialized_request, self._metadata) - except asyncio.CancelledError: - if not self.cancelled(): - self.cancel() - raise - async def _fetch_stream_responses(self) -> ResponseType: message = await self._read() while message is not cygrpc.EOF: @@ -349,7 +286,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): async def _read(self) -> ResponseType: # Wait for the request being sent - await self._send_unary_request_task + await self._prerequisite # Reads response message from Core try: @@ -366,7 +303,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): self._response_deserializer) async def read(self) -> ResponseType: - if self._cython_call.done(): + if self.done(): await self._raise_for_status() return cygrpc.EOF @@ -378,39 +315,16 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): return response_message -class StreamUnaryCall(Call, _base_call.StreamUnaryCall): - """Object for managing stream-unary RPC calls. - - Returned when an instance of `StreamUnaryMultiCallable` object is called. - """ - _metadata: MetadataType - _request_serializer: SerializingFunction - _response_deserializer: DeserializingFunction - +class _StreamRequestMixin(Call): _metadata_sent: asyncio.Event _done_writing: bool - _call_finisher: asyncio.Task - _async_request_poller: asyncio.Task + _async_request_poller: Optional[asyncio.Task] - # pylint: disable=too-many-arguments - def __init__(self, - request_async_iterator: Optional[AsyncIterable[RequestType]], - deadline: Optional[float], metadata: MetadataType, - credentials: Optional[grpc.CallCredentials], - channel: cygrpc.AioChannel, method: bytes, - request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction, - loop: asyncio.AbstractEventLoop) -> None: - super().__init__(channel.call(method, deadline, credentials), loop) - self._metadata = metadata - self._request_serializer = request_serializer - self._response_deserializer = response_deserializer - - self._metadata_sent = asyncio.Event(loop=loop) + def _init_stream_request_mixin( + self, request_async_iterator: Optional[AsyncIterable[RequestType]]): + self._metadata_sent = asyncio.Event(loop=self._loop) self._done_writing = False - self._call_finisher = loop.create_task(self._conduct_rpc()) - # If user passes in an async iterator, create a consumer Task. if request_async_iterator is not None: self._async_request_poller = self._loop.create_task( @@ -420,7 +334,6 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall): def cancel(self) -> bool: if super().cancel(): - self._call_finisher.cancel() if self._async_request_poller is not None: self._async_request_poller.cancel() return True @@ -430,38 +343,14 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall): def _metadata_sent_observer(self): self._metadata_sent.set() - async def _conduct_rpc(self) -> ResponseType: - try: - serialized_response = await self._cython_call.stream_unary( - self._metadata, self._metadata_sent_observer) - except asyncio.CancelledError: - if not self.cancelled(): - self.cancel() - - # Raises RpcError if the RPC failed or cancelled - await self._raise_for_status() - - return _common.deserialize(serialized_response, - self._response_deserializer) - async def _consume_request_iterator( self, request_async_iterator: AsyncIterable[RequestType]) -> None: async for request in request_async_iterator: await self.write(request) await self.done_writing() - def __await__(self) -> ResponseType: - """Wait till the ongoing RPC request finishes.""" - try: - response = yield from self._call_finisher - except asyncio.CancelledError: - if not self.cancelled(): - self.cancel() - raise - return response - async def write(self, request: RequestType) -> None: - if self._cython_call.done(): + if self.done(): raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) if self._done_writing: raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) @@ -480,7 +369,7 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall): async def done_writing(self) -> None: """Implementation of done_writing is idempotent.""" - if self._cython_call.done(): + if self.done(): # If the RPC is finished, do nothing. return if not self._done_writing: @@ -494,152 +383,153 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall): await self._raise_for_status() -class StreamStreamCall(Call, _base_call.StreamStreamCall): - """Object for managing stream-stream RPC calls. +class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall): + """Object for managing unary-unary RPC calls. - Returned when an instance of `StreamStreamMultiCallable` object is called. + Returned when an instance of `UnaryUnaryMultiCallable` object is called. """ - _metadata: MetadataType - _request_serializer: SerializingFunction - _response_deserializer: DeserializingFunction - - _metadata_sent: asyncio.Event - _done_writing: bool - _initializer: asyncio.Task - _async_request_poller: asyncio.Task - _message_aiter: AsyncIterable[ResponseType] + _request: RequestType + _call: asyncio.Task # pylint: disable=too-many-arguments - def __init__(self, - request_async_iterator: Optional[AsyncIterable[RequestType]], - deadline: Optional[float], metadata: MetadataType, + def __init__(self, request: RequestType, deadline: Optional[float], + metadata: MetadataType, credentials: Optional[grpc.CallCredentials], channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, response_deserializer: DeserializingFunction, loop: asyncio.AbstractEventLoop) -> None: - super().__init__(channel.call(method, deadline, credentials), loop) - self._metadata = metadata - self._request_serializer = request_serializer - self._response_deserializer = response_deserializer + super().__init__(channel.call(method, deadline, credentials), metadata, + request_serializer, response_deserializer, loop) + self._request = request + self._init_unary_response_mixin(self._invoke()) - self._metadata_sent = asyncio.Event(loop=loop) - self._done_writing = False + async def _invoke(self) -> ResponseType: + serialized_request = _common.serialize(self._request, + self._request_serializer) - self._initializer = self._loop.create_task(self._prepare_rpc()) + # NOTE(lidiz) asyncio.CancelledError is not a good transport for status, + # because the asyncio.Task class do not cache the exception object. + # https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785 + try: + serialized_response = await self._cython_call.unary_unary( + serialized_request, self._metadata) + except asyncio.CancelledError: + if not self.cancelled(): + self.cancel() - # If user passes in an async iterator, create a consumer coroutine. - if request_async_iterator is not None: - self._async_request_poller = loop.create_task( - self._consume_request_iterator(request_async_iterator)) - else: - self._async_request_poller = None - self._message_aiter = None + # Raises here if RPC failed or cancelled + await self._raise_for_status() - def cancel(self) -> bool: - if super().cancel(): - self._initializer.cancel() - if self._async_request_poller is not None: - self._async_request_poller.cancel() - return True - else: - return False + return _common.deserialize(serialized_response, + self._response_deserializer) - def _metadata_sent_observer(self): - self._metadata_sent.set() - async def _prepare_rpc(self): - """This method prepares the RPC for receiving/sending messages. +class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall): + """Object for managing unary-stream RPC calls. - All other operations around the stream should only happen after the - completion of this method. - """ + Returned when an instance of `UnaryStreamMultiCallable` object is called. + """ + _request: RequestType + _send_unary_request_task: asyncio.Task + + # pylint: disable=too-many-arguments + def __init__(self, request: RequestType, deadline: Optional[float], + metadata: MetadataType, + credentials: Optional[grpc.CallCredentials], + channel: cygrpc.AioChannel, method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop) -> None: + super().__init__(channel.call(method, deadline, credentials), metadata, + request_serializer, response_deserializer, loop) + self._request = request + self._send_unary_request_task = loop.create_task( + self._send_unary_request()) + self._init_stream_response_mixin(self._send_unary_request_task) + + async def _send_unary_request(self) -> ResponseType: + serialized_request = _common.serialize(self._request, + self._request_serializer) try: - await self._cython_call.initiate_stream_stream( - self._metadata, self._metadata_sent_observer) + await self._cython_call.initiate_unary_stream( + serialized_request, self._metadata) except asyncio.CancelledError: if not self.cancelled(): self.cancel() - # No need to raise RpcError here, because no one will `await` this task. + raise - async def _consume_request_iterator( - self, request_async_iterator: Optional[AsyncIterable[RequestType]] - ) -> None: - async for request in request_async_iterator: - await self.write(request) - await self.done_writing() - async def write(self, request: RequestType) -> None: - if self._cython_call.done(): - raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) - if self._done_writing: - raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) - if not self._metadata_sent.is_set(): - await self._metadata_sent.wait() +class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call, + _base_call.StreamUnaryCall): + """Object for managing stream-unary RPC calls. - serialized_request = _common.serialize(request, - self._request_serializer) + Returned when an instance of `StreamUnaryMultiCallable` object is called. + """ + # pylint: disable=too-many-arguments + def __init__(self, + request_async_iterator: Optional[AsyncIterable[RequestType]], + deadline: Optional[float], metadata: MetadataType, + credentials: Optional[grpc.CallCredentials], + channel: cygrpc.AioChannel, method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop) -> None: + super().__init__(channel.call(method, deadline, credentials), metadata, + request_serializer, response_deserializer, loop) + + self._init_stream_request_mixin(request_async_iterator) + self._init_unary_response_mixin(self._conduct_rpc()) + + async def _conduct_rpc(self) -> ResponseType: try: - await self._cython_call.send_serialized_message(serialized_request) + serialized_response = await self._cython_call.stream_unary( + self._metadata, self._metadata_sent_observer) except asyncio.CancelledError: if not self.cancelled(): self.cancel() - await self._raise_for_status() - async def done_writing(self) -> None: - """Implementation of done_writing is idempotent.""" - if self._cython_call.done(): - # If the RPC is finished, do nothing. - return - if not self._done_writing: - # If the done writing is not sent before, try to send it. - self._done_writing = True - try: - await self._cython_call.send_receive_close() - except asyncio.CancelledError: - if not self.cancelled(): - self.cancel() - await self._raise_for_status() + # Raises RpcError if the RPC failed or cancelled + await self._raise_for_status() - async def _fetch_stream_responses(self) -> ResponseType: - """The async generator that yields responses from peer.""" - message = await self._read() - while message is not cygrpc.EOF: - yield message - message = await self._read() + return _common.deserialize(serialized_response, + self._response_deserializer) - def __aiter__(self) -> AsyncIterable[ResponseType]: - if self._message_aiter is None: - self._message_aiter = self._fetch_stream_responses() - return self._message_aiter - async def _read(self) -> ResponseType: - # Wait for the setup - await self._initializer +class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call, + _base_call.StreamStreamCall): + """Object for managing stream-stream RPC calls. - # Reads response message from Core + Returned when an instance of `StreamStreamMultiCallable` object is called. + """ + _initializer: asyncio.Task + + # pylint: disable=too-many-arguments + def __init__(self, + request_async_iterator: Optional[AsyncIterable[RequestType]], + deadline: Optional[float], metadata: MetadataType, + credentials: Optional[grpc.CallCredentials], + channel: cygrpc.AioChannel, method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop) -> None: + super().__init__(channel.call(method, deadline, credentials), metadata, + request_serializer, response_deserializer, loop) + self._initializer = self._loop.create_task(self._prepare_rpc()) + self._init_stream_request_mixin(request_async_iterator) + self._init_stream_response_mixin(self._initializer) + + async def _prepare_rpc(self): + """This method prepares the RPC for receiving/sending messages. + + All other operations around the stream should only happen after the + completion of this method. + """ try: - raw_response = await self._cython_call.receive_serialized_message() + await self._cython_call.initiate_stream_stream( + self._metadata, self._metadata_sent_observer) except asyncio.CancelledError: if not self.cancelled(): self.cancel() - await self._raise_for_status() - - if raw_response is cygrpc.EOF: - return cygrpc.EOF - else: - return _common.deserialize(raw_response, - self._response_deserializer) - - async def read(self) -> ResponseType: - if self._cython_call.done(): - await self._raise_for_status() - return cygrpc.EOF - - response_message = await self._read() - - if response_message is cygrpc.EOF: - # If the read operation failed, Core should explain why. - await self._raise_for_status() - return response_message + # No need to raise RpcError here, because no one will `await` this task.