Use a weakset for storing ongoing calls

pull/21819/head
Pau Freixes 5 years ago
parent 2cef2fce39
commit c94364f311
  1. 9
      src/python/grpcio/grpc/experimental/aio/_channel.py
  2. 40
      src/python/grpcio_tests/tests_aio/unit/close_channel_test.py

@ -13,7 +13,8 @@
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""
import asyncio
from typing import Any, AsyncIterable, Optional, Sequence, Set, Text
from typing import Any, AsyncIterable, Optional, Sequence, AbstractSet, Text
from weakref import WeakSet
import logging
import grpc
@ -37,10 +38,10 @@ _LOGGER = logging.getLogger(__name__)
class _OngoingCalls:
"""Internal class used for have visibility of the ongoing calls."""
_calls: Set[_base_call.RpcContext]
_calls: AbstractSet[_base_call.RpcContext]
def __init__(self):
self._calls = set()
self._calls = WeakSet()
def _remove_call(self, call: _base_call.RpcContext):
self._calls.remove(call)
@ -401,7 +402,7 @@ class Channel:
# A new set is created acting as a shallow copy because
# when cancellation happens the calls are automatically
# removed from the originally set.
calls = set(self._ongoing_calls.calls)
calls = WeakSet(data=self._ongoing_calls.calls)
for call in calls:
call.cancel()

@ -16,6 +16,7 @@
import asyncio
import logging
import unittest
from weakref import WeakSet
import grpc
from grpc.experimental import aio
@ -32,36 +33,43 @@ _UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep'
class TestOngoingCalls(unittest.TestCase):
def test_trace_call(self):
class FakeCall(_base_call.RpcContext):
class FakeCall(_base_call.RpcContext):
def add_done_callback(self, callback):
self.callback = callback
def add_done_callback(self, callback):
self.callback = callback
def cancel(self):
raise NotImplementedError
def cancel(self):
raise NotImplementedError
def cancelled(self):
raise NotImplementedError
def cancelled(self):
raise NotImplementedError
def done(self):
raise NotImplementedError
def done(self):
raise NotImplementedError
def time_remaining(self):
raise NotImplementedError
def time_remaining(self):
raise NotImplementedError
def test_trace_call(self):
ongoing_calls = _OngoingCalls()
self.assertEqual(ongoing_calls.size(), 0)
call = FakeCall()
call = TestOngoingCalls.FakeCall()
ongoing_calls.trace_call(call)
self.assertEqual(ongoing_calls.size(), 1)
self.assertEqual(ongoing_calls.calls, set([call]))
self.assertEqual(ongoing_calls.calls, WeakSet([call]))
call.callback(call)
self.assertEqual(ongoing_calls.size(), 0)
self.assertEqual(ongoing_calls.calls, set())
self.assertEqual(ongoing_calls.calls, WeakSet())
def test_deleted_call(self):
ongoing_calls = _OngoingCalls()
call = TestOngoingCalls.FakeCall()
ongoing_calls.trace_call(call)
del(call)
self.assertEqual(ongoing_calls.size(), 0)
class TestCloseChannel(AioTestBase):

Loading…
Cancel
Save