From 650ba93a614281e6c8927865cda03c37819c2649 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Fri, 10 Jan 2020 15:35:26 -0800 Subject: [PATCH] Improve the surface API & rewrite the test --- .../grpc/_cython/_cygrpc/aio/channel.pxd.pxi | 6 ++ .../grpc/_cython/_cygrpc/aio/channel.pyx.pxi | 35 +++++++- .../grpcio/grpc/experimental/aio/_channel.py | 55 ++++++------ .../tests_aio/unit/connectivity_test.py | 90 +++++++++++-------- 4 files changed, 118 insertions(+), 68 deletions(-) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi index 1e9f6347a77..68a0d11b748 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi @@ -12,9 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +cdef enum AioChannelStatus: + AIO_CHANNEL_STATUS_UNKNOWN + AIO_CHANNEL_STATUS_READY + AIO_CHANNEL_STATUS_DESTROYED + cdef class AioChannel: cdef: grpc_channel * channel CallbackCompletionQueue cq bytes _target object _loop + AioChannelStatus _status diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi index e6ac6ec2791..3be108412ec 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi @@ -13,10 +13,14 @@ # limitations under the License. -class _WatchConnectivityFailed(Exception): pass +class _WatchConnectivityFailed(Exception): + """Dedicated exception class for watch connectivity failed. + + It might be failed due to deadline exceeded, or the channel is closing. + """ cdef CallbackFailureHandler _WATCH_CONNECTIVITY_FAILURE_HANDLER = CallbackFailureHandler( 'watch_connectivity_state', - 'Maybe timed out.', + 'Timed out or channel closed.', _WatchConnectivityFailed) @@ -38,6 +42,7 @@ cdef class AioChannel: channel_args.c_args(), NULL) self._loop = asyncio.get_event_loop() + self._status = AIO_CHANNEL_STATUS_READY def __repr__(self): class_name = self.__class__.__name__ @@ -45,6 +50,7 @@ cdef class AioChannel: return f"<{class_name} {id_}>" def check_connectivity_state(self, bint try_to_connect): + """A Cython wrapper for Core's check connectivity state API.""" return grpc_channel_check_connectivity_state( self.channel, try_to_connect, @@ -53,12 +59,21 @@ cdef class AioChannel: async def watch_connectivity_state(self, grpc_connectivity_state last_observed_state, object deadline): + """Watch for one connectivity state change. + + Keeps mirroring the behavior from Core, so we can easily switch to + other design of API if necessary. + """ + if self._status == AIO_CHANNEL_STATUS_DESTROYED: + # TODO(lidiz) switch to UsageError + raise RuntimeError('Channel is closed.') cdef gpr_timespec c_deadline = _timespec_from_time(deadline) cdef object future = self._loop.create_future() cdef CallbackWrapper wrapper = CallbackWrapper( future, _WATCH_CONNECTIVITY_FAILURE_HANDLER) + cpython.Py_INCREF(wrapper) grpc_channel_watch_connectivity_state( self.channel, last_observed_state, @@ -66,15 +81,24 @@ cdef class AioChannel: self.cq.c_ptr(), wrapper.c_functor()) + # NOTE(lidiz) The callback will be invoked after the channel is closed + # with a failure state. We need to keep wrapper alive until then, or we + # will observe a segfault. + def dealloc_wrapper(_): + cpython.Py_DECREF(wrapper) + future.add_done_callback(dealloc_wrapper) + try: await future except _WatchConnectivityFailed: - return None + return False else: - return self.check_connectivity_state(False) + return True + def close(self): grpc_channel_destroy(self.channel) + self._status = AIO_CHANNEL_STATUS_DESTROYED def call(self, bytes method, @@ -85,5 +109,8 @@ cdef class AioChannel: Returns: The _AioCall object. """ + if self._status == AIO_CHANNEL_STATUS_DESTROYED: + # TODO(lidiz) switch to UsageError + raise RuntimeError('Channel is closed.') cdef _AioCall call = _AioCall(self, deadline, method, credentials) return call diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 4014e823680..4997c4e3feb 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -13,7 +13,6 @@ # limitations under the License. """Invocation-side implementation of gRPC Asyncio Python.""" import asyncio -import time from typing import Any, Optional, Sequence, Text, Tuple import grpc @@ -225,50 +224,54 @@ class Channel: self._channel = cygrpc.AioChannel(_common.encode(target), options, credentials) - def check_connectivity_state(self, try_to_connect: bool = False - ) -> grpc.ChannelConnectivity: + def get_state(self, + try_to_connect: bool = False) -> grpc.ChannelConnectivity: """Check the connectivity state of a channel. This is an EXPERIMENTAL API. + It's the nature of connectivity states to change. The returned + connectivity state might become obsolete soon. Combining + "Channel.wait_for_state_change" we guarantee the convergence of + connectivity state between application and ground truth. + Args: - try_to_connect: a bool indicate whether the Channel should try to connect to peer or not. + try_to_connect: a bool indicate whether the Channel should try to + connect to peer or not. Returns: A ChannelConnectivity object. """ result = self._channel.check_connectivity_state(try_to_connect) - return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result] + return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY.get( + result) - async def watch_connectivity_state( + async def wait_for_state_change( self, last_observed_state: grpc.ChannelConnectivity, - timeout_seconds: Optional[float] = None, - ) -> Optional[grpc.ChannelConnectivity]: - """Watch for a change in connectivity state. + ) -> None: + """Wait for a change in connectivity state. This is an EXPERIMENTAL API. - Once the channel connectivity state is different from - last_observed_state, the function will return the new connectivity - state. If deadline expires BEFORE the state is changed, None will be - returned. + The function blocks until there is a change in the channel connectivity + state from the "last_observed_state". If the state is already + different, this function will return immediately. - Args: - try_to_connect: a bool indicate whether the Channel should try to connect to peer or not. + There is an inherent race between the invocation of + "Channel.wait_for_state_change" and "Channel.get_state". The state can + arbitrary times during the race, so there is no way to observe every + state transition. - Returns: - A ChannelConnectivity object or None. + If there is a need to put a timeout for this function, please refer to + "asyncio.wait_for". + + Args: + last_observed_state: A grpc.ChannelConnectivity object representing + the last known state. """ - deadline = time.time( - ) + timeout_seconds if timeout_seconds is not None else None - result = await self._channel.watch_connectivity_state( - last_observed_state.value[0], deadline) - if result is None: - return None - else: - return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[ - result] + assert await self._channel.watch_connectivity_state( + last_observed_state.value[0], None) def unary_unary( self, diff --git a/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py b/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py index de4497a15c3..a9c8dac39da 100644 --- a/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests behavior of the connectivity state.""" +import asyncio import logging import threading import unittest @@ -29,6 +30,13 @@ from tests_aio.unit._test_base import AioTestBase _INVALID_BACKEND_ADDRESS = '0.0.0.1:2' +async def _block_until_certain_state(channel, expected_state): + state = channel.get_state() + while state != expected_state: + await channel.wait_for_state_change(state) + state = channel.get_state() + + class TestConnectivityState(AioTestBase): async def setUp(self): @@ -38,59 +46,65 @@ class TestConnectivityState(AioTestBase): await self._server.stop(None) async def test_unavailable_backend(self): - channel = aio.insecure_channel(_INVALID_BACKEND_ADDRESS) - - self.assertEqual(grpc.ChannelConnectivity.IDLE, - channel.check_connectivity_state(False)) - self.assertEqual(grpc.ChannelConnectivity.IDLE, - channel.check_connectivity_state(True)) - self.assertEqual( - grpc.ChannelConnectivity.CONNECTING, await - channel.watch_connectivity_state(grpc.ChannelConnectivity.IDLE)) - self.assertEqual( - grpc.ChannelConnectivity.TRANSIENT_FAILURE, await - channel.watch_connectivity_state(grpc.ChannelConnectivity.CONNECTING - )) - - await channel.close() + async with aio.insecure_channel(_INVALID_BACKEND_ADDRESS) as channel: + self.assertEqual(grpc.ChannelConnectivity.IDLE, + channel.get_state(False)) + self.assertEqual(grpc.ChannelConnectivity.IDLE, + channel.get_state(True)) + + async def waiting_transient_failure(): + state = channel.get_state() + while state != grpc.ChannelConnectivity.TRANSIENT_FAILURE: + channel.wait_for_state_change(state) + + # Should not time out + await asyncio.wait_for( + _block_until_certain_state( + channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE), + test_constants.SHORT_TIMEOUT) async def test_normal_backend(self): - channel = aio.insecure_channel(self._server_address) - - current_state = channel.check_connectivity_state(True) - self.assertEqual(grpc.ChannelConnectivity.IDLE, current_state) - - deadline = time.time() + test_constants.SHORT_TIMEOUT + async with aio.insecure_channel(self._server_address) as channel: + current_state = channel.get_state(True) + self.assertEqual(grpc.ChannelConnectivity.IDLE, current_state) - while current_state != grpc.ChannelConnectivity.READY: - current_state = await channel.watch_connectivity_state( - current_state, deadline - time.time()) - self.assertIsNotNone(current_state) - - await channel.close() + # Should not time out + await asyncio.wait_for( + _block_until_certain_state(channel, + grpc.ChannelConnectivity.READY), + test_constants.SHORT_TIMEOUT) async def test_timeout(self): - channel = aio.insecure_channel(self._server_address) - - self.assertEqual(grpc.ChannelConnectivity.IDLE, - channel.check_connectivity_state(False)) + async with aio.insecure_channel(self._server_address) as channel: + self.assertEqual(grpc.ChannelConnectivity.IDLE, + channel.get_state(False)) - # If timed out, the function should return None. - self.assertIsNone(await channel.watch_connectivity_state( - grpc.ChannelConnectivity.IDLE, test_constants.SHORT_TIMEOUT)) - - await channel.close() + # If timed out, the function should return None. + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for( + _block_until_certain_state(channel, + grpc.ChannelConnectivity.READY), + test_constants.SHORT_TIMEOUT) async def test_shutdown(self): channel = aio.insecure_channel(self._server_address) self.assertEqual(grpc.ChannelConnectivity.IDLE, - channel.check_connectivity_state(False)) + channel.get_state(False)) await channel.close() self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN, - channel.check_connectivity_state(False)) + channel.get_state(True)) + + self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN, + channel.get_state(False)) + + # It can raise Exception since it is an usage error, but it should not + # segfault or abort. + with self.assertRaises(Exception): + await channel.wait_for_state_change( + grpc.ChannelConnectivity.SHUTDOWN) if __name__ == '__main__':