Merge pull request #21621 from lidizheng/aio-connectivity

[Aio] Implement connectivity state related APIs
pull/21662/head
Lidi Zheng 5 years ago committed by GitHub
commit 4bc37f9eea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pxd.pxi
  2. 22
      src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi
  3. 7
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi
  4. 57
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  5. 4
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  6. 45
      src/python/grpcio/grpc/experimental/aio/_channel.py
  7. 1
      src/python/grpcio_tests/tests_aio/tests.json
  8. 7
      src/python/grpcio_tests/tests_aio/unit/BUILD.bazel
  9. 16
      src/python/grpcio_tests/tests_aio/unit/_constants.py
  10. 8
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  11. 11
      src/python/grpcio_tests/tests_aio/unit/channel_test.py
  12. 118
      src/python/grpcio_tests/tests_aio/unit/connectivity_test.py

@ -33,9 +33,12 @@ cdef struct CallbackContext:
# invoked by Core. # invoked by Core.
# failure_handler: A CallbackFailureHandler object that called when Core # failure_handler: A CallbackFailureHandler object that called when Core
# returns 'success == 0' state. # returns 'success == 0' state.
# wrapper: A self-reference to the CallbackWrapper to help life cycle
# management.
grpc_experimental_completion_queue_functor functor grpc_experimental_completion_queue_functor functor
cpython.PyObject *waiter cpython.PyObject *waiter
cpython.PyObject *failure_handler cpython.PyObject *failure_handler
cpython.PyObject *callback_wrapper
cdef class CallbackWrapper: cdef class CallbackWrapper:

