diff --git a/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py b/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py index b671302cc04..fb253bf942d 100644 --- a/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py +++ b/src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py @@ -19,7 +19,7 @@ import threading import time import sys -from typing import DefaultDict, Dict, List, Mapping, Set +from typing import DefaultDict, Dict, List, Mapping, Set, Sequence import collections from concurrent import futures @@ -37,6 +37,8 @@ formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)-8s %(message)s') console_handler.setFormatter(formatter) logger.addHandler(console_handler) +_SUPPORTED_METHODS = ("UnaryCall", "EmptyCall",) + class _StatsWatcher: _start: int @@ -212,11 +214,11 @@ class _MethodHandle: channel_thread.join() -def _run(args: argparse.Namespace) -> None: +def _run(args: argparse.Namespace, methods: Sequence[str]) -> None: logger.info("Starting python xDS Interop Client.") global _global_server # pylint: disable=global-statement method_handles = [] - for method in ("UnaryCall",): + for method in methods: method_handles.append(_MethodHandle(method, args.num_channels, args.qps, args.server, args.rpc_timeout_sec, args.print_response)) _global_server = grpc.server(futures.ThreadPoolExecutor()) _global_server.add_insecure_port(f"0.0.0.0:{args.stats_port}") @@ -265,6 +267,13 @@ if __name__ == "__main__": default=None, type=str, help="A file to log to.") + rpc_help = "A comma-delimited list of RPC methods to run. Must be one of " + rpc_help += ", ".join(_SUPPORTED_METHODS) + rpc_help += "." + parser.add_argument("--rpc", + default="UnaryCall", + type=str, + help=rpc_help) args = parser.parse_args() signal.signal(signal.SIGINT, _handle_sigint) if args.verbose: @@ -273,4 +282,7 @@ if __name__ == "__main__": file_handler = logging.FileHandler(args.log_file, mode='a') file_handler.setFormatter(formatter) logger.addHandler(file_handler) - _run(args) + methods = args.rpc.split(",") + if set(methods) - set(_SUPPORTED_METHODS): + raise ValueError("--rpc supported methods: {}".format(", ".join(_SUPPORTED_METHODS))) + _run(args, methods)