Reuse 'bound_socket' to improve the test case

pull/21607/head
Lidi Zheng 5 years ago
parent d7698e7e1d
commit 305defc7cb
  1. 24
      src/python/grpcio_tests/tests/unit/framework/common/__init__.py
  2. 18
      src/python/grpcio_tests/tests_aio/unit/channel_argument_test.py

@ -17,17 +17,27 @@ import os
import socket
_DEFAULT_SOCK_OPTION = socket.SO_REUSEADDR if os.name == 'nt' else socket.SO_REUSEPORT
_UNRECOVERABLE_ERRORS = ('Address already in use',)
def _exception_is_unrecoverable(e):
for error in _UNRECOVERABLE_ERRORS:
if error in str(e):
return True
return False
def get_socket(bind_address='localhost',
port=0,
listen=True,
sock_options=(_DEFAULT_SOCK_OPTION,)):
"""Opens a socket bound to an arbitrary port.
"""Opens a socket.
Useful for reserving a port for a system-under-test.
Args:
bind_address: The host to which to bind.
port: The port to bind.
listen: A boolean value indicating whether or not to listen on the socket.
sock_options: A sequence of socket options to apply to the socket.
@ -47,19 +57,23 @@ def get_socket(bind_address='localhost',
sock = socket.socket(address_family, socket.SOCK_STREAM)
for sock_option in _sock_options:
sock.setsockopt(socket.SOL_SOCKET, sock_option, 1)
sock.bind((bind_address, 0))
sock.bind((bind_address, port))
if listen:
sock.listen(1)
return bind_address, sock.getsockname()[1], sock
except socket.error:
except socket.error as socket_error:
sock.close()
continue
if _exception_is_unrecoverable(socket_error):
raise
else:
continue
raise RuntimeError("Failed to bind to {} with sock_options {}".format(
bind_address, sock_options))
@contextlib.contextmanager
def bound_socket(bind_address='localhost',
port=0,
listen=True,
sock_options=(_DEFAULT_SOCK_OPTION,)):
"""Opens a socket bound to an arbitrary port.
@ -68,6 +82,7 @@ def bound_socket(bind_address='localhost',
Args:
bind_address: The host to which to bind.
port: The port to bind.
listen: A boolean value indicating whether or not to listen on the socket.
sock_options: A sequence of socket options to apply to the socket.
@ -77,6 +92,7 @@ def bound_socket(bind_address='localhost',
- the port to which the socket is bound
"""
host, port, sock = get_socket(bind_address=bind_address,
port=port,
listen=listen,
sock_options=sock_options)
try:

@ -17,13 +17,13 @@ import asyncio
import logging
import platform
import random
import socket
import unittest
import grpc
from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
from tests.unit.framework import common
from tests_aio.unit._test_base import AioTestBase
from tests_aio.unit._test_server import start_test_server
@ -71,21 +71,17 @@ async def test_if_reuse_port_enabled(server: aio.Server):
await server.start()
try:
if socket.has_ipv6:
another_socket = socket.socket(family=socket.AF_INET6)
else:
another_socket = socket.socket(family=socket.AF_INET)
another_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
another_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
another_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
another_socket.bind(('localhost', port))
with common.bound_socket(
bind_address='localhost',
port=port,
listen=False,
) as (unused_host, bound_port):
assert bound_port == port
except OSError as e:
assert 'Address already in use' in str(e)
return False
else:
return True
finally:
another_socket.close()
class TestChannelArgument(AioTestBase):

Loading…
Cancel
Save