Use set as data structure for trace ongoing calls

pull/21819/head
Pau Freixes 5 years ago
parent d0b218ae18
commit 2cef2fce39
  1. 19
      src/python/grpcio/grpc/experimental/aio/_channel.py
  2. 4
      src/python/grpcio_tests/tests_aio/unit/close_channel_test.py

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python.""" """Invocation-side implementation of gRPC Asyncio Python."""
import asyncio import asyncio
from typing import Any, AsyncIterable, Optional, Sequence, Text from typing import Any, AsyncIterable, Optional, Sequence, Set, Text
import logging import logging
import grpc import grpc
@ -37,18 +37,18 @@ _LOGGER = logging.getLogger(__name__)
class _OngoingCalls: class _OngoingCalls:
"""Internal class used for have visibility of the ongoing calls.""" """Internal class used for have visibility of the ongoing calls."""
_calls: Sequence[_base_call.RpcContext] _calls: Set[_base_call.RpcContext]
def __init__(self): def __init__(self):
self._calls = [] self._calls = set()
def _remove_call(self, call: _base_call.RpcContext): def _remove_call(self, call: _base_call.RpcContext):
self._calls.remove(call) self._calls.remove(call)
@property @property
def calls(self) -> Sequence[_base_call.RpcContext]: def calls(self) -> Set[_base_call.RpcContext]:
"""Returns a shallow copy of the ongoing calls sequence.""" """Returns the set of ongoing calls."""
return self._calls[:] return self._calls
def size(self) -> int: def size(self) -> int:
"""Returns the number of ongoing calls.""" """Returns the number of ongoing calls."""
@ -56,7 +56,7 @@ class _OngoingCalls:
def trace_call(self, call: _base_call.RpcContext): def trace_call(self, call: _base_call.RpcContext):
"""Adds and manages a new ongoing call.""" """Adds and manages a new ongoing call."""
self._calls.append(call) self._calls.add(call)
call.add_done_callback(self._remove_call) call.add_done_callback(self._remove_call)
@ -398,7 +398,10 @@ class Channel:
if not pending: if not pending:
return return
calls = self._ongoing_calls.calls # A new set is created acting as a shallow copy because
# when cancellation happens the calls are automatically
# removed from the originally set.
calls = set(self._ongoing_calls.calls)
for call in calls: for call in calls:
call.cancel() call.cancel()

@ -57,11 +57,11 @@ class TestOngoingCalls(unittest.TestCase):
call = FakeCall() call = FakeCall()
ongoing_calls.trace_call(call) ongoing_calls.trace_call(call)
self.assertEqual(ongoing_calls.size(), 1) self.assertEqual(ongoing_calls.size(), 1)
self.assertEqual(ongoing_calls.calls, [call]) self.assertEqual(ongoing_calls.calls, set([call]))
call.callback(call) call.callback(call)
self.assertEqual(ongoing_calls.size(), 0) self.assertEqual(ongoing_calls.size(), 0)
self.assertEqual(ongoing_calls.calls, []) self.assertEqual(ongoing_calls.calls, set())
class TestCloseChannel(AioTestBase): class TestCloseChannel(AioTestBase):

Loading…
Cancel
Save