Working client

pull/22104/head
Richard Belleville 5 years ago
parent f26f80d532
commit 34e320a439
  1. 109
      src/python/grpcio_tests/tests/interop/xds_interop_client.py

@ -13,12 +13,13 @@
# limitations under the License.
import argparse
import logging
import signal
import threading
import time
import sys
from typing import DefaultDict, List, Set
from typing import DefaultDict, Dict, List, Mapping, Set
import collections
from concurrent import futures
@ -30,6 +31,16 @@ from src.proto.grpc.testing import test_pb2_grpc
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import empty_pb2
logger = logging.getLogger()
console_handler = logging.StreamHandler()
formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)-8s %(message)s')
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# TODO: Make this logfile configurable.
file_handler = logging.FileHandler('/tmp/python_xds_interop_client.log', mode='a')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# TODO: Back with a LoadBalancerStatsResponse proto?
class _StatsWatcher:
@ -64,13 +75,17 @@ class _StatsWatcher:
def await_rpc_stats_response(self, timeout_sec: int
) -> messages_pb2.LoadBalancerStatsResponse:
"""Blocks until a full response has been collected."""
logger.info("Awaiting RPC stats response")
with self._lock:
logger.debug(f"Waiting for {timeout_sec} on condition variable.")
self._condition.wait_for(lambda: not self._rpcs_needed,
timeout=float(timeout_sec))
logger.debug(f"Waited for {timeout_sec} on condition variable.")
response = messages_pb2.LoadBalancerStatsResponse()
for peer, count in self._rpcs_by_peer.items():
response.rpcs_by_peer[peer] = count
response.num_failures = self._no_remote_peer + self._rpcs_needed
logger.info("Finished awaiting rpc stats response")
return response
@ -95,8 +110,7 @@ class _LoadBalancerStatsServicer(test_pb2_grpc.LoadBalancerStatsServiceServicer
def GetClientStats(self, request: messages_pb2.LoadBalancerStatsRequest,
context: grpc.ServicerContext
) -> messages_pb2.LoadBalancerStatsResponse:
print("Received stats request.")
sys.stdout.flush()
logger.info("Received stats request.")
start = None
end = None
watcher = None
@ -108,8 +122,62 @@ class _LoadBalancerStatsServicer(test_pb2_grpc.LoadBalancerStatsServiceServicer
response = watcher.await_rpc_stats_response(request.timeout_sec)
with _global_lock:
_watchers.remove(watcher)
logger.info("Returning stats response: {}".format(response))
return response
def _start_rpc(request_id: int,
stub: test_pb2_grpc.TestServiceStub,
timeout: float,
futures: Mapping[int, grpc.Future]) -> None:
logger.info(f"[{threading.get_ident()}] Sending request to backend: {request_id}")
future = stub.UnaryCall.future(messages_pb2.SimpleRequest(),
timeout=timeout)
futures[request_id] = future
def _on_rpc_done(rpc_id: int,
future: grpc.Future,
print_response: bool) -> None:
exception = future.exception()
hostname = ""
if exception is not None:
if exception.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
logger.error(f"RPC {rpc_id} timed out")
else:
logger.error(exception)
else:
response = future.result()
logger.info(f"Got result {rpc_id}")
hostname = response.hostname
if print_response:
if future.code() == grpc.StatusCode.OK:
logger.info("Successful response.")
else:
logger.info(f"RPC failed: {call}")
with _global_lock:
for watcher in _watchers:
watcher.on_rpc_complete(rpc_id, hostname)
def _remove_completed_rpcs(futures: Mapping[int, grpc.Future],
print_response: bool) -> None:
logger.debug("Removing completed RPCs")
done = []
for future_id, future in futures.items():
if future.done():
logger.debug("Calling _on_rpc_done")
_on_rpc_done(future_id, future, args.print_response)
logger.debug("Called _on_rpc_done")
done.append(future_id)
for rpc_id in done:
del futures[rpc_id]
logger.debug("Removed completed RPCs")
def _cancel_all_rpcs(futures: Mapping[int, grpc.Future]) -> None:
logger.info("Cancelling all remaining RPCs")
for future in futures.values():
future.cancel()
# TODO: Accept finer-grained arguments.
def _run_single_channel(args: argparse.Namespace):
@ -117,45 +185,28 @@ def _run_single_channel(args: argparse.Namespace):
duration_per_query = 1.0 / float(args.qps)
with grpc.insecure_channel(args.server) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
futures: Dict[int, grpc.Future] = {}
while not _stop_event.is_set():
request_id = None
with _global_lock:
request_id = _global_rpc_id
_global_rpc_id += 1
print(f"[{threading.get_ident()}] Sending request to backend: {request_id}")
sys.stdout.flush()
start = time.time()
end = start + duration_per_query
try:
response, call = stub.UnaryCall.with_call(messages_pb2.SimpleRequest(),
timeout=float(
args.rpc_timeout_sec))
except grpc.RpcError as e:
if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
print(f"RPC timed out after {args.rpc_timeout_sec}")
else:
raise
else:
print(f"Got result {request_id}")
sys.stdout.flush()
with _global_lock:
for watcher in _watchers:
watcher.on_rpc_complete(request_id, response.hostname)
if args.print_response:
if call.code() == grpc.StatusCode.OK:
print("Successful response.")
sys.stdout.flush()
else:
print(f"RPC failed: {call}")
sys.stdout.flush()
_start_rpc(request_id, stub, float(args.rpc_timeout_sec), futures)
# TODO: Complete RPCs more frequently than 1 / QPS?
_remove_completed_rpcs(futures, args.print_response)
logger.debug(f"Currently {len(futures)} in-flight RPCs")
now = time.time()
while now < end:
time.sleep(end - now)
now = time.time()
_cancel_all_rpcs(futures)
# TODO: Accept finer-grained arguments.
def _run(args: argparse.Namespace) -> None:
logger.info("Starting python xDS Interop Client.")
global _global_server # pylint: disable=global-statement
channel_threads: List[threading.Thread] = []
for i in range(args.num_channels):
@ -190,7 +241,7 @@ if __name__ == "__main__":
type=int,
help="The number of queries to send from each channel per second.")
parser.add_argument("--rpc_timeout_sec",
default=10,
default=30,
type=int,
help="The per-RPC timeout in seconds.")
parser.add_argument("--server",
@ -203,4 +254,6 @@ if __name__ == "__main__":
help="The port on which to expose the peer distribution stats service.")
args = parser.parse_args()
signal.signal(signal.SIGINT, _handle_sigint)
logger.setLevel(logging.DEBUG)
# logging.basicConfig(level=logging.INFO, stream=sys.stderr)
_run(args)

Loading…
Cancel
Save