diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi index acee440e673..867245a6944 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi @@ -15,7 +15,7 @@ cdef class _AioCall(GrpcCallWrapper): cdef: - AioChannel _channel + readonly AioChannel _channel list _references object _deadline list _done_callbacks diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 7201aabcc73..cb4ed80339c 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -12,16 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. """Invocation-side implementation of gRPC Asyncio Python.""" + import asyncio -from typing import Any, AsyncIterable, Optional, Sequence, AbstractSet -from weakref import WeakSet +import sys +from typing import Any, AsyncIterable, Iterable, Optional, Sequence -import logging import grpc -from grpc import _common +from grpc import _common, _compression, _grpcio_metadata from grpc._cython import cygrpc -from grpc import _compression -from grpc import _grpcio_metadata from . import _base_call from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall, @@ -35,6 +33,15 @@ from ._utils import _timeout_to_deadline _IMMUTABLE_EMPTY_TUPLE = tuple() _USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__) +if sys.version_info[1] < 7: + + def _all_tasks() -> Iterable[asyncio.Task]: + return asyncio.Task.all_tasks() +else: + + def _all_tasks() -> Iterable[asyncio.Task]: + return asyncio.all_tasks() + def _augment_channel_arguments(base_options: ChannelArgumentType, compression: Optional[grpc.Compression]): @@ -48,50 +55,12 @@ def _augment_channel_arguments(base_options: ChannelArgumentType, ) + compression_channel_argument + user_agent_channel_argument -_LOGGER = logging.getLogger(__name__) - - -class _OngoingCalls: - """Internal class used for have visibility of the ongoing calls.""" - - _calls: AbstractSet[_base_call.RpcContext] - - def __init__(self): - self._calls = WeakSet() - - def _remove_call(self, call: _base_call.RpcContext): - try: - self._calls.remove(call) - except KeyError: - pass - - @property - def calls(self) -> AbstractSet[_base_call.RpcContext]: - """Returns the set of ongoing calls.""" - return self._calls - - def size(self) -> int: - """Returns the number of ongoing calls.""" - return len(self._calls) - - def trace_call(self, call: _base_call.RpcContext): - """Adds and manages a new ongoing call.""" - self._calls.add(call) - call.add_done_callback(self._remove_call) - - class _BaseMultiCallable: """Base class of all multi callable objects. Handles the initialization logic and stores common attributes. """ _loop: asyncio.AbstractEventLoop - _channel: cygrpc.AioChannel - _ongoing_calls: _OngoingCalls - _method: bytes - _request_serializer: SerializingFunction - _response_deserializer: DeserializingFunction - _channel: cygrpc.AioChannel _method: bytes _request_serializer: SerializingFunction @@ -103,7 +72,6 @@ class _BaseMultiCallable: def __init__( self, channel: cygrpc.AioChannel, - ongoing_calls: _OngoingCalls, method: bytes, request_serializer: SerializingFunction, response_deserializer: DeserializingFunction, @@ -112,7 +80,6 @@ class _BaseMultiCallable: ) -> None: self._loop = loop self._channel = channel - self._ongoing_calls = ongoing_calls self._method = method self._request_serializer = request_serializer self._response_deserializer = response_deserializer @@ -170,7 +137,6 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable): self._request_serializer, self._response_deserializer, self._loop) - self._ongoing_calls.trace_call(call) return call @@ -213,7 +179,7 @@ class UnaryStreamMultiCallable(_BaseMultiCallable): wait_for_ready, self._channel, self._method, self._request_serializer, self._response_deserializer, self._loop) - self._ongoing_calls.trace_call(call) + return call @@ -260,7 +226,7 @@ class StreamUnaryMultiCallable(_BaseMultiCallable): credentials, wait_for_ready, self._channel, self._method, self._request_serializer, self._response_deserializer, self._loop) - self._ongoing_calls.trace_call(call) + return call @@ -307,7 +273,7 @@ class StreamStreamMultiCallable(_BaseMultiCallable): credentials, wait_for_ready, self._channel, self._method, self._request_serializer, self._response_deserializer, self._loop) - self._ongoing_calls.trace_call(call) + return call @@ -319,7 +285,6 @@ class Channel: _loop: asyncio.AbstractEventLoop _channel: cygrpc.AioChannel _unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] - _ongoing_calls: _OngoingCalls def __init__(self, target: str, options: ChannelArgumentType, credentials: Optional[grpc.ChannelCredentials], @@ -359,7 +324,6 @@ class Channel: _common.encode(target), _augment_channel_arguments(options, compression), credentials, self._loop) - self._ongoing_calls = _OngoingCalls() async def __aenter__(self): """Starts an asynchronous context manager. @@ -383,22 +347,48 @@ class Channel: # No new calls will be accepted by the Cython channel. self._channel.closing() - if grace: - # pylint: disable=unused-variable - _, pending = await asyncio.wait(self._ongoing_calls.calls, - timeout=grace, - loop=self._loop) - - if not pending: - return - - # A new set is created acting as a shallow copy because - # when cancellation happens the calls are automatically - # removed from the originally set. - calls = WeakSet(data=self._ongoing_calls.calls) + # Iterate through running tasks + tasks = _all_tasks() + calls = [] + call_tasks = [] + for task in tasks: + stack = task.get_stack(limit=1) + + # If the Task is created by a C-extension, the stack will be empty. + if not stack: + continue + + # Locate ones created by `aio.Call`. + frame = stack[0] + candidate = frame.f_locals.get('self') + if candidate: + if isinstance(candidate, _base_call.Call): + if hasattr(candidate, '_channel'): + # For intercepted Call object + if candidate._channel is not self._channel: + continue + elif hasattr(candidate, '_cython_call'): + # For normal Call object + if candidate._cython_call._channel is not self._channel: + continue + else: + # Unidentified Call object + raise cygrpc.InternalError( + f'Unrecognized call object: {candidate}') + + calls.append(candidate) + call_tasks.append(task) + + # If needed, try to wait for them to finish. + # Call objects are not always awaitables. + if grace and call_tasks: + await asyncio.wait(call_tasks, timeout=grace, loop=self._loop) + + # Time to cancel existing calls. for call in calls: call.cancel() + # Destroy the channel self._channel.close() async def close(self, grace: Optional[float] = None): @@ -487,8 +477,7 @@ class Channel: Returns: A UnaryUnaryMultiCallable value for the named unary-unary method. """ - return UnaryUnaryMultiCallable(self._channel, self._ongoing_calls, - _common.encode(method), + return UnaryUnaryMultiCallable(self._channel, _common.encode(method), request_serializer, response_deserializer, self._unary_unary_interceptors, @@ -500,8 +489,7 @@ class Channel: request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None ) -> UnaryStreamMultiCallable: - return UnaryStreamMultiCallable(self._channel, self._ongoing_calls, - _common.encode(method), + return UnaryStreamMultiCallable(self._channel, _common.encode(method), request_serializer, response_deserializer, None, self._loop) @@ -511,8 +499,7 @@ class Channel: request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None ) -> StreamUnaryMultiCallable: - return StreamUnaryMultiCallable(self._channel, self._ongoing_calls, - _common.encode(method), + return StreamUnaryMultiCallable(self._channel, _common.encode(method), request_serializer, response_deserializer, None, self._loop) @@ -522,8 +509,7 @@ class Channel: request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None ) -> StreamStreamMultiCallable: - return StreamStreamMultiCallable(self._channel, self._ongoing_calls, - _common.encode(method), + return StreamStreamMultiCallable(self._channel, _common.encode(method), request_serializer, response_deserializer, None, self._loop) diff --git a/src/python/grpcio_tests/tests_aio/tests.json b/src/python/grpcio_tests/tests_aio/tests.json index d3765c7a531..e05d64ac474 100644 --- a/src/python/grpcio_tests/tests_aio/tests.json +++ b/src/python/grpcio_tests/tests_aio/tests.json @@ -12,7 +12,6 @@ "unit.channel_ready_test.TestChannelReady", "unit.channel_test.TestChannel", "unit.close_channel_test.TestCloseChannel", - "unit.close_channel_test.TestOngoingCalls", "unit.compression_test.TestCompression", "unit.connectivity_test.TestConnectivityState", "unit.done_callback_test.TestDoneCallback", diff --git a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py index f7d9de5daf5..f05c74392d9 100644 --- a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py @@ -16,12 +16,10 @@ import asyncio import logging import unittest -from weakref import WeakSet import grpc from grpc.experimental import aio from grpc.experimental.aio import _base_call -from grpc.experimental.aio._channel import _OngoingCalls from src.proto.grpc.testing import messages_pb2, test_pb2_grpc from tests_aio.unit._test_base import AioTestBase @@ -31,47 +29,6 @@ _UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep' _LONG_TIMEOUT_THAT_SHOULD_NOT_EXPIRE = 60 -class TestOngoingCalls(unittest.TestCase): - - class FakeCall(_base_call.RpcContext): - - def add_done_callback(self, callback): - self.callback = callback - - def cancel(self): - raise NotImplementedError - - def cancelled(self): - raise NotImplementedError - - def done(self): - raise NotImplementedError - - def time_remaining(self): - raise NotImplementedError - - def test_trace_call(self): - ongoing_calls = _OngoingCalls() - self.assertEqual(ongoing_calls.size(), 0) - - call = TestOngoingCalls.FakeCall() - ongoing_calls.trace_call(call) - self.assertEqual(ongoing_calls.size(), 1) - self.assertEqual(ongoing_calls.calls, WeakSet([call])) - - call.callback(call) - self.assertEqual(ongoing_calls.size(), 0) - self.assertEqual(ongoing_calls.calls, WeakSet()) - - def test_deleted_call(self): - ongoing_calls = _OngoingCalls() - - call = TestOngoingCalls.FakeCall() - ongoing_calls.trace_call(call) - del (call) - self.assertEqual(ongoing_calls.size(), 0) - - class TestCloseChannel(AioTestBase): async def setUp(self): @@ -114,15 +71,11 @@ class TestCloseChannel(AioTestBase): calls = [stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2)] - self.assertEqual(channel._ongoing_calls.size(), 2) - await channel.close() for call in calls: self.assertTrue(call.cancelled()) - self.assertEqual(channel._ongoing_calls.size(), 0) - async def test_close_unary_stream(self): channel = aio.insecure_channel(self._server_target) stub = test_pb2_grpc.TestServiceStub(channel) @@ -130,15 +83,11 @@ class TestCloseChannel(AioTestBase): request = messages_pb2.StreamingOutputCallRequest() calls = [stub.StreamingOutputCall(request) for _ in range(2)] - self.assertEqual(channel._ongoing_calls.size(), 2) - await channel.close() for call in calls: self.assertTrue(call.cancelled()) - self.assertEqual(channel._ongoing_calls.size(), 0) - async def test_close_stream_unary(self): channel = aio.insecure_channel(self._server_target) stub = test_pb2_grpc.TestServiceStub(channel) @@ -150,35 +99,38 @@ class TestCloseChannel(AioTestBase): for call in calls: self.assertTrue(call.cancelled()) - self.assertEqual(channel._ongoing_calls.size(), 0) - async def test_close_stream_stream(self): channel = aio.insecure_channel(self._server_target) stub = test_pb2_grpc.TestServiceStub(channel) calls = [stub.FullDuplexCall() for _ in range(2)] - self.assertEqual(channel._ongoing_calls.size(), 2) - await channel.close() for call in calls: self.assertTrue(call.cancelled()) - self.assertEqual(channel._ongoing_calls.size(), 0) - async def test_close_async_context(self): async with aio.insecure_channel(self._server_target) as channel: stub = test_pb2_grpc.TestServiceStub(channel) calls = [ stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2) ] - self.assertEqual(channel._ongoing_calls.size(), 2) for call in calls: self.assertTrue(call.cancelled()) - self.assertEqual(channel._ongoing_calls.size(), 0) + async def test_channel_isolation(self): + async with aio.insecure_channel(self._server_target) as channel1: + async with aio.insecure_channel(self._server_target) as channel2: + stub1 = test_pb2_grpc.TestServiceStub(channel1) + stub2 = test_pb2_grpc.TestServiceStub(channel2) + + call1 = stub1.UnaryCall(messages_pb2.SimpleRequest()) + call2 = stub2.UnaryCall(messages_pb2.SimpleRequest()) + + self.assertFalse(call1.cancelled()) + self.assertTrue(call2.cancelled()) if __name__ == '__main__':