From c2b3e00068fc1e8a50a1d2714d858773664d8313 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Mon, 27 Jan 2020 23:39:08 +0100 Subject: [PATCH 01/14] [Aio] Close ongoing calls when the channel is closed When the channel is closed, either by calling explicitly the `close()` method or by leaving an asyncrhonous channel context all ongoing RPCs will be cancelled. --- .../grpc/_cython/_cygrpc/aio/call.pyx.pxi | 9 +- .../grpc/_cython/_cygrpc/aio/channel.pxd.pxi | 1 + .../grpc/_cython/_cygrpc/aio/channel.pyx.pxi | 9 +- .../grpcio/grpc/experimental/aio/_channel.py | 139 +++++++++++++----- .../grpc/experimental/aio/_interceptor.py | 37 ++++- src/python/grpcio_tests/tests_aio/tests.json | 1 + .../tests_aio/unit/channel_test.py | 100 ++++++++++++- .../tests_aio/unit/interceptor_test.py | 94 ++++++++++++ 8 files changed, 339 insertions(+), 51 deletions(-) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi index 6de1fa0b834..5ffbe403905 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi @@ -119,14 +119,14 @@ cdef class _AioCall(GrpcCallWrapper): cdef void _set_status(self, AioRpcStatus status) except *: cdef list waiters + self._status = status + if self._initial_metadata is None: self._set_initial_metadata(_IMMUTABLE_EMPTY_METADATA) - self._status = status - waiters = self._waiters_status - # No more waiters should be expected since status # has been set. + waiters = self._waiters_status self._waiters_status = None for waiter in waiters: @@ -141,10 +141,9 @@ cdef class _AioCall(GrpcCallWrapper): self._initial_metadata = initial_metadata - waiters = self._waiters_initial_metadata - # No more waiters should be expected since initial metadata # has been set. + waiters = self._waiters_initial_metadata self._waiters_initial_metadata = None for waiter in waiters: diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi index 45464cb2a62..50879b4a224 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi @@ -15,6 +15,7 @@ cdef enum AioChannelStatus: AIO_CHANNEL_STATUS_UNKNOWN AIO_CHANNEL_STATUS_READY + AIO_CHANNEL_STATUS_CLOSING AIO_CHANNEL_STATUS_DESTROYED cdef class AioChannel: diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi index 6c4b8422cdd..8b80e28edde 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# class _WatchConnectivityFailed(Exception): @@ -69,9 +70,10 @@ cdef class AioChannel: Keeps mirroring the behavior from Core, so we can easily switch to other design of API if necessary. """ - if self._status == AIO_CHANNEL_STATUS_DESTROYED: + if self._status in (AIO_CHANNEL_STATUS_DESTROYED, AIO_CHANNEL_STATUS_CLOSING): # TODO(lidiz) switch to UsageError raise RuntimeError('Channel is closed.') + cdef gpr_timespec c_deadline = _timespec_from_time(deadline) cdef object future = self.loop.create_future() @@ -92,6 +94,9 @@ cdef class AioChannel: else: return True + def closing(self): + self._status = AIO_CHANNEL_STATUS_CLOSING + def close(self): self._status = AIO_CHANNEL_STATUS_DESTROYED grpc_channel_destroy(self.channel) @@ -105,7 +110,7 @@ cdef class AioChannel: Returns: The _AioCall object. """ - if self._status == AIO_CHANNEL_STATUS_DESTROYED: + if self._status in (AIO_CHANNEL_STATUS_CLOSING, AIO_CHANNEL_STATUS_DESTROYED): # TODO(lidiz) switch to UsageError raise RuntimeError('Channel is closed.') diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 2788f4416e0..12fd52eaca2 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -15,6 +15,7 @@ import asyncio from typing import Any, AsyncIterable, Optional, Sequence, Text +import logging import grpc from grpc import _common from grpc._cython import cygrpc @@ -28,8 +29,37 @@ from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType, SerializingFunction) from ._utils import _timeout_to_deadline +_TIMEOUT_WAIT_FOR_CLOSE_ONGOING_CALLS_SEC = 0.1 _IMMUTABLE_EMPTY_TUPLE = tuple() +_LOGGER = logging.getLogger(__name__) + + +class _OngoingCalls: + """Internal class used for have visibility of the ongoing calls.""" + + _calls: Sequence[_base_call.RpcContext] + + def __init__(self): + self._calls = [] + + def _remove_call(self, call: _base_call.RpcContext): + self._calls.remove(call) + + @property + def calls(self) -> Sequence[_base_call.RpcContext]: + """Returns a shallow copy of the ongoing calls sequence.""" + return self._calls[:] + + def size(self) -> int: + """Returns the number of ongoing calls.""" + return len(self._calls) + + def trace_call(self, call: _base_call.RpcContext): + """Adds and manages a new ongoing call.""" + self._calls.append(call) + call.add_done_callback(self._remove_call) + class _BaseMultiCallable: """Base class of all multi callable objects. @@ -38,6 +68,7 @@ class _BaseMultiCallable: """ _loop: asyncio.AbstractEventLoop _channel: cygrpc.AioChannel + _ongoing_calls: _OngoingCalls _method: bytes _request_serializer: SerializingFunction _response_deserializer: DeserializingFunction @@ -49,9 +80,11 @@ class _BaseMultiCallable: _interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] _loop: asyncio.AbstractEventLoop + # pylint: disable=too-many-arguments def __init__( self, channel: cygrpc.AioChannel, + ongoing_calls: _OngoingCalls, method: bytes, request_serializer: SerializingFunction, response_deserializer: DeserializingFunction, @@ -60,6 +93,7 @@ class _BaseMultiCallable: ) -> None: self._loop = loop self._channel = channel + self._ongoing_calls = ongoing_calls self._method = method self._request_serializer = request_serializer self._response_deserializer = response_deserializer @@ -111,18 +145,21 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable): metadata = _IMMUTABLE_EMPTY_TUPLE if not self._interceptors: - return UnaryUnaryCall(request, _timeout_to_deadline(timeout), + call = UnaryUnaryCall(request, _timeout_to_deadline(timeout), metadata, credentials, self._channel, self._method, self._request_serializer, self._response_deserializer, self._loop) else: - return InterceptedUnaryUnaryCall(self._interceptors, request, + call = InterceptedUnaryUnaryCall(self._interceptors, request, timeout, metadata, credentials, self._channel, self._method, self._request_serializer, self._response_deserializer, self._loop) + self._ongoing_calls.trace_call(call) + return call + class UnaryStreamMultiCallable(_BaseMultiCallable): """Affords invoking a unary-stream RPC from client-side in an asynchronous way.""" @@ -165,10 +202,12 @@ class UnaryStreamMultiCallable(_BaseMultiCallable): if metadata is None: metadata = _IMMUTABLE_EMPTY_TUPLE - return UnaryStreamCall(request, deadline, metadata, credentials, + call = UnaryStreamCall(request, deadline, metadata, credentials, self._channel, self._method, self._request_serializer, self._response_deserializer, self._loop) + self._ongoing_calls.trace_call(call) + return call class StreamUnaryMultiCallable(_BaseMultiCallable): @@ -216,10 +255,12 @@ class StreamUnaryMultiCallable(_BaseMultiCallable): if metadata is None: metadata = _IMMUTABLE_EMPTY_TUPLE - return StreamUnaryCall(request_async_iterator, deadline, metadata, + call = StreamUnaryCall(request_async_iterator, deadline, metadata, credentials, self._channel, self._method, self._request_serializer, self._response_deserializer, self._loop) + self._ongoing_calls.trace_call(call) + return call class StreamStreamMultiCallable(_BaseMultiCallable): @@ -267,10 +308,12 @@ class StreamStreamMultiCallable(_BaseMultiCallable): if metadata is None: metadata = _IMMUTABLE_EMPTY_TUPLE - return StreamStreamCall(request_async_iterator, deadline, metadata, + call = StreamStreamCall(request_async_iterator, deadline, metadata, credentials, self._channel, self._method, self._request_serializer, self._response_deserializer, self._loop) + self._ongoing_calls.trace_call(call) + return call class Channel: @@ -281,6 +324,7 @@ class Channel: _loop: asyncio.AbstractEventLoop _channel: cygrpc.AioChannel _unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] + _ongoing_calls: _OngoingCalls def __init__(self, target: Text, options: Optional[ChannelArgumentType], credentials: Optional[grpc.ChannelCredentials], @@ -322,6 +366,53 @@ class Channel: self._loop = asyncio.get_event_loop() self._channel = cygrpc.AioChannel(_common.encode(target), options, credentials, self._loop) + self._ongoing_calls = _OngoingCalls() + + async def __aenter__(self): + """Starts an asynchronous context manager. + + Returns: + Channel the channel that was instantiated. + """ + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Finishes the asynchronous context manager by closing gracefully the channel.""" + await self._close() + + async def _wait_for_close_ongoing_calls(self): + sleep_iterations_sec = 0.001 + + while self._ongoing_calls.size() > 0: + await asyncio.sleep(sleep_iterations_sec) + + async def _close(self): + # No new calls will be accepted by the Cython channel. + self._channel.closing() + + calls = self._ongoing_calls.calls + for call in calls: + call.cancel() + + try: + await asyncio.wait_for(self._wait_for_close_ongoing_calls(), + _TIMEOUT_WAIT_FOR_CLOSE_ONGOING_CALLS_SEC, + loop=self._loop) + except asyncio.TimeoutError: + _LOGGER.warning("Closing channel %s, closing RPCs timed out", + str(self)) + + self._channel.close() + + async def close(self): + """Closes this Channel and releases all resources held by it. + + Closing the Channel will proactively terminate all RPCs active with the + Channel and it is not valid to invoke new RPCs with the Channel. + + This method is idempotent. + """ + await self._close() def get_state(self, try_to_connect: bool = False) -> grpc.ChannelConnectivity: @@ -387,7 +478,8 @@ class Channel: Returns: A UnaryUnaryMultiCallable value for the named unary-unary method. """ - return UnaryUnaryMultiCallable(self._channel, _common.encode(method), + return UnaryUnaryMultiCallable(self._channel, self._ongoing_calls, + _common.encode(method), request_serializer, response_deserializer, self._unary_unary_interceptors, @@ -399,7 +491,8 @@ class Channel: request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None ) -> UnaryStreamMultiCallable: - return UnaryStreamMultiCallable(self._channel, _common.encode(method), + return UnaryStreamMultiCallable(self._channel, self._ongoing_calls, + _common.encode(method), request_serializer, response_deserializer, None, self._loop) @@ -409,7 +502,8 @@ class Channel: request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None ) -> StreamUnaryMultiCallable: - return StreamUnaryMultiCallable(self._channel, _common.encode(method), + return StreamUnaryMultiCallable(self._channel, self._ongoing_calls, + _common.encode(method), request_serializer, response_deserializer, None, self._loop) @@ -419,33 +513,8 @@ class Channel: request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None ) -> StreamStreamMultiCallable: - return StreamStreamMultiCallable(self._channel, _common.encode(method), + return StreamStreamMultiCallable(self._channel, self._ongoing_calls, + _common.encode(method), request_serializer, response_deserializer, None, self._loop) - - async def _close(self): - # TODO: Send cancellation status - self._channel.close() - - async def __aenter__(self): - """Starts an asynchronous context manager. - - Returns: - Channel the channel that was instantiated. - """ - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """Finishes the asynchronous context manager by closing gracefully the channel.""" - await self._close() - - async def close(self): - """Closes this Channel and releases all resources held by it. - - Closing the Channel will proactively terminate all RPCs active with the - Channel and it is not valid to invoke new RPCs with the Channel. - - This method is idempotent. - """ - await self._close() diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index 87a28fae796..da03c793e0b 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -25,7 +25,7 @@ from . import _base_call from ._call import UnaryUnaryCall, AioRpcError from ._utils import _timeout_to_deadline from ._typing import (RequestType, SerializingFunction, DeserializingFunction, - MetadataType, ResponseType) + MetadataType, ResponseType, DoneCallbackType) _LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!' @@ -102,6 +102,7 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): _intercepted_call: Optional[_base_call.UnaryUnaryCall] _intercepted_call_created: asyncio.Event _interceptors_task: asyncio.Task + _pending_add_done_callbacks: Sequence[DoneCallbackType] # pylint: disable=too-many-arguments def __init__(self, interceptors: Sequence[UnaryUnaryClientInterceptor], @@ -118,6 +119,9 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): interceptors, method, timeout, metadata, credentials, request, request_serializer, response_deserializer), loop=loop) + self._pending_add_done_callbacks = [] + self._interceptors_task.add_done_callback( + self._fire_pending_add_done_callbacks) def __del__(self): self.cancel() @@ -163,6 +167,17 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): return await _run_interceptor(iter(interceptors), client_call_details, request) + def _fire_pending_add_done_callbacks(self, + unused_task: asyncio.Task) -> None: + for callback in self._pending_add_done_callbacks: + callback(self) + + self._pending_add_done_callbacks = [] + + def _wrap_add_done_callback(self, callback: DoneCallbackType, + unused_task: asyncio.Task) -> None: + callback(self) + def cancel(self) -> bool: if self._interceptors_task.done(): return False @@ -186,15 +201,21 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): if not self._interceptors_task.done(): return False - try: - call = self._interceptors_task.result() - except (AioRpcError, asyncio.CancelledError): - return True - + call = self._interceptors_task.result() return call.done() - def add_done_callback(self, unused_callback) -> None: - raise NotImplementedError() + def add_done_callback(self, callback: DoneCallbackType) -> None: + if not self._interceptors_task.done(): + self._pending_add_done_callbacks.append(callback) + return + + call = self._interceptors_task.result() + + if call.done(): + callback(self) + else: + callback = functools.partial(self._wrap_add_done_callback, callback) + call.add_done_callback(self._wrap_add_done_callback) def time_remaining(self) -> Optional[float]: raise NotImplementedError() diff --git a/src/python/grpcio_tests/tests_aio/tests.json b/src/python/grpcio_tests/tests_aio/tests.json index 884d7c98f1c..3fd3358e1dc 100644 --- a/src/python/grpcio_tests/tests_aio/tests.json +++ b/src/python/grpcio_tests/tests_aio/tests.json @@ -8,6 +8,7 @@ "unit.call_test.TestUnaryUnaryCall", "unit.channel_argument_test.TestChannelArgument", "unit.channel_test.TestChannel", + "unit.channel_test.Test_OngoingCalls", "unit.connectivity_test.TestConnectivityState", "unit.done_callback_test.TestDoneCallback", "unit.init_test.TestInsecureChannel", diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_test.py index 77c7fabdc20..da7fab9d66d 100644 --- a/src/python/grpcio_tests/tests_aio/unit/channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/channel_test.py @@ -20,6 +20,8 @@ import unittest import grpc from grpc.experimental import aio +from grpc.experimental.aio import _base_call +from grpc.experimental.aio._channel import _OngoingCalls from src.proto.grpc.testing import messages_pb2, test_pb2_grpc from tests.unit.framework.common import test_constants @@ -42,6 +44,43 @@ _REQUEST_PAYLOAD_SIZE = 7 _RESPONSE_PAYLOAD_SIZE = 42 +class Test_OngoingCalls(unittest.TestCase): + + def test_trace_call(self): + + class FakeCall(_base_call.RpcContext): + + def __init__(self): + self.callback = None + + def add_done_callback(self, callback): + self.callback = callback + + def cancel(self): + raise NotImplementedError + + def cancelled(self): + raise NotImplementedError + + def done(self): + raise NotImplementedError + + def time_remaining(self): + raise NotImplementedError + + ongoing_calls = _OngoingCalls() + self.assertEqual(ongoing_calls.size(), 0) + + call = FakeCall() + ongoing_calls.trace_call(call) + self.assertEqual(ongoing_calls.size(), 1) + self.assertEqual(ongoing_calls.calls, [call]) + + call.callback(call) + self.assertEqual(ongoing_calls.size(), 0) + self.assertEqual(ongoing_calls.calls, []) + + class TestChannel(AioTestBase): async def setUp(self): @@ -225,7 +264,66 @@ class TestChannel(AioTestBase): self.assertEqual(grpc.StatusCode.OK, await call.code()) await channel.close() + async def test_close_unary_unary(self): + channel = aio.insecure_channel(self._server_target) + stub = test_pb2_grpc.TestServiceStub(channel) + + calls = [stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2)] + + self.assertEqual(channel._ongoing_calls.size(), 2) + + await channel.close() + + for call in calls: + self.assertTrue(call.cancelled()) + + self.assertEqual(channel._ongoing_calls.size(), 0) + + async def test_close_unary_stream(self): + channel = aio.insecure_channel(self._server_target) + stub = test_pb2_grpc.TestServiceStub(channel) + + request = messages_pb2.StreamingOutputCallRequest() + calls = [stub.StreamingOutputCall(request) for _ in range(2)] + + self.assertEqual(channel._ongoing_calls.size(), 2) + + await channel.close() + + for call in calls: + self.assertTrue(call.cancelled()) + + self.assertEqual(channel._ongoing_calls.size(), 0) + + async def test_close_stream_stream(self): + channel = aio.insecure_channel(self._server_target) + stub = test_pb2_grpc.TestServiceStub(channel) + + calls = [stub.FullDuplexCall() for _ in range(2)] + + self.assertEqual(channel._ongoing_calls.size(), 2) + + await channel.close() + + for call in calls: + self.assertTrue(call.cancelled()) + + self.assertEqual(channel._ongoing_calls.size(), 0) + + async def test_close_async_context(self): + async with aio.insecure_channel(self._server_target) as channel: + stub = test_pb2_grpc.TestServiceStub(channel) + calls = [ + stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2) + ] + self.assertEqual(channel._ongoing_calls.size(), 2) + + for call in calls: + self.assertTrue(call.cancelled()) + + self.assertEqual(channel._ongoing_calls.size(), 0) + if __name__ == '__main__': - logging.basicConfig(level=logging.DEBUG) + logging.basicConfig(level=logging.INFO) unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py index 6d1ae543b34..c1a68f0352a 100644 --- a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py @@ -573,6 +573,100 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertEqual(await call.code(), grpc.StatusCode.OK) + async def test_add_done_callback_before_finishes(self): + called = False + interceptor_can_continue = asyncio.Event() + + def callback(call): + nonlocal called + called = True + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + + await interceptor_can_continue.wait() + call = await continuation(client_call_details, request) + return call + + async with aio.insecure_channel(self._server_target, + interceptors=[Interceptor() + ]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + call.add_done_callback(callback) + interceptor_can_continue.set() + await call + + self.assertTrue(called) + + async def test_add_done_callback_after_finishes(self): + called = False + + def callback(call): + nonlocal called + called = True + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + + call = await continuation(client_call_details, request) + return call + + async with aio.insecure_channel(self._server_target, + interceptors=[Interceptor() + ]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + + await call + + call.add_done_callback(callback) + + self.assertTrue(called) + + async def test_add_done_callback_after_finishes_before_await(self): + called = False + + def callback(call): + nonlocal called + called = True + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + + call = await continuation(client_call_details, request) + return call + + async with aio.insecure_channel(self._server_target, + interceptors=[Interceptor() + ]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + + call.add_done_callback(callback) + + await call + + self.assertTrue(called) + if __name__ == '__main__': logging.basicConfig() From 90331211a6fec6f6d3e80cd5e3e1b5053e385482 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Wed, 29 Jan 2020 22:39:52 +0100 Subject: [PATCH 02/14] Add graceful period for closing the channel Separated the tests for testing how a channel is closed into another test file. Fixed a bug introduced into the interceptors code in the previous commit. --- .../grpc/_cython/_cygrpc/aio/channel.pyx.pxi | 5 +- .../grpcio/grpc/experimental/aio/_channel.py | 38 ++-- .../grpc/experimental/aio/_interceptor.py | 12 +- .../tests_aio/unit/channel_test.py | 96 --------- .../tests_aio/unit/close_channel_test.py | 185 ++++++++++++++++++ 5 files changed, 225 insertions(+), 111 deletions(-) create mode 100644 src/python/grpcio_tests/tests_aio/unit/close_channel_test.py diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi index 8b80e28edde..b58030a08ee 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi @@ -101,6 +101,9 @@ cdef class AioChannel: self._status = AIO_CHANNEL_STATUS_DESTROYED grpc_channel_destroy(self.channel) + def closed(self): + return self._status in (AIO_CHANNEL_STATUS_CLOSING, AIO_CHANNEL_STATUS_DESTROYED) + def call(self, bytes method, object deadline, @@ -110,7 +113,7 @@ cdef class AioChannel: Returns: The _AioCall object. """ - if self._status in (AIO_CHANNEL_STATUS_CLOSING, AIO_CHANNEL_STATUS_DESTROYED): + if self.closed(): # TODO(lidiz) switch to UsageError raise RuntimeError('Channel is closed.') diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 12fd52eaca2..a2fa9bd269a 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -377,25 +377,34 @@ class Channel: return self async def __aexit__(self, exc_type, exc_val, exc_tb): - """Finishes the asynchronous context manager by closing gracefully the channel.""" - await self._close() + """Finishes the asynchronous context manager by closing the channel. - async def _wait_for_close_ongoing_calls(self): - sleep_iterations_sec = 0.001 + Still active RPCs will be cancelled. + """ + await self._close(None) - while self._ongoing_calls.size() > 0: - await asyncio.sleep(sleep_iterations_sec) + async def _close(self, grace): + if self._channel.closed(): + return - async def _close(self): # No new calls will be accepted by the Cython channel. self._channel.closing() + if grace: + _, pending = await asyncio.wait(self._ongoing_calls.calls, + timeout=grace, + loop=self._loop) + + if not pending: + return + calls = self._ongoing_calls.calls for call in calls: call.cancel() try: - await asyncio.wait_for(self._wait_for_close_ongoing_calls(), + await asyncio.wait_for(asyncio.gather(*self._ongoing_calls.calls, + loop=self._loop), _TIMEOUT_WAIT_FOR_CLOSE_ONGOING_CALLS_SEC, loop=self._loop) except asyncio.TimeoutError: @@ -404,15 +413,20 @@ class Channel: self._channel.close() - async def close(self): + async def close(self, grace: Optional[float] = None): """Closes this Channel and releases all resources held by it. - Closing the Channel will proactively terminate all RPCs active with the - Channel and it is not valid to invoke new RPCs with the Channel. + This method immediately stops the channel from executing new RPCs in + all cases. + + If a grace period is specified, this method wait until all active + RPCs are finshed, once the grace period is reached the ones that haven't + been terminated are cancelled. If a grace period is not specified + (by passing None for grace), all existing RPCs are cancelled immediately. This method is idempotent. """ - await self._close() + await self._close(grace) def get_state(self, try_to_connect: bool = False) -> grpc.ChannelConnectivity: diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index da03c793e0b..30429f25c69 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -201,7 +201,11 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): if not self._interceptors_task.done(): return False - call = self._interceptors_task.result() + try: + call = self._interceptors_task.result() + except (AioRpcError, asyncio.CancelledError): + return True + return call.done() def add_done_callback(self, callback: DoneCallbackType) -> None: @@ -209,7 +213,11 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): self._pending_add_done_callbacks.append(callback) return - call = self._interceptors_task.result() + try: + call = self._interceptors_task.result() + except (AioRpcError, asyncio.CancelledError): + callback(self) + return if call.done(): callback(self) diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_test.py index da7fab9d66d..c3f95a4c221 100644 --- a/src/python/grpcio_tests/tests_aio/unit/channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/channel_test.py @@ -44,43 +44,6 @@ _REQUEST_PAYLOAD_SIZE = 7 _RESPONSE_PAYLOAD_SIZE = 42 -class Test_OngoingCalls(unittest.TestCase): - - def test_trace_call(self): - - class FakeCall(_base_call.RpcContext): - - def __init__(self): - self.callback = None - - def add_done_callback(self, callback): - self.callback = callback - - def cancel(self): - raise NotImplementedError - - def cancelled(self): - raise NotImplementedError - - def done(self): - raise NotImplementedError - - def time_remaining(self): - raise NotImplementedError - - ongoing_calls = _OngoingCalls() - self.assertEqual(ongoing_calls.size(), 0) - - call = FakeCall() - ongoing_calls.trace_call(call) - self.assertEqual(ongoing_calls.size(), 1) - self.assertEqual(ongoing_calls.calls, [call]) - - call.callback(call) - self.assertEqual(ongoing_calls.size(), 0) - self.assertEqual(ongoing_calls.calls, []) - - class TestChannel(AioTestBase): async def setUp(self): @@ -264,65 +227,6 @@ class TestChannel(AioTestBase): self.assertEqual(grpc.StatusCode.OK, await call.code()) await channel.close() - async def test_close_unary_unary(self): - channel = aio.insecure_channel(self._server_target) - stub = test_pb2_grpc.TestServiceStub(channel) - - calls = [stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2)] - - self.assertEqual(channel._ongoing_calls.size(), 2) - - await channel.close() - - for call in calls: - self.assertTrue(call.cancelled()) - - self.assertEqual(channel._ongoing_calls.size(), 0) - - async def test_close_unary_stream(self): - channel = aio.insecure_channel(self._server_target) - stub = test_pb2_grpc.TestServiceStub(channel) - - request = messages_pb2.StreamingOutputCallRequest() - calls = [stub.StreamingOutputCall(request) for _ in range(2)] - - self.assertEqual(channel._ongoing_calls.size(), 2) - - await channel.close() - - for call in calls: - self.assertTrue(call.cancelled()) - - self.assertEqual(channel._ongoing_calls.size(), 0) - - async def test_close_stream_stream(self): - channel = aio.insecure_channel(self._server_target) - stub = test_pb2_grpc.TestServiceStub(channel) - - calls = [stub.FullDuplexCall() for _ in range(2)] - - self.assertEqual(channel._ongoing_calls.size(), 2) - - await channel.close() - - for call in calls: - self.assertTrue(call.cancelled()) - - self.assertEqual(channel._ongoing_calls.size(), 0) - - async def test_close_async_context(self): - async with aio.insecure_channel(self._server_target) as channel: - stub = test_pb2_grpc.TestServiceStub(channel) - calls = [ - stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2) - ] - self.assertEqual(channel._ongoing_calls.size(), 2) - - for call in calls: - self.assertTrue(call.cancelled()) - - self.assertEqual(channel._ongoing_calls.size(), 0) - if __name__ == '__main__': logging.basicConfig(level=logging.INFO) diff --git a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py new file mode 100644 index 00000000000..2d3734fc5e3 --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py @@ -0,0 +1,185 @@ +# Copyright 2019 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests behavior of closing a grpc.aio.Channel.""" + +import asyncio +import logging +import os +import threading +import unittest + +import grpc +from grpc.experimental import aio +from grpc.experimental.aio import _base_call +from grpc.experimental.aio._channel import _OngoingCalls + +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc +from tests.unit.framework.common import test_constants +from tests_aio.unit._constants import (UNARY_CALL_WITH_SLEEP_VALUE, + UNREACHABLE_TARGET) +from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit._test_server import start_test_server + +_UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall' +_UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep' +_STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall' + +_INVOCATION_METADATA = ( + ('initial-md-key', 'initial-md-value'), + ('trailing-md-key-bin', b'\x00\x02'), +) + +_NUM_STREAM_RESPONSES = 5 +_REQUEST_PAYLOAD_SIZE = 7 +_RESPONSE_PAYLOAD_SIZE = 42 + + +class TestOngoingCalls(unittest.TestCase): + + def test_trace_call(self): + + class FakeCall(_base_call.RpcContext): + + def __init__(self): + self.callback = None + + def add_done_callback(self, callback): + self.callback = callback + + def cancel(self): + raise NotImplementedError + + def cancelled(self): + raise NotImplementedError + + def done(self): + raise NotImplementedError + + def time_remaining(self): + raise NotImplementedError + + ongoing_calls = _OngoingCalls() + self.assertEqual(ongoing_calls.size(), 0) + + call = FakeCall() + ongoing_calls.trace_call(call) + self.assertEqual(ongoing_calls.size(), 1) + self.assertEqual(ongoing_calls.calls, [call]) + + call.callback(call) + self.assertEqual(ongoing_calls.size(), 0) + self.assertEqual(ongoing_calls.calls, []) + + +class TestCloseChannel(AioTestBase): + + async def setUp(self): + self._server_target, self._server = await start_test_server() + + async def tearDown(self): + await self._server.stop(None) + + async def test_graceful_close(self): + channel = aio.insecure_channel(self._server_target) + UnaryCallWithSleep = channel.unary_unary( + _UNARY_CALL_METHOD_WITH_SLEEP, + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) + + call = UnaryCallWithSleep(messages_pb2.SimpleRequest()) + task = asyncio.ensure_future(call) + + await channel.close(grace=UNARY_CALL_WITH_SLEEP_VALUE * 2) + + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + async def test_none_graceful_close(self): + channel = aio.insecure_channel(self._server_target) + UnaryCallWithSleep = channel.unary_unary( + _UNARY_CALL_METHOD_WITH_SLEEP, + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) + + call = UnaryCallWithSleep(messages_pb2.SimpleRequest()) + task = asyncio.ensure_future(call) + + await channel.close(grace=UNARY_CALL_WITH_SLEEP_VALUE / 2) + + self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) + + async def test_close_unary_unary(self): + channel = aio.insecure_channel(self._server_target) + stub = test_pb2_grpc.TestServiceStub(channel) + + calls = [stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2)] + + self.assertEqual(channel._ongoing_calls.size(), 2) + + await channel.close() + + for call in calls: + self.assertTrue(call.cancelled()) + + self.assertEqual(channel._ongoing_calls.size(), 0) + + async def test_close_unary_stream(self): + channel = aio.insecure_channel(self._server_target) + stub = test_pb2_grpc.TestServiceStub(channel) + + request = messages_pb2.StreamingOutputCallRequest() + calls = [stub.StreamingOutputCall(request) for _ in range(2)] + + self.assertEqual(channel._ongoing_calls.size(), 2) + + await channel.close() + + for call in calls: + self.assertTrue(call.cancelled()) + + self.assertEqual(channel._ongoing_calls.size(), 0) + + async def test_close_stream_stream(self): + channel = aio.insecure_channel(self._server_target) + stub = test_pb2_grpc.TestServiceStub(channel) + + calls = [stub.FullDuplexCall() for _ in range(2)] + + self.assertEqual(channel._ongoing_calls.size(), 2) + + await channel.close() + + for call in calls: + self.assertTrue(call.cancelled()) + + self.assertEqual(channel._ongoing_calls.size(), 0) + + async def test_close_async_context(self): + async with aio.insecure_channel(self._server_target) as channel: + stub = test_pb2_grpc.TestServiceStub(channel) + calls = [ + stub.UnaryCall(messages_pb2.SimpleRequest()) for _ in range(2) + ] + self.assertEqual(channel._ongoing_calls.size(), 2) + + for call in calls: + self.assertTrue(call.cancelled()) + + self.assertEqual(channel._ongoing_calls.size(), 0) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + unittest.main(verbosity=2) From 5959b685e803a648e0f4a8ae557579c6ce4617aa Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Wed, 29 Jan 2020 23:13:06 +0100 Subject: [PATCH 03/14] Remove unused imports, add pylint exceptions --- .../grpcio/grpc/experimental/aio/_channel.py | 1 + .../grpcio_tests/tests_aio/unit/channel_test.py | 2 -- .../tests_aio/unit/close_channel_test.py | 17 +---------------- 3 files changed, 2 insertions(+), 18 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index a2fa9bd269a..586a431fabc 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -391,6 +391,7 @@ class Channel: self._channel.closing() if grace: + # pylint: disable=unused-variable _, pending = await asyncio.wait(self._ongoing_calls.calls, timeout=grace, loop=self._loop) diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_test.py index c3f95a4c221..2a1ea97f6c1 100644 --- a/src/python/grpcio_tests/tests_aio/unit/channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/channel_test.py @@ -15,13 +15,11 @@ import logging import os -import threading import unittest import grpc from grpc.experimental import aio from grpc.experimental.aio import _base_call -from grpc.experimental.aio._channel import _OngoingCalls from src.proto.grpc.testing import messages_pb2, test_pb2_grpc from tests.unit.framework.common import test_constants diff --git a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py index 2d3734fc5e3..462169abf2d 100644 --- a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py @@ -15,8 +15,6 @@ import asyncio import logging -import os -import threading import unittest import grpc @@ -25,24 +23,11 @@ from grpc.experimental.aio import _base_call from grpc.experimental.aio._channel import _OngoingCalls from src.proto.grpc.testing import messages_pb2, test_pb2_grpc -from tests.unit.framework.common import test_constants -from tests_aio.unit._constants import (UNARY_CALL_WITH_SLEEP_VALUE, - UNREACHABLE_TARGET) +from tests_aio.unit._constants import UNARY_CALL_WITH_SLEEP_VALUE from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_server import start_test_server -_UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall' _UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep' -_STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall' - -_INVOCATION_METADATA = ( - ('initial-md-key', 'initial-md-value'), - ('trailing-md-key-bin', b'\x00\x02'), -) - -_NUM_STREAM_RESPONSES = 5 -_REQUEST_PAYLOAD_SIZE = 7 -_RESPONSE_PAYLOAD_SIZE = 42 class TestOngoingCalls(unittest.TestCase): From 61177a2cd85d71cf9631ed7de16e860202f3bb6c Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Wed, 29 Jan 2020 23:15:51 +0100 Subject: [PATCH 04/14] Make none graceful test explicit --- src/python/grpcio_tests/tests_aio/unit/close_channel_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py index 462169abf2d..111be3f1daf 100644 --- a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py @@ -101,7 +101,7 @@ class TestCloseChannel(AioTestBase): call = UnaryCallWithSleep(messages_pb2.SimpleRequest()) task = asyncio.ensure_future(call) - await channel.close(grace=UNARY_CALL_WITH_SLEEP_VALUE / 2) + await channel.close(None) self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) From cc8bd8cfdacd8fc1c72b168d88a72d49db0c8071 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Wed, 29 Jan 2020 23:23:51 +0100 Subject: [PATCH 05/14] Use Event for knowing if a callback was called or not --- .../tests_aio/unit/interceptor_test.py | 30 +++++++++++-------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py index c1a68f0352a..6cc5ab3f912 100644 --- a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py @@ -574,12 +574,11 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertEqual(await call.code(), grpc.StatusCode.OK) async def test_add_done_callback_before_finishes(self): - called = False + called = asyncio.Event() interceptor_can_continue = asyncio.Event() def callback(call): - nonlocal called - called = True + called.set() class Interceptor(aio.UnaryUnaryClientInterceptor): @@ -603,14 +602,16 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): interceptor_can_continue.set() await call - self.assertTrue(called) + try: + await asyncio.wait_for(called.wait(), timeout=0.1) + except: + self.fail("Callback was not called") async def test_add_done_callback_after_finishes(self): - called = False + called = asyncio.Event() def callback(call): - nonlocal called - called = True + called.set() class Interceptor(aio.UnaryUnaryClientInterceptor): @@ -634,14 +635,16 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): call.add_done_callback(callback) - self.assertTrue(called) + try: + await asyncio.wait_for(called.wait(), timeout=0.1) + except: + self.fail("Callback was not called") async def test_add_done_callback_after_finishes_before_await(self): - called = False + called = asyncio.Event() def callback(call): - nonlocal called - called = True + called.set() class Interceptor(aio.UnaryUnaryClientInterceptor): @@ -665,7 +668,10 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): await call - self.assertTrue(called) + try: + await asyncio.wait_for(called.wait(), timeout=0.1) + except: + self.fail("Callback was not called") if __name__ == '__main__': From a1eb58c6ffdfdf9c19dcb3e8f3c6aeb0fb38dc72 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Wed, 29 Jan 2020 23:52:01 +0100 Subject: [PATCH 06/14] Fix pytype issue --- src/python/grpcio_tests/tests_aio/unit/close_channel_test.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py index 111be3f1daf..9c6f0c24f41 100644 --- a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py @@ -36,9 +36,6 @@ class TestOngoingCalls(unittest.TestCase): class FakeCall(_base_call.RpcContext): - def __init__(self): - self.callback = None - def add_done_callback(self, callback): self.callback = callback From 2f81f3df82dfe00d8b7f158c06d67b2035f45ed3 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Wed, 29 Jan 2020 23:53:14 +0100 Subject: [PATCH 07/14] Fix typo --- src/python/grpcio_tests/tests_aio/unit/close_channel_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py index 9c6f0c24f41..f3b89e4e2d5 100644 --- a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py @@ -1,4 +1,4 @@ -# Copyright 2019 The gRPC Authors. +# Copyright 2020 The gRPC Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 46804d8125ac040f653d96f79951f479dfb158fd Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Thu, 30 Jan 2020 00:00:18 +0100 Subject: [PATCH 08/14] Add test for stream unary close channel, remove irrelevant code --- .../tests_aio/unit/close_channel_test.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py index f3b89e4e2d5..61bc18180bb 100644 --- a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py @@ -81,7 +81,6 @@ class TestCloseChannel(AioTestBase): ) call = UnaryCallWithSleep(messages_pb2.SimpleRequest()) - task = asyncio.ensure_future(call) await channel.close(grace=UNARY_CALL_WITH_SLEEP_VALUE * 2) @@ -96,7 +95,6 @@ class TestCloseChannel(AioTestBase): ) call = UnaryCallWithSleep(messages_pb2.SimpleRequest()) - task = asyncio.ensure_future(call) await channel.close(None) @@ -133,6 +131,19 @@ class TestCloseChannel(AioTestBase): self.assertEqual(channel._ongoing_calls.size(), 0) + async def test_close_stream_unary(self): + channel = aio.insecure_channel(self._server_target) + stub = test_pb2_grpc.TestServiceStub(channel) + + calls = [stub.StreamingInputCall() for _ in range(2)] + + await channel.close() + + for call in calls: + self.assertTrue(call.cancelled()) + + self.assertEqual(channel._ongoing_calls.size(), 0) + async def test_close_stream_stream(self): channel = aio.insecure_channel(self._server_target) stub = test_pb2_grpc.TestServiceStub(channel) From d0b218ae18eb719e65eceae826220f6744e53e9a Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Mon, 3 Feb 2020 15:29:41 +0100 Subject: [PATCH 09/14] Once cancelled just destroy the channel --- src/python/grpcio/grpc/experimental/aio/_channel.py | 12 +----------- src/python/grpcio_tests/tests_aio/tests.json | 3 ++- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 586a431fabc..7682ca96df2 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -29,7 +29,6 @@ from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType, SerializingFunction) from ._utils import _timeout_to_deadline -_TIMEOUT_WAIT_FOR_CLOSE_ONGOING_CALLS_SEC = 0.1 _IMMUTABLE_EMPTY_TUPLE = tuple() _LOGGER = logging.getLogger(__name__) @@ -402,16 +401,7 @@ class Channel: calls = self._ongoing_calls.calls for call in calls: call.cancel() - - try: - await asyncio.wait_for(asyncio.gather(*self._ongoing_calls.calls, - loop=self._loop), - _TIMEOUT_WAIT_FOR_CLOSE_ONGOING_CALLS_SEC, - loop=self._loop) - except asyncio.TimeoutError: - _LOGGER.warning("Closing channel %s, closing RPCs timed out", - str(self)) - + self._channel.close() async def close(self, grace: Optional[float] = None): diff --git a/src/python/grpcio_tests/tests_aio/tests.json b/src/python/grpcio_tests/tests_aio/tests.json index 3fd3358e1dc..3a9242c8ed6 100644 --- a/src/python/grpcio_tests/tests_aio/tests.json +++ b/src/python/grpcio_tests/tests_aio/tests.json @@ -8,7 +8,8 @@ "unit.call_test.TestUnaryUnaryCall", "unit.channel_argument_test.TestChannelArgument", "unit.channel_test.TestChannel", - "unit.channel_test.Test_OngoingCalls", + "unit.close_channel_test.TestCloseChannel", + "unit.close_channel_test.TestOngoingCalls", "unit.connectivity_test.TestConnectivityState", "unit.done_callback_test.TestDoneCallback", "unit.init_test.TestInsecureChannel", From 2cef2fce3996c140702e9c460d909e881ba960e3 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Mon, 3 Feb 2020 15:52:26 +0100 Subject: [PATCH 10/14] Use set as data structure for trace ongoing calls --- .../grpcio/grpc/experimental/aio/_channel.py | 19 +++++++++++-------- .../tests_aio/unit/close_channel_test.py | 4 ++-- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 7682ca96df2..2210d2fd641 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -13,7 +13,7 @@ # limitations under the License. """Invocation-side implementation of gRPC Asyncio Python.""" import asyncio -from typing import Any, AsyncIterable, Optional, Sequence, Text +from typing import Any, AsyncIterable, Optional, Sequence, Set, Text import logging import grpc @@ -37,18 +37,18 @@ _LOGGER = logging.getLogger(__name__) class _OngoingCalls: """Internal class used for have visibility of the ongoing calls.""" - _calls: Sequence[_base_call.RpcContext] + _calls: Set[_base_call.RpcContext] def __init__(self): - self._calls = [] + self._calls = set() def _remove_call(self, call: _base_call.RpcContext): self._calls.remove(call) @property - def calls(self) -> Sequence[_base_call.RpcContext]: - """Returns a shallow copy of the ongoing calls sequence.""" - return self._calls[:] + def calls(self) -> Set[_base_call.RpcContext]: + """Returns the set of ongoing calls.""" + return self._calls def size(self) -> int: """Returns the number of ongoing calls.""" @@ -56,7 +56,7 @@ class _OngoingCalls: def trace_call(self, call: _base_call.RpcContext): """Adds and manages a new ongoing call.""" - self._calls.append(call) + self._calls.add(call) call.add_done_callback(self._remove_call) @@ -398,7 +398,10 @@ class Channel: if not pending: return - calls = self._ongoing_calls.calls + # A new set is created acting as a shallow copy because + # when cancellation happens the calls are automatically + # removed from the originally set. + calls = set(self._ongoing_calls.calls) for call in calls: call.cancel() diff --git a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py index 61bc18180bb..6807bb5b6cc 100644 --- a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py @@ -57,11 +57,11 @@ class TestOngoingCalls(unittest.TestCase): call = FakeCall() ongoing_calls.trace_call(call) self.assertEqual(ongoing_calls.size(), 1) - self.assertEqual(ongoing_calls.calls, [call]) + self.assertEqual(ongoing_calls.calls, set([call])) call.callback(call) self.assertEqual(ongoing_calls.size(), 0) - self.assertEqual(ongoing_calls.calls, []) + self.assertEqual(ongoing_calls.calls, set()) class TestCloseChannel(AioTestBase): From c94364f311296502490ff476fadd6c3d204193c1 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Mon, 3 Feb 2020 17:04:54 +0100 Subject: [PATCH 11/14] Use a weakset for storing ongoing calls --- .../grpcio/grpc/experimental/aio/_channel.py | 9 +++-- .../tests_aio/unit/close_channel_test.py | 40 +++++++++++-------- 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 2210d2fd641..e0c761ae18e 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -13,7 +13,8 @@ # limitations under the License. """Invocation-side implementation of gRPC Asyncio Python.""" import asyncio -from typing import Any, AsyncIterable, Optional, Sequence, Set, Text +from typing import Any, AsyncIterable, Optional, Sequence, AbstractSet, Text +from weakref import WeakSet import logging import grpc @@ -37,10 +38,10 @@ _LOGGER = logging.getLogger(__name__) class _OngoingCalls: """Internal class used for have visibility of the ongoing calls.""" - _calls: Set[_base_call.RpcContext] + _calls: AbstractSet[_base_call.RpcContext] def __init__(self): - self._calls = set() + self._calls = WeakSet() def _remove_call(self, call: _base_call.RpcContext): self._calls.remove(call) @@ -401,7 +402,7 @@ class Channel: # A new set is created acting as a shallow copy because # when cancellation happens the calls are automatically # removed from the originally set. - calls = set(self._ongoing_calls.calls) + calls = WeakSet(data=self._ongoing_calls.calls) for call in calls: call.cancel() diff --git a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py index 6807bb5b6cc..3ae0baf62d7 100644 --- a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py @@ -16,6 +16,7 @@ import asyncio import logging import unittest +from weakref import WeakSet import grpc from grpc.experimental import aio @@ -32,36 +33,43 @@ _UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep' class TestOngoingCalls(unittest.TestCase): - def test_trace_call(self): - - class FakeCall(_base_call.RpcContext): + class FakeCall(_base_call.RpcContext): - def add_done_callback(self, callback): - self.callback = callback + def add_done_callback(self, callback): + self.callback = callback - def cancel(self): - raise NotImplementedError + def cancel(self): + raise NotImplementedError - def cancelled(self): - raise NotImplementedError + def cancelled(self): + raise NotImplementedError - def done(self): - raise NotImplementedError + def done(self): + raise NotImplementedError - def time_remaining(self): - raise NotImplementedError + def time_remaining(self): + raise NotImplementedError + def test_trace_call(self): ongoing_calls = _OngoingCalls() self.assertEqual(ongoing_calls.size(), 0) - call = FakeCall() + call = TestOngoingCalls.FakeCall() ongoing_calls.trace_call(call) self.assertEqual(ongoing_calls.size(), 1) - self.assertEqual(ongoing_calls.calls, set([call])) + self.assertEqual(ongoing_calls.calls, WeakSet([call])) call.callback(call) self.assertEqual(ongoing_calls.size(), 0) - self.assertEqual(ongoing_calls.calls, set()) + self.assertEqual(ongoing_calls.calls, WeakSet()) + + def test_deleted_call(self): + ongoing_calls = _OngoingCalls() + + call = TestOngoingCalls.FakeCall() + ongoing_calls.trace_call(call) + del(call) + self.assertEqual(ongoing_calls.size(), 0) class TestCloseChannel(AioTestBase): From 3a8be1784cc1986750c91092145c3fadd636180d Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Mon, 3 Feb 2020 17:06:05 +0100 Subject: [PATCH 12/14] make YAPF happy --- src/python/grpcio/grpc/experimental/aio/_channel.py | 2 +- src/python/grpcio_tests/tests_aio/unit/close_channel_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index e0c761ae18e..24692559c69 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -405,7 +405,7 @@ class Channel: calls = WeakSet(data=self._ongoing_calls.calls) for call in calls: call.cancel() - + self._channel.close() async def close(self, grace: Optional[float] = None): diff --git a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py index 3ae0baf62d7..05ead8834a0 100644 --- a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py @@ -68,7 +68,7 @@ class TestOngoingCalls(unittest.TestCase): call = TestOngoingCalls.FakeCall() ongoing_calls.trace_call(call) - del(call) + del (call) self.assertEqual(ongoing_calls.size(), 0) From 2d89ef0acdbd5817e3aff41a707cd9f4e4f8e9e1 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Mon, 3 Feb 2020 17:07:20 +0100 Subject: [PATCH 13/14] Fix pylint issue --- src/python/grpcio/grpc/experimental/aio/_channel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 24692559c69..4abb6d69780 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -47,7 +47,7 @@ class _OngoingCalls: self._calls.remove(call) @property - def calls(self) -> Set[_base_call.RpcContext]: + def calls(self) -> AbstractSet[_base_call.RpcContext]: """Returns the set of ongoing calls.""" return self._calls From 5cd4e133bf978ae31132a879e0eb115a5db67936 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Mon, 3 Feb 2020 21:52:02 +0100 Subject: [PATCH 14/14] Increase timeout in some tests --- .../grpcio_tests/tests_aio/unit/channel_test.py | 1 - .../tests_aio/unit/close_channel_test.py | 2 +- .../grpcio_tests/tests_aio/unit/interceptor_test.py | 13 ++++++++++--- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_test.py index 9e10300f917..10949ac180c 100644 --- a/src/python/grpcio_tests/tests_aio/unit/channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/channel_test.py @@ -19,7 +19,6 @@ import unittest import grpc from grpc.experimental import aio -from grpc.experimental.aio import _base_call from src.proto.grpc.testing import messages_pb2, test_pb2_grpc from tests.unit.framework.common import test_constants diff --git a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py index 05ead8834a0..b749603d52c 100644 --- a/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/close_channel_test.py @@ -90,7 +90,7 @@ class TestCloseChannel(AioTestBase): call = UnaryCallWithSleep(messages_pb2.SimpleRequest()) - await channel.close(grace=UNARY_CALL_WITH_SLEEP_VALUE * 2) + await channel.close(grace=UNARY_CALL_WITH_SLEEP_VALUE * 4) self.assertEqual(grpc.StatusCode.OK, await call.code()) diff --git a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py index e2717082ef2..9fa08a78806 100644 --- a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py @@ -29,6 +29,7 @@ _INITIAL_METADATA_TO_INJECT = ( (_INITIAL_METADATA_KEY, 'extra info'), (_TRAILING_METADATA_KEY, b'\x13\x37'), ) +_TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED = 1.0 class TestUnaryUnaryClientInterceptor(AioTestBase): @@ -607,7 +608,9 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): await call try: - await asyncio.wait_for(called.wait(), timeout=0.1) + await asyncio.wait_for( + called.wait(), + timeout=_TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED) except: self.fail("Callback was not called") @@ -640,7 +643,9 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): call.add_done_callback(callback) try: - await asyncio.wait_for(called.wait(), timeout=0.1) + await asyncio.wait_for( + called.wait(), + timeout=_TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED) except: self.fail("Callback was not called") @@ -673,7 +678,9 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): await call try: - await asyncio.wait_for(called.wait(), timeout=0.1) + await asyncio.wait_for( + called.wait(), + timeout=_TIMEOUT_CHECK_IF_CALLBACK_WAS_CALLED) except: self.fail("Callback was not called")