@ -16,18 +16,16 @@
import asyncio
import asyncio
import logging
import logging
import threading
import threading
import unittest
import time
import time
import grpc
import unittest
import grpc
from grpc . experimental import aio
from grpc . experimental import aio
from src . proto . grpc . testing import messages_pb2
from src . proto . grpc . testing import test_pb2_grpc
from tests . unit . framework . common import test_constants
from tests . unit . framework . common import test_constants
from tests_aio . unit . _test_server import start_test_server
from tests_aio . unit . _constants import UNREACHABLE_TARGET
from tests_aio . unit . _test_base import AioTestBase
from tests_aio . unit . _test_base import AioTestBase
from tests_aio . unit . _test_server import start_test_server
_INVALID_BACKEND_ADDRESS = ' 0.0.0.1:2 '
async def _block_until_certain_state ( channel , expected_state ) :
async def _block_until_certain_state ( channel , expected_state ) :
@ -46,17 +44,12 @@ class TestConnectivityState(AioTestBase):
await self . _server . stop ( None )
await self . _server . stop ( None )
async def test_unavailable_backend ( self ) :
async def test_unavailable_backend ( self ) :
async with aio . insecure_channel ( _INVALID_BACKEND_ADDRESS ) as channel :
async with aio . insecure_channel ( UNREACHABLE_TARGET ) as channel :
self . assertEqual ( grpc . ChannelConnectivity . IDLE ,
self . assertEqual ( grpc . ChannelConnectivity . IDLE ,
channel . get_state ( False ) )
channel . get_state ( False ) )
self . assertEqual ( grpc . ChannelConnectivity . IDLE ,
self . assertEqual ( grpc . ChannelConnectivity . IDLE ,
channel . get_state ( True ) )
channel . get_state ( True ) )
async def waiting_transient_failure ( ) :
state = channel . get_state ( )
while state != grpc . ChannelConnectivity . TRANSIENT_FAILURE :
channel . wait_for_state_change ( state )
# Should not time out
# Should not time out
await asyncio . wait_for (
await asyncio . wait_for (
_block_until_certain_state (
_block_until_certain_state (
@ -92,6 +85,16 @@ class TestConnectivityState(AioTestBase):
self . assertEqual ( grpc . ChannelConnectivity . IDLE ,
self . assertEqual ( grpc . ChannelConnectivity . IDLE ,
channel . get_state ( False ) )
channel . get_state ( False ) )
# Waiting for changes in a separate coroutine
wait_started = asyncio . Event ( )
async def a_pending_wait ( ) :
wait_started . set ( )
await channel . wait_for_state_change ( grpc . ChannelConnectivity . IDLE )
pending_task = self . loop . create_task ( a_pending_wait ( ) )
await wait_started . wait ( )
await channel . close ( )
await channel . close ( )
self . assertEqual ( grpc . ChannelConnectivity . SHUTDOWN ,
self . assertEqual ( grpc . ChannelConnectivity . SHUTDOWN ,
@ -100,6 +103,9 @@ class TestConnectivityState(AioTestBase):
self . assertEqual ( grpc . ChannelConnectivity . SHUTDOWN ,
self . assertEqual ( grpc . ChannelConnectivity . SHUTDOWN ,
channel . get_state ( False ) )
channel . get_state ( False ) )
# Make sure there isn't any exception in the task
await pending_task
# It can raise exceptions since it is an usage error, but it should not
# It can raise exceptions since it is an usage error, but it should not
# segfault or abort.
# segfault or abort.
with self . assertRaises ( RuntimeError ) :
with self . assertRaises ( RuntimeError ) :