From 4484918d346d8b116fd566737bc5394ecab8d4d0 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Mon, 4 May 2020 17:13:51 -0700 Subject: [PATCH 1/2] Close the socket to ensure falling into transient failure state --- .../tests/unit/_metadata_flags_test.py | 95 ++++++++++--------- 1 file changed, 51 insertions(+), 44 deletions(-) diff --git a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py index de9ded2bafb..32bc6d48b1e 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py @@ -17,6 +17,7 @@ import time import weakref import unittest import threading +import logging import socket from six.moves import queue @@ -25,7 +26,7 @@ import grpc from tests.unit import test_common from tests.unit.framework.common import test_constants import tests.unit.framework.common -from tests.unit.framework.common import bound_socket +from tests.unit.framework.common import get_socket _UNARY_UNARY = '/test/UnaryUnary' _UNARY_STREAM = '/test/UnaryStream' @@ -101,8 +102,9 @@ class _GenericHandler(grpc.GenericRpcHandler): def create_dummy_channel(): """Creating dummy channels is a workaround for retries""" - with bound_socket() as (host, port): - return grpc.insecure_channel('{}:{}'.format(host, port)) + host, port, sock = get_socket() + sock.close() + return grpc.insecure_channel('{}:{}'.format(host, port)) def perform_unary_unary_call(channel, wait_for_ready=None): @@ -203,51 +205,56 @@ class MetadataFlagsTest(unittest.TestCase): # main thread. So, it need another method to store the # exceptions and raise them again in main thread. unhandled_exceptions = queue.Queue() - with bound_socket(listen=False) as (host, port): - addr = '{}:{}'.format(host, port) - wg = test_common.WaitGroup(len(_ALL_CALL_CASES)) - def wait_for_transient_failure(channel_connectivity): - if channel_connectivity == grpc.ChannelConnectivity.TRANSIENT_FAILURE: + # We just need an unused TCP port + host, port, sock = get_socket() + sock.close() + + addr = '{}:{}'.format(host, port) + wg = test_common.WaitGroup(len(_ALL_CALL_CASES)) + + def wait_for_transient_failure(channel_connectivity): + if channel_connectivity == grpc.ChannelConnectivity.TRANSIENT_FAILURE: + wg.done() + + def test_call(perform_call): + with grpc.insecure_channel(addr) as channel: + try: + channel.subscribe(wait_for_transient_failure) + perform_call(channel, wait_for_ready=True) + except BaseException as e: # pylint: disable=broad-except + # If the call failed, the thread would be destroyed. The + # channel object can be collected before calling the + # callback, which will result in a deadlock. wg.done() + unhandled_exceptions.put(e, True) - def test_call(perform_call): - with grpc.insecure_channel(addr) as channel: - try: - channel.subscribe(wait_for_transient_failure) - perform_call(channel, wait_for_ready=True) - except BaseException as e: # pylint: disable=broad-except - # If the call failed, the thread would be destroyed. The - # channel object can be collected before calling the - # callback, which will result in a deadlock. - wg.done() - unhandled_exceptions.put(e, True) - - test_threads = [] - for perform_call in _ALL_CALL_CASES: - test_thread = threading.Thread(target=test_call, - args=(perform_call,)) - test_thread.exception = None - test_thread.start() - test_threads.append(test_thread) - - # Start the server after the connections are waiting - wg.wait() - server = test_common.test_server(reuse_port=True) - server.add_generic_rpc_handlers( - (_GenericHandler(weakref.proxy(self)),)) - server.add_insecure_port(addr) - server.start() - - for test_thread in test_threads: - test_thread.join() - - # Stop the server to make test end properly - server.stop(0) - - if not unhandled_exceptions.empty(): - raise unhandled_exceptions.get(True) + test_threads = [] + for perform_call in _ALL_CALL_CASES: + test_thread = threading.Thread(target=test_call, + args=(perform_call,), + daemon=True) + test_thread.exception = None + test_thread.start() + test_threads.append(test_thread) + + # Start the server after the connections are waiting + wg.wait() + server = test_common.test_server(reuse_port=True) + server.add_generic_rpc_handlers((_GenericHandler(weakref.proxy(self)),)) + server.add_insecure_port(addr) + server.start() + + for test_thread in test_threads: + test_thread.join() + + # Stop the server to make test end properly + server.stop(0) + + if not unhandled_exceptions.empty(): + raise unhandled_exceptions.get(True) if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) From e2a41a100105d574041a763a09e187a884b8a8b6 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Tue, 5 May 2020 09:54:23 -0700 Subject: [PATCH 2/2] Restore for 2.7 --- src/python/grpcio_tests/tests/unit/_metadata_flags_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py index 32bc6d48b1e..f47ee77c012 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py @@ -232,8 +232,8 @@ class MetadataFlagsTest(unittest.TestCase): test_threads = [] for perform_call in _ALL_CALL_CASES: test_thread = threading.Thread(target=test_call, - args=(perform_call,), - daemon=True) + args=(perform_call,)) + test_thread.daemon = True test_thread.exception = None test_thread.start() test_threads.append(test_thread)