|
|
|
@ -15,15 +15,15 @@ |
|
|
|
|
|
|
|
|
|
import asyncio |
|
|
|
|
from functools import partial |
|
|
|
|
from typing import AsyncIterable, Dict, Optional |
|
|
|
|
from typing import AsyncIterable, Awaitable, Dict, Optional |
|
|
|
|
|
|
|
|
|
import grpc |
|
|
|
|
from grpc import _common |
|
|
|
|
from grpc._cython import cygrpc |
|
|
|
|
|
|
|
|
|
from . import _base_call |
|
|
|
|
from ._typing import (DeserializingFunction, MetadataType, RequestType, |
|
|
|
|
ResponseType, SerializingFunction, DoneCallbackType) |
|
|
|
|
from ._typing import (DeserializingFunction, DoneCallbackType, MetadataType, |
|
|
|
|
RequestType, ResponseType, SerializingFunction) |
|
|
|
|
|
|
|
|
|
__all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' |
|
|
|
|
|
|
|
|
@ -145,7 +145,7 @@ def _create_rpc_error(initial_metadata: Optional[MetadataType], |
|
|
|
|
status.trailing_metadata()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Call(_base_call.Call): |
|
|
|
|
class Call: |
|
|
|
|
"""Base implementation of client RPC Call object. |
|
|
|
|
|
|
|
|
|
Implements logic around final status, metadata and cancellation. |
|
|
|
@ -153,11 +153,19 @@ class Call(_base_call.Call): |
|
|
|
|
_loop: asyncio.AbstractEventLoop |
|
|
|
|
_code: grpc.StatusCode |
|
|
|
|
_cython_call: cygrpc._AioCall |
|
|
|
|
_metadata: MetadataType |
|
|
|
|
_request_serializer: SerializingFunction |
|
|
|
|
_response_deserializer: DeserializingFunction |
|
|
|
|
|
|
|
|
|
def __init__(self, cython_call: cygrpc._AioCall, |
|
|
|
|
def __init__(self, cython_call: cygrpc._AioCall, metadata: MetadataType, |
|
|
|
|
request_serializer: SerializingFunction, |
|
|
|
|
response_deserializer: DeserializingFunction, |
|
|
|
|
loop: asyncio.AbstractEventLoop) -> None: |
|
|
|
|
self._loop = loop |
|
|
|
|
self._cython_call = cython_call |
|
|
|
|
self._metadata = metadata |
|
|
|
|
self._request_serializer = request_serializer |
|
|
|
|
self._response_deserializer = response_deserializer |
|
|
|
|
|
|
|
|
|
def __del__(self) -> None: |
|
|
|
|
if not self._cython_call.done(): |
|
|
|
@ -221,63 +229,24 @@ class Call(_base_call.Call): |
|
|
|
|
return self._repr() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): |
|
|
|
|
"""Object for managing unary-unary RPC calls. |
|
|
|
|
|
|
|
|
|
Returned when an instance of `UnaryUnaryMultiCallable` object is called. |
|
|
|
|
""" |
|
|
|
|
_request: RequestType |
|
|
|
|
_metadata: Optional[MetadataType] |
|
|
|
|
_request_serializer: SerializingFunction |
|
|
|
|
_response_deserializer: DeserializingFunction |
|
|
|
|
_call: asyncio.Task |
|
|
|
|
class _UnaryResponseMixin(Call): |
|
|
|
|
_call_finisher: asyncio.Task |
|
|
|
|
|
|
|
|
|
# pylint: disable=too-many-arguments |
|
|
|
|
def __init__(self, request: RequestType, deadline: Optional[float], |
|
|
|
|
metadata: MetadataType, |
|
|
|
|
credentials: Optional[grpc.CallCredentials], |
|
|
|
|
channel: cygrpc.AioChannel, method: bytes, |
|
|
|
|
request_serializer: SerializingFunction, |
|
|
|
|
response_deserializer: DeserializingFunction, |
|
|
|
|
loop: asyncio.AbstractEventLoop) -> None: |
|
|
|
|
super().__init__(channel.call(method, deadline, credentials), loop) |
|
|
|
|
self._request = request |
|
|
|
|
self._metadata = metadata |
|
|
|
|
self._request_serializer = request_serializer |
|
|
|
|
self._response_deserializer = response_deserializer |
|
|
|
|
self._call = loop.create_task(self._invoke()) |
|
|
|
|
def _init_unary_response_mixin(self, |
|
|
|
|
response_coro: Awaitable[ResponseType]): |
|
|
|
|
self._call_finisher = self._loop.create_task(response_coro) |
|
|
|
|
|
|
|
|
|
def cancel(self) -> bool: |
|
|
|
|
if super().cancel(): |
|
|
|
|
self._call.cancel() |
|
|
|
|
self._call_finisher.cancel() |
|
|
|
|
return True |
|
|
|
|
else: |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
async def _invoke(self) -> ResponseType: |
|
|
|
|
serialized_request = _common.serialize(self._request, |
|
|
|
|
self._request_serializer) |
|
|
|
|
|
|
|
|
|
# NOTE(lidiz) asyncio.CancelledError is not a good transport for status, |
|
|
|
|
# because the asyncio.Task class do not cache the exception object. |
|
|
|
|
# https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785 |
|
|
|
|
try: |
|
|
|
|
serialized_response = await self._cython_call.unary_unary( |
|
|
|
|
serialized_request, self._metadata) |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
|
|
|
|
|
|
# Raises here if RPC failed or cancelled |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
|
|
|
|
|
return _common.deserialize(serialized_response, |
|
|
|
|
self._response_deserializer) |
|
|
|
|
|
|
|
|
|
def __await__(self) -> ResponseType: |
|
|
|
|
"""Wait till the ongoing RPC request finishes.""" |
|
|
|
|
try: |
|
|
|
|
response = yield from self._call |
|
|
|
|
response = yield from self._call_finisher |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
# Even if we caught all other CancelledError, there is still |
|
|
|
|
# this corner case. If the application cancels immediately after |
|
|
|
@ -289,53 +258,21 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): |
|
|
|
|
return response |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UnaryStreamCall(Call, _base_call.UnaryStreamCall): |
|
|
|
|
"""Object for managing unary-stream RPC calls. |
|
|
|
|
|
|
|
|
|
Returned when an instance of `UnaryStreamMultiCallable` object is called. |
|
|
|
|
""" |
|
|
|
|
_request: RequestType |
|
|
|
|
_metadata: MetadataType |
|
|
|
|
_request_serializer: SerializingFunction |
|
|
|
|
_response_deserializer: DeserializingFunction |
|
|
|
|
_send_unary_request_task: asyncio.Task |
|
|
|
|
class _StreamResponseMixin(Call): |
|
|
|
|
_message_aiter: AsyncIterable[ResponseType] |
|
|
|
|
_prerequisite: asyncio.Task |
|
|
|
|
|
|
|
|
|
# pylint: disable=too-many-arguments |
|
|
|
|
def __init__(self, request: RequestType, deadline: Optional[float], |
|
|
|
|
metadata: MetadataType, |
|
|
|
|
credentials: Optional[grpc.CallCredentials], |
|
|
|
|
channel: cygrpc.AioChannel, method: bytes, |
|
|
|
|
request_serializer: SerializingFunction, |
|
|
|
|
response_deserializer: DeserializingFunction, |
|
|
|
|
loop: asyncio.AbstractEventLoop) -> None: |
|
|
|
|
super().__init__(channel.call(method, deadline, credentials), loop) |
|
|
|
|
self._request = request |
|
|
|
|
self._metadata = metadata |
|
|
|
|
self._request_serializer = request_serializer |
|
|
|
|
self._response_deserializer = response_deserializer |
|
|
|
|
self._send_unary_request_task = loop.create_task( |
|
|
|
|
self._send_unary_request()) |
|
|
|
|
def _init_stream_response_mixin(self, prerequisite: asyncio.Task): |
|
|
|
|
self._message_aiter = None |
|
|
|
|
self._prerequisite = prerequisite |
|
|
|
|
|
|
|
|
|
def cancel(self) -> bool: |
|
|
|
|
if super().cancel(): |
|
|
|
|
self._send_unary_request_task.cancel() |
|
|
|
|
self._prerequisite.cancel() |
|
|
|
|
return True |
|
|
|
|
else: |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
async def _send_unary_request(self) -> ResponseType: |
|
|
|
|
serialized_request = _common.serialize(self._request, |
|
|
|
|
self._request_serializer) |
|
|
|
|
try: |
|
|
|
|
await self._cython_call.initiate_unary_stream( |
|
|
|
|
serialized_request, self._metadata) |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
|
raise |
|
|
|
|
|
|
|
|
|
async def _fetch_stream_responses(self) -> ResponseType: |
|
|
|
|
message = await self._read() |
|
|
|
|
while message is not cygrpc.EOF: |
|
|
|
@ -349,7 +286,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): |
|
|
|
|
|
|
|
|
|
async def _read(self) -> ResponseType: |
|
|
|
|
# Wait for the request being sent |
|
|
|
|
await self._send_unary_request_task |
|
|
|
|
await self._prerequisite |
|
|
|
|
|
|
|
|
|
# Reads response message from Core |
|
|
|
|
try: |
|
|
|
@ -366,7 +303,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): |
|
|
|
|
self._response_deserializer) |
|
|
|
|
|
|
|
|
|
async def read(self) -> ResponseType: |
|
|
|
|
if self._cython_call.done(): |
|
|
|
|
if self.done(): |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
return cygrpc.EOF |
|
|
|
|
|
|
|
|
@ -378,39 +315,16 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): |
|
|
|
|
return response_message |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StreamUnaryCall(Call, _base_call.StreamUnaryCall): |
|
|
|
|
"""Object for managing stream-unary RPC calls. |
|
|
|
|
|
|
|
|
|
Returned when an instance of `StreamUnaryMultiCallable` object is called. |
|
|
|
|
""" |
|
|
|
|
_metadata: MetadataType |
|
|
|
|
_request_serializer: SerializingFunction |
|
|
|
|
_response_deserializer: DeserializingFunction |
|
|
|
|
|
|
|
|
|
class _StreamRequestMixin(Call): |
|
|
|
|
_metadata_sent: asyncio.Event |
|
|
|
|
_done_writing: bool |
|
|
|
|
_call_finisher: asyncio.Task |
|
|
|
|
_async_request_poller: asyncio.Task |
|
|
|
|
|
|
|
|
|
# pylint: disable=too-many-arguments |
|
|
|
|
def __init__(self, |
|
|
|
|
request_async_iterator: Optional[AsyncIterable[RequestType]], |
|
|
|
|
deadline: Optional[float], metadata: MetadataType, |
|
|
|
|
credentials: Optional[grpc.CallCredentials], |
|
|
|
|
channel: cygrpc.AioChannel, method: bytes, |
|
|
|
|
request_serializer: SerializingFunction, |
|
|
|
|
response_deserializer: DeserializingFunction, |
|
|
|
|
loop: asyncio.AbstractEventLoop) -> None: |
|
|
|
|
super().__init__(channel.call(method, deadline, credentials), loop) |
|
|
|
|
self._metadata = metadata |
|
|
|
|
self._request_serializer = request_serializer |
|
|
|
|
self._response_deserializer = response_deserializer |
|
|
|
|
_async_request_poller: Optional[asyncio.Task] |
|
|
|
|
|
|
|
|
|
self._metadata_sent = asyncio.Event(loop=loop) |
|
|
|
|
def _init_stream_request_mixin( |
|
|
|
|
self, request_async_iterator: Optional[AsyncIterable[RequestType]]): |
|
|
|
|
self._metadata_sent = asyncio.Event(loop=self._loop) |
|
|
|
|
self._done_writing = False |
|
|
|
|
|
|
|
|
|
self._call_finisher = loop.create_task(self._conduct_rpc()) |
|
|
|
|
|
|
|
|
|
# If user passes in an async iterator, create a consumer Task. |
|
|
|
|
if request_async_iterator is not None: |
|
|
|
|
self._async_request_poller = self._loop.create_task( |
|
|
|
@ -420,7 +334,6 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall): |
|
|
|
|
|
|
|
|
|
def cancel(self) -> bool: |
|
|
|
|
if super().cancel(): |
|
|
|
|
self._call_finisher.cancel() |
|
|
|
|
if self._async_request_poller is not None: |
|
|
|
|
self._async_request_poller.cancel() |
|
|
|
|
return True |
|
|
|
@ -430,38 +343,14 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall): |
|
|
|
|
def _metadata_sent_observer(self): |
|
|
|
|
self._metadata_sent.set() |
|
|
|
|
|
|
|
|
|
async def _conduct_rpc(self) -> ResponseType: |
|
|
|
|
try: |
|
|
|
|
serialized_response = await self._cython_call.stream_unary( |
|
|
|
|
self._metadata, self._metadata_sent_observer) |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
|
|
|
|
|
|
# Raises RpcError if the RPC failed or cancelled |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
|
|
|
|
|
return _common.deserialize(serialized_response, |
|
|
|
|
self._response_deserializer) |
|
|
|
|
|
|
|
|
|
async def _consume_request_iterator( |
|
|
|
|
self, request_async_iterator: AsyncIterable[RequestType]) -> None: |
|
|
|
|
async for request in request_async_iterator: |
|
|
|
|
await self.write(request) |
|
|
|
|
await self.done_writing() |
|
|
|
|
|
|
|
|
|
def __await__(self) -> ResponseType: |
|
|
|
|
"""Wait till the ongoing RPC request finishes.""" |
|
|
|
|
try: |
|
|
|
|
response = yield from self._call_finisher |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
|
raise |
|
|
|
|
return response |
|
|
|
|
|
|
|
|
|
async def write(self, request: RequestType) -> None: |
|
|
|
|
if self._cython_call.done(): |
|
|
|
|
if self.done(): |
|
|
|
|
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) |
|
|
|
|
if self._done_writing: |
|
|
|
|
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) |
|
|
|
@ -480,7 +369,7 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall): |
|
|
|
|
|
|
|
|
|
async def done_writing(self) -> None: |
|
|
|
|
"""Implementation of done_writing is idempotent.""" |
|
|
|
|
if self._cython_call.done(): |
|
|
|
|
if self.done(): |
|
|
|
|
# If the RPC is finished, do nothing. |
|
|
|
|
return |
|
|
|
|
if not self._done_writing: |
|
|
|
@ -494,152 +383,153 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall): |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StreamStreamCall(Call, _base_call.StreamStreamCall): |
|
|
|
|
"""Object for managing stream-stream RPC calls. |
|
|
|
|
class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall): |
|
|
|
|
"""Object for managing unary-unary RPC calls. |
|
|
|
|
|
|
|
|
|
Returned when an instance of `StreamStreamMultiCallable` object is called. |
|
|
|
|
Returned when an instance of `UnaryUnaryMultiCallable` object is called. |
|
|
|
|
""" |
|
|
|
|
_metadata: MetadataType |
|
|
|
|
_request_serializer: SerializingFunction |
|
|
|
|
_response_deserializer: DeserializingFunction |
|
|
|
|
|
|
|
|
|
_metadata_sent: asyncio.Event |
|
|
|
|
_done_writing: bool |
|
|
|
|
_initializer: asyncio.Task |
|
|
|
|
_async_request_poller: asyncio.Task |
|
|
|
|
_message_aiter: AsyncIterable[ResponseType] |
|
|
|
|
_request: RequestType |
|
|
|
|
_call: asyncio.Task |
|
|
|
|
|
|
|
|
|
# pylint: disable=too-many-arguments |
|
|
|
|
def __init__(self, |
|
|
|
|
request_async_iterator: Optional[AsyncIterable[RequestType]], |
|
|
|
|
deadline: Optional[float], metadata: MetadataType, |
|
|
|
|
def __init__(self, request: RequestType, deadline: Optional[float], |
|
|
|
|
metadata: MetadataType, |
|
|
|
|
credentials: Optional[grpc.CallCredentials], |
|
|
|
|
channel: cygrpc.AioChannel, method: bytes, |
|
|
|
|
request_serializer: SerializingFunction, |
|
|
|
|
response_deserializer: DeserializingFunction, |
|
|
|
|
loop: asyncio.AbstractEventLoop) -> None: |
|
|
|
|
super().__init__(channel.call(method, deadline, credentials), loop) |
|
|
|
|
self._metadata = metadata |
|
|
|
|
self._request_serializer = request_serializer |
|
|
|
|
self._response_deserializer = response_deserializer |
|
|
|
|
super().__init__(channel.call(method, deadline, credentials), metadata, |
|
|
|
|
request_serializer, response_deserializer, loop) |
|
|
|
|
self._request = request |
|
|
|
|
self._init_unary_response_mixin(self._invoke()) |
|
|
|
|
|
|
|
|
|
self._metadata_sent = asyncio.Event(loop=loop) |
|
|
|
|
self._done_writing = False |
|
|
|
|
async def _invoke(self) -> ResponseType: |
|
|
|
|
serialized_request = _common.serialize(self._request, |
|
|
|
|
self._request_serializer) |
|
|
|
|
|
|
|
|
|
self._initializer = self._loop.create_task(self._prepare_rpc()) |
|
|
|
|
# NOTE(lidiz) asyncio.CancelledError is not a good transport for status, |
|
|
|
|
# because the asyncio.Task class do not cache the exception object. |
|
|
|
|
# https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785 |
|
|
|
|
try: |
|
|
|
|
serialized_response = await self._cython_call.unary_unary( |
|
|
|
|
serialized_request, self._metadata) |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
|
|
|
|
|
|
# If user passes in an async iterator, create a consumer coroutine. |
|
|
|
|
if request_async_iterator is not None: |
|
|
|
|
self._async_request_poller = loop.create_task( |
|
|
|
|
self._consume_request_iterator(request_async_iterator)) |
|
|
|
|
else: |
|
|
|
|
self._async_request_poller = None |
|
|
|
|
self._message_aiter = None |
|
|
|
|
# Raises here if RPC failed or cancelled |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
|
|
|
|
|
def cancel(self) -> bool: |
|
|
|
|
if super().cancel(): |
|
|
|
|
self._initializer.cancel() |
|
|
|
|
if self._async_request_poller is not None: |
|
|
|
|
self._async_request_poller.cancel() |
|
|
|
|
return True |
|
|
|
|
else: |
|
|
|
|
return False |
|
|
|
|
return _common.deserialize(serialized_response, |
|
|
|
|
self._response_deserializer) |
|
|
|
|
|
|
|
|
|
def _metadata_sent_observer(self): |
|
|
|
|
self._metadata_sent.set() |
|
|
|
|
|
|
|
|
|
async def _prepare_rpc(self): |
|
|
|
|
"""This method prepares the RPC for receiving/sending messages. |
|
|
|
|
class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall): |
|
|
|
|
"""Object for managing unary-stream RPC calls. |
|
|
|
|
|
|
|
|
|
All other operations around the stream should only happen after the |
|
|
|
|
completion of this method. |
|
|
|
|
Returned when an instance of `UnaryStreamMultiCallable` object is called. |
|
|
|
|
""" |
|
|
|
|
try: |
|
|
|
|
await self._cython_call.initiate_stream_stream( |
|
|
|
|
self._metadata, self._metadata_sent_observer) |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
|
# No need to raise RpcError here, because no one will `await` this task. |
|
|
|
|
|
|
|
|
|
async def _consume_request_iterator( |
|
|
|
|
self, request_async_iterator: Optional[AsyncIterable[RequestType]] |
|
|
|
|
) -> None: |
|
|
|
|
async for request in request_async_iterator: |
|
|
|
|
await self.write(request) |
|
|
|
|
await self.done_writing() |
|
|
|
|
_request: RequestType |
|
|
|
|
_send_unary_request_task: asyncio.Task |
|
|
|
|
|
|
|
|
|
async def write(self, request: RequestType) -> None: |
|
|
|
|
if self._cython_call.done(): |
|
|
|
|
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) |
|
|
|
|
if self._done_writing: |
|
|
|
|
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) |
|
|
|
|
if not self._metadata_sent.is_set(): |
|
|
|
|
await self._metadata_sent.wait() |
|
|
|
|
# pylint: disable=too-many-arguments |
|
|
|
|
def __init__(self, request: RequestType, deadline: Optional[float], |
|
|
|
|
metadata: MetadataType, |
|
|
|
|
credentials: Optional[grpc.CallCredentials], |
|
|
|
|
channel: cygrpc.AioChannel, method: bytes, |
|
|
|
|
request_serializer: SerializingFunction, |
|
|
|
|
response_deserializer: DeserializingFunction, |
|
|
|
|
loop: asyncio.AbstractEventLoop) -> None: |
|
|
|
|
super().__init__(channel.call(method, deadline, credentials), metadata, |
|
|
|
|
request_serializer, response_deserializer, loop) |
|
|
|
|
self._request = request |
|
|
|
|
self._send_unary_request_task = loop.create_task( |
|
|
|
|
self._send_unary_request()) |
|
|
|
|
self._init_stream_response_mixin(self._send_unary_request_task) |
|
|
|
|
|
|
|
|
|
serialized_request = _common.serialize(request, |
|
|
|
|
async def _send_unary_request(self) -> ResponseType: |
|
|
|
|
serialized_request = _common.serialize(self._request, |
|
|
|
|
self._request_serializer) |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
await self._cython_call.send_serialized_message(serialized_request) |
|
|
|
|
await self._cython_call.initiate_unary_stream( |
|
|
|
|
serialized_request, self._metadata) |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
raise |
|
|
|
|
|
|
|
|
|
async def done_writing(self) -> None: |
|
|
|
|
"""Implementation of done_writing is idempotent.""" |
|
|
|
|
if self._cython_call.done(): |
|
|
|
|
# If the RPC is finished, do nothing. |
|
|
|
|
return |
|
|
|
|
if not self._done_writing: |
|
|
|
|
# If the done writing is not sent before, try to send it. |
|
|
|
|
self._done_writing = True |
|
|
|
|
try: |
|
|
|
|
await self._cython_call.send_receive_close() |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
|
|
|
|
|
async def _fetch_stream_responses(self) -> ResponseType: |
|
|
|
|
"""The async generator that yields responses from peer.""" |
|
|
|
|
message = await self._read() |
|
|
|
|
while message is not cygrpc.EOF: |
|
|
|
|
yield message |
|
|
|
|
message = await self._read() |
|
|
|
|
class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call, |
|
|
|
|
_base_call.StreamUnaryCall): |
|
|
|
|
"""Object for managing stream-unary RPC calls. |
|
|
|
|
|
|
|
|
|
def __aiter__(self) -> AsyncIterable[ResponseType]: |
|
|
|
|
if self._message_aiter is None: |
|
|
|
|
self._message_aiter = self._fetch_stream_responses() |
|
|
|
|
return self._message_aiter |
|
|
|
|
Returned when an instance of `StreamUnaryMultiCallable` object is called. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
async def _read(self) -> ResponseType: |
|
|
|
|
# Wait for the setup |
|
|
|
|
await self._initializer |
|
|
|
|
# pylint: disable=too-many-arguments |
|
|
|
|
def __init__(self, |
|
|
|
|
request_async_iterator: Optional[AsyncIterable[RequestType]], |
|
|
|
|
deadline: Optional[float], metadata: MetadataType, |
|
|
|
|
credentials: Optional[grpc.CallCredentials], |
|
|
|
|
channel: cygrpc.AioChannel, method: bytes, |
|
|
|
|
request_serializer: SerializingFunction, |
|
|
|
|
response_deserializer: DeserializingFunction, |
|
|
|
|
loop: asyncio.AbstractEventLoop) -> None: |
|
|
|
|
super().__init__(channel.call(method, deadline, credentials), metadata, |
|
|
|
|
request_serializer, response_deserializer, loop) |
|
|
|
|
|
|
|
|
|
# Reads response message from Core |
|
|
|
|
self._init_stream_request_mixin(request_async_iterator) |
|
|
|
|
self._init_unary_response_mixin(self._conduct_rpc()) |
|
|
|
|
|
|
|
|
|
async def _conduct_rpc(self) -> ResponseType: |
|
|
|
|
try: |
|
|
|
|
raw_response = await self._cython_call.receive_serialized_message() |
|
|
|
|
serialized_response = await self._cython_call.stream_unary( |
|
|
|
|
self._metadata, self._metadata_sent_observer) |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
|
|
|
|
|
|
# Raises RpcError if the RPC failed or cancelled |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
|
|
|
|
|
if raw_response is cygrpc.EOF: |
|
|
|
|
return cygrpc.EOF |
|
|
|
|
else: |
|
|
|
|
return _common.deserialize(raw_response, |
|
|
|
|
return _common.deserialize(serialized_response, |
|
|
|
|
self._response_deserializer) |
|
|
|
|
|
|
|
|
|
async def read(self) -> ResponseType: |
|
|
|
|
if self._cython_call.done(): |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
return cygrpc.EOF |
|
|
|
|
|
|
|
|
|
response_message = await self._read() |
|
|
|
|
class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call, |
|
|
|
|
_base_call.StreamStreamCall): |
|
|
|
|
"""Object for managing stream-stream RPC calls. |
|
|
|
|
|
|
|
|
|
if response_message is cygrpc.EOF: |
|
|
|
|
# If the read operation failed, Core should explain why. |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
return response_message |
|
|
|
|
Returned when an instance of `StreamStreamMultiCallable` object is called. |
|
|
|
|
""" |
|
|
|
|
_initializer: asyncio.Task |
|
|
|
|
|
|
|
|
|
# pylint: disable=too-many-arguments |
|
|
|
|
def __init__(self, |
|
|
|
|
request_async_iterator: Optional[AsyncIterable[RequestType]], |
|
|
|
|
deadline: Optional[float], metadata: MetadataType, |
|
|
|
|
credentials: Optional[grpc.CallCredentials], |
|
|
|
|
channel: cygrpc.AioChannel, method: bytes, |
|
|
|
|
request_serializer: SerializingFunction, |
|
|
|
|
response_deserializer: DeserializingFunction, |
|
|
|
|
loop: asyncio.AbstractEventLoop) -> None: |
|
|
|
|
super().__init__(channel.call(method, deadline, credentials), metadata, |
|
|
|
|
request_serializer, response_deserializer, loop) |
|
|
|
|
self._initializer = self._loop.create_task(self._prepare_rpc()) |
|
|
|
|
self._init_stream_request_mixin(request_async_iterator) |
|
|
|
|
self._init_stream_response_mixin(self._initializer) |
|
|
|
|
|
|
|
|
|
async def _prepare_rpc(self): |
|
|
|
|
"""This method prepares the RPC for receiving/sending messages. |
|
|
|
|
|
|
|
|
|
All other operations around the stream should only happen after the |
|
|
|
|
completion of this method. |
|
|
|
|
""" |
|
|
|
|
try: |
|
|
|
|
await self._cython_call.initiate_stream_stream( |
|
|
|
|
self._metadata, self._metadata_sent_observer) |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
|
# No need to raise RpcError here, because no one will `await` this task. |
|
|
|
|