diff --git a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py index de9ded2bafb..f47ee77c012 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py @@ -17,6 +17,7 @@ import time import weakref import unittest import threading +import logging import socket from six.moves import queue @@ -25,7 +26,7 @@ import grpc from tests.unit import test_common from tests.unit.framework.common import test_constants import tests.unit.framework.common -from tests.unit.framework.common import bound_socket +from tests.unit.framework.common import get_socket _UNARY_UNARY = '/test/UnaryUnary' _UNARY_STREAM = '/test/UnaryStream' @@ -101,8 +102,9 @@ class _GenericHandler(grpc.GenericRpcHandler): def create_dummy_channel(): """Creating dummy channels is a workaround for retries""" - with bound_socket() as (host, port): - return grpc.insecure_channel('{}:{}'.format(host, port)) + host, port, sock = get_socket() + sock.close() + return grpc.insecure_channel('{}:{}'.format(host, port)) def perform_unary_unary_call(channel, wait_for_ready=None): @@ -203,51 +205,56 @@ class MetadataFlagsTest(unittest.TestCase): # main thread. So, it need another method to store the # exceptions and raise them again in main thread. unhandled_exceptions = queue.Queue() - with bound_socket(listen=False) as (host, port): - addr = '{}:{}'.format(host, port) - wg = test_common.WaitGroup(len(_ALL_CALL_CASES)) - def wait_for_transient_failure(channel_connectivity): - if channel_connectivity == grpc.ChannelConnectivity.TRANSIENT_FAILURE: + # We just need an unused TCP port + host, port, sock = get_socket() + sock.close() + + addr = '{}:{}'.format(host, port) + wg = test_common.WaitGroup(len(_ALL_CALL_CASES)) + + def wait_for_transient_failure(channel_connectivity): + if channel_connectivity == grpc.ChannelConnectivity.TRANSIENT_FAILURE: + wg.done() + + def test_call(perform_call): + with grpc.insecure_channel(addr) as channel: + try: + channel.subscribe(wait_for_transient_failure) + perform_call(channel, wait_for_ready=True) + except BaseException as e: # pylint: disable=broad-except + # If the call failed, the thread would be destroyed. The + # channel object can be collected before calling the + # callback, which will result in a deadlock. wg.done() + unhandled_exceptions.put(e, True) - def test_call(perform_call): - with grpc.insecure_channel(addr) as channel: - try: - channel.subscribe(wait_for_transient_failure) - perform_call(channel, wait_for_ready=True) - except BaseException as e: # pylint: disable=broad-except - # If the call failed, the thread would be destroyed. The - # channel object can be collected before calling the - # callback, which will result in a deadlock. - wg.done() - unhandled_exceptions.put(e, True) - - test_threads = [] - for perform_call in _ALL_CALL_CASES: - test_thread = threading.Thread(target=test_call, - args=(perform_call,)) - test_thread.exception = None - test_thread.start() - test_threads.append(test_thread) - - # Start the server after the connections are waiting - wg.wait() - server = test_common.test_server(reuse_port=True) - server.add_generic_rpc_handlers( - (_GenericHandler(weakref.proxy(self)),)) - server.add_insecure_port(addr) - server.start() - - for test_thread in test_threads: - test_thread.join() - - # Stop the server to make test end properly - server.stop(0) - - if not unhandled_exceptions.empty(): - raise unhandled_exceptions.get(True) + test_threads = [] + for perform_call in _ALL_CALL_CASES: + test_thread = threading.Thread(target=test_call, + args=(perform_call,)) + test_thread.daemon = True + test_thread.exception = None + test_thread.start() + test_threads.append(test_thread) + + # Start the server after the connections are waiting + wg.wait() + server = test_common.test_server(reuse_port=True) + server.add_generic_rpc_handlers((_GenericHandler(weakref.proxy(self)),)) + server.add_insecure_port(addr) + server.start() + + for test_thread in test_threads: + test_thread.join() + + # Stop the server to make test end properly + server.stop(0) + + if not unhandled_exceptions.empty(): + raise unhandled_exceptions.get(True) if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2)