diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 7201aabcc73..64db553cdc4 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -13,15 +13,12 @@ # limitations under the License. """Invocation-side implementation of gRPC Asyncio Python.""" import asyncio -from typing import Any, AsyncIterable, Optional, Sequence, AbstractSet -from weakref import WeakSet +import sys +from typing import AbstractSet, Any, AsyncIterable, Optional, Sequence -import logging import grpc -from grpc import _common +from grpc import _common, _compression, _grpcio_metadata from grpc._cython import cygrpc -from grpc import _compression -from grpc import _grpcio_metadata from . import _base_call from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall, @@ -35,6 +32,15 @@ from ._utils import _timeout_to_deadline _IMMUTABLE_EMPTY_TUPLE = tuple() _USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__) +if sys.version_info[1] < 7: + + def _all_tasks() -> Sequence[asyncio.Task]: + return asyncio.Task.all_tasks() +else: + + def _all_tasks() -> Sequence[asyncio.Task]: + return asyncio.all_tasks() + def _augment_channel_arguments(base_options: ChannelArgumentType, compression: Optional[grpc.Compression]): @@ -48,38 +54,6 @@ def _augment_channel_arguments(base_options: ChannelArgumentType, ) + compression_channel_argument + user_agent_channel_argument -_LOGGER = logging.getLogger(__name__) - - -class _OngoingCalls: - """Internal class used for have visibility of the ongoing calls.""" - - _calls: AbstractSet[_base_call.RpcContext] - - def __init__(self): - self._calls = WeakSet() - - def _remove_call(self, call: _base_call.RpcContext): - try: - self._calls.remove(call) - except KeyError: - pass - - @property - def calls(self) -> AbstractSet[_base_call.RpcContext]: - """Returns the set of ongoing calls.""" - return self._calls - - def size(self) -> int: - """Returns the number of ongoing calls.""" - return len(self._calls) - - def trace_call(self, call: _base_call.RpcContext): - """Adds and manages a new ongoing call.""" - self._calls.add(call) - call.add_done_callback(self._remove_call) - - class _BaseMultiCallable: """Base class of all multi callable objects. @@ -87,7 +61,6 @@ class _BaseMultiCallable: """ _loop: asyncio.AbstractEventLoop _channel: cygrpc.AioChannel - _ongoing_calls: _OngoingCalls _method: bytes _request_serializer: SerializingFunction _response_deserializer: DeserializingFunction @@ -103,7 +76,6 @@ class _BaseMultiCallable: def __init__( self, channel: cygrpc.AioChannel, - ongoing_calls: _OngoingCalls, method: bytes, request_serializer: SerializingFunction, response_deserializer: DeserializingFunction, @@ -112,7 +84,6 @@ class _BaseMultiCallable: ) -> None: self._loop = loop self._channel = channel - self._ongoing_calls = ongoing_calls self._method = method self._request_serializer = request_serializer self._response_deserializer = response_deserializer @@ -170,7 +141,6 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable): self._request_serializer, self._response_deserializer, self._loop) - self._ongoing_calls.trace_call(call) return call @@ -213,7 +183,7 @@ class UnaryStreamMultiCallable(_BaseMultiCallable): wait_for_ready, self._channel, self._method, self._request_serializer, self._response_deserializer, self._loop) - self._ongoing_calls.trace_call(call) + return call @@ -260,7 +230,7 @@ class StreamUnaryMultiCallable(_BaseMultiCallable): credentials, wait_for_ready, self._channel, self._method, self._request_serializer, self._response_deserializer, self._loop) - self._ongoing_calls.trace_call(call) + return call @@ -307,7 +277,7 @@ class StreamStreamMultiCallable(_BaseMultiCallable): credentials, wait_for_ready, self._channel, self._method, self._request_serializer, self._response_deserializer, self._loop) - self._ongoing_calls.trace_call(call) + return call @@ -319,7 +289,6 @@ class Channel: _loop: asyncio.AbstractEventLoop _channel: cygrpc.AioChannel _unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] - _ongoing_calls: _OngoingCalls def __init__(self, target: str, options: ChannelArgumentType, credentials: Optional[grpc.ChannelCredentials], @@ -359,7 +328,6 @@ class Channel: _common.encode(target), _augment_channel_arguments(options, compression), credentials, self._loop) - self._ongoing_calls = _OngoingCalls() async def __aenter__(self): """Starts an asynchronous context manager. @@ -383,22 +351,32 @@ class Channel: # No new calls will be accepted by the Cython channel. self._channel.closing() - if grace: - # pylint: disable=unused-variable - _, pending = await asyncio.wait(self._ongoing_calls.calls, - timeout=grace, - loop=self._loop) - - if not pending: - return - - # A new set is created acting as a shallow copy because - # when cancellation happens the calls are automatically - # removed from the originally set. - calls = WeakSet(data=self._ongoing_calls.calls) + # Iterate through running tasks + tasks = _all_tasks() + calls = [] + call_tasks = [] + for task in tasks: + stack = task.get_stack(limit=1) + if not stack: + continue + + # Locate ones created by `aio.Call`. + frame = stack[0] + if 'self' in frame.f_locals: + if isinstance(frame.f_locals['self'], _base_call.Call): + calls.append(frame.f_locals['self']) + call_tasks.append(task) + + # If needed, try to wait for them to finish. + # Call objects are not always awaitables. + if grace and call_tasks: + await asyncio.wait(call_tasks, timeout=grace, loop=self._loop) + + # Time to cancel existing calls. for call in calls: call.cancel() + # Destroy the channel self._channel.close() async def close(self, grace: Optional[float] = None): @@ -487,8 +465,7 @@ class Channel: Returns: A UnaryUnaryMultiCallable value for the named unary-unary method. """ - return UnaryUnaryMultiCallable(self._channel, self._ongoing_calls, - _common.encode(method), + return UnaryUnaryMultiCallable(self._channel, _common.encode(method), request_serializer, response_deserializer, self._unary_unary_interceptors, @@ -500,8 +477,7 @@ class Channel: request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None ) -> UnaryStreamMultiCallable: - return UnaryStreamMultiCallable(self._channel, self._ongoing_calls, - _common.encode(method), + return UnaryStreamMultiCallable(self._channel, _common.encode(method), request_serializer, response_deserializer, None, self._loop) @@ -511,8 +487,7 @@ class Channel: request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None ) -> StreamUnaryMultiCallable: - return StreamUnaryMultiCallable(self._channel, self._ongoing_calls, - _common.encode(method), + return StreamUnaryMultiCallable(self._channel, _common.encode(method), request_serializer, response_deserializer, None, self._loop) @@ -522,8 +497,7 @@ class Channel: request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None ) -> StreamStreamMultiCallable: - return StreamStreamMultiCallable(self._channel, self._ongoing_calls, - _common.encode(method), + return StreamStreamMultiCallable(self._channel, _common.encode(method), request_serializer, response_deserializer, None, self._loop) diff --git a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py index f7d9de5daf5..d19fbff2bbb 100644 --- a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py @@ -21,7 +21,6 @@ from weakref import WeakSet import grpc from grpc.experimental import aio from grpc.experimental.aio import _base_call -from grpc.experimental.aio._channel import _OngoingCalls from src.proto.grpc.testing import messages_pb2, test_pb2_grpc from tests_aio.unit._test_base import AioTestBase @@ -31,47 +30,6 @@ _UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep' _LONG_TIMEOUT_THAT_SHOULD_NOT_EXPIRE = 60 -class TestOngoingCalls(unittest.TestCase): - - class FakeCall(_base_call.RpcContext): - - def add_done_callback(self, callback): - self.callback = callback - - def cancel(self): - raise NotImplementedError - - def cancelled(self): - raise NotImplementedError - - def done(self): - raise NotImplementedError - - def time_remaining(self): - raise NotImplementedError - - def test_trace_call(self): - ongoing_calls = _OngoingCalls() - self.assertEqual(ongoing_calls.size(), 0) - - call = TestOngoingCalls.FakeCall() - ongoing_calls.trace_call(call) - self.assertEqual(ongoing_calls.size(), 1) - self.assertEqual(ongoing_calls.calls, WeakSet([call])) - - call.callback(call) - self.assertEqual(ongoing_calls.size(), 0) - self.assertEqual(ongoing_calls.calls, WeakSet()) - - def test_deleted_call(self): - ongoing_calls = _OngoingCalls() - - call = TestOngoingCalls.FakeCall() - ongoing_calls.trace_call(call) - del (call) - self.assertEqual(ongoing_calls.size(), 0) - - class TestCloseChannel(AioTestBase): async def setUp(self): @@ -114,15 +72,11 @@ class TestCloseChannel(AioTestBase): calls = [stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2)] - self.assertEqual(channel._ongoing_calls.size(), 2) - await channel.close() for call in calls: self.assertTrue(call.cancelled()) - self.assertEqual(channel._ongoing_calls.size(), 0) - async def test_close_unary_stream(self): channel = aio.insecure_channel(self._server_target) stub = test_pb2_grpc.TestServiceStub(channel) @@ -130,15 +84,11 @@ class TestCloseChannel(AioTestBase): request = messages_pb2.StreamingOutputCallRequest() calls = [stub.StreamingOutputCall(request) for _ in range(2)] - self.assertEqual(channel._ongoing_calls.size(), 2) - await channel.close() for call in calls: self.assertTrue(call.cancelled()) - self.assertEqual(channel._ongoing_calls.size(), 0) - async def test_close_stream_unary(self): channel = aio.insecure_channel(self._server_target) stub = test_pb2_grpc.TestServiceStub(channel) @@ -150,36 +100,27 @@ class TestCloseChannel(AioTestBase): for call in calls: self.assertTrue(call.cancelled()) - self.assertEqual(channel._ongoing_calls.size(), 0) - async def test_close_stream_stream(self): channel = aio.insecure_channel(self._server_target) stub = test_pb2_grpc.TestServiceStub(channel) calls = [stub.FullDuplexCall() for _ in range(2)] - self.assertEqual(channel._ongoing_calls.size(), 2) - await channel.close() for call in calls: self.assertTrue(call.cancelled()) - self.assertEqual(channel._ongoing_calls.size(), 0) - async def test_close_async_context(self): async with aio.insecure_channel(self._server_target) as channel: stub = test_pb2_grpc.TestServiceStub(channel) calls = [ stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2) ] - self.assertEqual(channel._ongoing_calls.size(), 2) for call in calls: self.assertTrue(call.cancelled()) - self.assertEqual(channel._ongoing_calls.size(), 0) - if __name__ == '__main__': logging.basicConfig(level=logging.INFO)