|
|
|
@ -24,6 +24,8 @@ import grpc |
|
|
|
|
|
|
|
|
|
from tests.unit import test_common |
|
|
|
|
from tests.unit.framework.common import test_constants |
|
|
|
|
import tests.unit.framework.common |
|
|
|
|
from tests.unit.framework.common import bound_socket |
|
|
|
|
|
|
|
|
|
_UNARY_UNARY = '/test/UnaryUnary' |
|
|
|
|
_UNARY_STREAM = '/test/UnaryStream' |
|
|
|
@ -93,35 +95,10 @@ class _GenericHandler(grpc.GenericRpcHandler): |
|
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_socket_ipv6(bind_address): |
|
|
|
|
listen_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) |
|
|
|
|
listen_socket.bind((bind_address, 0, 0, 0)) |
|
|
|
|
return listen_socket |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_socket_ipv4(bind_address): |
|
|
|
|
listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
|
|
|
|
listen_socket.bind((bind_address, 0)) |
|
|
|
|
return listen_socket |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_free_loopback_tcp_port(): |
|
|
|
|
listen_socket = None |
|
|
|
|
if socket.has_ipv6: |
|
|
|
|
try: |
|
|
|
|
listen_socket = _create_socket_ipv6('') |
|
|
|
|
except socket.error: |
|
|
|
|
listen_socket = _create_socket_ipv4('') |
|
|
|
|
else: |
|
|
|
|
listen_socket = _create_socket_ipv4('') |
|
|
|
|
address_tuple = listen_socket.getsockname() |
|
|
|
|
return listen_socket, "localhost:%s" % (address_tuple[1]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_dummy_channel(): |
|
|
|
|
"""Creating dummy channels is a workaround for retries""" |
|
|
|
|
_, addr = get_free_loopback_tcp_port() |
|
|
|
|
return grpc.insecure_channel(addr) |
|
|
|
|
with bound_socket() as (host, port): |
|
|
|
|
return grpc.insecure_channel('{}:{}'.format(host, port)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def perform_unary_unary_call(channel, wait_for_ready=None): |
|
|
|
@ -221,49 +198,50 @@ class MetadataFlagsTest(unittest.TestCase): |
|
|
|
|
# main thread. So, it need another method to store the |
|
|
|
|
# exceptions and raise them again in main thread. |
|
|
|
|
unhandled_exceptions = queue.Queue() |
|
|
|
|
tcp, addr = get_free_loopback_tcp_port() |
|
|
|
|
wg = test_common.WaitGroup(len(_ALL_CALL_CASES)) |
|
|
|
|
|
|
|
|
|
def wait_for_transient_failure(channel_connectivity): |
|
|
|
|
if channel_connectivity == grpc.ChannelConnectivity.TRANSIENT_FAILURE: |
|
|
|
|
wg.done() |
|
|
|
|
|
|
|
|
|
def test_call(perform_call): |
|
|
|
|
with grpc.insecure_channel(addr) as channel: |
|
|
|
|
try: |
|
|
|
|
channel.subscribe(wait_for_transient_failure) |
|
|
|
|
perform_call(channel, wait_for_ready=True) |
|
|
|
|
except BaseException as e: # pylint: disable=broad-except |
|
|
|
|
# If the call failed, the thread would be destroyed. The |
|
|
|
|
# channel object can be collected before calling the |
|
|
|
|
# callback, which will result in a deadlock. |
|
|
|
|
with bound_socket(listen=False) as (host, port): |
|
|
|
|
addr = '{}:{}'.format(host, port) |
|
|
|
|
wg = test_common.WaitGroup(len(_ALL_CALL_CASES)) |
|
|
|
|
|
|
|
|
|
def wait_for_transient_failure(channel_connectivity): |
|
|
|
|
if channel_connectivity == grpc.ChannelConnectivity.TRANSIENT_FAILURE: |
|
|
|
|
wg.done() |
|
|
|
|
unhandled_exceptions.put(e, True) |
|
|
|
|
|
|
|
|
|
test_threads = [] |
|
|
|
|
for perform_call in _ALL_CALL_CASES: |
|
|
|
|
test_thread = threading.Thread( |
|
|
|
|
target=test_call, args=(perform_call,)) |
|
|
|
|
test_thread.exception = None |
|
|
|
|
test_thread.start() |
|
|
|
|
test_threads.append(test_thread) |
|
|
|
|
|
|
|
|
|
# Start the server after the connections are waiting |
|
|
|
|
wg.wait() |
|
|
|
|
tcp.close() |
|
|
|
|
server = test_common.test_server() |
|
|
|
|
server.add_generic_rpc_handlers((_GenericHandler(weakref.proxy(self)),)) |
|
|
|
|
server.add_insecure_port(addr) |
|
|
|
|
server.start() |
|
|
|
|
|
|
|
|
|
for test_thread in test_threads: |
|
|
|
|
test_thread.join() |
|
|
|
|
|
|
|
|
|
# Stop the server to make test end properly |
|
|
|
|
server.stop(0) |
|
|
|
|
|
|
|
|
|
if not unhandled_exceptions.empty(): |
|
|
|
|
raise unhandled_exceptions.get(True) |
|
|
|
|
def test_call(perform_call): |
|
|
|
|
with grpc.insecure_channel(addr) as channel: |
|
|
|
|
try: |
|
|
|
|
channel.subscribe(wait_for_transient_failure) |
|
|
|
|
perform_call(channel, wait_for_ready=True) |
|
|
|
|
except BaseException as e: # pylint: disable=broad-except |
|
|
|
|
# If the call failed, the thread would be destroyed. The |
|
|
|
|
# channel object can be collected before calling the |
|
|
|
|
# callback, which will result in a deadlock. |
|
|
|
|
wg.done() |
|
|
|
|
unhandled_exceptions.put(e, True) |
|
|
|
|
|
|
|
|
|
test_threads = [] |
|
|
|
|
for perform_call in _ALL_CALL_CASES: |
|
|
|
|
test_thread = threading.Thread( |
|
|
|
|
target=test_call, args=(perform_call,)) |
|
|
|
|
test_thread.exception = None |
|
|
|
|
test_thread.start() |
|
|
|
|
test_threads.append(test_thread) |
|
|
|
|
|
|
|
|
|
# Start the server after the connections are waiting |
|
|
|
|
wg.wait() |
|
|
|
|
server = test_common.test_server(reuse_port=True) |
|
|
|
|
server.add_generic_rpc_handlers((_GenericHandler( |
|
|
|
|
weakref.proxy(self)),)) |
|
|
|
|
server.add_insecure_port(addr) |
|
|
|
|
server.start() |
|
|
|
|
|
|
|
|
|
for test_thread in test_threads: |
|
|
|
|
test_thread.join() |
|
|
|
|
|
|
|
|
|
# Stop the server to make test end properly |
|
|
|
|
server.stop(0) |
|
|
|
|
|
|
|
|
|
if not unhandled_exceptions.empty(): |
|
|
|
|
raise unhandled_exceptions.get(True) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|