@ -15,6 +15,7 @@
import asyncio
from typing import Any , AsyncIterable , Optional , Sequence , Text
import logging
import grpc
from grpc import _common
from grpc . _cython import cygrpc
@ -28,8 +29,37 @@ from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType,
SerializingFunction )
from . _utils import _timeout_to_deadline
_TIMEOUT_WAIT_FOR_CLOSE_ONGOING_CALLS_SEC = 0.1
_IMMUTABLE_EMPTY_TUPLE = tuple ( )
_LOGGER = logging . getLogger ( __name__ )
class _OngoingCalls :
""" Internal class used for have visibility of the ongoing calls. """
_calls : Sequence [ _base_call . RpcContext ]
def __init__ ( self ) :
self . _calls = [ ]
def _remove_call ( self , call : _base_call . RpcContext ) :
self . _calls . remove ( call )
@property
def calls ( self ) - > Sequence [ _base_call . RpcContext ] :
""" Returns a shallow copy of the ongoing calls sequence. """
return self . _calls [ : ]
def size ( self ) - > int :
""" Returns the number of ongoing calls. """
return len ( self . _calls )
def trace_call ( self , call : _base_call . RpcContext ) :
""" Adds and manages a new ongoing call. """
self . _calls . append ( call )
call . add_done_callback ( self . _remove_call )
class _BaseMultiCallable :
""" Base class of all multi callable objects.
@ -38,6 +68,7 @@ class _BaseMultiCallable:
"""
_loop : asyncio . AbstractEventLoop
_channel : cygrpc . AioChannel
_ongoing_calls : _OngoingCalls
_method : bytes
_request_serializer : SerializingFunction
_response_deserializer : DeserializingFunction
@ -49,9 +80,11 @@ class _BaseMultiCallable:
_interceptors : Optional [ Sequence [ UnaryUnaryClientInterceptor ] ]
_loop : asyncio . AbstractEventLoop
# pylint: disable=too-many-arguments
def __init__ (
self ,
channel : cygrpc . AioChannel ,
ongoing_calls : _OngoingCalls ,
method : bytes ,
request_serializer : SerializingFunction ,
response_deserializer : DeserializingFunction ,
@ -60,6 +93,7 @@ class _BaseMultiCallable:
) - > None :
self . _loop = loop
self . _channel = channel
self . _ongoing_calls = ongoing_calls
self . _method = method
self . _request_serializer = request_serializer
self . _response_deserializer = response_deserializer
@ -111,18 +145,21 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable):
metadata = _IMMUTABLE_EMPTY_TUPLE
if not self . _interceptors :
return UnaryUnaryCall ( request , _timeout_to_deadline ( timeout ) ,
call = UnaryUnaryCall ( request , _timeout_to_deadline ( timeout ) ,
metadata , credentials , self . _channel ,
self . _method , self . _request_serializer ,
self . _response_deserializer , self . _loop )
else :
return InterceptedUnaryUnaryCall ( self . _interceptors , request ,
call = InterceptedUnaryUnaryCall ( self . _interceptors , request ,
timeout , metadata , credentials ,
self . _channel , self . _method ,
self . _request_serializer ,
self . _response_deserializer ,
self . _loop )
self . _ongoing_calls . trace_call ( call )
return call
class UnaryStreamMultiCallable ( _BaseMultiCallable ) :
""" Affords invoking a unary-stream RPC from client-side in an asynchronous way. """
@ -165,10 +202,12 @@ class UnaryStreamMultiCallable(_BaseMultiCallable):
if metadata is None :
metadata = _IMMUTABLE_EMPTY_TUPLE
return UnaryStreamCall ( request , deadline , metadata , credentials ,
call = UnaryStreamCall ( request , deadline , metadata , credentials ,
self . _channel , self . _method ,
self . _request_serializer ,
self . _response_deserializer , self . _loop )
self . _ongoing_calls . trace_call ( call )
return call
class StreamUnaryMultiCallable ( _BaseMultiCallable ) :
@ -216,10 +255,12 @@ class StreamUnaryMultiCallable(_BaseMultiCallable):
if metadata is None :
metadata = _IMMUTABLE_EMPTY_TUPLE
return StreamUnaryCall ( request_async_iterator , deadline , metadata ,
call = StreamUnaryCall ( request_async_iterator , deadline , metadata ,
credentials , self . _channel , self . _method ,
self . _request_serializer ,
self . _response_deserializer , self . _loop )
self . _ongoing_calls . trace_call ( call )
return call
class StreamStreamMultiCallable ( _BaseMultiCallable ) :
@ -267,10 +308,12 @@ class StreamStreamMultiCallable(_BaseMultiCallable):
if metadata is None :
metadata = _IMMUTABLE_EMPTY_TUPLE
return StreamStreamCall ( request_async_iterator , deadline , metadata ,
call = StreamStreamCall ( request_async_iterator , deadline , metadata ,
credentials , self . _channel , self . _method ,
self . _request_serializer ,
self . _response_deserializer , self . _loop )
self . _ongoing_calls . trace_call ( call )
return call
class Channel :
@ -281,6 +324,7 @@ class Channel:
_loop : asyncio . AbstractEventLoop
_channel : cygrpc . AioChannel
_unary_unary_interceptors : Optional [ Sequence [ UnaryUnaryClientInterceptor ] ]
_ongoing_calls : _OngoingCalls
def __init__ ( self , target : Text , options : Optional [ ChannelArgumentType ] ,
credentials : Optional [ grpc . ChannelCredentials ] ,
@ -322,6 +366,53 @@ class Channel:
self . _loop = asyncio . get_event_loop ( )
self . _channel = cygrpc . AioChannel ( _common . encode ( target ) , options ,
credentials , self . _loop )
self . _ongoing_calls = _OngoingCalls ( )
async def __aenter__ ( self ) :
""" Starts an asynchronous context manager.
Returns :
Channel the channel that was instantiated .
"""
return self
async def __aexit__ ( self , exc_type , exc_val , exc_tb ) :
""" Finishes the asynchronous context manager by closing gracefully the channel. """
await self . _close ( )
async def _wait_for_close_ongoing_calls ( self ) :
sleep_iterations_sec = 0.001
while self . _ongoing_calls . size ( ) > 0 :
await asyncio . sleep ( sleep_iterations_sec )
async def _close ( self ) :
# No new calls will be accepted by the Cython channel.
self . _channel . closing ( )
calls = self . _ongoing_calls . calls
for call in calls :
call . cancel ( )
try :
await asyncio . wait_for ( self . _wait_for_close_ongoing_calls ( ) ,
_TIMEOUT_WAIT_FOR_CLOSE_ONGOING_CALLS_SEC ,
loop = self . _loop )
except asyncio . TimeoutError :
_LOGGER . warning ( " Closing channel %s , closing RPCs timed out " ,
str ( self ) )
self . _channel . close ( )
async def close ( self ) :
""" Closes this Channel and releases all resources held by it.
Closing the Channel will proactively terminate all RPCs active with the
Channel and it is not valid to invoke new RPCs with the Channel .
This method is idempotent .
"""
await self . _close ( )
def get_state ( self ,
try_to_connect : bool = False ) - > grpc . ChannelConnectivity :
@ -387,7 +478,8 @@ class Channel:
Returns :
A UnaryUnaryMultiCallable value for the named unary - unary method .
"""
return UnaryUnaryMultiCallable ( self . _channel , _common . encode ( method ) ,
return UnaryUnaryMultiCallable ( self . _channel , self . _ongoing_calls ,
_common . encode ( method ) ,
request_serializer ,
response_deserializer ,
self . _unary_unary_interceptors ,
@ -399,7 +491,8 @@ class Channel:
request_serializer : Optional [ SerializingFunction ] = None ,
response_deserializer : Optional [ DeserializingFunction ] = None
) - > UnaryStreamMultiCallable :
return UnaryStreamMultiCallable ( self . _channel , _common . encode ( method ) ,
return UnaryStreamMultiCallable ( self . _channel , self . _ongoing_calls ,
_common . encode ( method ) ,
request_serializer ,
response_deserializer , None , self . _loop )
@ -409,7 +502,8 @@ class Channel:
request_serializer : Optional [ SerializingFunction ] = None ,
response_deserializer : Optional [ DeserializingFunction ] = None
) - > StreamUnaryMultiCallable :
return StreamUnaryMultiCallable ( self . _channel , _common . encode ( method ) ,
return StreamUnaryMultiCallable ( self . _channel , self . _ongoing_calls ,
_common . encode ( method ) ,
request_serializer ,
response_deserializer , None , self . _loop )
@ -419,33 +513,8 @@ class Channel:
request_serializer : Optional [ SerializingFunction ] = None ,
response_deserializer : Optional [ DeserializingFunction ] = None
) - > StreamStreamMultiCallable :
return StreamStreamMultiCallable ( self . _channel , _common . encode ( method ) ,
return StreamStreamMultiCallable ( self . _channel , self . _ongoing_calls ,
_common . encode ( method ) ,
request_serializer ,
response_deserializer , None ,
self . _loop )
async def _close ( self ) :
# TODO: Send cancellation status
self . _channel . close ( )
async def __aenter__ ( self ) :
""" Starts an asynchronous context manager.
Returns :
Channel the channel that was instantiated .
"""
return self
async def __aexit__ ( self , exc_type , exc_val , exc_tb ) :
""" Finishes the asynchronous context manager by closing gracefully the channel. """
await self . _close ( )
async def close ( self ) :
""" Closes this Channel and releases all resources held by it.
Closing the Channel will proactively terminate all RPCs active with the
Channel and it is not valid to invoke new RPCs with the Channel .
This method is idempotent .
"""
await self . _close ( )