diff --git a/src/python/grpcio_tests/tests_py3_only/interop/BUILD.bazel b/src/python/grpcio_tests/tests_py3_only/interop/BUILD.bazel index 2551c31db7c..ca360871209 100644 --- a/src/python/grpcio_tests/tests_py3_only/interop/BUILD.bazel +++ b/src/python/grpcio_tests/tests_py3_only/interop/BUILD.bazel @@ -41,3 +41,19 @@ py_binary( "//src/python/grpcio_reflection/grpc_reflection/v1alpha:grpc_reflection", ], ) + +py_test( + name = "xds_interop_client_test", + srcs = ["xds_interop_client_test.py"], + imports = ["."], + python_version = "PY3", + deps = [ + ":xds_interop_client", + ":xds_interop_server", + "//src/proto/grpc/testing:empty_py_pb2", + "//src/proto/grpc/testing:py_messages_proto", + "//src/proto/grpc/testing:py_test_proto", + "//src/proto/grpc/testing:test_py_pb2_grpc", + "//src/python/grpcio_tests/tests/unit/framework/common", + ], +) diff --git a/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py b/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py index c98bba39a5a..582da6a1c0f 100644 --- a/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py +++ b/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py @@ -198,7 +198,8 @@ def _on_rpc_done(rpc_id: int, future: grpc.Future, method: str, print_response: bool) -> None: exception = future.exception() hostname = "" - _global_rpc_statuses[method][future.code().value[0]] += 1 + with _global_lock: + _global_rpc_statuses[method][future.code().value[0]] += 1 if exception is not None: with _global_lock: _global_rpcs_failed[method] += 1 @@ -294,18 +295,17 @@ def _run_single_channel(config: _ChannelConfiguration) -> None: continue else: duration_per_query = 1.0 / float(config.qps) - request_id = None - with _global_lock: - request_id = _global_rpc_id - _global_rpc_id += 1 - _global_rpcs_started[config.method] += 1 - start = time.time() - end = start + duration_per_query - with config.condition: + request_id = None + with _global_lock: + request_id = _global_rpc_id + _global_rpc_id += 1 + _global_rpcs_started[config.method] += 1 + start = time.time() + end = start + duration_per_query _start_rpc(config.method, config.metadata, request_id, stub, float(config.rpc_timeout_sec), futures) - with config.condition: - _remove_completed_rpcs(futures, config.print_response) + print_response = config.print_response + _remove_completed_rpcs(futures, config.print_response) logger.debug(f"Currently {len(futures)} in-flight RPCs") now = time.time() while now < end: diff --git a/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client_test.py b/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client_test.py new file mode 100644 index 00000000000..98f6e388b19 --- /dev/null +++ b/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client_test.py @@ -0,0 +1,182 @@ +# Copyright 2022 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import contextlib +import logging +import os +import subprocess +import sys +import tempfile +import time +from typing import Iterable, List, Mapping, Set, Tuple +import unittest + +import grpc.experimental +import xds_interop_client +import xds_interop_server + +from src.proto.grpc.testing import empty_pb2 +from src.proto.grpc.testing import messages_pb2 +from src.proto.grpc.testing import test_pb2 +from src.proto.grpc.testing import test_pb2_grpc +import src.python.grpcio_tests.tests.unit.framework.common as framework_common + +_CLIENT_PATH = os.path.abspath(os.path.realpath(xds_interop_client.__file__)) +_SERVER_PATH = os.path.abspath(os.path.realpath(xds_interop_server.__file__)) + +_METHODS = ( + (messages_pb2.ClientConfigureRequest.UNARY_CALL, "UNARY_CALL"), + (messages_pb2.ClientConfigureRequest.EMPTY_CALL, "EMPTY_CALL"), +) + +_QPS = 100 +_NUM_CHANNELS = 20 + +_TEST_ITERATIONS = 10 +_ITERATION_DURATION_SECONDS = 1 +_SUBPROCESS_TIMEOUT_SECONDS = 2 + + +def _set_union(a: Iterable, b: Iterable) -> Set: + c = set(a) + c.update(b) + return c + + +@contextlib.contextmanager +def _start_python_with_args( + file: str, args: List[str] +) -> Tuple[subprocess.Popen, tempfile.TemporaryFile, tempfile.TemporaryFile]: + with tempfile.TemporaryFile(mode='r') as stdout: + with tempfile.TemporaryFile(mode='r') as stderr: + proc = subprocess.Popen((sys.executable, file) + tuple(args), + stdout=stdout, + stderr=stderr) + yield proc, stdout, stderr + + +def _dump_stream(process_name: str, stream_name: str, + stream: tempfile.TemporaryFile): + sys.stderr.write(f"{process_name} {stream_name}:\n") + stream.seek(0) + sys.stderr.write(stream.read()) + + +def _dump_streams(process_name: str, stdout: tempfile.TemporaryFile, + stderr: tempfile.TemporaryFile): + _dump_stream(process_name, "stdout", stdout) + _dump_stream(process_name, "stderr", stderr) + sys.stderr.write(f"End {process_name} output.\n") + + +def _index_accumulated_stats( + response: messages_pb2.LoadBalancerAccumulatedStatsResponse +) -> Mapping[str, Mapping[int, int]]: + indexed = collections.defaultdict(lambda: collections.defaultdict(int)) + for _, method_str in _METHODS: + for status in response.stats_per_method[method_str].result.keys(): + indexed[method_str][status] = response.stats_per_method[ + method_str].result[status] + return indexed + + +def _subtract_indexed_stats(a: Mapping[str, Mapping[int, int]], + b: Mapping[str, Mapping[int, int]]): + c = collections.defaultdict(lambda: collections.defaultdict(int)) + all_methods = _set_union(a.keys(), b.keys()) + for method in all_methods: + all_statuses = _set_union(a[method].keys(), b[method].keys()) + for status in all_statuses: + c[method][status] = a[method][status] - b[method][status] + return c + + +def _collect_stats(stats_port: int, + duration: int) -> Mapping[str, Mapping[int, int]]: + settings = { + "target": f"localhost:{stats_port}", + "insecure": True, + } + response = test_pb2_grpc.LoadBalancerStatsService.GetClientAccumulatedStats( + messages_pb2.LoadBalancerAccumulatedStatsRequest(), **settings) + before = _index_accumulated_stats(response) + time.sleep(duration) + response = test_pb2_grpc.LoadBalancerStatsService.GetClientAccumulatedStats( + messages_pb2.LoadBalancerAccumulatedStatsRequest(), **settings) + after = _index_accumulated_stats(response) + return _subtract_indexed_stats(after, before) + + +class XdsInteropClientTest(unittest.TestCase): + + def _assert_client_consistent(self, server_port: int, stats_port: int, + qps: int, num_channels: int): + settings = { + "target": f"localhost:{stats_port}", + "insecure": True, + } + for i in range(_TEST_ITERATIONS): + target_method, target_method_str = _METHODS[i % len(_METHODS)] + test_pb2_grpc.XdsUpdateClientConfigureService.Configure( + messages_pb2.ClientConfigureRequest(types=[target_method]), + **settings) + delta = _collect_stats(stats_port, _ITERATION_DURATION_SECONDS) + logging.info("Delta: %s", delta) + for _, method_str in _METHODS: + for status in delta[method_str]: + if status == 0 and method_str == target_method_str: + self.assertGreater(delta[method_str][status], 0, delta) + else: + self.assertEqual(delta[method_str][status], 0, delta) + + def test_configure_consistency(self): + _, server_port, socket = framework_common.get_socket() + + with _start_python_with_args( + _SERVER_PATH, + [f"--port={server_port}", f"--maintenance_port={server_port}" + ]) as (server, server_stdout, server_stderr): + # Send RPC to server to make sure it's running. + logging.info("Sending RPC to server.") + test_pb2_grpc.TestService.EmptyCall(empty_pb2.Empty(), + f"localhost:{server_port}", + insecure=True, + wait_for_ready=True) + logging.info("Server successfully started.") + socket.close() + _, stats_port, stats_socket = framework_common.get_socket() + with _start_python_with_args(_CLIENT_PATH, [ + f"--server=localhost:{server_port}", + f"--stats_port={stats_port}", f"--qps={_QPS}", + f"--num_channels={_NUM_CHANNELS}" + ]) as (client, client_stdout, client_stderr): + stats_socket.close() + try: + self._assert_client_consistent(server_port, stats_port, + _QPS, _NUM_CHANNELS) + except: + _dump_streams("server", server_stdout, server_stderr) + _dump_streams("client", client_stdout, client_stderr) + raise + finally: + server.kill() + client.kill() + server.wait(timeout=_SUBPROCESS_TIMEOUT_SECONDS) + client.wait(timeout=_SUBPROCESS_TIMEOUT_SECONDS) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2)