Improve the surface API & rewrite the test

pull/21621/head
Lidi Zheng 5 years ago
parent fa62339430
commit 650ba93a61
  1. 6
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi
  2. 35
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  3. 55
      src/python/grpcio/grpc/experimental/aio/_channel.py
  4. 78
      src/python/grpcio_tests/tests_aio/unit/connectivity_test.py

@ -12,9 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
cdef enum AioChannelStatus:
AIO_CHANNEL_STATUS_UNKNOWN
AIO_CHANNEL_STATUS_READY
AIO_CHANNEL_STATUS_DESTROYED
cdef class AioChannel: cdef class AioChannel:
cdef: cdef:
grpc_channel * channel grpc_channel * channel
CallbackCompletionQueue cq CallbackCompletionQueue cq
bytes _target bytes _target
object _loop object _loop
AioChannelStatus _status

@ -13,10 +13,14 @@
# limitations under the License. # limitations under the License.
class _WatchConnectivityFailed(Exception): pass class _WatchConnectivityFailed(Exception):
"""Dedicated exception class for watch connectivity failed.
It might be failed due to deadline exceeded, or the channel is closing.
"""
cdef CallbackFailureHandler _WATCH_CONNECTIVITY_FAILURE_HANDLER = CallbackFailureHandler( cdef CallbackFailureHandler _WATCH_CONNECTIVITY_FAILURE_HANDLER = CallbackFailureHandler(
'watch_connectivity_state', 'watch_connectivity_state',
'Maybe timed out.', 'Timed out or channel closed.',
_WatchConnectivityFailed) _WatchConnectivityFailed)
@ -38,6 +42,7 @@ cdef class AioChannel:
channel_args.c_args(), channel_args.c_args(),
NULL) NULL)
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
self._status = AIO_CHANNEL_STATUS_READY
def __repr__(self): def __repr__(self):
class_name = self.__class__.__name__ class_name = self.__class__.__name__
@ -45,6 +50,7 @@ cdef class AioChannel:
return f"<{class_name} {id_}>" return f"<{class_name} {id_}>"
def check_connectivity_state(self, bint try_to_connect): def check_connectivity_state(self, bint try_to_connect):
"""A Cython wrapper for Core's check connectivity state API."""
return grpc_channel_check_connectivity_state( return grpc_channel_check_connectivity_state(
self.channel, self.channel,
try_to_connect, try_to_connect,
@ -53,12 +59,21 @@ cdef class AioChannel:
async def watch_connectivity_state(self, async def watch_connectivity_state(self,
grpc_connectivity_state last_observed_state, grpc_connectivity_state last_observed_state,
object deadline): object deadline):
"""Watch for one connectivity state change.
Keeps mirroring the behavior from Core, so we can easily switch to
other design of API if necessary.
"""
if self._status == AIO_CHANNEL_STATUS_DESTROYED:
# TODO(lidiz) switch to UsageError
raise RuntimeError('Channel is closed.')
cdef gpr_timespec c_deadline = _timespec_from_time(deadline) cdef gpr_timespec c_deadline = _timespec_from_time(deadline)
cdef object future = self._loop.create_future() cdef object future = self._loop.create_future()
cdef CallbackWrapper wrapper = CallbackWrapper( cdef CallbackWrapper wrapper = CallbackWrapper(
future, future,
_WATCH_CONNECTIVITY_FAILURE_HANDLER) _WATCH_CONNECTIVITY_FAILURE_HANDLER)
cpython.Py_INCREF(wrapper)
grpc_channel_watch_connectivity_state( grpc_channel_watch_connectivity_state(
self.channel, self.channel,
last_observed_state, last_observed_state,
@ -66,15 +81,24 @@ cdef class AioChannel:
self.cq.c_ptr(), self.cq.c_ptr(),
wrapper.c_functor()) wrapper.c_functor())
# NOTE(lidiz) The callback will be invoked after the channel is closed
# with a failure state. We need to keep wrapper alive until then, or we
# will observe a segfault.
def dealloc_wrapper(_):
cpython.Py_DECREF(wrapper)
future.add_done_callback(dealloc_wrapper)
try: try:
await future await future
except _WatchConnectivityFailed: except _WatchConnectivityFailed:
return None return False
else: else:
return self.check_connectivity_state(False) return True
def close(self): def close(self):
grpc_channel_destroy(self.channel) grpc_channel_destroy(self.channel)
self._status = AIO_CHANNEL_STATUS_DESTROYED
def call(self, def call(self,
bytes method, bytes method,
@ -85,5 +109,8 @@ cdef class AioChannel:
Returns: Returns:
The _AioCall object. The _AioCall object.
""" """
if self._status == AIO_CHANNEL_STATUS_DESTROYED:
# TODO(lidiz) switch to UsageError
raise RuntimeError('Channel is closed.')
cdef _AioCall call = _AioCall(self, deadline, method, credentials) cdef _AioCall call = _AioCall(self, deadline, method, credentials)
return call return call

@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python.""" """Invocation-side implementation of gRPC Asyncio Python."""
import asyncio import asyncio
import time
from typing import Any, Optional, Sequence, Text, Tuple from typing import Any, Optional, Sequence, Text, Tuple
import grpc import grpc
@ -225,50 +224,54 @@ class Channel:
self._channel = cygrpc.AioChannel(_common.encode(target), options, self._channel = cygrpc.AioChannel(_common.encode(target), options,
credentials) credentials)
def check_connectivity_state(self, try_to_connect: bool = False def get_state(self,
) -> grpc.ChannelConnectivity: try_to_connect: bool = False) -> grpc.ChannelConnectivity:
"""Check the connectivity state of a channel. """Check the connectivity state of a channel.
This is an EXPERIMENTAL API. This is an EXPERIMENTAL API.
It's the nature of connectivity states to change. The returned
connectivity state might become obsolete soon. Combining
"Channel.wait_for_state_change" we guarantee the convergence of
connectivity state between application and ground truth.
Args: Args:
try_to_connect: a bool indicate whether the Channel should try to connect to peer or not. try_to_connect: a bool indicate whether the Channel should try to
connect to peer or not.
Returns: Returns:
A ChannelConnectivity object. A ChannelConnectivity object.
""" """
result = self._channel.check_connectivity_state(try_to_connect) result = self._channel.check_connectivity_state(try_to_connect)
return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result] return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY.get(
result)
async def watch_connectivity_state( async def wait_for_state_change(
self, self,
last_observed_state: grpc.ChannelConnectivity, last_observed_state: grpc.ChannelConnectivity,
timeout_seconds: Optional[float] = None, ) -> None:
) -> Optional[grpc.ChannelConnectivity]: """Wait for a change in connectivity state.
"""Watch for a change in connectivity state.
This is an EXPERIMENTAL API. This is an EXPERIMENTAL API.
Once the channel connectivity state is different from The function blocks until there is a change in the channel connectivity
last_observed_state, the function will return the new connectivity state from the "last_observed_state". If the state is already
state. If deadline expires BEFORE the state is changed, None will be different, this function will return immediately.
returned.
Args: There is an inherent race between the invocation of
try_to_connect: a bool indicate whether the Channel should try to connect to peer or not. "Channel.wait_for_state_change" and "Channel.get_state". The state can
arbitrary times during the race, so there is no way to observe every
state transition.
Returns: If there is a need to put a timeout for this function, please refer to
A ChannelConnectivity object or None. "asyncio.wait_for".
Args:
last_observed_state: A grpc.ChannelConnectivity object representing
the last known state.
""" """
deadline = time.time( assert await self._channel.watch_connectivity_state(
) + timeout_seconds if timeout_seconds is not None else None last_observed_state.value[0], None)
result = await self._channel.watch_connectivity_state(
last_observed_state.value[0], deadline)
if result is None:
return None
else:
return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[
result]
def unary_unary( def unary_unary(
self, self,

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Tests behavior of the connectivity state.""" """Tests behavior of the connectivity state."""
import asyncio
import logging import logging
import threading import threading
import unittest import unittest
@ -29,6 +30,13 @@ from tests_aio.unit._test_base import AioTestBase
_INVALID_BACKEND_ADDRESS = '0.0.0.1:2' _INVALID_BACKEND_ADDRESS = '0.0.0.1:2'
async def _block_until_certain_state(channel, expected_state):
state = channel.get_state()
while state != expected_state:
await channel.wait_for_state_change(state)
state = channel.get_state()
class TestConnectivityState(AioTestBase): class TestConnectivityState(AioTestBase):
async def setUp(self): async def setUp(self):
@ -38,59 +46,65 @@ class TestConnectivityState(AioTestBase):
await self._server.stop(None) await self._server.stop(None)
async def test_unavailable_backend(self): async def test_unavailable_backend(self):
channel = aio.insecure_channel(_INVALID_BACKEND_ADDRESS) async with aio.insecure_channel(_INVALID_BACKEND_ADDRESS) as channel:
self.assertEqual(grpc.ChannelConnectivity.IDLE, self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.check_connectivity_state(False)) channel.get_state(False))
self.assertEqual(grpc.ChannelConnectivity.IDLE, self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.check_connectivity_state(True)) channel.get_state(True))
self.assertEqual(
grpc.ChannelConnectivity.CONNECTING, await
channel.watch_connectivity_state(grpc.ChannelConnectivity.IDLE))
self.assertEqual(
grpc.ChannelConnectivity.TRANSIENT_FAILURE, await
channel.watch_connectivity_state(grpc.ChannelConnectivity.CONNECTING
))
await channel.close() async def waiting_transient_failure():
state = channel.get_state()
while state != grpc.ChannelConnectivity.TRANSIENT_FAILURE:
channel.wait_for_state_change(state)
async def test_normal_backend(self): # Should not time out
channel = aio.insecure_channel(self._server_address) await asyncio.wait_for(
_block_until_certain_state(
channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE),
test_constants.SHORT_TIMEOUT)
current_state = channel.check_connectivity_state(True) async def test_normal_backend(self):
async with aio.insecure_channel(self._server_address) as channel:
current_state = channel.get_state(True)
self.assertEqual(grpc.ChannelConnectivity.IDLE, current_state) self.assertEqual(grpc.ChannelConnectivity.IDLE, current_state)
deadline = time.time() + test_constants.SHORT_TIMEOUT # Should not time out
await asyncio.wait_for(
while current_state != grpc.ChannelConnectivity.READY: _block_until_certain_state(channel,
current_state = await channel.watch_connectivity_state( grpc.ChannelConnectivity.READY),
current_state, deadline - time.time()) test_constants.SHORT_TIMEOUT)
self.assertIsNotNone(current_state)
await channel.close()
async def test_timeout(self): async def test_timeout(self):
channel = aio.insecure_channel(self._server_address) async with aio.insecure_channel(self._server_address) as channel:
self.assertEqual(grpc.ChannelConnectivity.IDLE, self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.check_connectivity_state(False)) channel.get_state(False))
# If timed out, the function should return None. # If timed out, the function should return None.
self.assertIsNone(await channel.watch_connectivity_state( with self.assertRaises(asyncio.TimeoutError):
grpc.ChannelConnectivity.IDLE, test_constants.SHORT_TIMEOUT)) await asyncio.wait_for(
_block_until_certain_state(channel,
await channel.close() grpc.ChannelConnectivity.READY),
test_constants.SHORT_TIMEOUT)
async def test_shutdown(self): async def test_shutdown(self):
channel = aio.insecure_channel(self._server_address) channel = aio.insecure_channel(self._server_address)
self.assertEqual(grpc.ChannelConnectivity.IDLE, self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.check_connectivity_state(False)) channel.get_state(False))
await channel.close() await channel.close()
self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN, self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN,
channel.check_connectivity_state(False)) channel.get_state(True))
self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN,
channel.get_state(False))
# It can raise Exception since it is an usage error, but it should not
# segfault or abort.
with self.assertRaises(Exception):
await channel.wait_for_state_change(
grpc.ChannelConnectivity.SHUTDOWN)
if __name__ == '__main__': if __name__ == '__main__':

Loading…
Cancel
Save