WIP. Start writing signatures

pull/20838/head
Richard Belleville 5 years ago
parent 0c8bedca33
commit 3b652bc3ef
  1. BIN
      src/python/grpcio_tests/tests/stress/single_thread.cprof
  2. 135
      src/python/grpcio_tests/tests/unit/_metadata_flags_test.py
  3. 1
      src/python/grpcio_tests/tests/unit/framework/common/BUILD.bazel
  4. 42
      src/python/grpcio_tests/tests/unit/framework/common/__init__.py

@ -24,6 +24,8 @@ import grpc
from tests.unit import test_common
from tests.unit.framework.common import test_constants
import tests.unit.framework.common
from tests.unit.framework.common import listening_socket
_UNARY_UNARY = '/test/UnaryUnary'
_UNARY_STREAM = '/test/UnaryStream'
@ -93,35 +95,36 @@ class _GenericHandler(grpc.GenericRpcHandler):
return None
def _create_socket_ipv6(bind_address):
listen_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
listen_socket.bind((bind_address, 0, 0, 0))
return listen_socket
def _create_socket_ipv4(bind_address):
listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listen_socket.bind((bind_address, 0))
return listen_socket
def get_free_loopback_tcp_port():
listen_socket = None
if socket.has_ipv6:
try:
listen_socket = _create_socket_ipv6('')
except socket.error:
listen_socket = _create_socket_ipv4('')
else:
listen_socket = _create_socket_ipv4('')
address_tuple = listen_socket.getsockname()
return listen_socket, "localhost:%s" % (address_tuple[1])
# def _create_socket_ipv6(bind_address):
# listen_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
# listen_socket.bind((bind_address, 0, 0, 0))
# return listen_socket
#
#
# def _create_socket_ipv4(bind_address):
# listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# listen_socket.bind((bind_address, 0))
# return listen_socket
#
#
# def get_free_loopback_tcp_port():
# listen_socket = None
# if socket.has_ipv6:
# try:
# listen_socket = _create_socket_ipv6('')
# except socket.error:
# listen_socket = _create_socket_ipv4('')
# else:
# listen_socket = _create_socket_ipv4('')
# address_tuple = listen_socket.getsockname()
# return listen_socket, "localhost:%s" % (address_tuple[1])
def create_dummy_channel():
"""Creating dummy channels is a workaround for retries"""
_, addr = get_free_loopback_tcp_port()
return grpc.insecure_channel(addr)
# _, addr = get_free_loopback_tcp_port()
with listening_socket() as host, port:
return grpc.insecure_channel('{}:{}'.format(host, port))
def perform_unary_unary_call(channel, wait_for_ready=None):
@ -221,49 +224,49 @@ class MetadataFlagsTest(unittest.TestCase):
# main thread. So, it need another method to store the
# exceptions and raise them again in main thread.
unhandled_exceptions = queue.Queue()
tcp, addr = get_free_loopback_tcp_port()
wg = test_common.WaitGroup(len(_ALL_CALL_CASES))
def wait_for_transient_failure(channel_connectivity):
if channel_connectivity == grpc.ChannelConnectivity.TRANSIENT_FAILURE:
wg.done()
def test_call(perform_call):
with grpc.insecure_channel(addr) as channel:
try:
channel.subscribe(wait_for_transient_failure)
perform_call(channel, wait_for_ready=True)
except BaseException as e: # pylint: disable=broad-except
# If the call failed, the thread would be destroyed. The
# channel object can be collected before calling the
# callback, which will result in a deadlock.
with listening_socket() as (host, port):
addr = '{}:{}'.format(host, port)
wg = test_common.WaitGroup(len(_ALL_CALL_CASES))
def wait_for_transient_failure(channel_connectivity):
if channel_connectivity == grpc.ChannelConnectivity.TRANSIENT_FAILURE:
wg.done()
unhandled_exceptions.put(e, True)
test_threads = []
for perform_call in _ALL_CALL_CASES:
test_thread = threading.Thread(
target=test_call, args=(perform_call,))
test_thread.exception = None
test_thread.start()
test_threads.append(test_thread)
# Start the server after the connections are waiting
wg.wait()
tcp.close()
server = test_common.test_server()
server.add_generic_rpc_handlers((_GenericHandler(weakref.proxy(self)),))
server.add_insecure_port(addr)
server.start()
for test_thread in test_threads:
test_thread.join()
# Stop the server to make test end properly
server.stop(0)
if not unhandled_exceptions.empty():
raise unhandled_exceptions.get(True)
def test_call(perform_call):
with grpc.insecure_channel(addr) as channel:
try:
channel.subscribe(wait_for_transient_failure)
perform_call(channel, wait_for_ready=True)
except BaseException as e: # pylint: disable=broad-except
# If the call failed, the thread would be destroyed. The
# channel object can be collected before calling the
# callback, which will result in a deadlock.
wg.done()
unhandled_exceptions.put(e, True)
test_threads = []
for perform_call in _ALL_CALL_CASES:
test_thread = threading.Thread(
target=test_call, args=(perform_call,))
test_thread.exception = None
test_thread.start()
test_threads.append(test_thread)
# Start the server after the connections are waiting
wg.wait()
server = test_common.test_server()
server.add_generic_rpc_handlers((_GenericHandler(weakref.proxy(self)),))
server.add_insecure_port(addr)
server.start()
for test_thread in test_threads:
test_thread.join()
# Stop the server to make test end properly
server.stop(0)
if not unhandled_exceptions.empty():
raise unhandled_exceptions.get(True)
if __name__ == '__main__':

@ -3,6 +3,7 @@ package(default_visibility = ["//visibility:public"])
py_library(
name = "common",
srcs = [
"__init__.py",
"test_constants.py",
"test_control.py",
"test_coverage.py",

@ -11,3 +11,45 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import socket
def get_socket(bind_address='localhost', sock_options=(socket.SO_REUSEPORT,)):
"""Opens a listening socket on an arbitrary port.
Useful for reserving a port for a system-under-test.
Args:
bind_address: The host to which to bind.
sock_options: A sequence of socket options to apply to the socket.
Returns:
A tuple containing:
- the address to which the socket is bound
- the port to which the socket is bound
- the socket object itself
"""
_sock_options = sock_options if sock_options else []
for address_family in (socket.AF_INET, socket.AF_INET6):
try:
sock = socket.socket(address_family, socket.SOCK_STREAM)
for sock_option in _sock_options:
sock.setsockopt(socket.SOL_SOCKET, sock_option, 1)
sock.bind((bind_address, 0))
sock.listen(1)
return bind_address, sock.getsockname()[1], sock
except socket.error:
continue
raise RuntimeError("Failed to find to {} with sock_options {}".format(bind_address, sock_options))
@contextlib.contextmanager
def listening_socket(bind_address='localhost', sock_options=(socket.SO_REUSEPORT,)):
# TODO: Docstring.
host, port, sock = get_socket(bind_address=bind_address, sock_options=sock_options)
try:
yield host, port
finally:
sock.close()

Loading…
Cancel
Save