Add --metadata flag

pull/23892/head
Richard Belleville 5 years ago
parent 86525d703c
commit b0ab7197a6
  1. 41
      src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py

@ -19,7 +19,7 @@ import threading
import time
import sys
from typing import DefaultDict, Dict, List, Mapping, Set, Sequence
from typing import DefaultDict, Dict, List, Mapping, Set, Sequence, Tuple
import collections
from concurrent import futures
@ -39,6 +39,7 @@ logger.addHandler(console_handler)
_SUPPORTED_METHODS = ("UnaryCall", "EmptyCall",)
PerMethodMetadataType = Mapping[str, Sequence[Tuple[str, str]]]
class _StatsWatcher:
_start: int
@ -118,14 +119,16 @@ class _LoadBalancerStatsServicer(test_pb2_grpc.LoadBalancerStatsServiceServicer
return response
def _start_rpc(method: str, request_id: int, stub: test_pb2_grpc.TestServiceStub,
def _start_rpc(method: str, metadata: Sequence[Tuple[str, str]], request_id: int, stub: test_pb2_grpc.TestServiceStub,
timeout: float, futures: Mapping[int, grpc.Future]) -> None:
logger.info(f"Sending request to backend: {request_id}")
logger.info(f"Sending {method} request to backend: {request_id}")
if method == "UnaryCall":
future = stub.UnaryCall.future(messages_pb2.SimpleRequest(),
metadata=metadata,
timeout=timeout)
elif method == "EmptyCall":
future = stub.EmptyCall.future(empty_pb2.Empty(),
metadata=metadata,
timeout=timeout)
else:
raise ValueError(f"Unrecognized method '{method}'.")
@ -173,7 +176,7 @@ def _cancel_all_rpcs(futures: Mapping[int, grpc.Future]) -> None:
future.cancel()
def _run_single_channel(method: str, qps: int, server: str, rpc_timeout_sec: int, print_response: bool):
def _run_single_channel(method: str, metadata: Sequence[Tuple[str, str]], qps: int, server: str, rpc_timeout_sec: int, print_response: bool):
global _global_rpc_id # pylint: disable=global-statement
duration_per_query = 1.0 / float(qps)
with grpc.insecure_channel(server) as channel:
@ -186,7 +189,7 @@ def _run_single_channel(method: str, qps: int, server: str, rpc_timeout_sec: int
_global_rpc_id += 1
start = time.time()
end = start + duration_per_query
_start_rpc(method, request_id, stub, float(rpc_timeout_sec), futures)
_start_rpc(method, metadata, request_id, stub, float(rpc_timeout_sec), futures)
_remove_completed_rpcs(futures, print_response)
logger.debug(f"Currently {len(futures)} in-flight RPCs")
now = time.time()
@ -200,11 +203,11 @@ class _MethodHandle:
_channel_threads: List[threading.Thread]
def __init__(self, method: str, num_channels: int, qps: int, server: str, rpc_timeout_sec: int, print_response: bool):
def __init__(self, method: str, metadata: Sequence[Tuple[str, str]], num_channels: int, qps: int, server: str, rpc_timeout_sec: int, print_response: bool):
"""Creates and starts a group of threads running the indicated method."""
self._channel_threads = []
for i in range(num_channels):
thread = threading.Thread(target=_run_single_channel, args=(method, qps, server, rpc_timeout_sec, print_response,))
thread = threading.Thread(target=_run_single_channel, args=(method, metadata, qps, server, rpc_timeout_sec, print_response,))
thread.start()
self._channel_threads.append(thread)
@ -214,12 +217,12 @@ class _MethodHandle:
channel_thread.join()
def _run(args: argparse.Namespace, methods: Sequence[str]) -> None:
def _run(args: argparse.Namespace, methods: Sequence[str], per_method_metadata: PerMethodMetadataType) -> None:
logger.info("Starting python xDS Interop Client.")
global _global_server # pylint: disable=global-statement
method_handles = []
for method in methods:
method_handles.append(_MethodHandle(method, args.num_channels, args.qps, args.server, args.rpc_timeout_sec, args.print_response))
method_handles.append(_MethodHandle(method, per_method_metadata.get(method, []), args.num_channels, args.qps, args.server, args.rpc_timeout_sec, args.print_response))
_global_server = grpc.server(futures.ThreadPoolExecutor())
_global_server.add_insecure_port(f"0.0.0.0:{args.stats_port}")
test_pb2_grpc.add_LoadBalancerStatsServiceServicer_to_server(
@ -274,6 +277,13 @@ if __name__ == "__main__":
default="UnaryCall",
type=str,
help=rpc_help)
metadata_help = ("A comma-delimited list of 3-tuples of the form " +
"METHOD:KEY:VALUE, e.g. " +
"EmptyCall:key1:value1,UnaryCall:key2:value2,EmptyCall:k3:v3")
parser.add_argument("--metadata",
default="",
type=str,
help=metadata_help)
args = parser.parse_args()
signal.signal(signal.SIGINT, _handle_sigint)
if args.verbose:
@ -282,7 +292,16 @@ if __name__ == "__main__":
file_handler = logging.FileHandler(args.log_file, mode='a')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
methods = args.rpc.split(",")
methods = args.rpc.split(",")
if set(methods) - set(_SUPPORTED_METHODS):
raise ValueError("--rpc supported methods: {}".format(", ".join(_SUPPORTED_METHODS)))
_run(args, methods)
per_method_metadata = collections.defaultdict(list)
metadata = args.metadata.split(",") if args.metadata else []
for metadatum in metadata:
elems = metadatum.split(":")
if len(elems) != 3:
raise ValueError(f"'{metadatum}' was not in the form 'METHOD:KEY:VALUE'")
if elems[0] not in _SUPPORTED_METHODS:
raise ValueError(f"Unrecognized method '{elems[0]}'")
per_method_metadata[elems[0]].append((elems[1], elems[2]))
_run(args, methods, per_method_metadata)

Loading…
Cancel
Save