diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pxd.pxi index 6c653fa42e6..e48d48385b7 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pxd.pxi @@ -33,9 +33,12 @@ cdef struct CallbackContext: # invoked by Core. # failure_handler: A CallbackFailureHandler object that called when Core # returns 'success == 0' state. + # callback_wrapper: A self-reference to the CallbackWrapper to help life cycle + # management. grpc_experimental_completion_queue_functor functor cpython.PyObject *waiter cpython.PyObject *failure_handler + cpython.PyObject *callback_wrapper cdef class CallbackWrapper: diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi index 1bcc61a9856..280a7832e61 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi @@ -36,10 +36,15 @@ cdef class CallbackWrapper: self.context.functor.functor_run = self.functor_run self.context.waiter = future self.context.failure_handler = failure_handler + self.context.callback_wrapper = self # NOTE(lidiz) Not using a list here, because this class is critical in # data path. We should make it as efficient as possible. self._reference_of_future = future self._reference_of_failure_handler = failure_handler + # NOTE(lidiz) We need to ensure when Core invokes our callback, the + # callback function itself is not deallocated. Otherwise, we will get + # a segfault. We can view this as Core holding a ref. 
+ cpython.Py_INCREF(self) @staticmethod cdef void functor_run( @@ -47,12 +52,12 @@ cdef class CallbackWrapper: int success): cdef CallbackContext *context = functor cdef object waiter = context.waiter - if waiter.cancelled(): - return - if success == 0: - (context.failure_handler).handle(waiter) - else: - waiter.set_result(None) + if not waiter.cancelled(): + if success == 0: + (context.failure_handler).handle(waiter) + else: + waiter.set_result(None) + cpython.Py_DECREF(context.callback_wrapper) cdef grpc_experimental_completion_queue_functor *c_functor(self): return &self.context.functor @@ -99,9 +104,6 @@ async def execute_batch(GrpcCallWrapper grpc_call_wrapper, cdef CallbackWrapper wrapper = CallbackWrapper( future, CallbackFailureHandler('execute_batch', operations, ExecuteBatchError)) - # NOTE(lidiz) Without Py_INCREF, the wrapper object will be destructed - # when calling "await". This is an over-optimization by Cython. - cpython.Py_INCREF(wrapper) cdef grpc_call_error error = grpc_call_start_batch( grpc_call_wrapper.call, batch_operation_tag.c_ops, @@ -112,7 +114,7 @@ async def execute_batch(GrpcCallWrapper grpc_call_wrapper, raise ExecuteBatchError("Failed grpc_call_start_batch: {}".format(error)) await future - cpython.Py_DECREF(wrapper) + cdef grpc_event c_event # Tag.event must be called, otherwise messages won't be parsed from C batch_operation_tag.event(c_event) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi index 6e187d65c92..68a0d11b748 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi @@ -12,8 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+cdef enum AioChannelStatus: + AIO_CHANNEL_STATUS_UNKNOWN + AIO_CHANNEL_STATUS_READY + AIO_CHANNEL_STATUS_DESTROYED + cdef class AioChannel: cdef: grpc_channel * channel CallbackCompletionQueue cq bytes _target + object _loop + AioChannelStatus _status diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi index 850fb77d297..4022c892d20 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi @@ -13,6 +13,17 @@ # limitations under the License. +class _WatchConnectivityFailed(Exception): + """Dedicated exception class for a failed connectivity watch. + + It might fail because the deadline was exceeded. + """ +cdef CallbackFailureHandler _WATCH_CONNECTIVITY_FAILURE_HANDLER = CallbackFailureHandler( + 'watch_connectivity_state', + 'Timed out', + _WatchConnectivityFailed) + + cdef class AioChannel: def __cinit__(self, bytes target, tuple options, ChannelCredentials credentials): if options is None: @@ -20,6 +31,8 @@ cdef class AioChannel: cdef _ChannelArgs channel_args = _ChannelArgs(options) self._target = target self.cq = CallbackCompletionQueue() + self._loop = asyncio.get_event_loop() + self._status = AIO_CHANNEL_STATUS_READY if credentials is None: self.channel = grpc_insecure_channel_create( @@ -29,7 +42,7 @@ cdef class AioChannel: else: self.channel = grpc_secure_channel_create( credentials.c(), - target, + target, channel_args.c_args(), NULL) @@ -38,8 +51,47 @@ cdef class AioChannel: id_ = id(self) return f"<{class_name} {id_}>" + def check_connectivity_state(self, bint try_to_connect): + """A Cython wrapper for Core's check connectivity state API.""" + return grpc_channel_check_connectivity_state( + self.channel, + try_to_connect, + ) + + async def watch_connectivity_state(self, + grpc_connectivity_state last_observed_state, + object deadline): + """Watch for one connectivity state change. 
+ + Keeps mirroring the behavior from Core, so we can easily switch to + other design of API if necessary. + """ + if self._status == AIO_CHANNEL_STATUS_DESTROYED: + # TODO(lidiz) switch to UsageError + raise RuntimeError('Channel is closed.') + cdef gpr_timespec c_deadline = _timespec_from_time(deadline) + + cdef object future = self._loop.create_future() + cdef CallbackWrapper wrapper = CallbackWrapper( + future, + _WATCH_CONNECTIVITY_FAILURE_HANDLER) + grpc_channel_watch_connectivity_state( + self.channel, + last_observed_state, + c_deadline, + self.cq.c_ptr(), + wrapper.c_functor()) + + try: + await future + except _WatchConnectivityFailed: + return False + else: + return True + def close(self): grpc_channel_destroy(self.channel) + self._status = AIO_CHANNEL_STATUS_DESTROYED def call(self, bytes method, @@ -50,5 +102,8 @@ cdef class AioChannel: Returns: The _AioCall object. """ + if self._status == AIO_CHANNEL_STATUS_DESTROYED: + # TODO(lidiz) switch to UsageError + raise RuntimeError('Channel is closed.') cdef _AioCall call = _AioCall(self, deadline, method, credentials) return call diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi index 8aee3295f55..b264db2f0f5 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi @@ -307,9 +307,6 @@ cdef class AioServer: cdef CallbackWrapper wrapper = CallbackWrapper( future, REQUEST_CALL_FAILURE_HANDLER) - # NOTE(lidiz) Without Py_INCREF, the wrapper object will be destructed - # when calling "await". This is an over-optimization by Cython. 
- cpython.Py_INCREF(wrapper) error = grpc_server_request_call( self._server.c_server, &rpc_state.call, &rpc_state.details, &rpc_state.request_metadata, @@ -320,7 +317,6 @@ cdef class AioServer: raise RuntimeError("Error in grpc_server_request_call: %s" % error) await future - cpython.Py_DECREF(wrapper) return rpc_state async def _server_main_loop(self, diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 11405a8bd1b..2562f0f6d81 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -224,6 +224,51 @@ class Channel: self._channel = cygrpc.AioChannel(_common.encode(target), options, credentials) + def get_state(self, + try_to_connect: bool = False) -> grpc.ChannelConnectivity: + """Check the connectivity state of a channel. + + This is an EXPERIMENTAL API. + + If the channel reaches a stable connectivity state, it is guaranteed + that the return value of this function will eventually converge to that + state. + + Args: try_to_connect: a bool indicating whether the Channel should try to + connect to peer or not. + + Returns: A ChannelConnectivity object. + """ + result = self._channel.check_connectivity_state(try_to_connect) + return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result] + + async def wait_for_state_change( + self, + last_observed_state: grpc.ChannelConnectivity, + ) -> None: + """Wait for a change in connectivity state. + + This is an EXPERIMENTAL API. + + The function blocks until there is a change in the channel connectivity + state from the "last_observed_state". If the state is already + different, this function will return immediately. + + There is an inherent race between the invocation of + "Channel.wait_for_state_change" and "Channel.get_state". The state can + change an arbitrary number of times during the race, so there is no way + to observe every state transition. 
+ + If there is a need to put a timeout for this function, please refer to + "asyncio.wait_for". + + Args: + last_observed_state: A grpc.ChannelConnectivity object representing + the last known state. + """ + assert await self._channel.watch_connectivity_state( + last_observed_state.value[0], None) + def unary_unary( self, method: Text, diff --git a/src/python/grpcio_tests/tests_aio/tests.json b/src/python/grpcio_tests/tests_aio/tests.json index d5d7072dfda..082681dce0e 100644 --- a/src/python/grpcio_tests/tests_aio/tests.json +++ b/src/python/grpcio_tests/tests_aio/tests.json @@ -5,6 +5,7 @@ "unit.call_test.TestUnaryUnaryCall", "unit.channel_argument_test.TestChannelArgument", "unit.channel_test.TestChannel", + "unit.connectivity_test.TestConnectivityState", "unit.init_test.TestInsecureChannel", "unit.init_test.TestSecureChannel", "unit.interceptor_test.TestInterceptedUnaryUnaryCall", diff --git a/src/python/grpcio_tests/tests_aio/unit/BUILD.bazel b/src/python/grpcio_tests/tests_aio/unit/BUILD.bazel index 41aa33034cc..fd47d2c33d5 100644 --- a/src/python/grpcio_tests/tests_aio/unit/BUILD.bazel +++ b/src/python/grpcio_tests/tests_aio/unit/BUILD.bazel @@ -37,6 +37,12 @@ py_library( ], ) +py_library( + name = "_constants", + srcs = ["_constants.py"], + srcs_version = "PY3", +) + [ py_test( name = test_file_name[:-3], @@ -49,6 +55,7 @@ py_library( main = test_file_name, python_version = "PY3", deps = [ + ":_constants", ":_test_base", ":_test_server", "//external:six", diff --git a/src/python/grpcio_tests/tests_aio/unit/_constants.py b/src/python/grpcio_tests/tests_aio/unit/_constants.py new file mode 100644 index 00000000000..986a6f9d842 --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/unit/_constants.py @@ -0,0 +1,16 @@ +# Copyright 2020 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +UNREACHABLE_TARGET = '0.0.0.1:1111' +UNARY_CALL_WITH_SLEEP_VALUE = 0.2 diff --git a/src/python/grpcio_tests/tests_aio/unit/_test_server.py b/src/python/grpcio_tests/tests_aio/unit/_test_server.py index 9d0b7c0d358..d99c46f05c2 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_test_server.py +++ b/src/python/grpcio_tests/tests_aio/unit/_test_server.py @@ -13,17 +13,15 @@ # limitations under the License. import asyncio -import logging import datetime +import logging import grpc from grpc.experimental import aio -from tests.unit.framework.common import test_constants -from src.proto.grpc.testing import messages_pb2 -from src.proto.grpc.testing import test_pb2_grpc -UNARY_CALL_WITH_SLEEP_VALUE = 0.2 +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc +from tests_aio.unit._constants import UNARY_CALL_WITH_SLEEP_VALUE class _TestServiceServicer(test_pb2_grpc.TestServiceServicer): diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_test.py index 10a64b474e2..1ab372a0e8c 100644 --- a/src/python/grpcio_tests/tests_aio/unit/channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/channel_test.py @@ -19,21 +19,20 @@ import threading import unittest import grpc - from grpc.experimental import aio -from src.proto.grpc.testing import messages_pb2 -from src.proto.grpc.testing import test_pb2_grpc + +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc from tests.unit.framework.common import test_constants -from tests_aio.unit._test_server import start_test_server, 
UNARY_CALL_WITH_SLEEP_VALUE +from tests_aio.unit._constants import (UNARY_CALL_WITH_SLEEP_VALUE, + UNREACHABLE_TARGET) from tests_aio.unit._test_base import AioTestBase -from src.proto.grpc.testing import messages_pb2 +from tests_aio.unit._test_server import start_test_server _UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall' _UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep' _STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall' _NUM_STREAM_RESPONSES = 5 _RESPONSE_PAYLOAD_SIZE = 42 -_UNREACHABLE_TARGET = '0.1:1111' class TestChannel(AioTestBase): diff --git a/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py b/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py new file mode 100644 index 00000000000..95a819b2b5f --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py @@ -0,0 +1,118 @@ +# Copyright 2019 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests behavior of the connectivity state.""" + +import asyncio +import logging +import threading +import time +import unittest + +import grpc +from grpc.experimental import aio + +from tests.unit.framework.common import test_constants +from tests_aio.unit._constants import UNREACHABLE_TARGET +from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit._test_server import start_test_server + + +async def _block_until_certain_state(channel, expected_state): + state = channel.get_state() + while state != expected_state: + await channel.wait_for_state_change(state) + state = channel.get_state() + + +class TestConnectivityState(AioTestBase): + + async def setUp(self): + self._server_address, self._server = await start_test_server() + + async def tearDown(self): + await self._server.stop(None) + + async def test_unavailable_backend(self): + async with aio.insecure_channel(UNREACHABLE_TARGET) as channel: + self.assertEqual(grpc.ChannelConnectivity.IDLE, + channel.get_state(False)) + self.assertEqual(grpc.ChannelConnectivity.IDLE, + channel.get_state(True)) + + # Should not time out + await asyncio.wait_for( + _block_until_certain_state( + channel, grpc.ChannelConnectivity.TRANSIENT_FAILURE), + test_constants.SHORT_TIMEOUT) + + async def test_normal_backend(self): + async with aio.insecure_channel(self._server_address) as channel: + current_state = channel.get_state(True) + self.assertEqual(grpc.ChannelConnectivity.IDLE, current_state) + + # Should not time out + await asyncio.wait_for( + _block_until_certain_state(channel, + grpc.ChannelConnectivity.READY), + test_constants.SHORT_TIMEOUT) + + async def test_timeout(self): + async with aio.insecure_channel(self._server_address) as channel: + self.assertEqual(grpc.ChannelConnectivity.IDLE, + channel.get_state(False)) + + # If timed out, asyncio.wait_for should raise TimeoutError. 
+ with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for( + _block_until_certain_state(channel, + grpc.ChannelConnectivity.READY), + test_constants.SHORT_TIMEOUT) + + async def test_shutdown(self): + channel = aio.insecure_channel(self._server_address) + + self.assertEqual(grpc.ChannelConnectivity.IDLE, + channel.get_state(False)) + + # Waiting for changes in a separate coroutine + wait_started = asyncio.Event() + + async def a_pending_wait(): + wait_started.set() + await channel.wait_for_state_change(grpc.ChannelConnectivity.IDLE) + + pending_task = self.loop.create_task(a_pending_wait()) + await wait_started.wait() + + await channel.close() + + self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN, + channel.get_state(True)) + + self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN, + channel.get_state(False)) + + # Make sure there isn't any exception in the task + await pending_task + + # It can raise exceptions since it is a usage error, but it should not + # segfault or abort. + with self.assertRaises(RuntimeError): + await channel.wait_for_state_change( + grpc.ChannelConnectivity.SHUTDOWN) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2)