Adopt reviews' suggestions:

* Created a separate file for test constants
* Guarded current behavior of watch_connectivity_state
* Applied the same SEGV protection to callback_common
pull/21621/head
Lidi Zheng 5 years ago
parent 050b3989f0
commit 3099856a6a
  1. 6
      src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi
  2. 4
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  3. 7
      src/python/grpcio_tests/tests_aio/unit/BUILD.bazel
  4. 16
      src/python/grpcio_tests/tests_aio/unit/_constants.py
  5. 8
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  6. 11
      src/python/grpcio_tests/tests_aio/unit/channel_test.py
  7. 32
      src/python/grpcio_tests/tests_aio/unit/connectivity_test.py

@ -111,8 +111,12 @@ async def execute_batch(GrpcCallWrapper grpc_call_wrapper,
if error != GRPC_CALL_OK: if error != GRPC_CALL_OK:
raise ExecuteBatchError("Failed grpc_call_start_batch: {}".format(error)) raise ExecuteBatchError("Failed grpc_call_start_batch: {}".format(error))
# NOTE(lidiz) Guard against CanceledError from future.
def dealloc_wrapper(_):
cpython.Py_DECREF(wrapper)
future.add_done_callback(dealloc_wrapper)
await future await future
cpython.Py_DECREF(wrapper)
cdef grpc_event c_event cdef grpc_event c_event
# Tag.event must be called, otherwise messages won't be parsed from C # Tag.event must be called, otherwise messages won't be parsed from C
batch_operation_tag.event(c_event) batch_operation_tag.event(c_event)

@ -16,11 +16,11 @@
class _WatchConnectivityFailed(Exception): class _WatchConnectivityFailed(Exception):
"""Dedicated exception class for watch connectivity failed. """Dedicated exception class for watch connectivity failed.
It might be failed due to deadline exceeded, or the channel is closing. It might be failed due to deadline exceeded.
""" """
cdef CallbackFailureHandler _WATCH_CONNECTIVITY_FAILURE_HANDLER = CallbackFailureHandler( cdef CallbackFailureHandler _WATCH_CONNECTIVITY_FAILURE_HANDLER = CallbackFailureHandler(
'watch_connectivity_state', 'watch_connectivity_state',
'Timed out or channel closed.', 'Timed out',
_WatchConnectivityFailed) _WatchConnectivityFailed)

@ -37,6 +37,12 @@ py_library(
], ],
) )
py_library(
name = "_constants",
srcs = ["_constants.py"],
srcs_version = "PY3",
)
[ [
py_test( py_test(
name = test_file_name[:-3], name = test_file_name[:-3],
@ -49,6 +55,7 @@ py_library(
main = test_file_name, main = test_file_name,
python_version = "PY3", python_version = "PY3",
deps = [ deps = [
":_constants",
":_test_base", ":_test_base",
":_test_server", ":_test_server",
"//external:six", "//external:six",

@ -0,0 +1,16 @@
# Copyright 2020 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
UNREACHABLE_TARGET = '0.0.0.1:1111'
UNARY_CALL_WITH_SLEEP_VALUE = 0.2

@ -13,17 +13,15 @@
# limitations under the License. # limitations under the License.
import asyncio import asyncio
import logging
import datetime import datetime
import logging
import grpc import grpc
from grpc.experimental import aio from grpc.experimental import aio
from tests.unit.framework.common import test_constants
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
UNARY_CALL_WITH_SLEEP_VALUE = 0.2 from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests_aio.unit._constants import UNARY_CALL_WITH_SLEEP_VALUE
class _TestServiceServicer(test_pb2_grpc.TestServiceServicer): class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):

@ -19,21 +19,20 @@ import threading
import unittest import unittest
import grpc import grpc
from grpc.experimental import aio from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests.unit.framework.common import test_constants from tests.unit.framework.common import test_constants
from tests_aio.unit._test_server import start_test_server, UNARY_CALL_WITH_SLEEP_VALUE from tests_aio.unit._constants import (UNARY_CALL_WITH_SLEEP_VALUE,
UNREACHABLE_TARGET)
from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_base import AioTestBase
from src.proto.grpc.testing import messages_pb2 from tests_aio.unit._test_server import start_test_server
_UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall' _UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall'
_UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep' _UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep'
_STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall' _STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall'
_NUM_STREAM_RESPONSES = 5 _NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42 _RESPONSE_PAYLOAD_SIZE = 42
_UNREACHABLE_TARGET = '0.1:1111'
class TestChannel(AioTestBase): class TestChannel(AioTestBase):

@ -16,18 +16,16 @@
import asyncio import asyncio
import logging import logging
import threading import threading
import unittest
import time import time
import grpc import unittest
import grpc
from grpc.experimental import aio from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
from tests.unit.framework.common import test_constants from tests.unit.framework.common import test_constants
from tests_aio.unit._test_server import start_test_server from tests_aio.unit._constants import UNREACHABLE_TARGET
from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server
_INVALID_BACKEND_ADDRESS = '0.0.0.1:2'
async def _block_until_certain_state(channel, expected_state): async def _block_until_certain_state(channel, expected_state):
@ -46,17 +44,12 @@ class TestConnectivityState(AioTestBase):
await self._server.stop(None) await self._server.stop(None)
async def test_unavailable_backend(self): async def test_unavailable_backend(self):
async with aio.insecure_channel(_INVALID_BACKEND_ADDRESS) as channel: async with aio.insecure_channel(UNREACHABLE_TARGET) as channel:
self.assertEqual(grpc.ChannelConnectivity.IDLE, self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.get_state(False)) channel.get_state(False))
self.assertEqual(grpc.ChannelConnectivity.IDLE, self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.get_state(True)) channel.get_state(True))
async def waiting_transient_failure():
state = channel.get_state()
while state != grpc.ChannelConnectivity.TRANSIENT_FAILURE:
channel.wait_for_state_change(state)
# Should not time out # Should not time out
await asyncio.wait_for( await asyncio.wait_for(
_block_until_certain_state( _block_until_certain_state(
@ -92,6 +85,16 @@ class TestConnectivityState(AioTestBase):
self.assertEqual(grpc.ChannelConnectivity.IDLE, self.assertEqual(grpc.ChannelConnectivity.IDLE,
channel.get_state(False)) channel.get_state(False))
# Waiting for changes in a separate coroutine
wait_started = asyncio.Event()
async def a_pending_wait():
wait_started.set()
await channel.wait_for_state_change(grpc.ChannelConnectivity.IDLE)
pending_task = self.loop.create_task(a_pending_wait())
await wait_started.wait()
await channel.close() await channel.close()
self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN, self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN,
@ -100,6 +103,9 @@ class TestConnectivityState(AioTestBase):
self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN, self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN,
channel.get_state(False)) channel.get_state(False))
# Make sure there isn't any exception in the task
await pending_task
# It can raise exceptions since it is an usage error, but it should not # It can raise exceptions since it is an usage error, but it should not
# segfault or abort. # segfault or abort.
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):

Loading…
Cancel
Save