Use a weakset for storing ongoing calls

pull/21819/head
Pau Freixes 5 years ago
parent 2cef2fce39
commit c94364f311
  1. 9
      src/python/grpcio/grpc/experimental/aio/_channel.py
  2. 40
      src/python/grpcio_tests/tests_aio/unit/close_channel_test.py

@ -13,7 +13,8 @@
# limitations under the License. # limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python.""" """Invocation-side implementation of gRPC Asyncio Python."""
import asyncio import asyncio
from typing import Any, AsyncIterable, Optional, Sequence, Set, Text from typing import Any, AsyncIterable, Optional, Sequence, AbstractSet, Text
from weakref import WeakSet
import logging import logging
import grpc import grpc
@ -37,10 +38,10 @@ _LOGGER = logging.getLogger(__name__)
class _OngoingCalls: class _OngoingCalls:
"""Internal class used for have visibility of the ongoing calls.""" """Internal class used for have visibility of the ongoing calls."""
_calls: Set[_base_call.RpcContext] _calls: AbstractSet[_base_call.RpcContext]
def __init__(self): def __init__(self):
self._calls = set() self._calls = WeakSet()
def _remove_call(self, call: _base_call.RpcContext): def _remove_call(self, call: _base_call.RpcContext):
self._calls.remove(call) self._calls.remove(call)
@ -401,7 +402,7 @@ class Channel:
# A new set is created acting as a shallow copy because # A new set is created acting as a shallow copy because
# when cancellation happens the calls are automatically # when cancellation happens the calls are automatically
# removed from the originally set. # removed from the originally set.
calls = set(self._ongoing_calls.calls) calls = WeakSet(data=self._ongoing_calls.calls)
for call in calls: for call in calls:
call.cancel() call.cancel()

@ -16,6 +16,7 @@
import asyncio import asyncio
import logging import logging
import unittest import unittest
from weakref import WeakSet
import grpc import grpc
from grpc.experimental import aio from grpc.experimental import aio
@ -32,36 +33,43 @@ _UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep'
class TestOngoingCalls(unittest.TestCase): class TestOngoingCalls(unittest.TestCase):
def test_trace_call(self): class FakeCall(_base_call.RpcContext):
class FakeCall(_base_call.RpcContext):
def add_done_callback(self, callback): def add_done_callback(self, callback):
self.callback = callback self.callback = callback
def cancel(self): def cancel(self):
raise NotImplementedError raise NotImplementedError
def cancelled(self): def cancelled(self):
raise NotImplementedError raise NotImplementedError
def done(self): def done(self):
raise NotImplementedError raise NotImplementedError
def time_remaining(self): def time_remaining(self):
raise NotImplementedError raise NotImplementedError
def test_trace_call(self):
ongoing_calls = _OngoingCalls() ongoing_calls = _OngoingCalls()
self.assertEqual(ongoing_calls.size(), 0) self.assertEqual(ongoing_calls.size(), 0)
call = FakeCall() call = TestOngoingCalls.FakeCall()
ongoing_calls.trace_call(call) ongoing_calls.trace_call(call)
self.assertEqual(ongoing_calls.size(), 1) self.assertEqual(ongoing_calls.size(), 1)
self.assertEqual(ongoing_calls.calls, set([call])) self.assertEqual(ongoing_calls.calls, WeakSet([call]))
call.callback(call) call.callback(call)
self.assertEqual(ongoing_calls.size(), 0) self.assertEqual(ongoing_calls.size(), 0)
self.assertEqual(ongoing_calls.calls, set()) self.assertEqual(ongoing_calls.calls, WeakSet())
def test_deleted_call(self):
ongoing_calls = _OngoingCalls()
call = TestOngoingCalls.FakeCall()
ongoing_calls.trace_call(call)
del(call)
self.assertEqual(ongoing_calls.size(), 0)
class TestCloseChannel(AioTestBase): class TestCloseChannel(AioTestBase):

Loading…
Cancel
Save