From 2cef2fce3996c140702e9c460d909e881ba960e3 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Mon, 3 Feb 2020 15:52:26 +0100 Subject: [PATCH] Use set as data structure for trace ongoing calls --- .../grpcio/grpc/experimental/aio/_channel.py | 19 +++++++++++-------- .../tests_aio/unit/close_channel_test.py | 4 ++-- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 7682ca96df2..2210d2fd641 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -13,7 +13,7 @@ # limitations under the License. """Invocation-side implementation of gRPC Asyncio Python.""" import asyncio -from typing import Any, AsyncIterable, Optional, Sequence, Text +from typing import Any, AsyncIterable, Optional, Sequence, Set, Text import logging import grpc @@ -37,18 +37,18 @@ _LOGGER = logging.getLogger(__name__) class _OngoingCalls: """Internal class used for have visibility of the ongoing calls.""" - _calls: Sequence[_base_call.RpcContext] + _calls: Set[_base_call.RpcContext] def __init__(self): - self._calls = [] + self._calls = set() def _remove_call(self, call: _base_call.RpcContext): self._calls.remove(call) @property - def calls(self) -> Sequence[_base_call.RpcContext]: - """Returns a shallow copy of the ongoing calls sequence.""" - return self._calls[:] + def calls(self) -> Set[_base_call.RpcContext]: + """Returns the set of ongoing calls.""" + return self._calls def size(self) -> int: """Returns the number of ongoing calls.""" @@ -56,7 +56,7 @@ class _OngoingCalls: def trace_call(self, call: _base_call.RpcContext): """Adds and manages a new ongoing call.""" - self._calls.append(call) + self._calls.add(call) call.add_done_callback(self._remove_call) @@ -398,7 +398,10 @@ class Channel: if not pending: return - calls = self._ongoing_calls.calls + # A new set is created acting as a shallow copy because + # when cancellation happens the calls are automatically + # removed from the originally set. + calls = set(self._ongoing_calls.calls) for call in calls: call.cancel() diff --git a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py index 61bc18180bb..6807bb5b6cc 100644 --- a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py @@ -57,11 +57,11 @@ class TestOngoingCalls(unittest.TestCase): call = FakeCall() ongoing_calls.trace_call(call) self.assertEqual(ongoing_calls.size(), 1) - self.assertEqual(ongoing_calls.calls, [call]) + self.assertEqual(ongoing_calls.calls, set([call])) call.callback(call) self.assertEqual(ongoing_calls.size(), 0) - self.assertEqual(ongoing_calls.calls, []) + self.assertEqual(ongoing_calls.calls, set()) class TestCloseChannel(AioTestBase):