@ -36,10 +36,15 @@ cdef class CallbackWrapper:
self.context.functor.functor_run = self.functor_run self.context.functor.functor_run = self.functor_run
self.context.waiter = <cpython.PyObject*>future self.context.waiter = <cpython.PyObject*>future
self.context.failure_handler = <cpython.PyObject*>failure_handler self.context.failure_handler = <cpython.PyObject*>failure_handler
self.context.callback_wrapper = <cpython.PyObject*>self
# NOTE(lidiz) Not using a list here, because this class is critical in # NOTE(lidiz) Not using a list here, because this class is critical in
# data path. We should make it as efficient as possible. # data path. We should make it as efficient as possible.
self._reference_of_future = future self._reference_of_future = future
self._reference_of_failure_handler = failure_handler self._reference_of_failure_handler = failure_handler
# NOTE(lidiz) We need to ensure when Core invokes our callback, the
# callback function itself is not deallocated. Othersise, we will get
# a segfault. We can view this as Core holding a ref.
cpython.Py_INCREF(self)
@staticmethod @staticmethod
cdef void functor_run( cdef void functor_run(
@ -47,12 +52,12 @@ cdef class CallbackWrapper:
int success): int success):
cdef CallbackContext *context = <CallbackContext *>functor cdef CallbackContext *context = <CallbackContext *>functor
cdef object waiter = <object>context.waiter cdef object waiter = <object>context.waiter
if waiter.cancelled(): if not waiter.cancelled():
return if success == 0:
if success == 0: (<CallbackFailureHandler>context.failure_handler).handle(waiter)
(<CallbackFailureHandler>context.failure_handler).handle(waiter) else:
else: waiter.set_result(None)
waiter.set_result(None) cpython.Py_DECREF(<object>context.callback_wrapper)
cdef grpc_experimental_completion_queue_functor *c_functor(self): cdef grpc_experimental_completion_queue_functor *c_functor(self):
return &self.context.functor return &self.context.functor
@ -99,9 +104,6 @@ async def execute_batch(GrpcCallWrapper grpc_call_wrapper,
cdef CallbackWrapper wrapper = CallbackWrapper( cdef CallbackWrapper wrapper = CallbackWrapper(
future, future,
CallbackFailureHandler('execute_batch', operations, ExecuteBatchError)) CallbackFailureHandler('execute_batch', operations, ExecuteBatchError))
# NOTE(lidiz) Without Py_INCREF, the wrapper object will be destructed
# when calling "await". This is an over-optimization by Cython.
cpython.Py_INCREF(wrapper)
cdef grpc_call_error error = grpc_call_start_batch( cdef grpc_call_error error = grpc_call_start_batch(
grpc_call_wrapper.call, grpc_call_wrapper.call,
batch_operation_tag.c_ops, batch_operation_tag.c_ops,
@ -112,7 +114,7 @@ async def execute_batch(GrpcCallWrapper grpc_call_wrapper,
raise ExecuteBatchError("Failed grpc_call_start_batch: {}".format(error)) raise ExecuteBatchError("Failed grpc_call_start_batch: {}".format(error))
await future await future
cpython.Py_DECREF(wrapper)
cdef grpc_event c_event cdef grpc_event c_event
# Tag.event must be called, otherwise messages won't be parsed from C # Tag.event must be called, otherwise messages won't be parsed from C
batch_operation_tag.event(c_event) batch_operation_tag.event(c_event)

@ -12,8 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
cdef enum AioChannelStatus:
AIO_CHANNEL_STATUS_UNKNOWN
AIO_CHANNEL_STATUS_READY
AIO_CHANNEL_STATUS_DESTROYED
cdef class AioChannel: cdef class AioChannel:
cdef: cdef:
grpc_channel * channel grpc_channel * channel
CallbackCompletionQueue cq CallbackCompletionQueue cq
bytes _target bytes _target
object _loop
AioChannelStatus _status

@ -13,6 +13,17 @@
# limitations under the License. # limitations under the License.
class _WatchConnectivityFailed(Exception):
"""Dedicated exception class for watch connectivity failed.
It might be failed due to deadline exceeded.
"""
cdef CallbackFailureHandler _WATCH_CONNECTIVITY_FAILURE_HANDLER = CallbackFailureHandler(
'watch_connectivity_state',
'Timed out',
_WatchConnectivityFailed)
cdef class AioChannel: cdef class AioChannel:
def __cinit__(self, bytes target, tuple options, ChannelCredentials credentials): def __cinit__(self, bytes target, tuple options, ChannelCredentials credentials):
if options is None: if options is None:
@ -20,6 +31,8 @@ cdef class AioChannel:
cdef _ChannelArgs channel_args = _ChannelArgs(options) cdef _ChannelArgs channel_args = _ChannelArgs(options)
self._target = target self._target = target
self.cq = CallbackCompletionQueue() self.cq = CallbackCompletionQueue()
self._loop = asyncio.get_event_loop()
self._status = AIO_CHANNEL_STATUS_READY
if credentials is None: if credentials is None:
self.channel = grpc_insecure_channel_create( self.channel = grpc_insecure_channel_create(
@ -29,7 +42,7 @@ cdef class AioChannel:
else: else:
self.channel = grpc_secure_channel_create( self.channel = grpc_secure_channel_create(
<grpc_channel_credentials *> credentials.c(), <grpc_channel_credentials *> credentials.c(),
<char *> target, <char *>target,
channel_args.c_args(), channel_args.c_args(),
NULL) NULL)
@ -38,8 +51,47 @@ cdef class AioChannel:
id_ = id(self) id_ = id(self)
return f"<{class_name} {id_}>" return f"<{class_name} {id_}>"
def check_connectivity_state(self, bint try_to_connect):
"""A Cython wrapper for Core's check connectivity state API."""
return grpc_channel_check_connectivity_state(
self.channel,
try_to_connect,
)
async def watch_connectivity_state(self,
grpc_connectivity_state last_observed_state,
object deadline):
"""Watch for one connectivity state change.
Keeps mirroring the behavior from Core, so we can easily switch to
other design of API if necessary.
"""
if self._status == AIO_CHANNEL_STATUS_DESTROYED:
# TODO(lidiz) switch to UsageError
raise RuntimeError('Channel is closed.')
cdef gpr_timespec c_deadline = _timespec_from_time(deadline)
cdef object future = self._loop.create_future()
cdef CallbackWrapper wrapper = CallbackWrapper(
future,
_WATCH_CONNECTIVITY_FAILURE_HANDLER)
grpc_channel_watch_connectivity_state(
self.channel,
last_observed_state,
c_deadline,
self.cq.c_ptr(),
wrapper.c_functor())
try:
await future
except _WatchConnectivityFailed:
return False
else:
return True
def close(self): def close(self):
grpc_channel_destroy(self.channel) grpc_channel_destroy(self.channel)
self._status = AIO_CHANNEL_STATUS_DESTROYED
def call(self, def call(self,
bytes method, bytes method,
@ -50,5 +102,8 @@ cdef class AioChannel:
Returns: Returns:
The _AioCall object. The _AioCall object.
""" """
if self._status == AIO_CHANNEL_STATUS_DESTROYED:
# TODO(lidiz) switch to UsageError
raise RuntimeError('Channel is closed.')
cdef _AioCall call = _AioCall(self, deadline, method, credentials) cdef _AioCall call = _AioCall(self, deadline, method, credentials)
return call return call

@ -307,9 +307,6 @@ cdef class AioServer:
cdef CallbackWrapper wrapper = CallbackWrapper( cdef CallbackWrapper wrapper = CallbackWrapper(
future, future,
REQUEST_CALL_FAILURE_HANDLER) REQUEST_CALL_FAILURE_HANDLER)
# NOTE(lidiz) Without Py_INCREF, the wrapper object will be destructed
# when calling "await". This is an over-optimization by Cython.
cpython.Py_INCREF(wrapper)
error = grpc_server_request_call( error = grpc_server_request_call(
self._server.c_server, &rpc_state.call, &rpc_state.details, self._server.c_server, &rpc_state.call, &rpc_state.details,
&rpc_state.request_metadata, &rpc_state.request_metadata,
@ -320,7 +317,6 @@ cdef class AioServer:
raise RuntimeError("Error in grpc_server_request_call: %s" % error) raise RuntimeError("Error in grpc_server_request_call: %s" % error)
await future await future
cpython.Py_DECREF(wrapper)
return rpc_state return rpc_state
async def _server_main_loop(self, async def _server_main_loop(self,

@ -224,6 +224,51 @@ class Channel:
self._channel = cygrpc.AioChannel(_common.encode(target), options, self._channel = cygrpc.AioChannel(_common.encode(target), options,
credentials) credentials)
def get_state(self,
try_to_connect: bool = False) -> grpc.ChannelConnectivity:
"""Check the connectivity state of a channel.
This is an EXPERIMENTAL API.
If the channel reaches a stable connectivity state, it is guaranteed
that the return value of this function will eventually converge to that
state.
Args: try_to_connect: a bool indicate whether the Channel should try to
connect to peer or not.
Returns: A ChannelConnectivity object.
"""
result = self._channel.check_connectivity_state(try_to_connect)
return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]
async def wait_for_state_change(
self,
last_observed_state: grpc.ChannelConnectivity,
) -> None:
"""Wait for a change in connectivity state.
This is an EXPERIMENTAL API.
The function blocks until there is a change in the channel connectivity
state from the "last_observed_state". If the state is already
different, this function will return immediately.
There is an inherent race between the invocation of
"Channel.wait_for_state_change" and "Channel.get_state". The state can
change arbitrary times during the race, so there is no way to observe
every state transition.
If there is a need to put a timeout for this function, please refer to
"asyncio.wait_for".
Args:
last_observed_state: A grpc.ChannelConnectivity object representing
the last known state.
"""
assert await self._channel.watch_connectivity_state(
last_observed_state.value[0], None)
def unary_unary( def unary_unary(
self, self,
method: Text, method: Text,

@ -5,6 +5,7 @@
"unit.call_test.TestUnaryUnaryCall", "unit.call_test.TestUnaryUnaryCall",
"unit.channel_argument_test.TestChannelArgument", "unit.channel_argument_test.TestChannelArgument",
"unit.channel_test.TestChannel", "unit.channel_test.TestChannel",
"unit.connectivity_test.TestConnectivityState",
"unit.init_test.TestInsecureChannel", "unit.init_test.TestInsecureChannel",
"unit.init_test.TestSecureChannel", "unit.init_test.TestSecureChannel",
"unit.interceptor_test.TestInterceptedUnaryUnaryCall", "unit.interceptor_test.TestInterceptedUnaryUnaryCall",

@ -37,6 +37,12 @@ py_library(
], ],
) )
py_library(
name = "_constants",
srcs = ["_constants.py"],
srcs_version = "PY3",
)
[ [
py_test( py_test(
name = test_file_name[:-3], name = test_file_name[:-3],
@ -49,6 +55,7 @@ py_library(
main = test_file_name, main = test_file_name,
python_version = "PY3", python_version = "PY3",
deps = [ deps = [
":_constants",
":_test_base", ":_test_base",
":_test_server", ":_test_server",
"//external:six", "//external:six",

@ -0,0 +1,16 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
UNREACHABLE_TARGET = '0.0.0.1:1111'
UNARY_CALL_WITH_SLEEP_VALUE = 0.2

@ -13,17 +13,15 @@
# limitations under the License. # limitations under the License.
import asyncio import asyncio
import logging
import datetime import datetime
import logging
import grpc import grpc
from grpc.experimental import aio from grpc.experimental import aio
from tests.unit.framework.common import test_constants
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
UNARY_CALL_WITH_SLEEP_VALUE = 0.2 from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests_aio.unit._constants import UNARY_CALL_WITH_SLEEP_VALUE
class _TestServiceServicer(test_pb2_grpc.TestServiceServicer): class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):

@ -19,21 +19,20 @@ import threading
import unittest import unittest
import grpc import grpc
from grpc.experimental import aio from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests.unit.framework.common import test_constants from tests.unit.framework.common import test_constants
from tests_aio.unit._test_server import start_test_server, UNARY_CALL_WITH_SLEEP_VALUE from tests_aio.unit._constants import (UNARY_CALL_WITH_SLEEP_VALUE,
UNREACHABLE_TARGET)
from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_base import AioTestBase
from src.proto.grpc.testing import messages_pb2 from tests_aio.unit._test_server import start_test_server
_UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall' _UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall'
_UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep' _UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep'
_STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall' _STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall'
_NUM_STREAM_RESPONSES = 5 _NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42 _RESPONSE_PAYLOAD_SIZE = 42
_UNREACHABLE_TARGET = '0.1:1111'
class TestChannel(AioTestBase): class TestChannel(AioTestBase):

@ -0,0 +1,118 @@
# Copyright 2019 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests behavior of the connectivity state."""
import asyncio
import logging
import threading
import time
import unittest
import grpc
from grpc.experimental import aio
from tests.unit.framework.common import test_constants
from tests_aio.unit._constants import UNREACHABLE_TARGET
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server
async def _block_until_certain_state(channel, expected_state):
state = channel.get_state()
while state != expected_state:
await channel.wait_for_state_change(state)
state = channel.get_state()
class TestConnectivityState(AioTestBase):
async def setUp(self):
self._server_address, self._server = await start_test_server()
async def tearDown(self):
await self._server.stop(None)
async def test_unavailable_backend(self):
async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.get_state(False))
self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.get_state(True))
# Should not time out
await asyncio.wait_for(
_block_until_certain_state(
channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE),
test_constants.SHORT_TIMEOUT)
async def test_normal_backend(self):
async with aio.insecure_channel(self._server_address) as channel:
current_state = channel.get_state(True)
self.assertEqual(grpc.ChannelConnectivity.IDLE, current_state)
# Should not time out
await asyncio.wait_for(
_block_until_certain_state(channel,
grpc.ChannelConnectivity.READY),
test_constants.SHORT_TIMEOUT)
async def test_timeout(self):
async with aio.insecure_channel(self._server_address) as channel:
self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.get_state(False))
# If timed out, the function should return None.
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(
_block_until_certain_state(channel,
grpc.ChannelConnectivity.READY),
test_constants.SHORT_TIMEOUT)
async def test_shutdown(self):
channel = aio.insecure_channel(self._server_address)
self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.get_state(False))
# Waiting for changes in a separate coroutine
wait_started = asyncio.Event()
async def a_pending_wait():
wait_started.set()
await channel.wait_for_state_change(grpc.ChannelConnectivity.IDLE)
pending_task = self.loop.create_task(a_pending_wait())
await wait_started.wait()
await channel.close()
self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN,
channel.get_state(True))
self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN,
channel.get_state(False))
# Make sure there isn't any exception in the task
await pending_task
# It can raise exceptions since it is an usage error, but it should not
# segfault or abort.
with self.assertRaises(RuntimeError):
await channel.wait_for_state_change(
grpc.ChannelConnectivity.SHUTDOWN)
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)
Loading…
Cancel
Save