Merge pull request #22860 from lidizheng/fix-metadata-flags-test-flake

Fix the metadata flags test flake
pull/22886/head
Lidi Zheng 5 years ago committed by GitHub
commit 9993d2b9a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 95
      src/python/grpcio_tests/tests/unit/_metadata_flags_test.py

@ -17,6 +17,7 @@ import time
import weakref import weakref
import unittest import unittest
import threading import threading
import logging
import socket import socket
from six.moves import queue from six.moves import queue
@ -25,7 +26,7 @@ import grpc
from tests.unit import test_common from tests.unit import test_common
from tests.unit.framework.common import test_constants from tests.unit.framework.common import test_constants
import tests.unit.framework.common import tests.unit.framework.common
from tests.unit.framework.common import bound_socket from tests.unit.framework.common import get_socket
_UNARY_UNARY = '/test/UnaryUnary' _UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream' _UNARY_STREAM = '/test/UnaryStream'
@ -101,8 +102,9 @@ class _GenericHandler(grpc.GenericRpcHandler):
def create_dummy_channel(): def create_dummy_channel():
"""Creating dummy channels is a workaround for retries""" """Creating dummy channels is a workaround for retries"""
with bound_socket() as (host, port): host, port, sock = get_socket()
return grpc.insecure_channel('{}:{}'.format(host, port)) sock.close()
return grpc.insecure_channel('{}:{}'.format(host, port))
def perform_unary_unary_call(channel, wait_for_ready=None): def perform_unary_unary_call(channel, wait_for_ready=None):
@ -203,51 +205,56 @@ class MetadataFlagsTest(unittest.TestCase):
# main thread. So, it need another method to store the # main thread. So, it need another method to store the
# exceptions and raise them again in main thread. # exceptions and raise them again in main thread.
unhandled_exceptions = queue.Queue() unhandled_exceptions = queue.Queue()
with bound_socket(listen=False) as (host, port):
addr = '{}:{}'.format(host, port)
wg = test_common.WaitGroup(len(_ALL_CALL_CASES))
def wait_for_transient_failure(channel_connectivity): # We just need an unused TCP port
if channel_connectivity == grpc.ChannelConnectivity.TRANSIENT_FAILURE: host, port, sock = get_socket()
sock.close()
addr = '{}:{}'.format(host, port)
wg = test_common.WaitGroup(len(_ALL_CALL_CASES))
def wait_for_transient_failure(channel_connectivity):
if channel_connectivity == grpc.ChannelConnectivity.TRANSIENT_FAILURE:
wg.done()
def test_call(perform_call):
with grpc.insecure_channel(addr) as channel:
try:
channel.subscribe(wait_for_transient_failure)
perform_call(channel, wait_for_ready=True)
except BaseException as e: # pylint: disable=broad-except
# If the call failed, the thread would be destroyed. The
# channel object can be collected before calling the
# callback, which will result in a deadlock.
wg.done() wg.done()
unhandled_exceptions.put(e, True)
def test_call(perform_call): test_threads = []
with grpc.insecure_channel(addr) as channel: for perform_call in _ALL_CALL_CASES:
try: test_thread = threading.Thread(target=test_call,
channel.subscribe(wait_for_transient_failure) args=(perform_call,))
perform_call(channel, wait_for_ready=True) test_thread.daemon = True
except BaseException as e: # pylint: disable=broad-except test_thread.exception = None
# If the call failed, the thread would be destroyed. The test_thread.start()
# channel object can be collected before calling the test_threads.append(test_thread)
# callback, which will result in a deadlock.
wg.done() # Start the server after the connections are waiting
unhandled_exceptions.put(e, True) wg.wait()
server = test_common.test_server(reuse_port=True)
test_threads = [] server.add_generic_rpc_handlers((_GenericHandler(weakref.proxy(self)),))
for perform_call in _ALL_CALL_CASES: server.add_insecure_port(addr)
test_thread = threading.Thread(target=test_call, server.start()
args=(perform_call,))
test_thread.exception = None for test_thread in test_threads:
test_thread.start() test_thread.join()
test_threads.append(test_thread)
# Stop the server to make test end properly
# Start the server after the connections are waiting server.stop(0)
wg.wait()
server = test_common.test_server(reuse_port=True) if not unhandled_exceptions.empty():
server.add_generic_rpc_handlers( raise unhandled_exceptions.get(True)
(_GenericHandler(weakref.proxy(self)),))
server.add_insecure_port(addr)
server.start()
for test_thread in test_threads:
test_thread.join()
# Stop the server to make test end properly
server.stop(0)
if not unhandled_exceptions.empty():
raise unhandled_exceptions.get(True)
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2) unittest.main(verbosity=2)

Loading…
Cancel
Save