Merge pull request #21621 from lidizheng/aio-connectivity

[Aio] Implement connectivity state related APIs
pull/21662/head
Lidi Zheng 5 years ago committed by GitHub
commit 4bc37f9eea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pxd.pxi
  2. 22
      src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi
  3. 7
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi
  4. 57
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  5. 4
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  6. 45
      src/python/grpcio/grpc/experimental/aio/_channel.py
  7. 1
      src/python/grpcio_tests/tests_aio/tests.json
  8. 7
      src/python/grpcio_tests/tests_aio/unit/BUILD.bazel
  9. 16
      src/python/grpcio_tests/tests_aio/unit/_constants.py
  10. 8
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  11. 11
      src/python/grpcio_tests/tests_aio/unit/channel_test.py
  12. 118
      src/python/grpcio_tests/tests_aio/unit/connectivity_test.py

@ -33,9 +33,12 @@ cdef struct CallbackContext:
# invoked by Core.
# failure_handler: A CallbackFailureHandler object that called when Core
# returns 'success == 0' state.
# wrapper: A self-reference to the CallbackWrapper to help life cycle
# management.
grpc_experimental_completion_queue_functor functor
cpython.PyObject *waiter
cpython.PyObject *failure_handler
cpython.PyObject *callback_wrapper
cdef class CallbackWrapper:

