diff --git a/src/python/grpcio_tests/tests/unit/BUILD.bazel b/src/python/grpcio_tests/tests/unit/BUILD.bazel index 639205a0d43..fbaff2697a3 100644 --- a/src/python/grpcio_tests/tests/unit/BUILD.bazel +++ b/src/python/grpcio_tests/tests/unit/BUILD.bazel @@ -27,8 +27,7 @@ GRPCIO_TESTS_UNIT = [ "_metadata_flags_test.py", "_metadata_code_details_test.py", "_metadata_test.py", - # TODO: Issue 16336 - # "_reconnect_test.py", + "_reconnect_test.py", "_resource_exhausted_test.py", "_rpc_test.py", "_signal_handling_test.py", diff --git a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py index 05d943cc0fa..efc55ea2ddf 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_flags_test.py @@ -24,6 +24,8 @@ import grpc from tests.unit import test_common from tests.unit.framework.common import test_constants +import tests.unit.framework.common +from tests.unit.framework.common import bound_socket _UNARY_UNARY = '/test/UnaryUnary' _UNARY_STREAM = '/test/UnaryStream' @@ -93,35 +95,10 @@ class _GenericHandler(grpc.GenericRpcHandler): return None -def _create_socket_ipv6(bind_address): - listen_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - listen_socket.bind((bind_address, 0, 0, 0)) - return listen_socket - - -def _create_socket_ipv4(bind_address): - listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - listen_socket.bind((bind_address, 0)) - return listen_socket - - -def get_free_loopback_tcp_port(): - listen_socket = None - if socket.has_ipv6: - try: - listen_socket = _create_socket_ipv6('') - except socket.error: - listen_socket = _create_socket_ipv4('') - else: - listen_socket = _create_socket_ipv4('') - address_tuple = listen_socket.getsockname() - return listen_socket, "localhost:%s" % (address_tuple[1]) - - def create_dummy_channel(): """Creating dummy channels is a workaround for retries""" - _, addr = get_free_loopback_tcp_port() - return grpc.insecure_channel(addr) + with bound_socket() as (host, port): + return grpc.insecure_channel('{}:{}'.format(host, port)) def perform_unary_unary_call(channel, wait_for_ready=None): @@ -221,49 +198,50 @@ class MetadataFlagsTest(unittest.TestCase): # main thread. So, it need another method to store the # exceptions and raise them again in main thread. unhandled_exceptions = queue.Queue() - tcp, addr = get_free_loopback_tcp_port() - wg = test_common.WaitGroup(len(_ALL_CALL_CASES)) - - def wait_for_transient_failure(channel_connectivity): - if channel_connectivity == grpc.ChannelConnectivity.TRANSIENT_FAILURE: - wg.done() - - def test_call(perform_call): - with grpc.insecure_channel(addr) as channel: - try: - channel.subscribe(wait_for_transient_failure) - perform_call(channel, wait_for_ready=True) - except BaseException as e: # pylint: disable=broad-except - # If the call failed, the thread would be destroyed. The - # channel object can be collected before calling the - # callback, which will result in a deadlock. + with bound_socket(listen=False) as (host, port): + addr = '{}:{}'.format(host, port) + wg = test_common.WaitGroup(len(_ALL_CALL_CASES)) + + def wait_for_transient_failure(channel_connectivity): + if channel_connectivity == grpc.ChannelConnectivity.TRANSIENT_FAILURE: wg.done() - unhandled_exceptions.put(e, True) - test_threads = [] - for perform_call in _ALL_CALL_CASES: - test_thread = threading.Thread( - target=test_call, args=(perform_call,)) - test_thread.exception = None - test_thread.start() - test_threads.append(test_thread) - - # Start the server after the connections are waiting - wg.wait() - tcp.close() - server = test_common.test_server() - server.add_generic_rpc_handlers((_GenericHandler(weakref.proxy(self)),)) - server.add_insecure_port(addr) - server.start() - - for test_thread in test_threads: - test_thread.join() - - # Stop the server to make test end properly - server.stop(0) - - if not unhandled_exceptions.empty(): - raise unhandled_exceptions.get(True) + def test_call(perform_call): + with grpc.insecure_channel(addr) as channel: + try: + channel.subscribe(wait_for_transient_failure) + perform_call(channel, wait_for_ready=True) + except BaseException as e: # pylint: disable=broad-except + # If the call failed, the thread would be destroyed. The + # channel object can be collected before calling the + # callback, which will result in a deadlock. + wg.done() + unhandled_exceptions.put(e, True) + + test_threads = [] + for perform_call in _ALL_CALL_CASES: + test_thread = threading.Thread( + target=test_call, args=(perform_call,)) + test_thread.exception = None + test_thread.start() + test_threads.append(test_thread) + + # Start the server after the connections are waiting + wg.wait() + server = test_common.test_server(reuse_port=True) + server.add_generic_rpc_handlers((_GenericHandler( + weakref.proxy(self)),)) + server.add_insecure_port(addr) + server.start() + + for test_thread in test_threads: + test_thread.join() + + # Stop the server to make test end properly + server.stop(0) + + if not unhandled_exceptions.empty(): + raise unhandled_exceptions.get(True) if __name__ == '__main__': diff --git a/src/python/grpcio_tests/tests/unit/_reconnect_test.py b/src/python/grpcio_tests/tests/unit/_reconnect_test.py index d4ea126e2b5..0d97f9d735b 100644 --- a/src/python/grpcio_tests/tests/unit/_reconnect_test.py +++ b/src/python/grpcio_tests/tests/unit/_reconnect_test.py @@ -22,6 +22,7 @@ import grpc from grpc.framework.foundation import logging_pool from tests.unit.framework.common import test_constants +from tests.unit.framework.common import bound_socket _REQUEST = b'\x00\x00\x00' _RESPONSE = b'\x00\x00\x01' @@ -33,44 +34,6 @@ def _handle_unary_unary(unused_request, unused_servicer_context): return _RESPONSE -def _get_reuse_socket_option(): - try: - return socket.SO_REUSEPORT - except AttributeError: - # SO_REUSEPORT is unavailable on Windows, but SO_REUSEADDR - # allows forcibly re-binding to a port - return socket.SO_REUSEADDR - - -def _pick_and_bind_port(sock_opt): - # Reserve a port, when we restart the server we want - # to hold onto the port - port = 0 - for address_family in (socket.AF_INET6, socket.AF_INET): - try: - s = socket.socket(address_family, socket.SOCK_STREAM) - except socket.error: - continue # this address family is unavailable - s.setsockopt(socket.SOL_SOCKET, sock_opt, 1) - try: - s.bind(('localhost', port)) - # for socket.SOCK_STREAM sockets, it is necessary to call - # listen to get the desired behavior. - s.listen(1) - port = s.getsockname()[1] - except socket.error: - # port was not available on the current address family - # try again - port = 0 - break - finally: - s.close() - if s: - return port if port != 0 else _pick_and_bind_port(sock_opt) - else: - return None # no address family was available - - class ReconnectTest(unittest.TestCase): def test_reconnect(self): @@ -79,14 +42,13 @@ class ReconnectTest(unittest.TestCase): 'UnaryUnary': grpc.unary_unary_rpc_method_handler(_handle_unary_unary) }) - sock_opt = _get_reuse_socket_option() - port = _pick_and_bind_port(sock_opt) - self.assertIsNotNone(port) - - server = grpc.server(server_pool, (handler,)) - server.add_insecure_port('[::]:{}'.format(port)) - server.start() - channel = grpc.insecure_channel('localhost:%d' % port) + options = (('grpc.so_reuseport', 1),) + with bound_socket() as (host, port): + addr = '{}:{}'.format(host, port) + server = grpc.server(server_pool, (handler,), options=options) + server.add_insecure_port(addr) + server.start() + channel = grpc.insecure_channel(addr) multi_callable = channel.unary_unary(_UNARY_UNARY) self.assertEqual(_RESPONSE, multi_callable(_REQUEST)) server.stop(None) @@ -94,8 +56,8 @@ class ReconnectTest(unittest.TestCase): # GRPC_CLIENT_CHANNEL_BACKUP_POLL_INTERVAL_MS can be set to change # this. time.sleep(5.1) - server = grpc.server(server_pool, (handler,)) - server.add_insecure_port('[::]:{}'.format(port)) + server = grpc.server(server_pool, (handler,), options=options) + server.add_insecure_port(addr) server.start() self.assertEqual(_RESPONSE, multi_callable(_REQUEST)) server.stop(None) diff --git a/src/python/grpcio_tests/tests/unit/_tcp_proxy.py b/src/python/grpcio_tests/tests/unit/_tcp_proxy.py index 5ad0bf8f028..84dc0e2d6cf 100644 --- a/src/python/grpcio_tests/tests/unit/_tcp_proxy.py +++ b/src/python/grpcio_tests/tests/unit/_tcp_proxy.py @@ -27,35 +27,12 @@ import select import socket import threading +from tests.unit.framework.common import get_socket + _TCP_PROXY_BUFFER_SIZE = 1024 _TCP_PROXY_TIMEOUT = datetime.timedelta(milliseconds=500) -def _create_socket_ipv6(bind_address): - listen_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - listen_socket.bind((bind_address, 0, 0, 0)) - return listen_socket - - -def _create_socket_ipv4(bind_address): - listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - listen_socket.bind((bind_address, 0)) - return listen_socket - - -def _init_listen_socket(bind_address): - listen_socket = None - if socket.has_ipv6: - try: - listen_socket = _create_socket_ipv6(bind_address) - except socket.error: - listen_socket = _create_socket_ipv4(bind_address) - else: - listen_socket = _create_socket_ipv4(bind_address) - listen_socket.listen(1) - return listen_socket, listen_socket.getsockname()[1] - - def _init_proxy_socket(gateway_address, gateway_port): proxy_socket = socket.create_connection((gateway_address, gateway_port)) return proxy_socket @@ -87,8 +64,8 @@ class TcpProxy(object): self._thread = threading.Thread(target=self._run_proxy) def start(self): - self._listen_socket, self._port = _init_listen_socket( - self._bind_address) + _, self._port, self._listen_socket = get_socket( + bind_address=self._bind_address) self._proxy_socket = _init_proxy_socket(self._gateway_address, self._gateway_port) self._thread.start() diff --git a/src/python/grpcio_tests/tests/unit/framework/common/BUILD.bazel b/src/python/grpcio_tests/tests/unit/framework/common/BUILD.bazel index cd5d99cfa83..a5ddcf3402b 100644 --- a/src/python/grpcio_tests/tests/unit/framework/common/BUILD.bazel +++ b/src/python/grpcio_tests/tests/unit/framework/common/BUILD.bazel @@ -3,6 +3,7 @@ package(default_visibility = ["//visibility:public"]) py_library( name = "common", srcs = [ + "__init__.py", "test_constants.py", "test_control.py", "test_coverage.py", diff --git a/src/python/grpcio_tests/tests/unit/framework/common/__init__.py b/src/python/grpcio_tests/tests/unit/framework/common/__init__.py index 5fb4f3c3cfd..c1ac76248ed 100644 --- a/src/python/grpcio_tests/tests/unit/framework/common/__init__.py +++ b/src/python/grpcio_tests/tests/unit/framework/common/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2015 gRPC authors. +# Copyright 2019 The gRPC authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,3 +11,74 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import contextlib +import os +import socket + +_DEFAULT_SOCK_OPTION = socket.SO_REUSEADDR if os.name == 'nt' else socket.SO_REUSEPORT + + +def get_socket(bind_address='localhost', + listen=True, + sock_options=(_DEFAULT_SOCK_OPTION,)): + """Opens a socket bound to an arbitrary port. + + Useful for reserving a port for a system-under-test. + + Args: + bind_address: The host to which to bind. + listen: A boolean value indicating whether or not to listen on the socket. + sock_options: A sequence of socket options to apply to the socket. + + Returns: + A tuple containing: + - the address to which the socket is bound + - the port to which the socket is bound + - the socket object itself + """ + _sock_options = sock_options if sock_options else [] + if socket.has_ipv6: + address_families = (socket.AF_INET6, socket.AF_INET) + else: + address_families = (socket.AF_INET) + for address_family in address_families: + try: + sock = socket.socket(address_family, socket.SOCK_STREAM) + for sock_option in _sock_options: + sock.setsockopt(socket.SOL_SOCKET, sock_option, 1) + sock.bind((bind_address, 0)) + if listen: + sock.listen(1) + return bind_address, sock.getsockname()[1], sock + except socket.error: + sock.close() + continue + raise RuntimeError("Failed to bind to {} with sock_options {}".format( + bind_address, sock_options)) + + +@contextlib.contextmanager +def bound_socket(bind_address='localhost', + listen=True, + sock_options=(_DEFAULT_SOCK_OPTION,)): + """Opens a socket bound to an arbitrary port. + + Useful for reserving a port for a system-under-test. + + Args: + bind_address: The host to which to bind. + listen: A boolean value indicating whether or not to listen on the socket. + sock_options: A sequence of socket options to apply to the socket. + + Yields: + A tuple containing: + - the address to which the socket is bound + - the port to which the socket is bound + """ + host, port, sock = get_socket( + bind_address=bind_address, listen=listen, sock_options=sock_options) + try: + yield host, port + finally: + sock.close() diff --git a/src/python/grpcio_tests/tests/unit/test_common.py b/src/python/grpcio_tests/tests/unit/test_common.py index bc3b24862dc..305781bd533 100644 --- a/src/python/grpcio_tests/tests/unit/test_common.py +++ b/src/python/grpcio_tests/tests/unit/test_common.py @@ -100,14 +100,14 @@ def test_secure_channel(target, channel_credentials, server_host_override): return channel -def test_server(max_workers=10): +def test_server(max_workers=10, reuse_port=False): """Creates an insecure grpc server. These servers have SO_REUSEPORT disabled to prevent cross-talk. """ return grpc.server( futures.ThreadPoolExecutor(max_workers=max_workers), - options=(('grpc.so_reuseport', 0),)) + options=(('grpc.so_reuseport', int(reuse_port)),)) class WaitGroup(object):