Improve the surface API & rewrite the test

pull/21621/head
Lidi Zheng 5 years ago
parent fa62339430
commit 650ba93a61
  1. 6
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi
  2. 35
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  3. 55
      src/python/grpcio/grpc/experimental/aio/_channel.py
  4. 78
      src/python/grpcio_tests/tests_aio/unit/connectivity_test.py

@ -12,9 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
cdef enum AioChannelStatus:
AIO_CHANNEL_STATUS_UNKNOWN
AIO_CHANNEL_STATUS_READY
AIO_CHANNEL_STATUS_DESTROYED
cdef class AioChannel:
cdef:
grpc_channel * channel
CallbackCompletionQueue cq
bytes _target
object _loop
AioChannelStatus _status

@ -13,10 +13,14 @@
# limitations under the License.
class _WatchConnectivityFailed(Exception): pass
class _WatchConnectivityFailed(Exception):
"""Dedicated exception class for watch connectivity failed.
It might be failed due to deadline exceeded, or the channel is closing.
"""
cdef CallbackFailureHandler _WATCH_CONNECTIVITY_FAILURE_HANDLER = CallbackFailureHandler(
'watch_connectivity_state',
'Maybe timed out.',
'Timed out or channel closed.',
_WatchConnectivityFailed)
@ -38,6 +42,7 @@ cdef class AioChannel:
channel_args.c_args(),
NULL)
self._loop = asyncio.get_event_loop()
self._status = AIO_CHANNEL_STATUS_READY
def __repr__(self):
class_name = self.__class__.__name__
@ -45,6 +50,7 @@ cdef class AioChannel:
return f"<{class_name} {id_}>"
def check_connectivity_state(self, bint try_to_connect):
"""A Cython wrapper for Core's check connectivity state API."""
return grpc_channel_check_connectivity_state(
self.channel,
try_to_connect,
@ -53,12 +59,21 @@ cdef class AioChannel:
async def watch_connectivity_state(self,
grpc_connectivity_state last_observed_state,
object deadline):
"""Watch for one connectivity state change.
Keeps mirroring the behavior from Core, so we can easily switch to
other design of API if necessary.
"""
if self._status == AIO_CHANNEL_STATUS_DESTROYED:
# TODO(lidiz) switch to UsageError
raise RuntimeError('Channel is closed.')
cdef gpr_timespec c_deadline = _timespec_from_time(deadline)
cdef object future = self._loop.create_future()
cdef CallbackWrapper wrapper = CallbackWrapper(
future,
_WATCH_CONNECTIVITY_FAILURE_HANDLER)
cpython.Py_INCREF(wrapper)
grpc_channel_watch_connectivity_state(
self.channel,
last_observed_state,
@ -66,15 +81,24 @@ cdef class AioChannel:
self.cq.c_ptr(),
wrapper.c_functor())
# NOTE(lidiz) The callback will be invoked after the channel is closed
# with a failure state. We need to keep wrapper alive until then, or we
# will observe a segfault.
def dealloc_wrapper(_):
cpython.Py_DECREF(wrapper)
future.add_done_callback(dealloc_wrapper)
try:
await future
except _WatchConnectivityFailed:
return None
return False
else:
return self.check_connectivity_state(False)
return True
def close(self):
grpc_channel_destroy(self.channel)
self._status = AIO_CHANNEL_STATUS_DESTROYED
def call(self,
bytes method,
@ -85,5 +109,8 @@ cdef class AioChannel:
Returns:
The _AioCall object.
"""
if self._status == AIO_CHANNEL_STATUS_DESTROYED:
# TODO(lidiz) switch to UsageError
raise RuntimeError('Channel is closed.')
cdef _AioCall call = _AioCall(self, deadline, method, credentials)
return call

@ -13,7 +13,6 @@
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""
import asyncio
import time
from typing import Any, Optional, Sequence, Text, Tuple
import grpc
@ -225,50 +224,54 @@ class Channel:
self._channel = cygrpc.AioChannel(_common.encode(target), options,
credentials)
def check_connectivity_state(self, try_to_connect: bool = False
) -> grpc.ChannelConnectivity:
def get_state(self,
try_to_connect: bool = False) -> grpc.ChannelConnectivity:
"""Check the connectivity state of a channel.
This is an EXPERIMENTAL API.
It's the nature of connectivity states to change. The returned
connectivity state might become obsolete soon. Combining
"Channel.wait_for_state_change" we guarantee the convergence of
connectivity state between application and ground truth.
Args:
try_to_connect: a bool indicate whether the Channel should try to connect to peer or not.
try_to_connect: a bool indicate whether the Channel should try to
connect to peer or not.
Returns:
A ChannelConnectivity object.
"""
result = self._channel.check_connectivity_state(try_to_connect)
return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]
return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY.get(
result)
async def watch_connectivity_state(
async def wait_for_state_change(
self,
last_observed_state: grpc.ChannelConnectivity,
timeout_seconds: Optional[float] = None,
) -> Optional[grpc.ChannelConnectivity]:
"""Watch for a change in connectivity state.
) -> None:
"""Wait for a change in connectivity state.
This is an EXPERIMENTAL API.
Once the channel connectivity state is different from
last_observed_state, the function will return the new connectivity
state. If deadline expires BEFORE the state is changed, None will be
returned.
The function blocks until there is a change in the channel connectivity
state from the "last_observed_state". If the state is already
different, this function will return immediately.
Args:
try_to_connect: a bool indicate whether the Channel should try to connect to peer or not.
There is an inherent race between the invocation of
"Channel.wait_for_state_change" and "Channel.get_state". The state can
arbitrary times during the race, so there is no way to observe every
state transition.
Returns:
A ChannelConnectivity object or None.
If there is a need to put a timeout for this function, please refer to
"asyncio.wait_for".
Args:
last_observed_state: A grpc.ChannelConnectivity object representing
the last known state.
"""
deadline = time.time(
) + timeout_seconds if timeout_seconds is not None else None
result = await self._channel.watch_connectivity_state(
last_observed_state.value[0], deadline)
if result is None:
return None
else:
return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[
result]
assert await self._channel.watch_connectivity_state(
last_observed_state.value[0], None)
def unary_unary(
self,

@ -13,6 +13,7 @@
# limitations under the License.
"""Tests behavior of the connectivity state."""
import asyncio
import logging
import threading
import unittest
@ -29,6 +30,13 @@ from tests_aio.unit._test_base import AioTestBase
_INVALID_BACKEND_ADDRESS = '0.0.0.1:2'
async def _block_until_certain_state(channel, expected_state):
state = channel.get_state()
while state != expected_state:
await channel.wait_for_state_change(state)
state = channel.get_state()
class TestConnectivityState(AioTestBase):
async def setUp(self):
@ -38,59 +46,65 @@ class TestConnectivityState(AioTestBase):
await self._server.stop(None)
async def test_unavailable_backend(self):
channel = aio.insecure_channel(_INVALID_BACKEND_ADDRESS)
async with aio.insecure_channel(_INVALID_BACKEND_ADDRESS) as channel:
self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.check_connectivity_state(False))
channel.get_state(False))
self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.check_connectivity_state(True))
self.assertEqual(
grpc.ChannelConnectivity.CONNECTING, await
channel.watch_connectivity_state(grpc.ChannelConnectivity.IDLE))
self.assertEqual(
grpc.ChannelConnectivity.TRANSIENT_FAILURE, await
channel.watch_connectivity_state(grpc.ChannelConnectivity.CONNECTING
))
channel.get_state(True))
await channel.close()
async def waiting_transient_failure():
state = channel.get_state()
while state != grpc.ChannelConnectivity.TRANSIENT_FAILURE:
channel.wait_for_state_change(state)
async def test_normal_backend(self):
channel = aio.insecure_channel(self._server_address)
# Should not time out
await asyncio.wait_for(
_block_until_certain_state(
channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE),
test_constants.SHORT_TIMEOUT)
current_state = channel.check_connectivity_state(True)
async def test_normal_backend(self):
async with aio.insecure_channel(self._server_address) as channel:
current_state = channel.get_state(True)
self.assertEqual(grpc.ChannelConnectivity.IDLE, current_state)
deadline = time.time() + test_constants.SHORT_TIMEOUT
while current_state != grpc.ChannelConnectivity.READY:
current_state = await channel.watch_connectivity_state(
current_state, deadline - time.time())
self.assertIsNotNone(current_state)
await channel.close()
# Should not time out
await asyncio.wait_for(
_block_until_certain_state(channel,
grpc.ChannelConnectivity.READY),
test_constants.SHORT_TIMEOUT)
async def test_timeout(self):
channel = aio.insecure_channel(self._server_address)
async with aio.insecure_channel(self._server_address) as channel:
self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.check_connectivity_state(False))
channel.get_state(False))
# If timed out, the function should return None.
self.assertIsNone(await channel.watch_connectivity_state(
grpc.ChannelConnectivity.IDLE, test_constants.SHORT_TIMEOUT))
await channel.close()
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(
_block_until_certain_state(channel,
grpc.ChannelConnectivity.READY),
test_constants.SHORT_TIMEOUT)
async def test_shutdown(self):
channel = aio.insecure_channel(self._server_address)
self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.check_connectivity_state(False))
channel.get_state(False))
await channel.close()
self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN,
channel.check_connectivity_state(False))
channel.get_state(True))
self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN,
channel.get_state(False))
# It can raise Exception since it is an usage error, but it should not
# segfault or abort.
with self.assertRaises(Exception):
await channel.wait_for_state_change(
grpc.ChannelConnectivity.SHUTDOWN)
if __name__ == '__main__':

Loading…
Cancel
Save