From fbd213d04b88cf73b576b0e85c2d10cc42c7d12c Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Tue, 11 Feb 2020 16:45:03 -0800 Subject: [PATCH 1/7] Make client-side graceful shutdown faster --- .../grpcio/grpc/experimental/aio/_channel.py | 110 +++++++----------- .../tests_aio/unit/close_channel_test.py | 59 ---------- 2 files changed, 42 insertions(+), 127 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 7201aabcc73..64db553cdc4 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -13,15 +13,12 @@ # limitations under the License. """Invocation-side implementation of gRPC Asyncio Python.""" import asyncio -from typing import Any, AsyncIterable, Optional, Sequence, AbstractSet -from weakref import WeakSet +import sys +from typing import AbstractSet, Any, AsyncIterable, Optional, Sequence -import logging import grpc -from grpc import _common +from grpc import _common, _compression, _grpcio_metadata from grpc._cython import cygrpc -from grpc import _compression -from grpc import _grpcio_metadata from . import _base_call from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall, @@ -35,6 +32,15 @@ from ._utils import _timeout_to_deadline _IMMUTABLE_EMPTY_TUPLE = tuple() _USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__) +if sys.version_info[1] < 7: + + def _all_tasks() -> Sequence[asyncio.Task]: + return asyncio.Task.all_tasks() +else: + + def _all_tasks() -> Sequence[asyncio.Task]: + return asyncio.all_tasks() + def _augment_channel_arguments(base_options: ChannelArgumentType, compression: Optional[grpc.Compression]): @@ -48,38 +54,6 @@ def _augment_channel_arguments(base_options: ChannelArgumentType, ) + compression_channel_argument + user_agent_channel_argument -_LOGGER = logging.getLogger(__name__) - - -class _OngoingCalls: - """Internal class used for have visibility of the ongoing calls.""" - - _calls: AbstractSet[_base_call.RpcContext] - - def __init__(self): - self._calls = WeakSet() - - def _remove_call(self, call: _base_call.RpcContext): - try: - self._calls.remove(call) - except KeyError: - pass - - @property - def calls(self) -> AbstractSet[_base_call.RpcContext]: - """Returns the set of ongoing calls.""" - return self._calls - - def size(self) -> int: - """Returns the number of ongoing calls.""" - return len(self._calls) - - def trace_call(self, call: _base_call.RpcContext): - """Adds and manages a new ongoing call.""" - self._calls.add(call) - call.add_done_callback(self._remove_call) - - class _BaseMultiCallable: """Base class of all multi callable objects. @@ -87,7 +61,6 @@ class _BaseMultiCallable: """ _loop: asyncio.AbstractEventLoop _channel: cygrpc.AioChannel - _ongoing_calls: _OngoingCalls _method: bytes _request_serializer: SerializingFunction _response_deserializer: DeserializingFunction @@ -103,7 +76,6 @@ class _BaseMultiCallable: def __init__( self, channel: cygrpc.AioChannel, - ongoing_calls: _OngoingCalls, method: bytes, request_serializer: SerializingFunction, response_deserializer: DeserializingFunction, @@ -112,7 +84,6 @@ class _BaseMultiCallable: ) -> None: self._loop = loop self._channel = channel - self._ongoing_calls = ongoing_calls self._method = method self._request_serializer = request_serializer self._response_deserializer = response_deserializer @@ -170,7 +141,6 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable): self._request_serializer, self._response_deserializer, self._loop) - self._ongoing_calls.trace_call(call) return call @@ -213,7 +183,7 @@ class UnaryStreamMultiCallable(_BaseMultiCallable): wait_for_ready, self._channel, self._method, self._request_serializer, self._response_deserializer, self._loop) - self._ongoing_calls.trace_call(call) + return call @@ -260,7 +230,7 @@ class StreamUnaryMultiCallable(_BaseMultiCallable): credentials, wait_for_ready, self._channel, self._method, self._request_serializer, self._response_deserializer, self._loop) - self._ongoing_calls.trace_call(call) + return call @@ -307,7 +277,7 @@ class StreamStreamMultiCallable(_BaseMultiCallable): credentials, wait_for_ready, self._channel, self._method, self._request_serializer, self._response_deserializer, self._loop) - self._ongoing_calls.trace_call(call) + return call @@ -319,7 +289,6 @@ class Channel: _loop: asyncio.AbstractEventLoop _channel: cygrpc.AioChannel _unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] - _ongoing_calls: _OngoingCalls def __init__(self, target: str, options: ChannelArgumentType, credentials: Optional[grpc.ChannelCredentials], @@ -359,7 +328,6 @@ class Channel: _common.encode(target), _augment_channel_arguments(options, compression), credentials, self._loop) - self._ongoing_calls = _OngoingCalls() async def __aenter__(self): """Starts an asynchronous context manager. @@ -383,22 +351,32 @@ class Channel: # No new calls will be accepted by the Cython channel. self._channel.closing() - if grace: - # pylint: disable=unused-variable - _, pending = await asyncio.wait(self._ongoing_calls.calls, - timeout=grace, - loop=self._loop) - - if not pending: - return - - # A new set is created acting as a shallow copy because - # when cancellation happens the calls are automatically - # removed from the originally set. - calls = WeakSet(data=self._ongoing_calls.calls) + # Iterate through running tasks + tasks = _all_tasks() + calls = [] + call_tasks = [] + for task in tasks: + stack = task.get_stack(limit=1) + if not stack: + continue + + # Locate ones created by `aio.Call`. + frame = stack[0] + if 'self' in frame.f_locals: + if isinstance(frame.f_locals['self'], _base_call.Call): + calls.append(frame.f_locals['self']) + call_tasks.append(task) + + # If needed, try to wait for them to finish. + # Call objects are not always awaitables. + if grace and call_tasks: + await asyncio.wait(call_tasks, timeout=grace, loop=self._loop) + + # Time to cancel existing calls. for call in calls: call.cancel() + # Destroy the channel self._channel.close() async def close(self, grace: Optional[float] = None): @@ -487,8 +465,7 @@ class Channel: Returns: A UnaryUnaryMultiCallable value for the named unary-unary method. """ - return UnaryUnaryMultiCallable(self._channel, self._ongoing_calls, - _common.encode(method), + return UnaryUnaryMultiCallable(self._channel, _common.encode(method), request_serializer, response_deserializer, self._unary_unary_interceptors, @@ -500,8 +477,7 @@ class Channel: request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None ) -> UnaryStreamMultiCallable: - return UnaryStreamMultiCallable(self._channel, self._ongoing_calls, - _common.encode(method), + return UnaryStreamMultiCallable(self._channel, _common.encode(method), request_serializer, response_deserializer, None, self._loop) @@ -511,8 +487,7 @@ class Channel: request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None ) -> StreamUnaryMultiCallable: - return StreamUnaryMultiCallable(self._channel, self._ongoing_calls, - _common.encode(method), + return StreamUnaryMultiCallable(self._channel, _common.encode(method), request_serializer, response_deserializer, None, self._loop) @@ -522,8 +497,7 @@ class Channel: request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None ) -> StreamStreamMultiCallable: - return StreamStreamMultiCallable(self._channel, self._ongoing_calls, - _common.encode(method), + return StreamStreamMultiCallable(self._channel, _common.encode(method), request_serializer, response_deserializer, None, self._loop) diff --git a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py index f7d9de5daf5..d19fbff2bbb 100644 --- a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py @@ -21,7 +21,6 @@ from weakref import WeakSet import grpc from grpc.experimental import aio from grpc.experimental.aio import _base_call -from grpc.experimental.aio._channel import _OngoingCalls from src.proto.grpc.testing import messages_pb2, test_pb2_grpc from tests_aio.unit._test_base import AioTestBase @@ -31,47 +30,6 @@ _UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep' _LONG_TIMEOUT_THAT_SHOULD_NOT_EXPIRE = 60 -class TestOngoingCalls(unittest.TestCase): - - class FakeCall(_base_call.RpcContext): - - def add_done_callback(self, callback): - self.callback = callback - - def cancel(self): - raise NotImplementedError - - def cancelled(self): - raise NotImplementedError - - def done(self): - raise NotImplementedError - - def time_remaining(self): - raise NotImplementedError - - def test_trace_call(self): - ongoing_calls = _OngoingCalls() - self.assertEqual(ongoing_calls.size(), 0) - - call = TestOngoingCalls.FakeCall() - ongoing_calls.trace_call(call) - self.assertEqual(ongoing_calls.size(), 1) - self.assertEqual(ongoing_calls.calls, WeakSet([call])) - - call.callback(call) - self.assertEqual(ongoing_calls.size(), 0) - self.assertEqual(ongoing_calls.calls, WeakSet()) - - def test_deleted_call(self): - ongoing_calls = _OngoingCalls() - - call = TestOngoingCalls.FakeCall() - ongoing_calls.trace_call(call) - del (call) - self.assertEqual(ongoing_calls.size(), 0) - - class TestCloseChannel(AioTestBase): async def setUp(self): @@ -114,15 +72,11 @@ class TestCloseChannel(AioTestBase): calls = [stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2)] - self.assertEqual(channel._ongoing_calls.size(), 2) - await channel.close() for call in calls: self.assertTrue(call.cancelled()) - self.assertEqual(channel._ongoing_calls.size(), 0) - async def test_close_unary_stream(self): channel = aio.insecure_channel(self._server_target) stub = test_pb2_grpc.TestServiceStub(channel) @@ -130,15 +84,11 @@ class TestCloseChannel(AioTestBase): request = messages_pb2.StreamingOutputCallRequest() calls = [stub.StreamingOutputCall(request) for _ in range(2)] - self.assertEqual(channel._ongoing_calls.size(), 2) - await channel.close() for call in calls: self.assertTrue(call.cancelled()) - self.assertEqual(channel._ongoing_calls.size(), 0) - async def test_close_stream_unary(self): channel = aio.insecure_channel(self._server_target) stub = test_pb2_grpc.TestServiceStub(channel) @@ -150,36 +100,27 @@ class TestCloseChannel(AioTestBase): for call in calls: self.assertTrue(call.cancelled()) - self.assertEqual(channel._ongoing_calls.size(), 0) - async def test_close_stream_stream(self): channel = aio.insecure_channel(self._server_target) stub = test_pb2_grpc.TestServiceStub(channel) calls = [stub.FullDuplexCall() for _ in range(2)] - self.assertEqual(channel._ongoing_calls.size(), 2) - await channel.close() for call in calls: self.assertTrue(call.cancelled()) - self.assertEqual(channel._ongoing_calls.size(), 0) - async def test_close_async_context(self): async with aio.insecure_channel(self._server_target) as channel: stub = test_pb2_grpc.TestServiceStub(channel) calls = [ stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2) ] - self.assertEqual(channel._ongoing_calls.size(), 2) for call in calls: self.assertTrue(call.cancelled()) - self.assertEqual(channel._ongoing_calls.size(), 0) - if __name__ == '__main__': logging.basicConfig(level=logging.INFO) From 3061ee37c00f315e1875c81b129adf15bd40563c Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Tue, 11 Feb 2020 16:46:48 -0800 Subject: [PATCH 2/7] Remove unused import --- src/python/grpcio/grpc/experimental/aio/_channel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 64db553cdc4..65863862831 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -14,7 +14,7 @@ """Invocation-side implementation of gRPC Asyncio Python.""" import asyncio import sys -from typing import AbstractSet, Any, AsyncIterable, Optional, Sequence +from typing import Any, AsyncIterable, Optional, Sequence import grpc from grpc import _common, _compression, _grpcio_metadata From d2832e4bc6660b77d3709e19db2cc1c01d9d1d85 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Wed, 12 Feb 2020 10:44:15 -0800 Subject: [PATCH 3/7] Update tests.json --- src/python/grpcio_tests/tests_aio/tests.json | 1 - 1 file changed, 1 deletion(-) diff --git a/src/python/grpcio_tests/tests_aio/tests.json b/src/python/grpcio_tests/tests_aio/tests.json index d3765c7a531..e05d64ac474 100644 --- a/src/python/grpcio_tests/tests_aio/tests.json +++ b/src/python/grpcio_tests/tests_aio/tests.json @@ -12,7 +12,6 @@ "unit.channel_ready_test.TestChannelReady", "unit.channel_test.TestChannel", "unit.close_channel_test.TestCloseChannel", - "unit.close_channel_test.TestOngoingCalls", "unit.compression_test.TestCompression", "unit.connectivity_test.TestConnectivityState", "unit.done_callback_test.TestDoneCallback", From 2789f83f8519b006a4ab83c55c556679d68d28a3 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Wed, 12 Feb 2020 10:46:56 -0800 Subject: [PATCH 4/7] Make pytype happy --- src/python/grpcio/grpc/experimental/aio/_channel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 65863862831..27e899d61e7 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -14,7 +14,7 @@ """Invocation-side implementation of gRPC Asyncio Python.""" import asyncio import sys -from typing import Any, AsyncIterable, Optional, Sequence +from typing import Any, AsyncIterable, Iterable, Optional, Sequence import grpc from grpc import _common, _compression, _grpcio_metadata @@ -34,11 +34,11 @@ _USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__) if sys.version_info[1] < 7: - def _all_tasks() -> Sequence[asyncio.Task]: + def _all_tasks() -> Iterable[asyncio.Task]: return asyncio.Task.all_tasks() else: - def _all_tasks() -> Sequence[asyncio.Task]: + def _all_tasks() -> Iterable[asyncio.Task]: return asyncio.all_tasks() From 5a29f33a25d96852b8e64805ae960c95bb2136ab Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Wed, 12 Feb 2020 14:32:17 -0800 Subject: [PATCH 5/7] Optimize the logic & add comments --- src/python/grpcio/grpc/experimental/aio/_channel.py | 10 +++++++--- .../grpcio_tests/tests_aio/unit/close_channel_test.py | 1 - 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 27e899d61e7..19a35bb142a 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Invocation-side implementation of gRPC Asyncio Python.""" + import asyncio import sys from typing import Any, AsyncIterable, Iterable, Optional, Sequence @@ -357,14 +358,17 @@ class Channel: call_tasks = [] for task in tasks: stack = task.get_stack(limit=1) + + # If the Task is created by a C-extension, the stack will be empty. if not stack: continue # Locate ones created by `aio.Call`. frame = stack[0] - if 'self' in frame.f_locals: - if isinstance(frame.f_locals['self'], _base_call.Call): - calls.append(frame.f_locals['self']) + candidate = frame.f_locals.get('self') + if candidate: + if isinstance(candidate, _base_call.Call): + calls.append(candidate) call_tasks.append(task) # If needed, try to wait for them to finish. diff --git a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py index d19fbff2bbb..a5a5bc3b080 100644 --- a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py @@ -16,7 +16,6 @@ import asyncio import logging import unittest -from weakref import WeakSet import grpc from grpc.experimental import aio From e0f8fe3254110fb4ba0d7aed08e29e970d94da3b Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Thu, 13 Feb 2020 15:05:50 -0800 Subject: [PATCH 6/7] Ensure channel isolation is maintained while graceful close --- .../grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi | 2 +- src/python/grpcio/grpc/experimental/aio/_channel.py | 10 +++------- .../tests_aio/unit/close_channel_test.py | 12 ++++++++++++ 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi index acee440e673..867245a6944 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi @@ -15,7 +15,7 @@ cdef class _AioCall(GrpcCallWrapper): cdef: - AioChannel _channel + readonly AioChannel _channel list _references object _deadline list _done_callbacks diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 19a35bb142a..db61bc589e8 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -61,11 +61,6 @@ class _BaseMultiCallable: Handles the initialization logic and stores common attributes. """ _loop: asyncio.AbstractEventLoop - _channel: cygrpc.AioChannel - _method: bytes - _request_serializer: SerializingFunction - _response_deserializer: DeserializingFunction - _channel: cygrpc.AioChannel _method: bytes _request_serializer: SerializingFunction @@ -368,8 +363,9 @@ class Channel: candidate = frame.f_locals.get('self') if candidate: if isinstance(candidate, _base_call.Call): - calls.append(candidate) - call_tasks.append(task) + if candidate._cython_call._channel is self._channel: + calls.append(candidate) + call_tasks.append(task) # If needed, try to wait for them to finish. # Call objects are not always awaitables. diff --git a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py index a5a5bc3b080..f05c74392d9 100644 --- a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py @@ -120,6 +120,18 @@ class TestCloseChannel(AioTestBase): for call in calls: self.assertTrue(call.cancelled()) + async def test_channel_isolation(self): + async with aio.insecure_channel(self._server_target) as channel1: + async with aio.insecure_channel(self._server_target) as channel2: + stub1 = test_pb2_grpc.TestServiceStub(channel1) + stub2 = test_pb2_grpc.TestServiceStub(channel2) + + call1 = stub1.UnaryCall(messages_pb2.SimpleRequest()) + call2 = stub2.UnaryCall(messages_pb2.SimpleRequest()) + + self.assertFalse(call1.cancelled()) + self.assertTrue(call2.cancelled()) + if __name__ == '__main__': logging.basicConfig(level=logging.INFO) From 7a50172cc8381346c718086c2a0592756a28e2c3 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Thu, 13 Feb 2020 16:00:54 -0800 Subject: [PATCH 7/7] Handle the intercepted call case --- .../grpcio/grpc/experimental/aio/_channel.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index db61bc589e8..cb4ed80339c 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -363,9 +363,21 @@ class Channel: candidate = frame.f_locals.get('self') if candidate: if isinstance(candidate, _base_call.Call): - if candidate._cython_call._channel is self._channel: - calls.append(candidate) - call_tasks.append(task) + if hasattr(candidate, '_channel'): + # For intercepted Call object + if candidate._channel is not self._channel: + continue + elif hasattr(candidate, '_cython_call'): + # For normal Call object + if candidate._cython_call._channel is not self._channel: + continue + else: + # Unidentified Call object + raise cygrpc.InternalError( + f'Unrecognized call object: {candidate}') + + calls.append(candidate) + call_tasks.append(task) # If needed, try to wait for them to finish. # Call objects are not always awaitables.