@ -14,7 +14,8 @@
""" Invocation-side implementation of gRPC Asyncio Python. """
import asyncio
from typing import AsyncIterable , Awaitable , List , Dict , Optional
from functools import partial
from typing import AsyncIterable , List , Dict , Optional
import grpc
from grpc import _common
@ -42,8 +43,6 @@ _NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
' \t debug_error_string = " {} " \n '
' > ' )
_EMPTY_METADATA = tuple ( )
class AioRpcError ( grpc . RpcError ) :
""" An implementation of RpcError to be used by the asynchronous API.
@ -153,116 +152,69 @@ class Call(_base_call.Call):
"""
_loop : asyncio . AbstractEventLoop
_code : grpc . StatusCode
_status : Awaitable [ cygrpc . AioRpcStatus ]
_initial_metadata : Awaitable [ MetadataType ]
_locally_cancelled : bool
_cython_call : cygrpc . _AioCall
_done_callbacks : List [ DoneCallbackType ]
def __init__ ( self , cython_call : cygrpc . _AioCall ) - > None :
self . _loop = asyncio . get_event_loop ( )
self . _code = None
self . _status = self . _loop . create_future ( )
self . _initial_metadata = self . _loop . create_future ( )
self . _locally_cancelled = False
def __init__ ( self , cython_call : cygrpc . _AioCall ,
loop : asyncio . AbstractEventLoop ) - > None :
self . _loop = loop
self . _cython_call = cython_call
self . _done_callbacks = [ ]
def __del__ ( self ) - > None :
if not self . _status . done ( ) :
self . _cancel (
cygrpc . AioRpcStatus ( cygrpc . StatusCode . cancelled ,
_GC_CANCELLATION_DETAILS , None , None ) )
if not self . _cython_call . done ( ) :
self . _cancel ( _GC_CANCELLATION_DETAILS )
def cancelled ( self ) - > bool :
return self . _code == grpc . StatusCode . CANCELLED
return self . _cython_call . cancelled ( )
def _cancel ( self , status : cygrpc . AioRpcStatus ) - > bool :
def _cancel ( self , details : str ) - > bool :
""" Forwards the application cancellation reasoning. """
if not self . _status . done ( ) :
self . _set_status ( status )
self . _cython_call . cancel ( status )
if not self . _cython_call . done ( ) :
self . _cython_call . cancel ( details )
return True
else :
return False
def cancel ( self ) - > bool :
return self . _cancel (
cygrpc . AioRpcStatus ( cygrpc . StatusCode . cancelled ,
_LOCAL_CANCELLATION_DETAILS , None , None ) )
return self . _cancel ( _LOCAL_CANCELLATION_DETAILS )
def done ( self ) - > bool :
return self . _status . done ( )
return self . _cython_call . done ( )
def add_done_callback ( self , callback : DoneCallbackType ) - > None :
if self . done ( ) :
callback ( self )
else :
self . _done_callbacks . append ( callback )
cb = partial ( callback , self )
self . _cython_call . add_done_callback ( cb )
def time_remaining ( self ) - > Optional [ float ] :
return self . _cython_call . time_remaining ( )
async def initial_metadata ( self ) - > MetadataType :
return await self . _initial_metadata
return await self . _cython_call . initial_metadata ( )
async def trailing_metadata ( self ) - > MetadataType :
return ( await self . _status ) . trailing_metadata ( )
return ( await self . _cython_call . status ( ) ) . trailing_metadata ( )
async def code ( self ) - > grpc . StatusCode :
await self . _status
return self . _code
cygrpc_code = ( await self . _cython_call . status ( ) ) . code ( )
return _common . CYGRPC_STATUS_CODE_TO_STATUS_CODE [ cygrpc_code ]
async def details ( self ) - > str :
return ( await self . _status ) . details ( )
return ( await self . _cython_call . status ( ) ) . details ( )
async def debug_error_string ( self ) - > str :
return ( await self . _status ) . debug_error_string ( )
def _set_initial_metadata ( self , metadata : MetadataType ) - > None :
self . _initial_metadata . set_result ( metadata )
def _set_status ( self , status : cygrpc . AioRpcStatus ) - > None :
""" Private method to set final status of the RPC.
This method should only be invoked once .
"""
# In case of local cancellation, flip the flag.
if status . details ( ) is _LOCAL_CANCELLATION_DETAILS :
self . _locally_cancelled = True
# In case of the RPC finished without receiving metadata.
if not self . _initial_metadata . done ( ) :
self . _initial_metadata . set_result ( _EMPTY_METADATA )
# Sets final status
self . _status . set_result ( status )
self . _code = _common . CYGRPC_STATUS_CODE_TO_STATUS_CODE [ status . code ( ) ]
for callback in self . _done_callbacks :
callback ( self )
return ( await self . _cython_call . status ( ) ) . debug_error_string ( )
async def _raise_for_status ( self ) - > None :
if self . _locally_cancelled :
if self . _cython_call . is_locally_cancelled ( ) :
raise asyncio . CancelledError ( )
await self . _status
if self . _ code != grpc . StatusCode . OK :
raise _create_rpc_error ( await self . initial_metadata ( ) ,
self . _status . result ( ) )
code = await self . code ( )
if code != grpc . StatusCode . OK :
raise _create_rpc_error ( await self . initial_metadata ( ) , await
self . _cython_call . status ( ) )
def _repr ( self ) - > str :
""" Assembles the RPC representation string. """
if not self . _status . done ( ) :
return ' < {} object> ' . format ( self . __class__ . __name__ )
if self . _code is grpc . StatusCode . OK :
return _OK_CALL_REPRESENTATION . format (
self . __class__ . __name__ , self . _code ,
self . _status . result ( ) . details ( ) )
else :
return _NON_OK_CALL_REPRESENTATION . format (
self . __class__ . __name__ , self . _code ,
self . _status . result ( ) . details ( ) ,
self . _status . result ( ) . debug_error_string ( ) )
return repr ( self . _cython_call )
def __repr__ ( self ) - > str :
return self . _repr ( )
@ -288,13 +240,14 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
credentials : Optional [ grpc . CallCredentials ] ,
channel : cygrpc . AioChannel , method : bytes ,
request_serializer : SerializingFunction ,
response_deserializer : DeserializingFunction ) - > None :
super ( ) . __init__ ( channel . call ( method , deadline , credentials ) )
response_deserializer : DeserializingFunction ,
loop : asyncio . AbstractEventLoop ) - > None :
super ( ) . __init__ ( channel . call ( method , deadline , credentials ) , loop )
self . _request = request
self . _metadata = metadata
self . _request_serializer = request_serializer
self . _response_deserializer = response_deserializer
self . _call = self . _ loop. create_task ( self . _invoke ( ) )
self . _call = loop . create_task ( self . _invoke ( ) )
def cancel ( self ) - > bool :
if super ( ) . cancel ( ) :
@ -312,11 +265,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
# https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
try :
serialized_response = await self . _cython_call . unary_unary (
serialized_request ,
self . _metadata ,
self . _set_initial_metadata ,
self . _set_status ,
)
serialized_request , self . _metadata )
except asyncio . CancelledError :
if not self . cancelled ( ) :
self . cancel ( )
@ -360,13 +309,14 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
credentials : Optional [ grpc . CallCredentials ] ,
channel : cygrpc . AioChannel , method : bytes ,
request_serializer : SerializingFunction ,
response_deserializer : DeserializingFunction ) - > None :
super ( ) . __init__ ( channel . call ( method , deadline , credentials ) )
response_deserializer : DeserializingFunction ,
loop : asyncio . AbstractEventLoop ) - > None :
super ( ) . __init__ ( channel . call ( method , deadline , credentials ) , loop )
self . _request = request
self . _metadata = metadata
self . _request_serializer = request_serializer
self . _response_deserializer = response_deserializer
self . _send_unary_request_task = self . _ loop. create_task (
self . _send_unary_request_task = loop . create_task (
self . _send_unary_request ( ) )
self . _message_aiter = None
@ -382,8 +332,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
self . _request_serializer )
try :
await self . _cython_call . initiate_unary_stream (
serialized_request , self . _metadata , self . _set_initial_metadata ,
self . _set_status )
serialized_request , self . _metadata )
except asyncio . CancelledError :
if not self . cancelled ( ) :
self . cancel ( )
@ -419,7 +368,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
self . _response_deserializer )
async def read ( self ) - > ResponseType :
if self . _status . done ( ) :
if self . _cython_call . done ( ) :
await self . _raise_for_status ( )
return cygrpc . EOF
@ -452,16 +401,17 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
credentials : Optional [ grpc . CallCredentials ] ,
channel : cygrpc . AioChannel , method : bytes ,
request_serializer : SerializingFunction ,
response_deserializer : DeserializingFunction ) - > None :
super ( ) . __init__ ( channel . call ( method , deadline , credentials ) )
response_deserializer : DeserializingFunction ,
loop : asyncio . AbstractEventLoop ) - > None :
super ( ) . __init__ ( channel . call ( method , deadline , credentials ) , loop )
self . _metadata = metadata
self . _request_serializer = request_serializer
self . _response_deserializer = response_deserializer
self . _metadata_sent = asyncio . Event ( loop = self . _ loop)
self . _metadata_sent = asyncio . Event ( loop = loop )
self . _done_writing = False
self . _call_finisher = self . _ loop. create_task ( self . _conduct_rpc ( ) )
self . _call_finisher = loop . create_task ( self . _conduct_rpc ( ) )
# If user passes in an async iterator, create a consumer Task.
if request_async_iterator is not None :
@ -485,11 +435,7 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
async def _conduct_rpc ( self ) - > ResponseType :
try :
serialized_response = await self . _cython_call . stream_unary (
self . _metadata ,
self . _metadata_sent_observer ,
self . _set_initial_metadata ,
self . _set_status ,
)
self . _metadata , self . _metadata_sent_observer )
except asyncio . CancelledError :
if not self . cancelled ( ) :
self . cancel ( )
@ -517,7 +463,7 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
return response
async def write ( self , request : RequestType ) - > None :
if self . _status . done ( ) :
if self . _cython_call . done ( ) :
raise asyncio . InvalidStateError ( _RPC_ALREADY_FINISHED_DETAILS )
if self . _done_writing :
raise asyncio . InvalidStateError ( _RPC_HALF_CLOSED_DETAILS )
@ -536,7 +482,7 @@ class StreamUnaryCall(Call, _base_call.StreamUnaryCall):
async def done_writing ( self ) - > None :
""" Implementation of done_writing is idempotent. """
if self . _status . done ( ) :
if self . _cython_call . done ( ) :
# If the RPC is finished, do nothing.
return
if not self . _done_writing :
@ -572,20 +518,21 @@ class StreamStreamCall(Call, _base_call.StreamStreamCall):
credentials : Optional [ grpc . CallCredentials ] ,
channel : cygrpc . AioChannel , method : bytes ,
request_serializer : SerializingFunction ,
response_deserializer : DeserializingFunction ) - > None :
super ( ) . __init__ ( channel . call ( method , deadline , credentials ) )
response_deserializer : DeserializingFunction ,
loop : asyncio . AbstractEventLoop ) - > None :
super ( ) . __init__ ( channel . call ( method , deadline , credentials ) , loop )
self . _metadata = metadata
self . _request_serializer = request_serializer
self . _response_deserializer = response_deserializer
self . _metadata_sent = asyncio . Event ( loop = self . _ loop)
self . _metadata_sent = asyncio . Event ( loop = loop )
self . _done_writing = False
self . _initializer = self . _loop . create_task ( self . _prepare_rpc ( ) )
# If user passes in an async iterator, create a consumer coroutine.
if request_async_iterator is not None :
self . _async_request_poller = self . _ loop. create_task (
self . _async_request_poller = loop . create_task (
self . _consume_request_iterator ( request_async_iterator ) )
else :
self . _async_request_poller = None
@ -611,11 +558,7 @@ class StreamStreamCall(Call, _base_call.StreamStreamCall):
"""
try :
await self . _cython_call . initiate_stream_stream (
self . _metadata ,
self . _metadata_sent_observer ,
self . _set_initial_metadata ,
self . _set_status ,
)
self . _metadata , self . _metadata_sent_observer )
except asyncio . CancelledError :
if not self . cancelled ( ) :
self . cancel ( )
@ -629,7 +572,7 @@ class StreamStreamCall(Call, _base_call.StreamStreamCall):
await self . done_writing ( )
async def write ( self , request : RequestType ) - > None :
if self . _status . done ( ) :
if self . _cython_call . done ( ) :
raise asyncio . InvalidStateError ( _RPC_ALREADY_FINISHED_DETAILS )
if self . _done_writing :
raise asyncio . InvalidStateError ( _RPC_HALF_CLOSED_DETAILS )
@ -648,7 +591,7 @@ class StreamStreamCall(Call, _base_call.StreamStreamCall):
async def done_writing ( self ) - > None :
""" Implementation of done_writing is idempotent. """
if self . _status . done ( ) :
if self . _cython_call . done ( ) :
# If the RPC is finished, do nothing.
return
if not self . _done_writing :
@ -692,7 +635,7 @@ class StreamStreamCall(Call, _base_call.StreamStreamCall):
self . _response_deserializer )
async def read ( self ) - > ResponseType :
if self . _status . done ( ) :
if self . _cython_call . done ( ) :
await self . _raise_for_status ( )
return cygrpc . EOF