diff --git a/src/python/grpcio_tests/tests/unit/framework/common/__init__.py b/src/python/grpcio_tests/tests/unit/framework/common/__init__.py index 44ce1ddde2f..083828104af 100644 --- a/src/python/grpcio_tests/tests/unit/framework/common/__init__.py +++ b/src/python/grpcio_tests/tests/unit/framework/common/__init__.py @@ -17,17 +17,27 @@ import os import socket _DEFAULT_SOCK_OPTION = socket.SO_REUSEADDR if os.name == 'nt' else socket.SO_REUSEPORT +_UNRECOVERABLE_ERRORS = ('Address already in use',) + + +def _exception_is_unrecoverable(e): + for error in _UNRECOVERABLE_ERRORS: + if error in str(e): + return True + return False def get_socket(bind_address='localhost', + port=0, listen=True, sock_options=(_DEFAULT_SOCK_OPTION,)): - """Opens a socket bound to an arbitrary port. + """Opens a socket. Useful for reserving a port for a system-under-test. Args: bind_address: The host to which to bind. + port: The port to bind. listen: A boolean value indicating whether or not to listen on the socket. sock_options: A sequence of socket options to apply to the socket. @@ -47,19 +57,23 @@ def get_socket(bind_address='localhost', sock = socket.socket(address_family, socket.SOCK_STREAM) for sock_option in _sock_options: sock.setsockopt(socket.SOL_SOCKET, sock_option, 1) - sock.bind((bind_address, 0)) + sock.bind((bind_address, port)) if listen: sock.listen(1) return bind_address, sock.getsockname()[1], sock - except socket.error: + except socket.error as socket_error: sock.close() - continue + if _exception_is_unrecoverable(socket_error): + raise + else: + continue raise RuntimeError("Failed to bind to {} with sock_options {}".format( bind_address, sock_options)) @contextlib.contextmanager def bound_socket(bind_address='localhost', + port=0, listen=True, sock_options=(_DEFAULT_SOCK_OPTION,)): """Opens a socket bound to an arbitrary port. @@ -68,6 +82,7 @@ def bound_socket(bind_address='localhost', Args: bind_address: The host to which to bind. + port: The port to bind. listen: A boolean value indicating whether or not to listen on the socket. sock_options: A sequence of socket options to apply to the socket. @@ -77,6 +92,7 @@ def bound_socket(bind_address='localhost', - the port to which the socket is bound """ host, port, sock = get_socket(bind_address=bind_address, + port=port, listen=listen, sock_options=sock_options) try: diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_argument_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_argument_test.py index 0a30b8e28a2..6037bade998 100644 --- a/src/python/grpcio_tests/tests_aio/unit/channel_argument_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/channel_argument_test.py @@ -17,13 +17,13 @@ import asyncio import logging import platform import random -import socket import unittest import grpc from grpc.experimental import aio from src.proto.grpc.testing import messages_pb2, test_pb2_grpc +from tests.unit.framework import common from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_server import start_test_server @@ -71,21 +71,17 @@ async def test_if_reuse_port_enabled(server: aio.Server): await server.start() try: - if socket.has_ipv6: - another_socket = socket.socket(family=socket.AF_INET6) - else: - another_socket = socket.socket(family=socket.AF_INET) - another_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - another_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) - another_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True) - another_socket.bind(('localhost', port)) + with common.bound_socket( + bind_address='localhost', + port=port, + listen=False, + ) as (unused_host, bound_port): + assert bound_port == port except OSError as e: assert 'Address already in use' in str(e) return False else: return True - finally: - another_socket.close() class TestChannelArgument(AioTestBase):