|
|
|
@ -14,7 +14,8 @@ |
|
|
|
|
"""Invocation-side implementation of gRPC Asyncio Python.""" |
|
|
|
|
|
|
|
|
|
import asyncio |
|
|
|
|
from typing import AsyncIterable, Awaitable, List, Dict, Optional |
|
|
|
|
from functools import partial |
|
|
|
|
from typing import AsyncIterable, Dict, Optional |
|
|
|
|
|
|
|
|
|
import grpc |
|
|
|
|
from grpc import _common |
|
|
|
@ -42,8 +43,6 @@ _NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n' |
|
|
|
|
'\tdebug_error_string = "{}"\n' |
|
|
|
|
'>') |
|
|
|
|
|
|
|
|
|
_EMPTY_METADATA = tuple() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AioRpcError(grpc.RpcError): |
|
|
|
|
"""An implementation of RpcError to be used by the asynchronous API. |
|
|
|
@ -153,116 +152,67 @@ class Call(_base_call.Call): |
|
|
|
|
""" |
|
|
|
|
_loop: asyncio.AbstractEventLoop |
|
|
|
|
_code: grpc.StatusCode |
|
|
|
|
_status: Awaitable[cygrpc.AioRpcStatus] |
|
|
|
|
_initial_metadata: Awaitable[MetadataType] |
|
|
|
|
_locally_cancelled: bool |
|
|
|
|
_cython_call: cygrpc._AioCall |
|
|
|
|
_done_callbacks: List[DoneCallbackType] |
|
|
|
|
|
|
|
|
|
def __init__(self, cython_call: cygrpc._AioCall) -> None: |
|
|
|
|
self._loop = asyncio.get_event_loop() |
|
|
|
|
self._code = None |
|
|
|
|
self._status = self._loop.create_future() |
|
|
|
|
self._initial_metadata = self._loop.create_future() |
|
|
|
|
self._locally_cancelled = False |
|
|
|
|
|
|
|
|
|
def __init__(self, cython_call: cygrpc._AioCall, |
|
|
|
|
loop: asyncio.AbstractEventLoop) -> None: |
|
|
|
|
self._loop = loop |
|
|
|
|
self._cython_call = cython_call |
|
|
|
|
self._done_callbacks = [] |
|
|
|
|
|
|
|
|
|
def __del__(self) -> None: |
|
|
|
|
if not self._status.done(): |
|
|
|
|
self._cancel( |
|
|
|
|
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled, |
|
|
|
|
_GC_CANCELLATION_DETAILS, None, None)) |
|
|
|
|
if not self._cython_call.done(): |
|
|
|
|
self._cancel(_GC_CANCELLATION_DETAILS) |
|
|
|
|
|
|
|
|
|
def cancelled(self) -> bool: |
|
|
|
|
return self._code == grpc.StatusCode.CANCELLED |
|
|
|
|
return self._cython_call.cancelled() |
|
|
|
|
|
|
|
|
|
def _cancel(self, status: cygrpc.AioRpcStatus) -> bool: |
|
|
|
|
def _cancel(self, details: str) -> bool: |
|
|
|
|
"""Forwards the application cancellation reasoning.""" |
|
|
|
|
if not self._status.done(): |
|
|
|
|
self._set_status(status) |
|
|
|
|
self._cython_call.cancel(status) |
|
|
|
|
if not self._cython_call.done(): |
|
|
|
|
self._cython_call.cancel(details) |
|
|
|
|
return True |
|
|
|
|
else: |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
def cancel(self) -> bool: |
|
|
|
|
return self._cancel( |
|
|
|
|
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled, |
|
|
|
|
_LOCAL_CANCELLATION_DETAILS, None, None)) |
|
|
|
|
return self._cancel(_LOCAL_CANCELLATION_DETAILS) |
|
|
|
|
|
|
|
|
|
def done(self) -> bool: |
|
|
|
|
return self._status.done() |
|
|
|
|
return self._cython_call.done() |
|
|
|
|
|
|
|
|
|
def add_done_callback(self, callback: DoneCallbackType) -> None: |
|
|
|
|
if self.done(): |
|
|
|
|
callback(self) |
|
|
|
|
else: |
|
|
|
|
self._done_callbacks.append(callback) |
|
|
|
|
cb = partial(callback, self) |
|
|
|
|
self._cython_call.add_done_callback(cb) |
|
|
|
|
|
|
|
|
|
def time_remaining(self) -> Optional[float]: |
|
|
|
|
return self._cython_call.time_remaining() |
|
|
|
|
|
|
|
|
|
async def initial_metadata(self) -> MetadataType: |
|
|
|
|
return await self._initial_metadata |
|
|
|
|
return await self._cython_call.initial_metadata() |
|
|
|
|
|
|
|
|
|
async def trailing_metadata(self) -> MetadataType: |
|
|
|
|
return (await self._status).trailing_metadata() |
|
|
|
|
return (await self._cython_call.status()).trailing_metadata() |
|
|
|
|
|
|
|
|
|
async def code(self) -> grpc.StatusCode: |
|
|
|
|
await self._status |
|
|
|
|
return self._code |
|
|
|
|
cygrpc_code = (await self._cython_call.status()).code() |
|
|
|
|
return _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[cygrpc_code] |
|
|
|
|
|
|
|
|
|
async def details(self) -> str: |
|
|
|
|
return (await self._status).details() |
|
|
|
|
return (await self._cython_call.status()).details() |
|
|
|
|
|
|
|
|
|
async def debug_error_string(self) -> str: |
|
|
|
|
return (await self._status).debug_error_string() |
|
|
|
|
|
|
|
|
|
def _set_initial_metadata(self, metadata: MetadataType) -> None: |
|
|
|
|
self._initial_metadata.set_result(metadata) |
|
|
|
|
|
|
|
|
|
def _set_status(self, status: cygrpc.AioRpcStatus) -> None: |
|
|
|
|
"""Private method to set final status of the RPC. |
|
|
|
|
|
|
|
|
|
This method should only be invoked once. |
|
|
|
|
""" |
|
|
|
|
# In case of local cancellation, flip the flag. |
|
|
|
|
if status.details() is _LOCAL_CANCELLATION_DETAILS: |
|
|
|
|
self._locally_cancelled = True |
|
|
|
|
|
|
|
|
|
# In case of the RPC finished without receiving metadata. |
|
|
|
|
if not self._initial_metadata.done(): |
|
|
|
|
self._initial_metadata.set_result(_EMPTY_METADATA) |
|
|
|
|
|
|
|
|
|
# Sets final status |
|
|
|
|
self._status.set_result(status) |
|
|
|
|
self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()] |
|
|
|
|
|
|
|
|
|
for callback in self._done_callbacks: |
|
|
|
|
callback(self) |
|
|
|
|
return (await self._cython_call.status()).debug_error_string() |
|
|
|
|
|
|
|
|
|
async def _raise_for_status(self) -> None: |
|
|
|
|
if self._locally_cancelled: |
|
|
|
|
if self._cython_call.is_locally_cancelled(): |
|
|
|
|
raise asyncio.CancelledError() |
|
|
|
|
await self._status |
|
|
|
|
if self._code != grpc.StatusCode.OK: |
|
|
|
|
raise _create_rpc_error(await self.initial_metadata(), |
|
|
|
|
self._status.result()) |
|
|
|
|
code = await self.code() |
|
|
|
|
if code != grpc.StatusCode.OK: |
|
|
|
|
raise _create_rpc_error(await self.initial_metadata(), await |
|
|
|
|
self._cython_call.status()) |
|
|
|
|
|
|
|
|
|
def _repr(self) -> str: |
|
|
|
|
"""Assembles the RPC representation string.""" |
|
|
|
|
if not self._status.done(): |
|
|
|
|
return '<{} object>'.format(self.__class__.__name__) |
|
|
|
|
if self._code is grpc.StatusCode.OK: |
|
|
|
|
return _OK_CALL_REPRESENTATION.format( |
|
|
|
|
self.__class__.__name__, self._code, |
|
|
|
|
self._status.result().details()) |
|
|
|
|
else: |
|
|
|
|
return _NON_OK_CALL_REPRESENTATION.format( |
|
|
|
|
self.__class__.__name__, self._code, |
|
|
|
|
self._status.result().details(), |
|
|
|
|
self._status.result().debug_error_string()) |
|
|
|
|
return repr(self._cython_call) |
|
|
|
|
|
|
|
|
|
def __repr__(self) -> str: |
|
|
|
|
return self._repr() |
|
|
|
@ -288,13 +238,14 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): |
|
|
|
|
credentials: Optional[grpc.CallCredentials], |
|
|
|
|
channel: cygrpc.AioChannel, method: bytes, |
|
|
|
|
request_serializer: SerializingFunction, |
|
|
|
|
response_deserializer: DeserializingFunction) -> None: |
|
|
|
|
super().__init__(channel.call(method, deadline, credentials)) |
|
|
|
|
response_deserializer: DeserializingFunction, |
|
|
|
|
loop: asyncio.AbstractEventLoop) -> None: |
|
|
|
|
super().__init__(channel.call(method, deadline, credentials), loop) |
|
|
|
|
self._request = request |
|
|
|
|
self._metadata = metadata |
|
|
|
|
self._request_serializer = request_serializer |
|
|
|
|
self._response_deserializer = response_deserializer |
|
|
|
|
self._call = self._loop.create_task(self._invoke()) |
|
|
|
|
self._call = loop.create_task(self._invoke()) |
|
|
|
|
|
|
|
|
|
def cancel(self) -> bool: |
|
|
|
|
if super().cancel(): |
|
|
|
@ -312,11 +263,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): |
|
|
|
|
# https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785 |
|
|
|
|
try: |
|
|
|
|
serialized_response = await self._cython_call.unary_unary( |
|
|
|
|
serialized_request, |
|
|
|
|
self._metadata, |
|
|
|
|
self._set_initial_metadata, |
|
|
|
|
self._set_status, |
|
|
|
|
) |
|
|
|
|
serialized_request, self._metadata) |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
@ -360,13 +307,14 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): |
|
|
|
|
credentials: Optional[grpc.CallCredentials], |
|
|
|
|
channel: cygrpc.AioChannel, method: bytes, |
|
|
|
|
request_serializer: SerializingFunction, |
|
|
|
|
response_deserializer: DeserializingFunction) -> None: |
|
|
|
|
super().__init__(channel.call(method, deadline, credentials)) |
|
|
|
|
response_deserializer: DeserializingFunction, |
|
|
|
|
loop: asyncio.AbstractEventLoop) -> None: |
|
|
|
|
super().__init__(channel.call(method, deadline, credentials), loop) |
|
|
|
|
self._request = request |
|
|
|
|
self._metadata = metadata |
|
|
|
|
self._request_serializer = request_serializer |
|
|
|
|
self._response_deserializer = response_deserializer |
|
|
|
|
self._send_unary_request_task = self._loop.create_task( |
|
|
|
|
self._send_unary_request_task = loop.create_task( |
|
|
|
|
self._send_unary_request()) |
|
|
|
|
self._message_aiter = None |
|
|
|
|
|
|
|
|
@ -382,8 +330,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): |
|
|
|
|
self._request_serializer) |
|
|
|
|
try: |
|
|
|
|
await self._cython_call.initiate_unary_stream( |
|
|
|
|
serialized_request, self._metadata, self._set_initial_metadata, |
|
|
|
|
self._set_status) |
|
|
|
|
serialized_request, self._metadata) |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
@ -419,7 +366,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): |
|
|
|
|
self._response_deserializer) |
|
|
|
|
|
|
|
|
|
async def read(self) -> ResponseType: |
|
|
|
|
if self._status.done(): |
|
|
|
|
if self._cython_call.done(): |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
return cygrpc.EOF |
|
|
|
|
|
|
|
|
@ -452,16 +399,17 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall): |
|
|
|
|
credentials: Optional[grpc.CallCredentials], |
|
|
|
|
channel: cygrpc.AioChannel, method: bytes, |
|
|
|
|
request_serializer: SerializingFunction, |
|
|
|
|
response_deserializer: DeserializingFunction) -> None: |
|
|
|
|
super().__init__(channel.call(method, deadline, credentials)) |
|
|
|
|
response_deserializer: DeserializingFunction, |
|
|
|
|
loop: asyncio.AbstractEventLoop) -> None: |
|
|
|
|
super().__init__(channel.call(method, deadline, credentials), loop) |
|
|
|
|
self._metadata = metadata |
|
|
|
|
self._request_serializer = request_serializer |
|
|
|
|
self._response_deserializer = response_deserializer |
|
|
|
|
|
|
|
|
|
self._metadata_sent = asyncio.Event(loop=self._loop) |
|
|
|
|
self._metadata_sent = asyncio.Event(loop=loop) |
|
|
|
|
self._done_writing = False |
|
|
|
|
|
|
|
|
|
self._call_finisher = self._loop.create_task(self._conduct_rpc()) |
|
|
|
|
self._call_finisher = loop.create_task(self._conduct_rpc()) |
|
|
|
|
|
|
|
|
|
# If user passes in an async iterator, create a consumer Task. |
|
|
|
|
if request_async_iterator is not None: |
|
|
|
@ -485,11 +433,7 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall): |
|
|
|
|
async def _conduct_rpc(self) -> ResponseType: |
|
|
|
|
try: |
|
|
|
|
serialized_response = await self._cython_call.stream_unary( |
|
|
|
|
self._metadata, |
|
|
|
|
self._metadata_sent_observer, |
|
|
|
|
self._set_initial_metadata, |
|
|
|
|
self._set_status, |
|
|
|
|
) |
|
|
|
|
self._metadata, self._metadata_sent_observer) |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
@ -517,7 +461,7 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall): |
|
|
|
|
return response |
|
|
|
|
|
|
|
|
|
async def write(self, request: RequestType) -> None: |
|
|
|
|
if self._status.done(): |
|
|
|
|
if self._cython_call.done(): |
|
|
|
|
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) |
|
|
|
|
if self._done_writing: |
|
|
|
|
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) |
|
|
|
@ -536,7 +480,7 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall): |
|
|
|
|
|
|
|
|
|
async def done_writing(self) -> None: |
|
|
|
|
"""Implementation of done_writing is idempotent.""" |
|
|
|
|
if self._status.done(): |
|
|
|
|
if self._cython_call.done(): |
|
|
|
|
# If the RPC is finished, do nothing. |
|
|
|
|
return |
|
|
|
|
if not self._done_writing: |
|
|
|
@ -572,20 +516,21 @@ class StreamStreamCall(Call, _base_call.StreamStreamCall): |
|
|
|
|
credentials: Optional[grpc.CallCredentials], |
|
|
|
|
channel: cygrpc.AioChannel, method: bytes, |
|
|
|
|
request_serializer: SerializingFunction, |
|
|
|
|
response_deserializer: DeserializingFunction) -> None: |
|
|
|
|
super().__init__(channel.call(method, deadline, credentials)) |
|
|
|
|
response_deserializer: DeserializingFunction, |
|
|
|
|
loop: asyncio.AbstractEventLoop) -> None: |
|
|
|
|
super().__init__(channel.call(method, deadline, credentials), loop) |
|
|
|
|
self._metadata = metadata |
|
|
|
|
self._request_serializer = request_serializer |
|
|
|
|
self._response_deserializer = response_deserializer |
|
|
|
|
|
|
|
|
|
self._metadata_sent = asyncio.Event(loop=self._loop) |
|
|
|
|
self._metadata_sent = asyncio.Event(loop=loop) |
|
|
|
|
self._done_writing = False |
|
|
|
|
|
|
|
|
|
self._initializer = self._loop.create_task(self._prepare_rpc()) |
|
|
|
|
|
|
|
|
|
# If user passes in an async iterator, create a consumer coroutine. |
|
|
|
|
if request_async_iterator is not None: |
|
|
|
|
self._async_request_poller = self._loop.create_task( |
|
|
|
|
self._async_request_poller = loop.create_task( |
|
|
|
|
self._consume_request_iterator(request_async_iterator)) |
|
|
|
|
else: |
|
|
|
|
self._async_request_poller = None |
|
|
|
@ -611,11 +556,7 @@ class StreamStreamCall(Call, _base_call.StreamStreamCall): |
|
|
|
|
""" |
|
|
|
|
try: |
|
|
|
|
await self._cython_call.initiate_stream_stream( |
|
|
|
|
self._metadata, |
|
|
|
|
self._metadata_sent_observer, |
|
|
|
|
self._set_initial_metadata, |
|
|
|
|
self._set_status, |
|
|
|
|
) |
|
|
|
|
self._metadata, self._metadata_sent_observer) |
|
|
|
|
except asyncio.CancelledError: |
|
|
|
|
if not self.cancelled(): |
|
|
|
|
self.cancel() |
|
|
|
@ -629,7 +570,7 @@ class StreamStreamCall(Call, _base_call.StreamStreamCall): |
|
|
|
|
await self.done_writing() |
|
|
|
|
|
|
|
|
|
async def write(self, request: RequestType) -> None: |
|
|
|
|
if self._status.done(): |
|
|
|
|
if self._cython_call.done(): |
|
|
|
|
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) |
|
|
|
|
if self._done_writing: |
|
|
|
|
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) |
|
|
|
@ -648,7 +589,7 @@ class StreamStreamCall(Call, _base_call.StreamStreamCall): |
|
|
|
|
|
|
|
|
|
async def done_writing(self) -> None: |
|
|
|
|
"""Implementation of done_writing is idempotent.""" |
|
|
|
|
if self._status.done(): |
|
|
|
|
if self._cython_call.done(): |
|
|
|
|
# If the RPC is finished, do nothing. |
|
|
|
|
return |
|
|
|
|
if not self._done_writing: |
|
|
|
@ -692,7 +633,7 @@ class StreamStreamCall(Call, _base_call.StreamStreamCall): |
|
|
|
|
self._response_deserializer) |
|
|
|
|
|
|
|
|
|
async def read(self) -> ResponseType: |
|
|
|
|
if self._status.done(): |
|
|
|
|
if self._cython_call.done(): |
|
|
|
|
await self._raise_for_status() |
|
|
|
|
return cygrpc.EOF |
|
|
|
|
|
|
|
|
|