@ -36,10 +36,15 @@ cdef class CallbackWrapper:
self.context.functor.functor_run = self.functor_run
self.context.waiter = <cpython.PyObject*>future
self.context.failure_handler = <cpython.PyObject*>failure_handler
self.context.callback_wrapper = <cpython.PyObject*>self
# NOTE(lidiz) Not using a list here, because this class is critical in
# data path. We should make it as efficient as possible.
self._reference_of_future = future
self._reference_of_failure_handler = failure_handler
# NOTE(lidiz) We need to ensure when Core invokes our callback, the
# callback function itself is not deallocated. Othersise, we will get
# a segfault. We can view this as Core holding a ref.
cpython.Py_INCREF(self)
@staticmethod
cdef void functor_run(
@ -47,12 +52,12 @@ cdef class CallbackWrapper:
int success):
cdef CallbackContext *context = <CallbackContext *>functor
cdef object waiter = <object>context.waiter
if waiter.cancelled():
return
if success == 0:
(<CallbackFailureHandler>context.failure_handler).handle(waiter)
else:
waiter.set_result(None)
if not waiter.cancelled():
if success == 0:
(<CallbackFailureHandler>context.failure_handler).handle(waiter)
else:
waiter.set_result(None)
cpython.Py_DECREF(<object>context.callback_wrapper)
cdef grpc_experimental_completion_queue_functor *c_functor(self):
return &self.context.functor
@ -99,9 +104,6 @@ async def execute_batch(GrpcCallWrapper grpc_call_wrapper,
cdef CallbackWrapper wrapper = CallbackWrapper(
future,
CallbackFailureHandler('execute_batch', operations, ExecuteBatchError))
# NOTE(lidiz) Without Py_INCREF, the wrapper object will be destructed
# when calling "await". This is an over-optimization by Cython.
cpython.Py_INCREF(wrapper)
cdef grpc_call_error error = grpc_call_start_batch(
grpc_call_wrapper.call,
batch_operation_tag.c_ops,
@ -112,7 +114,7 @@ async def execute_batch(GrpcCallWrapper grpc_call_wrapper,
raise ExecuteBatchError("Failed grpc_call_start_batch: {}".format(error))
await future
cpython.Py_DECREF(wrapper)
cdef grpc_event c_event
# Tag.event must be called, otherwise messages won't be parsed from C
batch_operation_tag.event(c_event)

@ -12,8 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
cdef enum AioChannelStatus:
AIO_CHANNEL_STATUS_UNKNOWN
AIO_CHANNEL_STATUS_READY
AIO_CHANNEL_STATUS_DESTROYED
cdef class AioChannel:
cdef:
grpc_channel * channel
CallbackCompletionQueue cq
bytes _target
object _loop
AioChannelStatus _status

@ -13,6 +13,17 @@
# limitations under the License.
class _WatchConnectivityFailed(Exception):
"""Dedicated exception class for watch connectivity failed.
It might be failed due to deadline exceeded.
"""
cdef CallbackFailureHandler _WATCH_CONNECTIVITY_FAILURE_HANDLER = CallbackFailureHandler(
'watch_connectivity_state',
'Timed out',
_WatchConnectivityFailed)
cdef class AioChannel:
def __cinit__(self, bytes target, tuple options, ChannelCredentials credentials):
if options is None:
@ -20,6 +31,8 @@ cdef class AioChannel:
cdef _ChannelArgs channel_args = _ChannelArgs(options)
self._target = target
self.cq = CallbackCompletionQueue()
self._loop = asyncio.get_event_loop()
self._status = AIO_CHANNEL_STATUS_READY
if credentials is None:
self.channel = grpc_insecure_channel_create(
@ -29,7 +42,7 @@ cdef class AioChannel:
else:
self.channel = grpc_secure_channel_create(
<grpc_channel_credentials *> credentials.c(),
<char *> target,
<char *>target,
channel_args.c_args(),
NULL)
@ -38,8 +51,47 @@ cdef class AioChannel:
id_ = id(self)
return f"<{class_name} {id_}>"
def check_connectivity_state(self, bint try_to_connect):
"""A Cython wrapper for Core's check connectivity state API."""
return grpc_channel_check_connectivity_state(
self.channel,
try_to_connect,
)
async def watch_connectivity_state(self,
grpc_connectivity_state last_observed_state,
object deadline):
"""Watch for one connectivity state change.
Keeps mirroring the behavior from Core, so we can easily switch to
other design of API if necessary.
"""
if self._status == AIO_CHANNEL_STATUS_DESTROYED:
# TODO(lidiz) switch to UsageError
raise RuntimeError('Channel is closed.')
cdef gpr_timespec c_deadline = _timespec_from_time(deadline)
cdef object future = self._loop.create_future()
cdef CallbackWrapper wrapper = CallbackWrapper(
future,
_WATCH_CONNECTIVITY_FAILURE_HANDLER)
grpc_channel_watch_connectivity_state(
self.channel,
last_observed_state,
c_deadline,
self.cq.c_ptr(),
wrapper.c_functor())
try:
await future
except _WatchConnectivityFailed:
return False
else:
return True
def close(self):
grpc_channel_destroy(self.channel)
self._status = AIO_CHANNEL_STATUS_DESTROYED
def call(self,
bytes method,
@ -50,5 +102,8 @@ cdef class AioChannel:
Returns:
The _AioCall object.
"""
if self._status == AIO_CHANNEL_STATUS_DESTROYED:
# TODO(lidiz) switch to UsageError
raise RuntimeError('Channel is closed.')
cdef _AioCall call = _AioCall(self, deadline, method, credentials)
return call

@ -307,9 +307,6 @@ cdef class AioServer:
cdef CallbackWrapper wrapper = CallbackWrapper(
future,
REQUEST_CALL_FAILURE_HANDLER)
# NOTE(lidiz) Without Py_INCREF, the wrapper object will be destructed
# when calling "await". This is an over-optimization by Cython.
cpython.Py_INCREF(wrapper)
error = grpc_server_request_call(
self._server.c_server, &rpc_state.call, &rpc_state.details,
&rpc_state.request_metadata,
@ -320,7 +317,6 @@ cdef class AioServer:
raise RuntimeError("Error in grpc_server_request_call: %s" % error)
await future
cpython.Py_DECREF(wrapper)
return rpc_state
async def _server_main_loop(self,

@ -224,6 +224,51 @@ class Channel:
self._channel = cygrpc.AioChannel(_common.encode(target), options,
credentials)
def get_state(self,
try_to_connect: bool = False) -> grpc.ChannelConnectivity:
"""Check the connectivity state of a channel.
This is an EXPERIMENTAL API.
If the channel reaches a stable connectivity state, it is guaranteed
that the return value of this function will eventually converge to that
state.
Args: try_to_connect: a bool indicate whether the Channel should try to
connect to peer or not.
Returns: A ChannelConnectivity object.
"""
result = self._channel.check_connectivity_state(try_to_connect)
return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]
async def wait_for_state_change(
self,
last_observed_state: grpc.ChannelConnectivity,
) -> None:
"""Wait for a change in connectivity state.
This is an EXPERIMENTAL API.
The function blocks until there is a change in the channel connectivity
state from the "last_observed_state". If the state is already
different, this function will return immediately.
There is an inherent race between the invocation of
"Channel.wait_for_state_change" and "Channel.get_state". The state can
change arbitrary times during the race, so there is no way to observe
every state transition.
If there is a need to put a timeout for this function, please refer to
"asyncio.wait_for".
Args:
last_observed_state: A grpc.ChannelConnectivity object representing
the last known state.
"""
assert await self._channel.watch_connectivity_state(
last_observed_state.value[0], None)
def unary_unary(
self,
method: Text,

@ -5,6 +5,7 @@
"unit.call_test.TestUnaryUnaryCall",
"unit.channel_argument_test.TestChannelArgument",
"unit.channel_test.TestChannel",
"unit.connectivity_test.TestConnectivityState",
"unit.init_test.TestInsecureChannel",
"unit.init_test.TestSecureChannel",
"unit.interceptor_test.TestInterceptedUnaryUnaryCall",

@ -37,6 +37,12 @@ py_library(
],
)
py_library(
name = "_constants",
srcs = ["_constants.py"],
srcs_version = "PY3",
)
[
py_test(
name = test_file_name[:-3],
@ -49,6 +55,7 @@ py_library(
main = test_file_name,
python_version = "PY3",
deps = [
":_constants",
":_test_base",
":_test_server",
"//external:six",

@ -0,0 +1,16 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
UNREACHABLE_TARGET = '0.0.0.1:1111'
UNARY_CALL_WITH_SLEEP_VALUE = 0.2

@ -13,17 +13,15 @@
# limitations under the License.
import asyncio
import logging
import datetime
import logging
import grpc
from grpc.experimental import aio
from tests.unit.framework.common import test_constants
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
UNARY_CALL_WITH_SLEEP_VALUE = 0.2
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests_aio.unit._constants import UNARY_CALL_WITH_SLEEP_VALUE
class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):

@ -19,21 +19,20 @@ import threading
import unittest
import grpc
from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests.unit.framework.common import test_constants
from tests_aio.unit._test_server import start_test_server, UNARY_CALL_WITH_SLEEP_VALUE
from tests_aio.unit._constants import (UNARY_CALL_WITH_SLEEP_VALUE,
UNREACHABLE_TARGET)
from tests_aio.unit._test_base import AioTestBase
from src.proto.grpc.testing import messages_pb2
from tests_aio.unit._test_server import start_test_server
_UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall'
_UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep'
_STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall'
_NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42
_UNREACHABLE_TARGET = '0.1:1111'
class TestChannel(AioTestBase):

@ -0,0 +1,118 @@
# Copyright 2019 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior of the connectivity state."""
import asyncio
import logging
import threading
import time
import unittest
import grpc
from grpc.experimental import aio
from tests.unit.framework.common import test_constants
from tests_aio.unit._constants import UNREACHABLE_TARGET
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server
async def _block_until_certain_state(channel, expected_state):
state = channel.get_state()
while state != expected_state:
await channel.wait_for_state_change(state)
state = channel.get_state()
class TestConnectivityState(AioTestBase):
async def setUp(self):
self._server_address, self._server = await start_test_server()
async def tearDown(self):
await self._server.stop(None)
async def test_unavailable_backend(self):
async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.get_state(False))
self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.get_state(True))
# Should not time out
await asyncio.wait_for(
_block_until_certain_state(
channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE),
test_constants.SHORT_TIMEOUT)
async def test_normal_backend(self):
async with aio.insecure_channel(self._server_address) as channel:
current_state = channel.get_state(True)
self.assertEqual(grpc.ChannelConnectivity.IDLE, current_state)
# Should not time out
await asyncio.wait_for(
_block_until_certain_state(channel,
grpc.ChannelConnectivity.READY),
test_constants.SHORT_TIMEOUT)
async def test_timeout(self):
async with aio.insecure_channel(self._server_address) as channel:
self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.get_state(False))
# If timed out, the function should return None.
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(
_block_until_certain_state(channel,
grpc.ChannelConnectivity.READY),
test_constants.SHORT_TIMEOUT)
async def test_shutdown(self):
channel = aio.insecure_channel(self._server_address)
self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.get_state(False))
# Waiting for changes in a separate coroutine
wait_started = asyncio.Event()
async def a_pending_wait():
wait_started.set()
await channel.wait_for_state_change(grpc.ChannelConnectivity.IDLE)
pending_task = self.loop.create_task(a_pending_wait())
await wait_started.wait()
await channel.close()
self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN,
channel.get_state(True))
self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN,
channel.get_state(False))
# Make sure there isn't any exception in the task
await pending_task
# It can raise exceptions since it is an usage error, but it should not
# segfault or abort.
with self.assertRaises(RuntimeError):
await channel.wait_for_state_change(
grpc.ChannelConnectivity.SHUTDOWN)
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)
Loading…
Cancel
Save