De-duplication for the call objects

pull/21772/head
Lidi Zheng 5 years ago
parent 152de3cf91
commit 0c142306be
  1. 1
      .pylintrc
  2. 400
      src/python/grpcio/grpc/experimental/aio/_call.py

@ -20,6 +20,7 @@ dummy-variables-rgx=^ignored_|^unused_
# be what works for us at the moment (excepting the dead-code-walking Beta
# API).
max-args=7
max-parents=8
[MISCELLANEOUS]

@ -15,15 +15,15 @@
import asyncio
from functools import partial
from typing import AsyncIterable, Dict, Optional
from typing import AsyncIterable, Awaitable, Dict, Optional
import grpc
from grpc import _common
from grpc._cython import cygrpc
from . import _base_call
from ._typing import (DeserializingFunction, MetadataType, RequestType,
ResponseType, SerializingFunction, DoneCallbackType)
from ._typing import (DeserializingFunction, DoneCallbackType, MetadataType,
RequestType, ResponseType, SerializingFunction)
__all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'
@ -145,7 +145,7 @@ def _create_rpc_error(initial_metadata: Optional[MetadataType],
status.trailing_metadata())
class Call(_base_call.Call):
class Call:
"""Base implementation of client RPC Call object.
Implements logic around final status, metadata and cancellation.
@ -153,11 +153,19 @@ class Call(_base_call.Call):
_loop: asyncio.AbstractEventLoop
_code: grpc.StatusCode
_cython_call: cygrpc._AioCall
_metadata: MetadataType
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
def __init__(self, cython_call: cygrpc._AioCall,
def __init__(self, cython_call: cygrpc._AioCall, metadata: MetadataType,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._cython_call = cython_call
self._metadata = metadata
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
def __del__(self) -> None:
if not self._cython_call.done():
@ -221,63 +229,24 @@ class Call(_base_call.Call):
return self._repr()
class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
"""Object for managing unary-unary RPC calls.
Returned when an instance of `UnaryUnaryMultiCallable` object is called.
"""
_request: RequestType
_metadata: Optional[MetadataType]
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_call: asyncio.Task
class _UnaryResponseMixin(Call):
_call_finisher: asyncio.Task
# pylint: disable=too-many-arguments
def __init__(self, request: RequestType, deadline: Optional[float],
metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), loop)
self._request = request
self._metadata = metadata
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._call = loop.create_task(self._invoke())
def _init_unary_response_mixin(self,
response_coro: Awaitable[ResponseType]):
self._call_finisher = self._loop.create_task(response_coro)
def cancel(self) -> bool:
if super().cancel():
self._call.cancel()
self._call_finisher.cancel()
return True
else:
return False
async def _invoke(self) -> ResponseType:
serialized_request = _common.serialize(self._request,
self._request_serializer)
# NOTE(lidiz) asyncio.CancelledError is not a good transport for status,
# because the asyncio.Task class do not cache the exception object.
# https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
try:
serialized_response = await self._cython_call.unary_unary(
serialized_request, self._metadata)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
# Raises here if RPC failed or cancelled
await self._raise_for_status()
return _common.deserialize(serialized_response,
self._response_deserializer)
def __await__(self) -> ResponseType:
"""Wait till the ongoing RPC request finishes."""
try:
response = yield from self._call
response = yield from self._call_finisher
except asyncio.CancelledError:
# Even if we caught all other CancelledError, there is still
# this corner case. If the application cancels immediately after
@ -289,53 +258,21 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
return response
class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
"""Object for managing unary-stream RPC calls.
Returned when an instance of `UnaryStreamMultiCallable` object is called.
"""
_request: RequestType
_metadata: MetadataType
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_send_unary_request_task: asyncio.Task
class _StreamResponseMixin(Call):
_message_aiter: AsyncIterable[ResponseType]
_prerequisite: asyncio.Task
# pylint: disable=too-many-arguments
def __init__(self, request: RequestType, deadline: Optional[float],
metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), loop)
self._request = request
self._metadata = metadata
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._send_unary_request_task = loop.create_task(
self._send_unary_request())
def _init_stream_response_mixin(self, prerequisite: asyncio.Task):
self._message_aiter = None
self._prerequisite = prerequisite
def cancel(self) -> bool:
if super().cancel():
self._send_unary_request_task.cancel()
self._prerequisite.cancel()
return True
else:
return False
async def _send_unary_request(self) -> ResponseType:
serialized_request = _common.serialize(self._request,
self._request_serializer)
try:
await self._cython_call.initiate_unary_stream(
serialized_request, self._metadata)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
raise
async def _fetch_stream_responses(self) -> ResponseType:
message = await self._read()
while message is not cygrpc.EOF:
@ -349,7 +286,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
async def _read(self) -> ResponseType:
# Wait for the request being sent
await self._send_unary_request_task
await self._prerequisite
# Reads response message from Core
try:
@ -366,7 +303,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
self._response_deserializer)
async def read(self) -> ResponseType:
if self._cython_call.done():
if self.done():
await self._raise_for_status()
return cygrpc.EOF
@ -378,39 +315,16 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
return response_message
class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
"""Object for managing stream-unary RPC calls.
Returned when an instance of `StreamUnaryMultiCallable` object is called.
"""
_metadata: MetadataType
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
class _StreamRequestMixin(Call):
_metadata_sent: asyncio.Event
_done_writing: bool
_call_finisher: asyncio.Task
_async_request_poller: asyncio.Task
# pylint: disable=too-many-arguments
def __init__(self,
request_async_iterator: Optional[AsyncIterable[RequestType]],
deadline: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), loop)
self._metadata = metadata
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
_async_request_poller: Optional[asyncio.Task]
self._metadata_sent = asyncio.Event(loop=loop)
def _init_stream_request_mixin(
self, request_async_iterator: Optional[AsyncIterable[RequestType]]):
self._metadata_sent = asyncio.Event(loop=self._loop)
self._done_writing = False
self._call_finisher = loop.create_task(self._conduct_rpc())
# If user passes in an async iterator, create a consumer Task.
if request_async_iterator is not None:
self._async_request_poller = self._loop.create_task(
@ -420,7 +334,6 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
def cancel(self) -> bool:
if super().cancel():
self._call_finisher.cancel()
if self._async_request_poller is not None:
self._async_request_poller.cancel()
return True
@ -430,38 +343,14 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
def _metadata_sent_observer(self):
self._metadata_sent.set()
async def _conduct_rpc(self) -> ResponseType:
try:
serialized_response = await self._cython_call.stream_unary(
self._metadata, self._metadata_sent_observer)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
# Raises RpcError if the RPC failed or cancelled
await self._raise_for_status()
return _common.deserialize(serialized_response,
self._response_deserializer)
async def _consume_request_iterator(
self, request_async_iterator: AsyncIterable[RequestType]) -> None:
async for request in request_async_iterator:
await self.write(request)
await self.done_writing()
def __await__(self) -> ResponseType:
"""Wait till the ongoing RPC request finishes."""
try:
response = yield from self._call_finisher
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
raise
return response
async def write(self, request: RequestType) -> None:
if self._cython_call.done():
if self.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
if self._done_writing:
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
@ -480,7 +369,7 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
async def done_writing(self) -> None:
"""Implementation of done_writing is idempotent."""
if self._cython_call.done():
if self.done():
# If the RPC is finished, do nothing.
return
if not self._done_writing:
@ -494,152 +383,153 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
await self._raise_for_status()
class StreamStreamCall(Call, _base_call.StreamStreamCall):
"""Object for managing stream-stream RPC calls.
class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
"""Object for managing unary-unary RPC calls.
Returned when an instance of `StreamStreamMultiCallable` object is called.
Returned when an instance of `UnaryUnaryMultiCallable` object is called.
"""
_metadata: MetadataType
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_metadata_sent: asyncio.Event
_done_writing: bool
_initializer: asyncio.Task
_async_request_poller: asyncio.Task
_message_aiter: AsyncIterable[ResponseType]
_request: RequestType
_call: asyncio.Task
# pylint: disable=too-many-arguments
def __init__(self,
request_async_iterator: Optional[AsyncIterable[RequestType]],
deadline: Optional[float], metadata: MetadataType,
def __init__(self, request: RequestType, deadline: Optional[float],
metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), loop)
self._metadata = metadata
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
super().__init__(channel.call(method, deadline, credentials), metadata,
request_serializer, response_deserializer, loop)
self._request = request
self._init_unary_response_mixin(self._invoke())
self._metadata_sent = asyncio.Event(loop=loop)
self._done_writing = False
async def _invoke(self) -> ResponseType:
serialized_request = _common.serialize(self._request,
self._request_serializer)
self._initializer = self._loop.create_task(self._prepare_rpc())
# NOTE(lidiz) asyncio.CancelledError is not a good transport for status,
# because the asyncio.Task class do not cache the exception object.
# https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
try:
serialized_response = await self._cython_call.unary_unary(
serialized_request, self._metadata)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
# If user passes in an async iterator, create a consumer coroutine.
if request_async_iterator is not None:
self._async_request_poller = loop.create_task(
self._consume_request_iterator(request_async_iterator))
else:
self._async_request_poller = None
self._message_aiter = None
# Raises here if RPC failed or cancelled
await self._raise_for_status()
def cancel(self) -> bool:
if super().cancel():
self._initializer.cancel()
if self._async_request_poller is not None:
self._async_request_poller.cancel()
return True
else:
return False
return _common.deserialize(serialized_response,
self._response_deserializer)
def _metadata_sent_observer(self):
self._metadata_sent.set()
async def _prepare_rpc(self):
"""This method prepares the RPC for receiving/sending messages.
class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
"""Object for managing unary-stream RPC calls.
All other operations around the stream should only happen after the
completion of this method.
Returned when an instance of `UnaryStreamMultiCallable` object is called.
"""
try:
await self._cython_call.initiate_stream_stream(
self._metadata, self._metadata_sent_observer)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
# No need to raise RpcError here, because no one will `await` this task.
async def _consume_request_iterator(
self, request_async_iterator: Optional[AsyncIterable[RequestType]]
) -> None:
async for request in request_async_iterator:
await self.write(request)
await self.done_writing()
_request: RequestType
_send_unary_request_task: asyncio.Task
async def write(self, request: RequestType) -> None:
if self._cython_call.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
if self._done_writing:
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
if not self._metadata_sent.is_set():
await self._metadata_sent.wait()
# pylint: disable=too-many-arguments
def __init__(self, request: RequestType, deadline: Optional[float],
metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), metadata,
request_serializer, response_deserializer, loop)
self._request = request
self._send_unary_request_task = loop.create_task(
self._send_unary_request())
self._init_stream_response_mixin(self._send_unary_request_task)
serialized_request = _common.serialize(request,
async def _send_unary_request(self) -> ResponseType:
serialized_request = _common.serialize(self._request,
self._request_serializer)
try:
await self._cython_call.send_serialized_message(serialized_request)
await self._cython_call.initiate_unary_stream(
serialized_request, self._metadata)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
await self._raise_for_status()
raise
async def done_writing(self) -> None:
"""Implementation of done_writing is idempotent."""
if self._cython_call.done():
# If the RPC is finished, do nothing.
return
if not self._done_writing:
# If the done writing is not sent before, try to send it.
self._done_writing = True
try:
await self._cython_call.send_receive_close()
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
await self._raise_for_status()
async def _fetch_stream_responses(self) -> ResponseType:
"""The async generator that yields responses from peer."""
message = await self._read()
while message is not cygrpc.EOF:
yield message
message = await self._read()
class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
_base_call.StreamUnaryCall):
"""Object for managing stream-unary RPC calls.
def __aiter__(self) -> AsyncIterable[ResponseType]:
if self._message_aiter is None:
self._message_aiter = self._fetch_stream_responses()
return self._message_aiter
Returned when an instance of `StreamUnaryMultiCallable` object is called.
"""
async def _read(self) -> ResponseType:
# Wait for the setup
await self._initializer
# pylint: disable=too-many-arguments
def __init__(self,
request_async_iterator: Optional[AsyncIterable[RequestType]],
deadline: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), metadata,
request_serializer, response_deserializer, loop)
# Reads response message from Core
self._init_stream_request_mixin(request_async_iterator)
self._init_unary_response_mixin(self._conduct_rpc())
async def _conduct_rpc(self) -> ResponseType:
try:
raw_response = await self._cython_call.receive_serialized_message()
serialized_response = await self._cython_call.stream_unary(
self._metadata, self._metadata_sent_observer)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
# Raises RpcError if the RPC failed or cancelled
await self._raise_for_status()
if raw_response is cygrpc.EOF:
return cygrpc.EOF
else:
return _common.deserialize(raw_response,
return _common.deserialize(serialized_response,
self._response_deserializer)
async def read(self) -> ResponseType:
if self._cython_call.done():
await self._raise_for_status()
return cygrpc.EOF
response_message = await self._read()
class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call,
_base_call.StreamStreamCall):
"""Object for managing stream-stream RPC calls.
if response_message is cygrpc.EOF:
# If the read operation failed, Core should explain why.
await self._raise_for_status()
return response_message
Returned when an instance of `StreamStreamMultiCallable` object is called.
"""
_initializer: asyncio.Task
# pylint: disable=too-many-arguments
def __init__(self,
request_async_iterator: Optional[AsyncIterable[RequestType]],
deadline: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
super().__init__(channel.call(method, deadline, credentials), metadata,
request_serializer, response_deserializer, loop)
self._initializer = self._loop.create_task(self._prepare_rpc())
self._init_stream_request_mixin(request_async_iterator)
self._init_stream_response_mixin(self._initializer)
async def _prepare_rpc(self):
"""This method prepares the RPC for receiving/sending messages.
All other operations around the stream should only happen after the
completion of this method.
"""
try:
await self._cython_call.initiate_stream_stream(
self._metadata, self._metadata_sent_observer)
except asyncio.CancelledError:
if not self.cancelled():
self.cancel()
# No need to raise RpcError here, because no one will `await` this task.

Loading…
Cancel
Save