|
|
|
@ -13,15 +13,12 @@ |
|
|
|
|
# limitations under the License. |
|
|
|
|
"""Invocation-side implementation of gRPC Asyncio Python.""" |
|
|
|
|
import asyncio |
|
|
|
|
from typing import Any, AsyncIterable, Optional, Sequence, AbstractSet |
|
|
|
|
from weakref import WeakSet |
|
|
|
|
import sys |
|
|
|
|
from typing import AbstractSet, Any, AsyncIterable, Optional, Sequence |
|
|
|
|
|
|
|
|
|
import logging |
|
|
|
|
import grpc |
|
|
|
|
from grpc import _common |
|
|
|
|
from grpc import _common, _compression, _grpcio_metadata |
|
|
|
|
from grpc._cython import cygrpc |
|
|
|
|
from grpc import _compression |
|
|
|
|
from grpc import _grpcio_metadata |
|
|
|
|
|
|
|
|
|
from . import _base_call |
|
|
|
|
from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall, |
|
|
|
@ -35,6 +32,15 @@ from ._utils import _timeout_to_deadline |
|
|
|
|
_IMMUTABLE_EMPTY_TUPLE = tuple() |
|
|
|
|
_USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__) |
|
|
|
|
|
|
|
|
|
if sys.version_info[1] < 7: |
|
|
|
|
|
|
|
|
|
def _all_tasks() -> Sequence[asyncio.Task]: |
|
|
|
|
return asyncio.Task.all_tasks() |
|
|
|
|
else: |
|
|
|
|
|
|
|
|
|
def _all_tasks() -> Sequence[asyncio.Task]: |
|
|
|
|
return asyncio.all_tasks() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _augment_channel_arguments(base_options: ChannelArgumentType, |
|
|
|
|
compression: Optional[grpc.Compression]): |
|
|
|
@ -48,38 +54,6 @@ def _augment_channel_arguments(base_options: ChannelArgumentType, |
|
|
|
|
) + compression_channel_argument + user_agent_channel_argument |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_LOGGER = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _OngoingCalls: |
|
|
|
|
"""Internal class used for have visibility of the ongoing calls.""" |
|
|
|
|
|
|
|
|
|
_calls: AbstractSet[_base_call.RpcContext] |
|
|
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
|
self._calls = WeakSet() |
|
|
|
|
|
|
|
|
|
def _remove_call(self, call: _base_call.RpcContext): |
|
|
|
|
try: |
|
|
|
|
self._calls.remove(call) |
|
|
|
|
except KeyError: |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
@property |
|
|
|
|
def calls(self) -> AbstractSet[_base_call.RpcContext]: |
|
|
|
|
"""Returns the set of ongoing calls.""" |
|
|
|
|
return self._calls |
|
|
|
|
|
|
|
|
|
def size(self) -> int: |
|
|
|
|
"""Returns the number of ongoing calls.""" |
|
|
|
|
return len(self._calls) |
|
|
|
|
|
|
|
|
|
def trace_call(self, call: _base_call.RpcContext): |
|
|
|
|
"""Adds and manages a new ongoing call.""" |
|
|
|
|
self._calls.add(call) |
|
|
|
|
call.add_done_callback(self._remove_call) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _BaseMultiCallable: |
|
|
|
|
"""Base class of all multi callable objects. |
|
|
|
|
|
|
|
|
@ -87,7 +61,6 @@ class _BaseMultiCallable: |
|
|
|
|
""" |
|
|
|
|
_loop: asyncio.AbstractEventLoop |
|
|
|
|
_channel: cygrpc.AioChannel |
|
|
|
|
_ongoing_calls: _OngoingCalls |
|
|
|
|
_method: bytes |
|
|
|
|
_request_serializer: SerializingFunction |
|
|
|
|
_response_deserializer: DeserializingFunction |
|
|
|
@ -103,7 +76,6 @@ class _BaseMultiCallable: |
|
|
|
|
def __init__( |
|
|
|
|
self, |
|
|
|
|
channel: cygrpc.AioChannel, |
|
|
|
|
ongoing_calls: _OngoingCalls, |
|
|
|
|
method: bytes, |
|
|
|
|
request_serializer: SerializingFunction, |
|
|
|
|
response_deserializer: DeserializingFunction, |
|
|
|
@ -112,7 +84,6 @@ class _BaseMultiCallable: |
|
|
|
|
) -> None: |
|
|
|
|
self._loop = loop |
|
|
|
|
self._channel = channel |
|
|
|
|
self._ongoing_calls = ongoing_calls |
|
|
|
|
self._method = method |
|
|
|
|
self._request_serializer = request_serializer |
|
|
|
|
self._response_deserializer = response_deserializer |
|
|
|
@ -170,7 +141,6 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable): |
|
|
|
|
self._request_serializer, self._response_deserializer, |
|
|
|
|
self._loop) |
|
|
|
|
|
|
|
|
|
self._ongoing_calls.trace_call(call) |
|
|
|
|
return call |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -213,7 +183,7 @@ class UnaryStreamMultiCallable(_BaseMultiCallable): |
|
|
|
|
wait_for_ready, self._channel, self._method, |
|
|
|
|
self._request_serializer, |
|
|
|
|
self._response_deserializer, self._loop) |
|
|
|
|
self._ongoing_calls.trace_call(call) |
|
|
|
|
|
|
|
|
|
return call |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -260,7 +230,7 @@ class StreamUnaryMultiCallable(_BaseMultiCallable): |
|
|
|
|
credentials, wait_for_ready, self._channel, |
|
|
|
|
self._method, self._request_serializer, |
|
|
|
|
self._response_deserializer, self._loop) |
|
|
|
|
self._ongoing_calls.trace_call(call) |
|
|
|
|
|
|
|
|
|
return call |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -307,7 +277,7 @@ class StreamStreamMultiCallable(_BaseMultiCallable): |
|
|
|
|
credentials, wait_for_ready, self._channel, |
|
|
|
|
self._method, self._request_serializer, |
|
|
|
|
self._response_deserializer, self._loop) |
|
|
|
|
self._ongoing_calls.trace_call(call) |
|
|
|
|
|
|
|
|
|
return call |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -319,7 +289,6 @@ class Channel: |
|
|
|
|
_loop: asyncio.AbstractEventLoop |
|
|
|
|
_channel: cygrpc.AioChannel |
|
|
|
|
_unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] |
|
|
|
|
_ongoing_calls: _OngoingCalls |
|
|
|
|
|
|
|
|
|
def __init__(self, target: str, options: ChannelArgumentType, |
|
|
|
|
credentials: Optional[grpc.ChannelCredentials], |
|
|
|
@ -359,7 +328,6 @@ class Channel: |
|
|
|
|
_common.encode(target), |
|
|
|
|
_augment_channel_arguments(options, compression), credentials, |
|
|
|
|
self._loop) |
|
|
|
|
self._ongoing_calls = _OngoingCalls() |
|
|
|
|
|
|
|
|
|
async def __aenter__(self): |
|
|
|
|
"""Starts an asynchronous context manager. |
|
|
|
@ -383,22 +351,32 @@ class Channel: |
|
|
|
|
# No new calls will be accepted by the Cython channel. |
|
|
|
|
self._channel.closing() |
|
|
|
|
|
|
|
|
|
if grace: |
|
|
|
|
# pylint: disable=unused-variable |
|
|
|
|
_, pending = await asyncio.wait(self._ongoing_calls.calls, |
|
|
|
|
timeout=grace, |
|
|
|
|
loop=self._loop) |
|
|
|
|
|
|
|
|
|
if not pending: |
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
# A new set is created acting as a shallow copy because |
|
|
|
|
# when cancellation happens the calls are automatically |
|
|
|
|
# removed from the originally set. |
|
|
|
|
calls = WeakSet(data=self._ongoing_calls.calls) |
|
|
|
|
# Iterate through running tasks |
|
|
|
|
tasks = _all_tasks() |
|
|
|
|
calls = [] |
|
|
|
|
call_tasks = [] |
|
|
|
|
for task in tasks: |
|
|
|
|
stack = task.get_stack(limit=1) |
|
|
|
|
if not stack: |
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
# Locate ones created by `aio.Call`. |
|
|
|
|
frame = stack[0] |
|
|
|
|
if 'self' in frame.f_locals: |
|
|
|
|
if isinstance(frame.f_locals['self'], _base_call.Call): |
|
|
|
|
calls.append(frame.f_locals['self']) |
|
|
|
|
call_tasks.append(task) |
|
|
|
|
|
|
|
|
|
# If needed, try to wait for them to finish. |
|
|
|
|
# Call objects are not always awaitables. |
|
|
|
|
if grace and call_tasks: |
|
|
|
|
await asyncio.wait(call_tasks, timeout=grace, loop=self._loop) |
|
|
|
|
|
|
|
|
|
# Time to cancel existing calls. |
|
|
|
|
for call in calls: |
|
|
|
|
call.cancel() |
|
|
|
|
|
|
|
|
|
# Destroy the channel |
|
|
|
|
self._channel.close() |
|
|
|
|
|
|
|
|
|
async def close(self, grace: Optional[float] = None): |
|
|
|
@ -487,8 +465,7 @@ class Channel: |
|
|
|
|
Returns: |
|
|
|
|
A UnaryUnaryMultiCallable value for the named unary-unary method. |
|
|
|
|
""" |
|
|
|
|
return UnaryUnaryMultiCallable(self._channel, self._ongoing_calls, |
|
|
|
|
_common.encode(method), |
|
|
|
|
return UnaryUnaryMultiCallable(self._channel, _common.encode(method), |
|
|
|
|
request_serializer, |
|
|
|
|
response_deserializer, |
|
|
|
|
self._unary_unary_interceptors, |
|
|
|
@ -500,8 +477,7 @@ class Channel: |
|
|
|
|
request_serializer: Optional[SerializingFunction] = None, |
|
|
|
|
response_deserializer: Optional[DeserializingFunction] = None |
|
|
|
|
) -> UnaryStreamMultiCallable: |
|
|
|
|
return UnaryStreamMultiCallable(self._channel, self._ongoing_calls, |
|
|
|
|
_common.encode(method), |
|
|
|
|
return UnaryStreamMultiCallable(self._channel, _common.encode(method), |
|
|
|
|
request_serializer, |
|
|
|
|
response_deserializer, None, self._loop) |
|
|
|
|
|
|
|
|
@ -511,8 +487,7 @@ class Channel: |
|
|
|
|
request_serializer: Optional[SerializingFunction] = None, |
|
|
|
|
response_deserializer: Optional[DeserializingFunction] = None |
|
|
|
|
) -> StreamUnaryMultiCallable: |
|
|
|
|
return StreamUnaryMultiCallable(self._channel, self._ongoing_calls, |
|
|
|
|
_common.encode(method), |
|
|
|
|
return StreamUnaryMultiCallable(self._channel, _common.encode(method), |
|
|
|
|
request_serializer, |
|
|
|
|
response_deserializer, None, self._loop) |
|
|
|
|
|
|
|
|
@ -522,8 +497,7 @@ class Channel: |
|
|
|
|
request_serializer: Optional[SerializingFunction] = None, |
|
|
|
|
response_deserializer: Optional[DeserializingFunction] = None |
|
|
|
|
) -> StreamStreamMultiCallable: |
|
|
|
|
return StreamStreamMultiCallable(self._channel, self._ongoing_calls, |
|
|
|
|
_common.encode(method), |
|
|
|
|
return StreamStreamMultiCallable(self._channel, _common.encode(method), |
|
|
|
|
request_serializer, |
|
|
|
|
response_deserializer, None, |
|
|
|
|
self._loop) |
|
|
|
|