|
|
@ -13,12 +13,13 @@ |
|
|
|
# limitations under the License. |
|
|
|
# limitations under the License. |
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
|
import argparse |
|
|
|
|
|
|
|
import logging |
|
|
|
import signal |
|
|
|
import signal |
|
|
|
import threading |
|
|
|
import threading |
|
|
|
import time |
|
|
|
import time |
|
|
|
import sys |
|
|
|
import sys |
|
|
|
|
|
|
|
|
|
|
|
from typing import DefaultDict, List, Set |
|
|
|
from typing import DefaultDict, Dict, List, Mapping, Set |
|
|
|
import collections |
|
|
|
import collections |
|
|
|
|
|
|
|
|
|
|
|
from concurrent import futures |
|
|
|
from concurrent import futures |
|
|
@ -30,6 +31,16 @@ from src.proto.grpc.testing import test_pb2_grpc |
|
|
|
from src.proto.grpc.testing import messages_pb2 |
|
|
|
from src.proto.grpc.testing import messages_pb2 |
|
|
|
from src.proto.grpc.testing import empty_pb2 |
|
|
|
from src.proto.grpc.testing import empty_pb2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Logging setup: the root logger sends every record to both the console and a
# log file, using one shared format.
formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)-8s %(message)s')
logger = logging.getLogger()

# Console output.
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

