|
|
|
@ -29,6 +29,7 @@ __all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall' |
|
|
|
|
_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!' |
|
|
|
|
_GC_CANCELLATION_DETAILS = 'Cancelled upon garbage collection!' |
|
|
|
|
_RPC_ALREADY_FINISHED_DETAILS = 'RPC already finished.' |
|
|
|
|
_RPC_HALF_CLOSED_DETAILS = 'RPC is half closed after calling "done_writing".' |
|
|
|
|
|
|
|
|
|
_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n' |
|
|
|
|
'\tstatus = {}\n' |
|
|
|
@ -146,31 +147,48 @@ def _create_rpc_error(initial_metadata: Optional[MetadataType], |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Call(_base_call.Call): |
|
|
|
|
"""Base implementation of client RPC Call object. |
|
|
|
|
|
|
|
|
|
Implements logic around final status, metadata and cancellation. |
|
|
|
|
""" |
|
|
|
|
_loop: asyncio.AbstractEventLoop |
|
|
|
|
_code: grpc.StatusCode |
|
|
|
|
_status: Awaitable[cygrpc.AioRpcStatus] |
|
|
|
|
_initial_metadata: Awaitable[MetadataType] |
|
|
|
|
_locally_cancelled: bool |
|
|
|
|
_cython_call: cygrpc._AioCall |
|
|
|
|
|
|
|
|
|
def __init__(self) -> None: |
|
|
|
|
def __init__(self, cython_call: cygrpc._AioCall) -> None: |
|
|
|
|
self._loop = asyncio.get_event_loop() |
|
|
|
|
self._code = None |
|
|
|
|
self._status = self._loop.create_future() |
|
|
|
|
self._initial_metadata = self._loop.create_future() |
|
|
|
|
self._locally_cancelled = False |
|
|
|
|
self._cython_call = cython_call |
|
|
|
|
|
|
|
|
|
def cancel(self) -> bool: |
|
|
|
|
"""Placeholder cancellation method. |
|
|
|
|
|
|
|
|
|
The implementation of this method needs to pass the cancellation reason |
|
|
|
|
into self._cancellation, using `set_result` instead of |
|
|
|
|
`set_exception`. |
|
|
|
|
""" |
|
|
|
|
raise NotImplementedError() |
|
|
|
|
def __del__(self) -> None: |
|
|
|
|
if not self._status.done(): |
|
|
|
|
self._cancel( |
|
|
|
|
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled, |
|
|
|
|
_GC_CANCELLATION_DETAILS, None, None)) |
|
|
|
|
|
|
|
|
|
def cancelled(self) -> bool: |
|
|
|
|
return self._code == grpc.StatusCode.CANCELLED |
|
|
|
|
|
|
|
|
|
def _cancel(self, status: cygrpc.AioRpcStatus) -> bool: |
|
|
|
|
"""Forwards the application cancellation reasoning.""" |
|
|
|
|
if not self._status.done(): |
|
|
|
|
self._set_status(status) |
|
|
|
|
self._cython_call.cancel(status) |
|
|
|
|
return True |
|
|
|
|
else: |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
def cancel(self) -> bool: |
|
|
|
|
return self._cancel( |
|
|
|
|
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled, |
|
|
|
|
_LOCAL_CANCELLATION_DETAILS, None, None)) |
|
|
|
|
|
|
|
|
|
def done(self) -> bool: |
|
|
|
|
return self._status.done() |
|
|
|
|
|
|
|
|
@ -247,6 +265,7 @@ class Call(_base_call.Call): |
|
|
|
|
return self._repr() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression |
|
|
|
|
# pylint: disable=abstract-method |
|
|
|
|
class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): |
|
|
|
|
"""Object for managing unary-unary RPC calls. |
|
|
|
@ -254,37 +273,29 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): |
|
|
|
|
Returned when an instance of `UnaryUnaryMultiCallable` object is called. |
|
|
|
|
""" |
|
|
|
|
_request: RequestType |
|
|
|
|
_channel: cygrpc.AioChannel |
|
|
|
|
_request_serializer: SerializingFunction |
|
|
|
|
_response_deserializer: DeserializingFunction |
|
|
|
|
_call: asyncio.Task |
|
|
|
|
_cython_call: cygrpc._AioCall |
|
|
|
|
|
|
|
|
|
def __init__( # pylint: disable=R0913 |
|
|
|
|
self, request: RequestType, deadline: Optional[float], |
|
|
|
|
credentials: Optional[grpc.CallCredentials], |
|
|
|
|
channel: cygrpc.AioChannel, method: bytes, |
|
|
|
|
request_serializer: SerializingFunction, |
|
|
|
|
response_deserializer: DeserializingFunction) -> None: |
|
|
|
|
super().__init__() |
|
|
|
|
# pylint: disable=too-many-arguments |
|
|
|
|
def __init__(self, request: RequestType, deadline: Optional[float], |
|
|
|
|
credentials: Optional[grpc.CallCredentials], |
|
|
|
|
channel: cygrpc.AioChannel, method: bytes, |
|
|
|
|
request_serializer: SerializingFunction, |
|
|
|
|
response_deserializer: DeserializingFunction) -> None: |
|
|
|
|
channel.call(method, deadline, credentials) |
|
|
|
|
super().__init__(channel.call(method, deadline, credentials)) |
|
|
|
|
self._request = request |
|
|
|
|
self._channel = channel |
|
|
|
|
self._request_serializer = request_serializer |
|
|
|
|
self._response_deserializer = response_deserializer |
|
|
|
|
|
|
|
|
|
if credentials is not None: |
|
|
|
|
grpc_credentials = credentials._credentials |
|
|
|
|
else: |
|
|
|
|
grpc_credentials = None |
|
|
|
|
self._cython_call = self._channel.call(method, deadline, |
|
|
|
|
grpc_credentials) |
|
|
|
|
self._call = self._loop.create_task(self._invoke()) |
|
|
|
|
|
|
|
|
|
def __del__(self) -> None: |
|
|
|
|
if not self._call.done(): |
|
|
|
|
self._cancel( |
|
|
|
|
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled, |
|
|
|
|
_GC_CANCELLATION_DETAILS, None, None)) |
|
|
|
|
def cancel(self) -> bool: |
|
|
|
|
if super().cancel(): |
|
|
|
|
self._call.cancel() |
|
|
|
|
return True |
|
|
|
|
else: |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
async def _invoke(self) -> ResponseType: |
|
|
|
|
serialized_request = _common.serialize(self._request, |
|
|
|
@ -300,7 +311,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): |
|
|
|
|
self._set_status, |
|
|
|
|
) |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if self._code != grpc.StatusCode.CANCELLED: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
|
|
|
|
|
|
# Raises here if RPC failed or cancelled |
|
|
|
@ -309,21 +320,6 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): |
|
|
|
|
return _common.deserialize(serialized_response, |
|
|
|
|
self._response_deserializer) |
|
|
|
|
|
|
|
|
|
def _cancel(self, status: cygrpc.AioRpcStatus) -> bool: |
|
|
|
|
"""Forwards the application cancellation reasoning.""" |
|
|
|
|
if not self._status.done(): |
|
|
|
|
self._set_status(status) |
|
|
|
|
self._cython_call.cancel(status) |
|
|
|
|
self._call.cancel() |
|
|
|
|
return True |
|
|
|
|
else: |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
def cancel(self) -> bool: |
|
|
|
|
return self._cancel( |
|
|
|
|
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled, |
|
|
|
|
_LOCAL_CANCELLATION_DETAILS, None, None)) |
|
|
|
|
|
|
|
|
|
def __await__(self) -> ResponseType: |
|
|
|
|
"""Wait till the ongoing RPC request finishes.""" |
|
|
|
|
try: |
|
|
|
@ -339,6 +335,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): |
|
|
|
|
return response |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression |
|
|
|
|
# pylint: disable=abstract-method |
|
|
|
|
class UnaryStreamCall(Call, _base_call.UnaryStreamCall): |
|
|
|
|
"""Object for managing unary-stream RPC calls. |
|
|
|
@ -346,107 +343,346 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): |
|
|
|
|
Returned when an instance of `UnaryStreamMultiCallable` object is called. |
|
|
|
|
""" |
|
|
|
|
_request: RequestType |
|
|
|
|
_channel: cygrpc.AioChannel |
|
|
|
|
_request_serializer: SerializingFunction |
|
|
|
|
_response_deserializer: DeserializingFunction |
|
|
|
|
_cython_call: cygrpc._AioCall |
|
|
|
|
_send_unary_request_task: asyncio.Task |
|
|
|
|
_message_aiter: AsyncIterable[ResponseType] |
|
|
|
|
|
|
|
|
|
def __init__( # pylint: disable=R0913 |
|
|
|
|
self, request: RequestType, deadline: Optional[float], |
|
|
|
|
credentials: Optional[grpc.CallCredentials], |
|
|
|
|
channel: cygrpc.AioChannel, method: bytes, |
|
|
|
|
request_serializer: SerializingFunction, |
|
|
|
|
response_deserializer: DeserializingFunction) -> None: |
|
|
|
|
super().__init__() |
|
|
|
|
# pylint: disable=too-many-arguments |
|
|
|
|
def __init__(self, request: RequestType, deadline: Optional[float], |
|
|
|
|
credentials: Optional[grpc.CallCredentials], |
|
|
|
|
channel: cygrpc.AioChannel, method: bytes, |
|
|
|
|
request_serializer: SerializingFunction, |
|
|
|
|
response_deserializer: DeserializingFunction) -> None: |
|
|
|
|
super().__init__(channel.call(method, deadline, credentials)) |
|
|
|
|
self._request = request |
|
|
|
|
self._channel = channel |
|
|
|
|
self._request_serializer = request_serializer |
|
|
|
|
self._response_deserializer = response_deserializer |
|
|
|
|
self._send_unary_request_task = self._loop.create_task( |
|
|
|
|
self._send_unary_request()) |
|
|
|
|
self._message_aiter = self._fetch_stream_responses() |
|
|
|
|
self._message_aiter = None |
|
|
|
|
|
|
|
|
|
if credentials is not None: |
|
|
|
|
grpc_credentials = credentials._credentials |
|
|
|
|
def cancel(self) -> bool: |
|
|
|
|
if super().cancel(): |
|
|
|
|
self._send_unary_request_task.cancel() |
|
|
|
|
return True |
|
|
|
|
else: |
|
|
|
|
grpc_credentials = None |
|
|
|
|
|
|
|
|
|
self._cython_call = self._channel.call(method, deadline, |
|
|
|
|
grpc_credentials) |
|
|
|
|
|
|
|
|
|
def __del__(self) -> None: |
|
|
|
|
if not self._status.done(): |
|
|
|
|
self._cancel( |
|
|
|
|
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled, |
|
|
|
|
_GC_CANCELLATION_DETAILS, None, None)) |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
async def _send_unary_request(self) -> ResponseType: |
|
|
|
|
serialized_request = _common.serialize(self._request, |
|
|
|
|
self._request_serializer) |
|
|
|
|
try: |
|
|
|
|
await self._cython_call.unary_stream(serialized_request, |
|
|
|
|
self._set_initial_metadata, |
|
|
|
|
self._set_status) |
|
|
|
|
await self._cython_call.initiate_unary_stream( |
|
|
|
|
serialized_request, self._set_initial_metadata, |
|
|
|
|
self._set_status) |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if self._code != grpc.StatusCode.CANCELLED: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
|
raise |
|
|
|
|
|
|
|
|
|
async def _fetch_stream_responses(self) -> ResponseType: |
|
|
|
|
await self._send_unary_request_task |
|
|
|
|
message = await self._read() |
|
|
|
|
while message: |
|
|
|
|
while message is not cygrpc.EOF: |
|
|
|
|
yield message |
|
|
|
|
message = await self._read() |
|
|
|
|
|
|
|
|
|
def _cancel(self, status: cygrpc.AioRpcStatus) -> bool: |
|
|
|
|
"""Forwards the application cancellation reasoning. |
|
|
|
|
def __aiter__(self) -> AsyncIterable[ResponseType]: |
|
|
|
|
if self._message_aiter is None: |
|
|
|
|
self._message_aiter = self._fetch_stream_responses() |
|
|
|
|
return self._message_aiter |
|
|
|
|
|
|
|
|
|
Async generator will receive an exception. The cancellation will go |
|
|
|
|
deep down into Core, and then propagates backup as the |
|
|
|
|
`cygrpc.AioRpcStatus` exception. |
|
|
|
|
async def _read(self) -> ResponseType: |
|
|
|
|
# Wait for the request being sent |
|
|
|
|
await self._send_unary_request_task |
|
|
|
|
|
|
|
|
|
So, under race condition, e.g. the server sent out final state headers |
|
|
|
|
and the client calling "cancel" at the same time, this method respects |
|
|
|
|
the winner in Core. |
|
|
|
|
""" |
|
|
|
|
if not self._status.done(): |
|
|
|
|
self._set_status(status) |
|
|
|
|
self._cython_call.cancel(status) |
|
|
|
|
# Reads response message from Core |
|
|
|
|
try: |
|
|
|
|
raw_response = await self._cython_call.receive_serialized_message() |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
|
|
|
|
|
if raw_response is cygrpc.EOF: |
|
|
|
|
return cygrpc.EOF |
|
|
|
|
else: |
|
|
|
|
return _common.deserialize(raw_response, |
|
|
|
|
self._response_deserializer) |
|
|
|
|
|
|
|
|
|
async def read(self) -> ResponseType: |
|
|
|
|
if self._status.done(): |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
return cygrpc.EOF |
|
|
|
|
|
|
|
|
|
response_message = await self._read() |
|
|
|
|
|
|
|
|
|
if response_message is cygrpc.EOF: |
|
|
|
|
# If the read operation failed, Core should explain why. |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
return response_message |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression |
|
|
|
|
# pylint: disable=abstract-method |
|
|
|
|
class StreamUnaryCall(Call, _base_call.StreamUnaryCall): |
|
|
|
|
"""Object for managing stream-unary RPC calls. |
|
|
|
|
|
|
|
|
|
Returned when an instance of `StreamUnaryMultiCallable` object is called. |
|
|
|
|
""" |
|
|
|
|
_metadata: MetadataType |
|
|
|
|
_request_serializer: SerializingFunction |
|
|
|
|
_response_deserializer: DeserializingFunction |
|
|
|
|
|
|
|
|
|
_metadata_sent: asyncio.Event |
|
|
|
|
_done_writing: bool |
|
|
|
|
_call_finisher: asyncio.Task |
|
|
|
|
_async_request_poller: asyncio.Task |
|
|
|
|
|
|
|
|
|
if not self._send_unary_request_task.done(): |
|
|
|
|
# Injects CancelledError to the Task. The exception will |
|
|
|
|
# propagate to _fetch_stream_responses as well, if the sending |
|
|
|
|
# is not done. |
|
|
|
|
self._send_unary_request_task.cancel() |
|
|
|
|
# pylint: disable=too-many-arguments |
|
|
|
|
def __init__(self, |
|
|
|
|
request_async_iterator: Optional[AsyncIterable[RequestType]], |
|
|
|
|
deadline: Optional[float], |
|
|
|
|
credentials: Optional[grpc.CallCredentials], |
|
|
|
|
channel: cygrpc.AioChannel, method: bytes, |
|
|
|
|
request_serializer: SerializingFunction, |
|
|
|
|
response_deserializer: DeserializingFunction) -> None: |
|
|
|
|
super().__init__(channel.call(method, deadline, credentials)) |
|
|
|
|
self._metadata = _EMPTY_METADATA |
|
|
|
|
self._request_serializer = request_serializer |
|
|
|
|
self._response_deserializer = response_deserializer |
|
|
|
|
|
|
|
|
|
self._metadata_sent = asyncio.Event(loop=self._loop) |
|
|
|
|
self._done_writing = False |
|
|
|
|
|
|
|
|
|
self._call_finisher = self._loop.create_task(self._conduct_rpc()) |
|
|
|
|
|
|
|
|
|
# If user passes in an async iterator, create a consumer Task. |
|
|
|
|
if request_async_iterator is not None: |
|
|
|
|
self._async_request_poller = self._loop.create_task( |
|
|
|
|
self._consume_request_iterator(request_async_iterator)) |
|
|
|
|
else: |
|
|
|
|
self._async_request_poller = None |
|
|
|
|
|
|
|
|
|
def cancel(self) -> bool: |
|
|
|
|
if super().cancel(): |
|
|
|
|
self._call_finisher.cancel() |
|
|
|
|
if self._async_request_poller is not None: |
|
|
|
|
self._async_request_poller.cancel() |
|
|
|
|
return True |
|
|
|
|
else: |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
def _metadata_sent_observer(self): |
|
|
|
|
self._metadata_sent.set() |
|
|
|
|
|
|
|
|
|
async def _conduct_rpc(self) -> ResponseType: |
|
|
|
|
try: |
|
|
|
|
serialized_response = await self._cython_call.stream_unary( |
|
|
|
|
self._metadata, |
|
|
|
|
self._metadata_sent_observer, |
|
|
|
|
self._set_initial_metadata, |
|
|
|
|
self._set_status, |
|
|
|
|
) |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
|
|
|
|
|
|
# Raises RpcError if the RPC failed or cancelled |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
|
|
|
|
|
return _common.deserialize(serialized_response, |
|
|
|
|
self._response_deserializer) |
|
|
|
|
|
|
|
|
|
async def _consume_request_iterator( |
|
|
|
|
self, request_async_iterator: AsyncIterable[RequestType]) -> None: |
|
|
|
|
async for request in request_async_iterator: |
|
|
|
|
await self.write(request) |
|
|
|
|
await self.done_writing() |
|
|
|
|
|
|
|
|
|
def __await__(self) -> ResponseType: |
|
|
|
|
"""Wait till the ongoing RPC request finishes.""" |
|
|
|
|
try: |
|
|
|
|
response = yield from self._call_finisher |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
|
raise |
|
|
|
|
return response |
|
|
|
|
|
|
|
|
|
async def write(self, request: RequestType) -> None: |
|
|
|
|
if self._status.done(): |
|
|
|
|
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) |
|
|
|
|
if self._done_writing: |
|
|
|
|
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) |
|
|
|
|
if not self._metadata_sent.is_set(): |
|
|
|
|
await self._metadata_sent.wait() |
|
|
|
|
|
|
|
|
|
serialized_request = _common.serialize(request, |
|
|
|
|
self._request_serializer) |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
await self._cython_call.send_serialized_message(serialized_request) |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
|
|
|
|
|
async def done_writing(self) -> None: |
|
|
|
|
"""Implementation of done_writing is idempotent.""" |
|
|
|
|
if self._status.done(): |
|
|
|
|
# If the RPC is finished, do nothing. |
|
|
|
|
return |
|
|
|
|
if not self._done_writing: |
|
|
|
|
# If the done writing is not sent before, try to send it. |
|
|
|
|
self._done_writing = True |
|
|
|
|
try: |
|
|
|
|
await self._cython_call.send_receive_close() |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO(https://github.com/grpc/grpc/issues/21623) remove this suppression |
|
|
|
|
# pylint: disable=abstract-method |
|
|
|
|
class StreamStreamCall(Call, _base_call.StreamStreamCall): |
|
|
|
|
"""Object for managing stream-stream RPC calls. |
|
|
|
|
|
|
|
|
|
Returned when an instance of `StreamStreamMultiCallable` object is called. |
|
|
|
|
""" |
|
|
|
|
_metadata: MetadataType |
|
|
|
|
_request_serializer: SerializingFunction |
|
|
|
|
_response_deserializer: DeserializingFunction |
|
|
|
|
|
|
|
|
|
_metadata_sent: asyncio.Event |
|
|
|
|
_done_writing: bool |
|
|
|
|
_initializer: asyncio.Task |
|
|
|
|
_async_request_poller: asyncio.Task |
|
|
|
|
_message_aiter: AsyncIterable[ResponseType] |
|
|
|
|
|
|
|
|
|
# pylint: disable=too-many-arguments |
|
|
|
|
def __init__(self, |
|
|
|
|
request_async_iterator: Optional[AsyncIterable[RequestType]], |
|
|
|
|
deadline: Optional[float], |
|
|
|
|
credentials: Optional[grpc.CallCredentials], |
|
|
|
|
channel: cygrpc.AioChannel, method: bytes, |
|
|
|
|
request_serializer: SerializingFunction, |
|
|
|
|
response_deserializer: DeserializingFunction) -> None: |
|
|
|
|
super().__init__(channel.call(method, deadline, credentials)) |
|
|
|
|
self._metadata = _EMPTY_METADATA |
|
|
|
|
self._request_serializer = request_serializer |
|
|
|
|
self._response_deserializer = response_deserializer |
|
|
|
|
|
|
|
|
|
self._metadata_sent = asyncio.Event(loop=self._loop) |
|
|
|
|
self._done_writing = False |
|
|
|
|
|
|
|
|
|
self._initializer = self._loop.create_task(self._prepare_rpc()) |
|
|
|
|
|
|
|
|
|
# If user passes in an async iterator, create a consumer coroutine. |
|
|
|
|
if request_async_iterator is not None: |
|
|
|
|
self._async_request_poller = self._loop.create_task( |
|
|
|
|
self._consume_request_iterator(request_async_iterator)) |
|
|
|
|
else: |
|
|
|
|
self._async_request_poller = None |
|
|
|
|
self._message_aiter = None |
|
|
|
|
|
|
|
|
|
def cancel(self) -> bool: |
|
|
|
|
return self._cancel( |
|
|
|
|
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled, |
|
|
|
|
_LOCAL_CANCELLATION_DETAILS, None, None)) |
|
|
|
|
if super().cancel(): |
|
|
|
|
self._initializer.cancel() |
|
|
|
|
if self._async_request_poller is not None: |
|
|
|
|
self._async_request_poller.cancel() |
|
|
|
|
return True |
|
|
|
|
else: |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
def _metadata_sent_observer(self): |
|
|
|
|
self._metadata_sent.set() |
|
|
|
|
|
|
|
|
|
async def _prepare_rpc(self): |
|
|
|
|
"""This method prepares the RPC for receiving/sending messages. |
|
|
|
|
|
|
|
|
|
All other operations around the stream should only happen after the |
|
|
|
|
completion of this method. |
|
|
|
|
""" |
|
|
|
|
try: |
|
|
|
|
await self._cython_call.initiate_stream_stream( |
|
|
|
|
self._metadata, |
|
|
|
|
self._metadata_sent_observer, |
|
|
|
|
self._set_initial_metadata, |
|
|
|
|
self._set_status, |
|
|
|
|
) |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
|
# No need to raise RpcError here, because no one will `await` this task. |
|
|
|
|
|
|
|
|
|
async def _consume_request_iterator( |
|
|
|
|
self, request_async_iterator: Optional[AsyncIterable[RequestType]] |
|
|
|
|
) -> None: |
|
|
|
|
async for request in request_async_iterator: |
|
|
|
|
await self.write(request) |
|
|
|
|
await self.done_writing() |
|
|
|
|
|
|
|
|
|
async def write(self, request: RequestType) -> None: |
|
|
|
|
if self._status.done(): |
|
|
|
|
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) |
|
|
|
|
if self._done_writing: |
|
|
|
|
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) |
|
|
|
|
if not self._metadata_sent.is_set(): |
|
|
|
|
await self._metadata_sent.wait() |
|
|
|
|
|
|
|
|
|
serialized_request = _common.serialize(request, |
|
|
|
|
self._request_serializer) |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
await self._cython_call.send_serialized_message(serialized_request) |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
|
|
|
|
|
async def done_writing(self) -> None: |
|
|
|
|
"""Implementation of done_writing is idempotent.""" |
|
|
|
|
if self._status.done(): |
|
|
|
|
# If the RPC is finished, do nothing. |
|
|
|
|
return |
|
|
|
|
if not self._done_writing: |
|
|
|
|
# If the done writing is not sent before, try to send it. |
|
|
|
|
self._done_writing = True |
|
|
|
|
try: |
|
|
|
|
await self._cython_call.send_receive_close() |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
|
|
|
|
|
async def _fetch_stream_responses(self) -> ResponseType: |
|
|
|
|
"""The async generator that yields responses from peer.""" |
|
|
|
|
message = await self._read() |
|
|
|
|
while message is not cygrpc.EOF: |
|
|
|
|
yield message |
|
|
|
|
message = await self._read() |
|
|
|
|
|
|
|
|
|
def __aiter__(self) -> AsyncIterable[ResponseType]: |
|
|
|
|
if self._message_aiter is None: |
|
|
|
|
self._message_aiter = self._fetch_stream_responses() |
|
|
|
|
return self._message_aiter |
|
|
|
|
|
|
|
|
|
async def _read(self) -> ResponseType: |
|
|
|
|
# Wait for the request being sent |
|
|
|
|
await self._send_unary_request_task |
|
|
|
|
# Wait for the setup |
|
|
|
|
await self._initializer |
|
|
|
|
|
|
|
|
|
# Reads response message from Core |
|
|
|
|
try: |
|
|
|
|
raw_response = await self._cython_call.receive_serialized_message() |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if self._code != grpc.StatusCode.CANCELLED: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
|
raise |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
|
|
|
|
|
if raw_response is None: |
|
|
|
|
return None |
|
|
|
|
if raw_response is cygrpc.EOF: |
|
|
|
|
return cygrpc.EOF |
|
|
|
|
else: |
|
|
|
|
return _common.deserialize(raw_response, |
|
|
|
|
self._response_deserializer) |
|
|
|
@ -454,14 +690,11 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): |
|
|
|
|
async def read(self) -> ResponseType: |
|
|
|
|
if self._status.done(): |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) |
|
|
|
|
return cygrpc.EOF |
|
|
|
|
|
|
|
|
|
response_message = await self._read() |
|
|
|
|
|
|
|
|
|
if response_message is None: |
|
|
|
|
if response_message is cygrpc.EOF: |
|
|
|
|
# If the read operation failed, Core should explain why. |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
# If no exception raised, there is something wrong internally. |
|
|
|
|
assert False, 'Read operation failed with StatusCode.OK' |
|
|
|
|
else: |
|
|
|
|
return response_message |
|
|
|
|
return response_message |
|
|
|
|