|
|
|
@ -19,7 +19,7 @@ import threading |
|
|
|
|
import time |
|
|
|
|
import sys |
|
|
|
|
|
|
|
|
|
from typing import DefaultDict, Dict, List, Mapping, Set |
|
|
|
|
from typing import DefaultDict, Dict, List, Mapping, Set, Sequence, Tuple |
|
|
|
|
import collections |
|
|
|
|
|
|
|
|
|
from concurrent import futures |
|
|
|
@ -37,12 +37,20 @@ formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)-8s %(message)s') |
|
|
|
|
console_handler.setFormatter(formatter) |
|
|
|
|
logger.addHandler(console_handler) |
|
|
|
|
|
|
|
|
|
_SUPPORTED_METHODS = ( |
|
|
|
|
"UnaryCall", |
|
|
|
|
"EmptyCall", |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
PerMethodMetadataType = Mapping[str, Sequence[Tuple[str, str]]] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _StatsWatcher: |
|
|
|
|
_start: int |
|
|
|
|
_end: int |
|
|
|
|
_rpcs_needed: int |
|
|
|
|
_rpcs_by_peer: DefaultDict[str, int] |
|
|
|
|
_rpcs_by_method: DefaultDict[str, DefaultDict[str, int]] |
|
|
|
|
_no_remote_peer: int |
|
|
|
|
_lock: threading.Lock |
|
|
|
|
_condition: threading.Condition |
|
|
|
@ -52,10 +60,12 @@ class _StatsWatcher: |
|
|
|
|
self._end = end |
|
|
|
|
self._rpcs_needed = end - start |
|
|
|
|
self._rpcs_by_peer = collections.defaultdict(int) |
|
|
|
|
self._rpcs_by_method = collections.defaultdict( |
|
|
|
|
lambda: collections.defaultdict(int)) |
|
|
|
|
self._condition = threading.Condition() |
|
|
|
|
self._no_remote_peer = 0 |
|
|
|
|
|
|
|
|
|
def on_rpc_complete(self, request_id: int, peer: str) -> None: |
|
|
|
|
def on_rpc_complete(self, request_id: int, peer: str, method: str) -> None: |
|
|
|
|
"""Records statistics for a single RPC.""" |
|
|
|
|
if self._start <= request_id < self._end: |
|
|
|
|
with self._condition: |
|
|
|
@ -63,6 +73,7 @@ class _StatsWatcher: |
|
|
|
|
self._no_remote_peer += 1 |
|
|
|
|
else: |
|
|
|
|
self._rpcs_by_peer[peer] += 1 |
|
|
|
|
self._rpcs_by_method[method][peer] += 1 |
|
|
|
|
self._rpcs_needed -= 1 |
|
|
|
|
self._condition.notify() |
|
|
|
|
|
|
|
|
@ -75,6 +86,9 @@ class _StatsWatcher: |
|
|
|
|
response = messages_pb2.LoadBalancerStatsResponse() |
|
|
|
|
for peer, count in self._rpcs_by_peer.items(): |
|
|
|
|
response.rpcs_by_peer[peer] = count |
|
|
|
|
for method, count_by_peer in self._rpcs_by_method.items(): |
|
|
|
|
for peer, count in count_by_peer.items(): |
|
|
|
|
response.rpcs_by_method[method].rpcs_by_peer[peer] = count |
|
|
|
|
response.num_failures = self._no_remote_peer + self._rpcs_needed |
|
|
|
|
return response |
|
|
|
|
|
|
|
|
@ -116,15 +130,25 @@ class _LoadBalancerStatsServicer(test_pb2_grpc.LoadBalancerStatsServiceServicer |
|
|
|
|
return response |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _start_rpc(method: str, metadata: Sequence[Tuple[str, str]],
               request_id: int, stub: test_pb2_grpc.TestServiceStub,
               timeout: float,
               futures: Dict[int, Tuple[grpc.Future, str]]) -> None:
    """Starts one asynchronous RPC and records it as in flight.

    Args:
      method: Name of the RPC to issue; "UnaryCall" or "EmptyCall".
      metadata: (key, value) pairs attached to the call.
      request_id: Identifier under which the call is tracked.
      stub: Stub used to issue the call.
      timeout: Timeout forwarded to the call.
      futures: In-flight table; receives `(future, method)` under
        `request_id`. It is mutated in place, hence the `Dict` annotation.

    Raises:
      ValueError: If `method` is not one of the recognized method names.
    """
    logger.info(f"Sending {method} request to backend: {request_id}")
    if method == "UnaryCall":
        future = stub.UnaryCall.future(messages_pb2.SimpleRequest(),
                                       metadata=metadata,
                                       timeout=timeout)
    elif method == "EmptyCall":
        future = stub.EmptyCall.future(empty_pb2.Empty(),
                                       metadata=metadata,
                                       timeout=timeout)
    else:
        raise ValueError(f"Unrecognized method '{method}'.")
    # The method name travels with the future so that completion handling can
    # attribute the result to the right RPC method.
    futures[request_id] = (future, method)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _on_rpc_done(rpc_id: int, future: grpc.Future, |
|
|
|
|
def _on_rpc_done(rpc_id: int, future: grpc.Future, method: str, |
|
|
|
|
print_response: bool) -> None: |
|
|
|
|
exception = future.exception() |
|
|
|
|
hostname = "" |
|
|
|
@ -135,8 +159,13 @@ def _on_rpc_done(rpc_id: int, future: grpc.Future, |
|
|
|
|
logger.error(exception) |
|
|
|
|
else: |
|
|
|
|
response = future.result() |
|
|
|
|
logger.info(f"Got result {rpc_id}") |
|
|
|
|
hostname = response.hostname |
|
|
|
|
hostname = None |
|
|
|
|
for metadatum in future.initial_metadata(): |
|
|
|
|
if metadatum[0] == "hostname": |
|
|
|
|
hostname = metadatum[1] |
|
|
|
|
break |
|
|
|
|
else: |
|
|
|
|
hostname = response.hostname |
|
|
|
|
if print_response: |
|
|
|
|
if future.code() == grpc.StatusCode.OK: |
|
|
|
|
logger.info("Successful response.") |
|
|
|
@ -144,33 +173,35 @@ def _on_rpc_done(rpc_id: int, future: grpc.Future, |
|
|
|
|
logger.info(f"RPC failed: {call}") |
|
|
|
|
with _global_lock: |
|
|
|
|
for watcher in _watchers: |
|
|
|
|
watcher.on_rpc_complete(rpc_id, hostname) |
|
|
|
|
watcher.on_rpc_complete(rpc_id, hostname, method) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _remove_completed_rpcs(futures: Dict[int, Tuple[grpc.Future, str]],
                           print_response: bool) -> None:
    """Reports every finished RPC and drops it from the in-flight table.

    Args:
      futures: In-flight table mapping RPC id to `(future, method)`. Entries
        whose future is done are removed in place.
      print_response: Forwarded to `_on_rpc_done` for each finished RPC.
    """
    logger.debug("Removing completed RPCs")
    done = []
    for future_id, (future, method) in futures.items():
        if future.done():
            # Pass the caller-supplied flag rather than reading the
            # module-level `args`, which only exists when run as a script.
            _on_rpc_done(future_id, future, method, print_response)
            done.append(future_id)
    # Deletion is deferred: a dict must not change size while being iterated.
    for rpc_id in done:
        del futures[rpc_id]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cancel_all_rpcs(futures: Mapping[int, Tuple[grpc.Future, str]]) -> None:
    """Requests cancellation of every RPC still present in `futures`.

    Args:
      futures: In-flight table mapping RPC id to `(future, method)`. The table
        itself is not modified.
    """
    logger.info("Cancelling all remaining RPCs")
    # Each value is a (future, method) pair; only the future is cancellable.
    for future, _ in futures.values():
        future.cancel()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _run_single_channel(args: argparse.Namespace): |
|
|
|
|
def _run_single_channel(method: str, metadata: Sequence[Tuple[str, str]], |
|
|
|
|
qps: int, server: str, rpc_timeout_sec: int, |
|
|
|
|
print_response: bool): |
|
|
|
|
global _global_rpc_id # pylint: disable=global-statement |
|
|
|
|
duration_per_query = 1.0 / float(args.qps) |
|
|
|
|
with grpc.insecure_channel(args.server) as channel: |
|
|
|
|
duration_per_query = 1.0 / float(qps) |
|
|
|
|
with grpc.insecure_channel(server) as channel: |
|
|
|
|
stub = test_pb2_grpc.TestServiceStub(channel) |
|
|
|
|
futures: Dict[int, grpc.Future] = {} |
|
|
|
|
futures: Dict[int, Tuple[grpc.Future, str]] = {} |
|
|
|
|
while not _stop_event.is_set(): |
|
|
|
|
request_id = None |
|
|
|
|
with _global_lock: |
|
|
|
@ -178,8 +209,9 @@ def _run_single_channel(args: argparse.Namespace): |
|
|
|
|
_global_rpc_id += 1 |
|
|
|
|
start = time.time() |
|
|
|
|
end = start + duration_per_query |
|
|
|
|
_start_rpc(request_id, stub, float(args.rpc_timeout_sec), futures) |
|
|
|
|
_remove_completed_rpcs(futures, args.print_response) |
|
|
|
|
_start_rpc(method, metadata, request_id, stub, |
|
|
|
|
float(rpc_timeout_sec), futures) |
|
|
|
|
_remove_completed_rpcs(futures, print_response) |
|
|
|
|
logger.debug(f"Currently {len(futures)} in-flight RPCs") |
|
|
|
|
now = time.time() |
|
|
|
|
while now < end: |
|
|
|
@ -188,22 +220,75 @@ def _run_single_channel(args: argparse.Namespace): |
|
|
|
|
_cancel_all_rpcs(futures) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _run(args: argparse.Namespace) -> None: |
|
|
|
|
class _MethodHandle: |
|
|
|
|
"""An object grouping together threads driving RPCs for a method.""" |
|
|
|
|
|
|
|
|
|
_channel_threads: List[threading.Thread] |
|
|
|
|
|
|
|
|
|
def __init__(self, method: str, metadata: Sequence[Tuple[str, str]], |
|
|
|
|
num_channels: int, qps: int, server: str, rpc_timeout_sec: int, |
|
|
|
|
print_response: bool): |
|
|
|
|
"""Creates and starts a group of threads running the indicated method.""" |
|
|
|
|
self._channel_threads = [] |
|
|
|
|
for i in range(num_channels): |
|
|
|
|
thread = threading.Thread(target=_run_single_channel, |
|
|
|
|
args=( |
|
|
|
|
method, |
|
|
|
|
metadata, |
|
|
|
|
qps, |
|
|
|
|
server, |
|
|
|
|
rpc_timeout_sec, |
|
|
|
|
print_response, |
|
|
|
|
)) |
|
|
|
|
thread.start() |
|
|
|
|
self._channel_threads.append(thread) |
|
|
|
|
|
|
|
|
|
def stop(self): |
|
|
|
|
"""Joins all threads referenced by the handle.""" |
|
|
|
|
for channel_thread in self._channel_threads: |
|
|
|
|
channel_thread.join() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _run(args: argparse.Namespace, methods: Sequence[str],
         per_method_metadata: PerMethodMetadataType) -> None:
    """Runs the client until the stats server terminates.

    Starts one `_MethodHandle` (a group of RPC-driving threads) per entry in
    `methods`, then serves the load balancer stats service on
    `args.stats_port` and blocks until that server terminates. Afterwards it
    joins all worker threads.

    Args:
      args: Parsed command-line flags; reads `num_channels`, `qps`, `server`,
        `rpc_timeout_sec`, `print_response` and `stats_port`.
      methods: Names of the RPC methods to drive.
      per_method_metadata: Metadata to attach per method; methods without an
        entry are sent with no metadata.
    """
    logger.info("Starting python xDS Interop Client.")
    global _global_server  # pylint: disable=global-statement
    # Thread lifetime is owned by the handles; no separate thread list is kept.
    method_handles = [
        _MethodHandle(method, per_method_metadata.get(method, []),
                      args.num_channels, args.qps, args.server,
                      args.rpc_timeout_sec, args.print_response)
        for method in methods
    ]
    _global_server = grpc.server(futures.ThreadPoolExecutor())
    _global_server.add_insecure_port(f"0.0.0.0:{args.stats_port}")
    test_pb2_grpc.add_LoadBalancerStatsServiceServicer_to_server(
        _LoadBalancerStatsServicer(), _global_server)
    _global_server.start()
    _global_server.wait_for_termination()
    for method_handle in method_handles:
        method_handle.stop()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_metadata_arg(metadata_arg: str) -> PerMethodMetadataType:
    """Parses the value of the --metadata flag into per-method metadata.

    Args:
      metadata_arg: Comma-delimited entries of the form 'METHOD:KEY:VALUE'.
        An empty string yields an empty result.

    Returns:
      A mapping from method name to the list of (key, value) pairs given for
      that method, in the order they appeared.

    Raises:
      ValueError: If an entry does not have exactly three ':'-separated
        fields, or names a method not in `_SUPPORTED_METHODS`.
    """
    # Decide on the argument itself rather than the module-level `args`, so
    # the function also works when it is not called from the script entry
    # point.
    metadata = metadata_arg.split(",") if metadata_arg else []
    per_method_metadata = collections.defaultdict(list)
    for metadatum in metadata:
        elems = metadatum.split(":")
        if len(elems) != 3:
            raise ValueError(
                f"'{metadatum}' was not in the form 'METHOD:KEY:VALUE'")
        method, key, value = elems
        if method not in _SUPPORTED_METHODS:
            raise ValueError(f"Unrecognized method '{method}'")
        per_method_metadata[method].append((key, value))
    return per_method_metadata
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_rpc_arg(rpc_arg: str) -> Sequence[str]:
    """Splits the --rpc flag value into a list of method names.

    Raises:
      ValueError: If any listed name is not in `_SUPPORTED_METHODS`.
    """
    methods = rpc_arg.split(",")
    unsupported = set(methods).difference(_SUPPORTED_METHODS)
    if unsupported:
        supported = ", ".join(_SUPPORTED_METHODS)
        raise ValueError("--rpc supported methods: {}".format(supported))
    return methods
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
@ -243,6 +328,15 @@ if __name__ == "__main__": |
|
|
|
|
default=None, |
|
|
|
|
type=str, |
|
|
|
|
help="A file to log to.") |
|
|
|
|
rpc_help = "A comma-delimited list of RPC methods to run. Must be one of " |
|
|
|
|
rpc_help += ", ".join(_SUPPORTED_METHODS) |
|
|
|
|
rpc_help += "." |
|
|
|
|
parser.add_argument("--rpc", default="UnaryCall", type=str, help=rpc_help) |
|
|
|
|
metadata_help = ( |
|
|
|
|
"A comma-delimited list of 3-tuples of the form " + |
|
|
|
|
"METHOD:KEY:VALUE, e.g. " + |
|
|
|
|
"EmptyCall:key1:value1,UnaryCall:key2:value2,EmptyCall:k3:v3") |
|
|
|
|
parser.add_argument("--metadata", default="", type=str, help=metadata_help) |
|
|
|
|
args = parser.parse_args() |
|
|
|
|
signal.signal(signal.SIGINT, _handle_sigint) |
|
|
|
|
if args.verbose: |
|
|
|
@ -251,4 +345,4 @@ if __name__ == "__main__": |
|
|
|
|
file_handler = logging.FileHandler(args.log_file, mode='a') |
|
|
|
|
file_handler.setFormatter(formatter) |
|
|
|
|
logger.addHandler(file_handler) |
|
|
|
|
_run(args) |
|
|
|
|
_run(args, parse_rpc_arg(args.rpc), parse_metadata_arg(args.metadata)) |
|
|
|
|