# File output, appended across runs.
# TODO: Make this logfile configurable.
file_handler = logging.FileHandler('/tmp/python_xds_interop_client.log', mode='a')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
|
|
|
|
|
|
|
|
|
|
|
# TODO: Back with a LoadBalancerStatsResponse proto? |
|
|
|
# TODO: Back with a LoadBalancerStatsResponse proto? |
|
|
|
class _StatsWatcher: |
|
|
|
class _StatsWatcher: |
|
|
@ -64,13 +75,17 @@ class _StatsWatcher: |
|
|
|
def await_rpc_stats_response(self, timeout_sec: int
                            ) -> messages_pb2.LoadBalancerStatsResponse:
    """Wait for the needed RPCs, then summarize them in a stats response.

    Args:
      timeout_sec: Maximum number of seconds to wait on the condition
        variable for the remaining RPCs.

    Returns:
      A LoadBalancerStatsResponse carrying the per-peer RPC counts and a
      failure count equal to the RPCs without a remote peer plus the RPCs
      still needed when the wait ended.
    """
    logger.info("Awaiting RPC stats response")
    with self._lock:
        logger.debug(f"Waiting for {timeout_sec} on condition variable.")
        # Returns early once no more RPCs are needed; otherwise gives up
        # after the timeout and reports whatever has been collected.
        self._condition.wait_for(lambda: not self._rpcs_needed,
                                 timeout=float(timeout_sec))
        logger.debug(f"Waited for {timeout_sec} on condition variable.")
        stats = messages_pb2.LoadBalancerStatsResponse()
        for peer_name, rpc_count in self._rpcs_by_peer.items():
            stats.rpcs_by_peer[peer_name] = rpc_count
        stats.num_failures = self._no_remote_peer + self._rpcs_needed
    logger.info("Finished awaiting rpc stats response")
    return stats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -95,8 +110,7 @@ class _LoadBalancerStatsServicer(test_pb2_grpc.LoadBalancerStatsServiceServicer |
|
|
|
def GetClientStats(self, request: messages_pb2.LoadBalancerStatsRequest, |
|
|
|
def GetClientStats(self, request: messages_pb2.LoadBalancerStatsRequest, |
|
|
|
context: grpc.ServicerContext |
|
|
|
context: grpc.ServicerContext |
|
|
|
) -> messages_pb2.LoadBalancerStatsResponse: |
|
|
|
) -> messages_pb2.LoadBalancerStatsResponse: |
|
|
|
print("Received stats request.") |
|
|
|
logger.info("Received stats request.") |
|
|
|
sys.stdout.flush() |
|
|
|
|
|
|
|
start = None |
|
|
|
start = None |
|
|
|
end = None |
|
|
|
end = None |
|
|
|
watcher = None |
|
|
|
watcher = None |
|
|
@ -108,8 +122,62 @@ class _LoadBalancerStatsServicer(test_pb2_grpc.LoadBalancerStatsServiceServicer |
|
|
|
response = watcher.await_rpc_stats_response(request.timeout_sec) |
|
|
|
response = watcher.await_rpc_stats_response(request.timeout_sec) |
|
|
|
with _global_lock: |
|
|
|
with _global_lock: |
|
|
|
_watchers.remove(watcher) |
|
|
|
_watchers.remove(watcher) |
|
|
|
|
|
|
|
logger.info("Returning stats response: {}".format(response)) |
|
|
|
return response |
|
|
|
return response |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _start_rpc(request_id: int,
               stub: test_pb2_grpc.TestServiceStub,
               timeout: float,
               futures: Dict[int, grpc.Future]) -> None:
    """Start one asynchronous UnaryCall and record its future.

    Args:
      request_id: Key under which the new future is stored.
      stub: TestService stub used to issue the call.
      timeout: Per-RPC timeout handed to the call.
      futures: Dictionary of in-flight RPC futures keyed by request id. It is
        mutated: the future for this call is inserted under `request_id`.
    """
    logger.info(f"[{threading.get_ident()}] Sending request to backend: {request_id}")
    future = stub.UnaryCall.future(messages_pb2.SimpleRequest(),
                                   timeout=timeout)
    futures[request_id] = future
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _on_rpc_done(rpc_id: int,
                 future: grpc.Future,
                 print_response: bool) -> None:
    """Log the outcome of a finished RPC and report it to every watcher.

    Args:
      rpc_id: Identifier of the completed RPC.
      future: The completed future for that RPC.
      print_response: Whether to additionally log the status of an RPC that
        produced a response.
    """
    exception = future.exception()
    # Failed RPCs are reported to the watchers with an empty hostname.
    hostname = ""
    if exception is not None:
        if exception.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
            logger.error(f"RPC {rpc_id} timed out")
        else:
            logger.error(exception)
    else:
        response = future.result()
        logger.info(f"Got result {rpc_id}")
        hostname = response.hostname
        if print_response:
            if future.code() == grpc.StatusCode.OK:
                logger.info("Successful response.")
            else:
                # The future is the only handle on the call in scope here,
                # so it is what gets interpolated into the message.
                logger.info(f"RPC failed: {future}")
    # Watchers are notified for successes and failures alike.
    with _global_lock:
        for watcher in _watchers:
            watcher.on_rpc_complete(rpc_id, hostname)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _remove_completed_rpcs(futures: Dict[int, grpc.Future],
                           print_response: bool) -> None:
    """Process every finished RPC and drop it from the in-flight dictionary.

    Args:
      futures: Dictionary of in-flight RPC futures keyed by RPC id. It is
        mutated: entries whose future is done are deleted.
      print_response: Forwarded to `_on_rpc_done` for each finished RPC.
    """
    logger.debug("Removing completed RPCs")
    done = []
    for future_id, future in futures.items():
        if future.done():
            logger.debug("Calling _on_rpc_done")
            # Use the caller-supplied flag rather than any module-level
            # state, so the function also works when imported as a module.
            _on_rpc_done(future_id, future, print_response)
            logger.debug("Called _on_rpc_done")
            done.append(future_id)
    # Deletion is deferred so the dictionary is not modified while it is
    # being iterated.
    for rpc_id in done:
        del futures[rpc_id]
    logger.debug("Removed completed RPCs")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cancel_all_rpcs(futures: Mapping[int, grpc.Future]) -> None:
    """Request cancellation of every RPC future in `futures`.

    Args:
      futures: RPC futures keyed by RPC id; only the values are used.
    """
    logger.info("Cancelling all remaining RPCs")
    for pending_rpc in futures.values():
        pending_rpc.cancel()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO: Accept finer-grained arguments. |
|
|
|
# TODO: Accept finer-grained arguments. |
|
|
|
def _run_single_channel(args: argparse.Namespace): |
|
|
|
def _run_single_channel(args: argparse.Namespace): |
|
|
@ -117,45 +185,28 @@ def _run_single_channel(args: argparse.Namespace): |
|
|
|
duration_per_query = 1.0 / float(args.qps) |
|
|
|
duration_per_query = 1.0 / float(args.qps) |
|
|
|
with grpc.insecure_channel(args.server) as channel: |
|
|
|
with grpc.insecure_channel(args.server) as channel: |
|
|
|
stub = test_pb2_grpc.TestServiceStub(channel) |
|
|
|
stub = test_pb2_grpc.TestServiceStub(channel) |
|
|
|
|
|
|
|
futures: Dict[int, grpc.Future] = {} |
|
|
|
while not _stop_event.is_set(): |
|
|
|
while not _stop_event.is_set(): |
|
|
|
request_id = None |
|
|
|
request_id = None |
|
|
|
with _global_lock: |
|
|
|
with _global_lock: |
|
|
|
request_id = _global_rpc_id |
|
|
|
request_id = _global_rpc_id |
|
|
|
_global_rpc_id += 1 |
|
|
|
_global_rpc_id += 1 |
|
|
|
print(f"[{threading.get_ident()}] Sending request to backend: {request_id}") |
|
|
|
|
|
|
|
sys.stdout.flush() |
|
|
|
|
|
|
|
start = time.time() |
|
|
|
start = time.time() |
|
|
|
end = start + duration_per_query |
|
|
|
end = start + duration_per_query |
|
|
|
try: |
|
|
|
_start_rpc(request_id, stub, float(args.rpc_timeout_sec), futures) |
|
|
|
response, call = stub.UnaryCall.with_call(messages_pb2.SimpleRequest(), |
|
|
|
# TODO: Complete RPCs more frequently than 1 / QPS? |
|
|
|
timeout=float( |
|
|
|
_remove_completed_rpcs(futures, args.print_response) |
|
|
|
args.rpc_timeout_sec)) |
|
|
|
logger.debug(f"Currently {len(futures)} in-flight RPCs") |
|
|
|
except grpc.RpcError as e: |
|
|
|
|
|
|
|
if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED: |
|
|
|
|
|
|
|
print(f"RPC timed out after {args.rpc_timeout_sec}") |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
raise |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
print(f"Got result {request_id}") |
|
|
|
|
|
|
|
sys.stdout.flush() |
|
|
|
|
|
|
|
with _global_lock: |
|
|
|
|
|
|
|
for watcher in _watchers: |
|
|
|
|
|
|
|
watcher.on_rpc_complete(request_id, response.hostname) |
|
|
|
|
|
|
|
if args.print_response: |
|
|
|
|
|
|
|
if call.code() == grpc.StatusCode.OK: |
|
|
|
|
|
|
|
print("Successful response.") |
|
|
|
|
|
|
|
sys.stdout.flush() |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
print(f"RPC failed: {call}") |
|
|
|
|
|
|
|
sys.stdout.flush() |
|
|
|
|
|
|
|
now = time.time() |
|
|
|
now = time.time() |
|
|
|
while now < end: |
|
|
|
while now < end: |
|
|
|
time.sleep(end - now) |
|
|
|
time.sleep(end - now) |
|
|
|
now = time.time() |
|
|
|
now = time.time() |
|
|
|
|
|
|
|
_cancel_all_rpcs(futures) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO: Accept finer-grained arguments. |
|
|
|
# TODO: Accept finer-grained arguments. |
|
|
|
def _run(args: argparse.Namespace) -> None: |
|
|
|
def _run(args: argparse.Namespace) -> None: |
|
|
|
|
|
|
|
logger.info("Starting python xDS Interop Client.") |
|
|
|
global _global_server # pylint: disable=global-statement |
|
|
|
global _global_server # pylint: disable=global-statement |
|
|
|
channel_threads: List[threading.Thread] = [] |
|
|
|
channel_threads: List[threading.Thread] = [] |
|
|
|
for i in range(args.num_channels): |
|
|
|
for i in range(args.num_channels): |
|
|
@ -190,7 +241,7 @@ if __name__ == "__main__": |
|
|
|
type=int, |
|
|
|
type=int, |
|
|
|
help="The number of queries to send from each channel per second.") |
|
|
|
help="The number of queries to send from each channel per second.") |
|
|
|
parser.add_argument("--rpc_timeout_sec", |
|
|
|
parser.add_argument("--rpc_timeout_sec", |
|
|
|
default=10, |
|
|
|
default=30, |
|
|
|
type=int, |
|
|
|
type=int, |
|
|
|
help="The per-RPC timeout in seconds.") |
|
|
|
help="The per-RPC timeout in seconds.") |
|
|
|
parser.add_argument("--server", |
|
|
|
parser.add_argument("--server", |
|
|
@ -203,4 +254,6 @@ if __name__ == "__main__": |
|
|
|
help="The port on which to expose the peer distribution stats service.") |
|
|
|
help="The port on which to expose the peer distribution stats service.") |
|
|
|
args = parser.parse_args() |
|
|
|
args = parser.parse_args() |
|
|
|
signal.signal(signal.SIGINT, _handle_sigint) |
|
|
|
signal.signal(signal.SIGINT, _handle_sigint) |
|
|
|
|
|
|
|
logger.setLevel(logging.DEBUG) |
|
|
|
|
|
|
|
# logging.basicConfig(level=logging.INFO, stream=sys.stderr) |
|
|
|
_run(args) |
|
|
|
_run(args) |
|
|
|