From c94364f311296502490ff476fadd6c3d204193c1 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Mon, 3 Feb 2020 17:04:54 +0100 Subject: [PATCH] Use a weakset for storing ongoing calls --- .../grpcio/grpc/experimental/aio/_channel.py | 9 +++-- .../tests_aio/unit/close_channel_test.py | 40 +++++++++++-------- 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 2210d2fd641..e0c761ae18e 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -13,7 +13,8 @@ # limitations under the License. """Invocation-side implementation of gRPC Asyncio Python.""" import asyncio -from typing import Any, AsyncIterable, Optional, Sequence, Set, Text +from typing import Any, AsyncIterable, Optional, Sequence, AbstractSet, Text +from weakref import WeakSet import logging import grpc @@ -37,10 +38,10 @@ _LOGGER = logging.getLogger(__name__) class _OngoingCalls: """Internal class used for have visibility of the ongoing calls.""" - _calls: Set[_base_call.RpcContext] + _calls: AbstractSet[_base_call.RpcContext] def __init__(self): - self._calls = set() + self._calls = WeakSet() def _remove_call(self, call: _base_call.RpcContext): self._calls.remove(call) @@ -401,7 +402,7 @@ class Channel: # A new set is created acting as a shallow copy because # when cancellation happens the calls are automatically # removed from the originally set. - calls = set(self._ongoing_calls.calls) + calls = WeakSet(data=self._ongoing_calls.calls) for call in calls: call.cancel() diff --git a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py index 6807bb5b6cc..3ae0baf62d7 100644 --- a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py @@ -16,6 +16,7 @@ import asyncio import logging import unittest +from weakref import WeakSet import grpc from grpc.experimental import aio @@ -32,36 +33,43 @@ _UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep' class TestOngoingCalls(unittest.TestCase): - def test_trace_call(self): - - class FakeCall(_base_call.RpcContext): + class FakeCall(_base_call.RpcContext): - def add_done_callback(self, callback): - self.callback = callback + def add_done_callback(self, callback): + self.callback = callback - def cancel(self): - raise NotImplementedError + def cancel(self): + raise NotImplementedError - def cancelled(self): - raise NotImplementedError + def cancelled(self): + raise NotImplementedError - def done(self): - raise NotImplementedError + def done(self): + raise NotImplementedError - def time_remaining(self): - raise NotImplementedError + def time_remaining(self): + raise NotImplementedError + def test_trace_call(self): ongoing_calls = _OngoingCalls() self.assertEqual(ongoing_calls.size(), 0) - call = FakeCall() + call = TestOngoingCalls.FakeCall() ongoing_calls.trace_call(call) self.assertEqual(ongoing_calls.size(), 1) - self.assertEqual(ongoing_calls.calls, set([call])) + self.assertEqual(ongoing_calls.calls, WeakSet([call])) call.callback(call) self.assertEqual(ongoing_calls.size(), 0) - self.assertEqual(ongoing_calls.calls, set()) + self.assertEqual(ongoing_calls.calls, WeakSet()) + + def test_deleted_call(self): + ongoing_calls = _OngoingCalls() + + call = TestOngoingCalls.FakeCall() + ongoing_calls.trace_call(call) + del(call) + self.assertEqual(ongoing_calls.size(), 0) class TestCloseChannel(AioTestBase):