Implement connectivity state related APIs

pull/21621/head
Lidi Zheng 5 years ago
parent b0d7e680cb
commit 5f0a70973e
  1. 1
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi
  2. 39
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  3. 48
      src/python/grpcio/grpc/experimental/aio/_channel.py
  4. 98
      src/python/grpcio_tests/tests_aio/unit/connectivity_test.py

@ -17,3 +17,4 @@ cdef class AioChannel:
grpc_channel * channel
CallbackCompletionQueue cq
bytes _target
object _loop

@ -13,14 +13,19 @@
# limitations under the License.
class _WatchConnectivityFailed(Exception): pass
cdef CallbackFailureHandler _WATCH_CONNECTIVITY_FAILURE_HANDLER = CallbackFailureHandler(
'watch_connectivity_state',
'Maybe timed out.',
_WatchConnectivityFailed)
cdef class AioChannel:
def __cinit__(self, bytes target, tuple options, ChannelCredentials credentials):
if options is None:
options = ()
cdef _ChannelArgs channel_args = _ChannelArgs(options)
self._target = target
self.cq = CallbackCompletionQueue()
if credentials is None:
self.channel = grpc_insecure_channel_create(
<char *>target,
@ -32,12 +37,42 @@ cdef class AioChannel:
<char *> target,
channel_args.c_args(),
NULL)
self._loop = asyncio.get_event_loop()
def __repr__(self):
class_name = self.__class__.__name__
id_ = id(self)
return f"<{class_name} {id_}>"
def check_connectivity_state(self, bint try_to_connect):
return grpc_channel_check_connectivity_state(
self.channel,
try_to_connect,
)
async def watch_connectivity_state(self,
grpc_connectivity_state last_observed_state,
object deadline):
cdef gpr_timespec c_deadline = _timespec_from_time(deadline)
cdef object future = self._loop.create_future()
cdef CallbackWrapper wrapper = CallbackWrapper(
future,
_WATCH_CONNECTIVITY_FAILURE_HANDLER)
grpc_channel_watch_connectivity_state(
self.channel,
last_observed_state,
c_deadline,
self.cq.c_ptr(),
wrapper.c_functor())
try:
await future
except _WatchConnectivityFailed:
return None
else:
return self.check_connectivity_state(False)
def close(self):
grpc_channel_destroy(self.channel)

@ -13,7 +13,8 @@
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""
import asyncio
from typing import Any, Optional, Sequence, Text
import time
from typing import Any, Optional, Sequence, Text, Tuple
import grpc
from grpc import _common
@ -224,6 +225,51 @@ class Channel:
self._channel = cygrpc.AioChannel(_common.encode(target), options,
credentials)
def check_connectivity_state(self, try_to_connect: bool = False
) -> grpc.ChannelConnectivity:
"""Check the connectivity state of a channel.
This is an EXPERIMENTAL API.
Args:
try_to_connect: a bool indicate whether the Channel should try to connect to peer or not.
Returns:
A ChannelConnectivity object.
"""
result = self._channel.check_connectivity_state(try_to_connect)
return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]
async def watch_connectivity_state(
self,
last_observed_state: grpc.ChannelConnectivity,
timeout_seconds: Optional[float] = None,
) -> Optional[grpc.ChannelConnectivity]:
"""Watch for a change in connectivity state.
This is an EXPERIMENTAL API.
Once the channel connectivity state is different from
last_observed_state, the function will return the new connectivity
state. If deadline expires BEFORE the state is changed, None will be
returned.
Args:
try_to_connect: a bool indicate whether the Channel should try to connect to peer or not.
Returns:
A ChannelConnectivity object or None.
"""
deadline = time.time(
) + timeout_seconds if timeout_seconds is not None else None
result = await self._channel.watch_connectivity_state(
last_observed_state.value[0], deadline)
if result is None:
return None
else:
return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[
result]
def unary_unary(
self,
method: Text,

@ -0,0 +1,98 @@
# Copyright 2019 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior of the connectivity state."""
import logging
import threading
import unittest
import time
import grpc
from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
from tests.unit.framework.common import test_constants
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase
_INVALID_BACKEND_ADDRESS = '0.0.0.1:2'
class TestChannel(AioTestBase):
async def setUp(self):
self._server_address, self._server = await start_test_server()
async def tearDown(self):
await self._server.stop(None)
async def test_unavailable_backend(self):
channel = aio.insecure_channel(_INVALID_BACKEND_ADDRESS)
self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.check_connectivity_state(False))
self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.check_connectivity_state(True))
self.assertEqual(
grpc.ChannelConnectivity.CONNECTING, await
channel.watch_connectivity_state(grpc.ChannelConnectivity.IDLE))
self.assertEqual(
grpc.ChannelConnectivity.TRANSIENT_FAILURE, await
channel.watch_connectivity_state(grpc.ChannelConnectivity.CONNECTING
))
await channel.close()
async def test_normal_backend(self):
channel = aio.insecure_channel(self._server_address)
current_state = channel.check_connectivity_state(True)
self.assertEqual(grpc.ChannelConnectivity.IDLE, current_state)
deadline = time.time() + test_constants.SHORT_TIMEOUT
while current_state != grpc.ChannelConnectivity.READY:
current_state = await channel.watch_connectivity_state(
current_state, deadline - time.time())
self.assertIsNotNone(current_state)
await channel.close()
async def test_timeout(self):
channel = aio.insecure_channel(self._server_address)
self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.check_connectivity_state(False))
# If timed out, the function should return None.
self.assertIsNone(await channel.watch_connectivity_state(
grpc.ChannelConnectivity.IDLE, test_constants.SHORT_TIMEOUT))
await channel.close()
async def test_shutdown(self):
channel = aio.insecure_channel(self._server_address)
self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.check_connectivity_state(False))
await channel.close()
self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN,
channel.check_connectivity_state(False))
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)
Loading…
Cancel
Save