Use set as data structure for trace ongoing calls

pull/21819/head
Pau Freixes 5 years ago
parent d0b218ae18
commit 2cef2fce39
  1. 19
      src/python/grpcio/grpc/experimental/aio/_channel.py
  2. 4
      src/python/grpcio_tests/tests_aio/unit/close_channel_test.py

@ -13,7 +13,7 @@
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""
import asyncio
from typing import Any, AsyncIterable, Optional, Sequence, Text
from typing import Any, AsyncIterable, Optional, Sequence, Set, Text
import logging
import grpc
@ -37,18 +37,18 @@ _LOGGER = logging.getLogger(__name__)
class _OngoingCalls:
"""Internal class used for have visibility of the ongoing calls."""
_calls: Sequence[_base_call.RpcContext]
_calls: Set[_base_call.RpcContext]
def __init__(self):
self._calls = []
self._calls = set()
def _remove_call(self, call: _base_call.RpcContext):
self._calls.remove(call)
@property
def calls(self) -> Sequence[_base_call.RpcContext]:
"""Returns a shallow copy of the ongoing calls sequence."""
return self._calls[:]
def calls(self) -> Set[_base_call.RpcContext]:
"""Returns the set of ongoing calls."""
return self._calls
def size(self) -> int:
"""Returns the number of ongoing calls."""
@ -56,7 +56,7 @@ class _OngoingCalls:
def trace_call(self, call: _base_call.RpcContext):
"""Adds and manages a new ongoing call."""
self._calls.append(call)
self._calls.add(call)
call.add_done_callback(self._remove_call)
@ -398,7 +398,10 @@ class Channel:
if not pending:
return
calls = self._ongoing_calls.calls
# A new set is created acting as a shallow copy because
# when cancellation happens the calls are automatically
# removed from the originally set.
calls = set(self._ongoing_calls.calls)
for call in calls:
call.cancel()

@ -57,11 +57,11 @@ class TestOngoingCalls(unittest.TestCase):
call = FakeCall()
ongoing_calls.trace_call(call)
self.assertEqual(ongoing_calls.size(), 1)
self.assertEqual(ongoing_calls.calls, [call])
self.assertEqual(ongoing_calls.calls, set([call]))
call.callback(call)
self.assertEqual(ongoing_calls.size(), 0)
self.assertEqual(ongoing_calls.calls, [])
self.assertEqual(ongoing_calls.calls, set())
class TestCloseChannel(AioTestBase):

Loading…
Cancel
Save