De-duplication for the call objects

pull/21772/head
Lidi Zheng 5 years ago
parent 152de3cf91
commit 0c142306be
  1. 1
      .pylintrc
  2. 400
      src/python/grpcio/grpc/experimental/aio/_call.py

@ -20,6 +20,7 @@ dummy-variables-rgx=^ignored_|^unused_
# be what works for us at the moment (excepting the dead-code-walking Beta # be what works for us at the moment (excepting the dead-code-walking Beta
# API). # API).
max-args=7 max-args=7
max-parents=8
[MISCELLANEOUS] [MISCELLANEOUS]

@ -15,15 +15,15 @@
import asyncio import asyncio
from functools import partial from functools import partial
from typing import AsyncIterable, Dict, Optional from typing import AsyncIterable, Awaitable, Dict, Optional
import grpc import grpc
from grpc import _common from grpc import _common
from grpc._cython import cygrpc from grpc._cython import cygrpc
from . import _base_call from . import _base_call
from ._typing import (DeserializingFunction, MetadataType, RequestType, from ._typing import (DeserializingFunction, DoneCallbackType, MetadataType,
ResponseType, SerializingFunction, DoneCallbackType) RequestType, ResponseType, SerializingFunction)
__all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' __all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
@ -145,7 +145,7 @@ def _create_rpc_error(initial_metadata: Optional[MetadataType],
status.trailing_metadata()) status.trailing_metadata())
class Call(_base_call.Call): class Call:
"""Base implementation of client RPC Call object. """Base implementation of client RPC Call object.
Implements logic around final status, metadata and cancellation. Implements logic around final status, metadata and cancellation.
@ -153,11 +153,19 @@ class Call(_base_call.Call):
_loop: asyncio.AbstractEventLoop _loop: asyncio.AbstractEventLoop
_code: grpc.StatusCode _code: grpc.StatusCode
_cython_call: cygrpc._AioCall _cython_call: cygrpc._AioCall
_metadata: MetadataType
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
def __init__(self, cython_call: cygrpc._AioCall, def __init__(self, cython_call: cygrpc._AioCall, metadata: MetadataType,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None: loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop self._loop = loop
self._cython_call = cython_call self._cython_call = cython_call
self._metadata = metadata
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
def __del__(self) -> None: def __del__(self) -> None:
if not self._cython_call.done(): if not self._cython_call.done():
@ -221,63 +229,24 @@ class Call(_base_call.Call):
return self._repr() return self._repr()
class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): class _UnaryResponseMixin(Call):
"""Object for managing unary-unary RPC calls. _call_finisher: asyncio.Task
Returned when an instance of `UnaryUnaryMultiCallable` object is called.
"""
_request: RequestType
_metadata: Optional[MetadataType]
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_call: asyncio.Task
# pylint: disable=too-many-arguments def _init_unary_response_mixin(self,
def __init__(self, request: RequestType, deadline: Optional[float], response_coro: Awaitable[ResponseType]):
metadata: MetadataType, self._call_finisher = self._loop.create_task(response_coro)
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), loop)
self._request = request
self._metadata = metadata
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._call = loop.create_task(self._invoke())
def cancel(self) -> bool: def cancel(self) -> bool:
if super().cancel(): if super().cancel():
self._call.cancel() self._call_finisher.cancel()
return True return True
else: else:
return False return False
async def _invoke(self) -> ResponseType:
serialized_request = _common.serialize(self._request,
self._request_serializer)
# NOTE(lidiz) asyncio.CancelledError is not a good transport for status,
# because the asyncio.Task class do not cache the exception object.
# https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
try:
serialized_response = await self._cython_call.unary_unary(
serialized_request, self._metadata)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
# Raises here if RPC failed or cancelled
await self._raise_for_status()
return _common.deserialize(serialized_response,
self._response_deserializer)
def __await__(self) -> ResponseType: def __await__(self) -> ResponseType:
"""Wait till the ongoing RPC request finishes.""" """Wait till the ongoing RPC request finishes."""
try: try:
response = yield from self._call response = yield from self._call_finisher
except asyncio.CancelledError: except asyncio.CancelledError:
# Even if we caught all other CancelledError, there is still # Even if we caught all other CancelledError, there is still
# this corner case. If the application cancels immediately after # this corner case. If the application cancels immediately after
@ -289,53 +258,21 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
return response return response
class UnaryStreamCall(Call, _base_call.UnaryStreamCall): class _StreamResponseMixin(Call):
"""Object for managing unary-stream RPC calls.
Returned when an instance of `UnaryStreamMultiCallable` object is called.
"""
_request: RequestType
_metadata: MetadataType
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_send_unary_request_task: asyncio.Task
_message_aiter: AsyncIterable[ResponseType] _message_aiter: AsyncIterable[ResponseType]
_prerequisite: asyncio.Task
# pylint: disable=too-many-arguments def _init_stream_response_mixin(self, prerequisite: asyncio.Task):
def __init__(self, request: RequestType, deadline: Optional[float],
metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), loop)
self._request = request
self._metadata = metadata
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._send_unary_request_task = loop.create_task(
self._send_unary_request())
self._message_aiter = None self._message_aiter = None
self._prerequisite = prerequisite
def cancel(self) -> bool: def cancel(self) -> bool:
if super().cancel(): if super().cancel():
self._send_unary_request_task.cancel() self._prerequisite.cancel()
return True return True
else: else:
return False return False
async def _send_unary_request(self) -> ResponseType:
serialized_request = _common.serialize(self._request,
self._request_serializer)
try:
await self._cython_call.initiate_unary_stream(
serialized_request, self._metadata)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
raise
async def _fetch_stream_responses(self) -> ResponseType: async def _fetch_stream_responses(self) -> ResponseType:
message = await self._read() message = await self._read()
while message is not cygrpc.EOF: while message is not cygrpc.EOF:
@ -349,7 +286,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
async def _read(self) -> ResponseType: async def _read(self) -> ResponseType:
# Wait for the request being sent # Wait for the request being sent
await self._send_unary_request_task await self._prerequisite
# Reads response message from Core # Reads response message from Core
try: try:
@ -366,7 +303,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
self._response_deserializer) self._response_deserializer)
async def read(self) -> ResponseType: async def read(self) -> ResponseType:
if self._cython_call.done(): if self.done():
await self._raise_for_status() await self._raise_for_status()
return cygrpc.EOF return cygrpc.EOF
@ -378,39 +315,16 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
return response_message return response_message
class StreamUnaryCall(Call, _base_call.StreamUnaryCall): class _StreamRequestMixin(Call):
"""Object for managing stream-unary RPC calls.
Returned when an instance of `StreamUnaryMultiCallable` object is called.
"""
_metadata: MetadataType
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_metadata_sent: asyncio.Event _metadata_sent: asyncio.Event
_done_writing: bool _done_writing: bool
_call_finisher: asyncio.Task _async_request_poller: Optional[asyncio.Task]
_async_request_poller: asyncio.Task
# pylint: disable=too-many-arguments
def __init__(self,
request_async_iterator: Optional[AsyncIterable[RequestType]],
deadline: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), loop)
self._metadata = metadata
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._metadata_sent = asyncio.Event(loop=loop) def _init_stream_request_mixin(
self, request_async_iterator: Optional[AsyncIterable[RequestType]]):
self._metadata_sent = asyncio.Event(loop=self._loop)
self._done_writing = False self._done_writing = False
self._call_finisher = loop.create_task(self._conduct_rpc())
# If user passes in an async iterator, create a consumer Task. # If user passes in an async iterator, create a consumer Task.
if request_async_iterator is not None: if request_async_iterator is not None:
self._async_request_poller = self._loop.create_task( self._async_request_poller = self._loop.create_task(
@ -420,7 +334,6 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
def cancel(self) -> bool: def cancel(self) -> bool:
if super().cancel(): if super().cancel():
self._call_finisher.cancel()
if self._async_request_poller is not None: if self._async_request_poller is not None:
self._async_request_poller.cancel() self._async_request_poller.cancel()
return True return True
@ -430,38 +343,14 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
def _metadata_sent_observer(self): def _metadata_sent_observer(self):
self._metadata_sent.set() self._metadata_sent.set()
async def _conduct_rpc(self) -> ResponseType:
try:
serialized_response = await self._cython_call.stream_unary(
self._metadata, self._metadata_sent_observer)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
# Raises RpcError if the RPC failed or cancelled
await self._raise_for_status()
return _common.deserialize(serialized_response,
self._response_deserializer)
async def _consume_request_iterator( async def _consume_request_iterator(
self, request_async_iterator: AsyncIterable[RequestType]) -> None: self, request_async_iterator: AsyncIterable[RequestType]) -> None:
async for request in request_async_iterator: async for request in request_async_iterator:
await self.write(request) await self.write(request)
await self.done_writing() await self.done_writing()
def __await__(self) -> ResponseType:
"""Wait till the ongoing RPC request finishes."""
try:
response = yield from self._call_finisher
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
raise
return response
async def write(self, request: RequestType) -> None: async def write(self, request: RequestType) -> None:
if self._cython_call.done(): if self.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
if self._done_writing: if self._done_writing:
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
@ -480,7 +369,7 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
async def done_writing(self) -> None: async def done_writing(self) -> None:
"""Implementation of done_writing is idempotent.""" """Implementation of done_writing is idempotent."""
if self._cython_call.done(): if self.done():
# If the RPC is finished, do nothing. # If the RPC is finished, do nothing.
return return
if not self._done_writing: if not self._done_writing:
@ -494,152 +383,153 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
await self._raise_for_status() await self._raise_for_status()
class StreamStreamCall(Call, _base_call.StreamStreamCall): class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
"""Object for managing stream-stream RPC calls. """Object for managing unary-unary RPC calls.
Returned when an instance of `StreamStreamMultiCallable` object is called. Returned when an instance of `UnaryUnaryMultiCallable` object is called.
""" """
_metadata: MetadataType _request: RequestType
_request_serializer: SerializingFunction _call: asyncio.Task
_response_deserializer: DeserializingFunction
_metadata_sent: asyncio.Event
_done_writing: bool
_initializer: asyncio.Task
_async_request_poller: asyncio.Task
_message_aiter: AsyncIterable[ResponseType]
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments
def __init__(self, def __init__(self, request: RequestType, deadline: Optional[float],
request_async_iterator: Optional[AsyncIterable[RequestType]], metadata: MetadataType,
deadline: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials], credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes, channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction, response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None: loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), loop) super().__init__(channel.call(method, deadline, credentials), metadata,
self._metadata = metadata request_serializer, response_deserializer, loop)
self._request_serializer = request_serializer self._request = request
self._response_deserializer = response_deserializer self._init_unary_response_mixin(self._invoke())
self._metadata_sent = asyncio.Event(loop=loop) async def _invoke(self) -> ResponseType:
self._done_writing = False serialized_request = _common.serialize(self._request,
self._request_serializer)
self._initializer = self._loop.create_task(self._prepare_rpc()) # NOTE(lidiz) asyncio.CancelledError is not a good transport for status,
# because the asyncio.Task class do not cache the exception object.
# https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
try:
serialized_response = await self._cython_call.unary_unary(
serialized_request, self._metadata)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
# If user passes in an async iterator, create a consumer coroutine. # Raises here if RPC failed or cancelled
if request_async_iterator is not None: await self._raise_for_status()
self._async_request_poller = loop.create_task(
self._consume_request_iterator(request_async_iterator))
else:
self._async_request_poller = None
self._message_aiter = None
def cancel(self) -> bool: return _common.deserialize(serialized_response,
if super().cancel(): self._response_deserializer)
self._initializer.cancel()
if self._async_request_poller is not None:
self._async_request_poller.cancel()
return True
else:
return False
def _metadata_sent_observer(self):
self._metadata_sent.set()
async def _prepare_rpc(self): class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
"""This method prepares the RPC for receiving/sending messages. """Object for managing unary-stream RPC calls.
All other operations around the stream should only happen after the Returned when an instance of `UnaryStreamMultiCallable` object is called.
completion of this method.
""" """
try: _request: RequestType
await self._cython_call.initiate_stream_stream( _send_unary_request_task: asyncio.Task
self._metadata, self._metadata_sent_observer)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
# No need to raise RpcError here, because no one will `await` this task.
async def _consume_request_iterator(
self, request_async_iterator: Optional[AsyncIterable[RequestType]]
) -> None:
async for request in request_async_iterator:
await self.write(request)
await self.done_writing()
async def write(self, request: RequestType) -> None: # pylint: disable=too-many-arguments
if self._cython_call.done(): def __init__(self, request: RequestType, deadline: Optional[float],
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) metadata: MetadataType,
if self._done_writing: credentials: Optional[grpc.CallCredentials],
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) channel: cygrpc.AioChannel, method: bytes,
if not self._metadata_sent.is_set(): request_serializer: SerializingFunction,
await self._metadata_sent.wait() response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), metadata,
request_serializer, response_deserializer, loop)
self._request = request
self._send_unary_request_task = loop.create_task(
self._send_unary_request())
self._init_stream_response_mixin(self._send_unary_request_task)
serialized_request = _common.serialize(request, async def _send_unary_request(self) -> ResponseType:
serialized_request = _common.serialize(self._request,
self._request_serializer) self._request_serializer)
try: try:
await self._cython_call.send_serialized_message(serialized_request) await self._cython_call.initiate_unary_stream(
serialized_request, self._metadata)
except asyncio.CancelledError: except asyncio.CancelledError:
if not self.cancelled(): if not self.cancelled():
self.cancel() self.cancel()
await self._raise_for_status() raise
async def done_writing(self) -> None:
"""Implementation of done_writing is idempotent."""
if self._cython_call.done():
# If the RPC is finished, do nothing.
return
if not self._done_writing:
# If the done writing is not sent before, try to send it.
self._done_writing = True
try:
await self._cython_call.send_receive_close()
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
await self._raise_for_status()
async def _fetch_stream_responses(self) -> ResponseType: class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
"""The async generator that yields responses from peer.""" _base_call.StreamUnaryCall):
message = await self._read() """Object for managing stream-unary RPC calls.
while message is not cygrpc.EOF:
yield message
message = await self._read()
def __aiter__(self) -> AsyncIterable[ResponseType]: Returned when an instance of `StreamUnaryMultiCallable` object is called.
if self._message_aiter is None: """
self._message_aiter = self._fetch_stream_responses()
return self._message_aiter
async def _read(self) -> ResponseType: # pylint: disable=too-many-arguments
# Wait for the setup def __init__(self,
await self._initializer request_async_iterator: Optional[AsyncIterable[RequestType]],
deadline: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), metadata,
request_serializer, response_deserializer, loop)
# Reads response message from Core self._init_stream_request_mixin(request_async_iterator)
self._init_unary_response_mixin(self._conduct_rpc())
async def _conduct_rpc(self) -> ResponseType:
try: try:
raw_response = await self._cython_call.receive_serialized_message() serialized_response = await self._cython_call.stream_unary(
self._metadata, self._metadata_sent_observer)
except asyncio.CancelledError: except asyncio.CancelledError:
if not self.cancelled(): if not self.cancelled():
self.cancel() self.cancel()
# Raises RpcError if the RPC failed or cancelled
await self._raise_for_status() await self._raise_for_status()
if raw_response is cygrpc.EOF: return _common.deserialize(serialized_response,
return cygrpc.EOF
else:
return _common.deserialize(raw_response,
self._response_deserializer) self._response_deserializer)
async def read(self) -> ResponseType:
if self._cython_call.done():
await self._raise_for_status()
return cygrpc.EOF
response_message = await self._read() class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call,
_base_call.StreamStreamCall):
"""Object for managing stream-stream RPC calls.
if response_message is cygrpc.EOF: Returned when an instance of `StreamStreamMultiCallable` object is called.
# If the read operation failed, Core should explain why. """
await self._raise_for_status() _initializer: asyncio.Task
return response_message
# pylint: disable=too-many-arguments
def __init__(self,
request_async_iterator: Optional[AsyncIterable[RequestType]],
deadline: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), metadata,
request_serializer, response_deserializer, loop)
self._initializer = self._loop.create_task(self._prepare_rpc())
self._init_stream_request_mixin(request_async_iterator)
self._init_stream_response_mixin(self._initializer)
async def _prepare_rpc(self):
"""This method prepares the RPC for receiving/sending messages.
All other operations around the stream should only happen after the
completion of this method.
"""
try:
await self._cython_call.initiate_stream_stream(
self._metadata, self._metadata_sent_observer)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
# No need to raise RpcError here, because no one will `await` this task.

Loading…
Cancel
Save