Working client

pull/22104/head
Richard Belleville 5 years ago
parent f26f80d532
commit 34e320a439
  1. 109
      src/python/grpcio_tests/tests/interop/xds_interop_client.py

@ -13,12 +13,13 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import logging
import signal import signal
import threading import threading
import time import time
import sys import sys
from typing import DefaultDict, List, Set from typing import DefaultDict, Dict, List, Mapping, Set
import collections import collections
from concurrent import futures from concurrent import futures
@ -30,6 +31,16 @@ from src.proto.grpc.testing import test_pb2_grpc
from src.proto.grpc.testing import messages_pb2 from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import empty_pb2 from src.proto.grpc.testing import empty_pb2
logger = logging.getLogger()
console_handler = logging.StreamHandler()
formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)-8s %(message)s')
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
# TODO: Make this logfile configurable.
file_handler = logging.FileHandler('/tmp/python_xds_interop_client.log', mode='a')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# TODO: Back with a LoadBalancerStatsResponse proto? # TODO: Back with a LoadBalancerStatsResponse proto?
class _StatsWatcher: class _StatsWatcher:
@ -64,13 +75,17 @@ class _StatsWatcher:
def await_rpc_stats_response(self, timeout_sec: int def await_rpc_stats_response(self, timeout_sec: int
) -> messages_pb2.LoadBalancerStatsResponse: ) -> messages_pb2.LoadBalancerStatsResponse:
"""Blocks until a full response has been collected.""" """Blocks until a full response has been collected."""
logger.info("Awaiting RPC stats response")
with self._lock: with self._lock:
logger.debug(f"Waiting for {timeout_sec} on condition variable.")
self._condition.wait_for(lambda: not self._rpcs_needed, self._condition.wait_for(lambda: not self._rpcs_needed,
timeout=float(timeout_sec)) timeout=float(timeout_sec))
logger.debug(f"Waited for {timeout_sec} on condition variable.")
response = messages_pb2.LoadBalancerStatsResponse() response = messages_pb2.LoadBalancerStatsResponse()
for peer, count in self._rpcs_by_peer.items(): for peer, count in self._rpcs_by_peer.items():
response.rpcs_by_peer[peer] = count response.rpcs_by_peer[peer] = count
response.num_failures = self._no_remote_peer + self._rpcs_needed response.num_failures = self._no_remote_peer + self._rpcs_needed
logger.info("Finished awaiting rpc stats response")
return response return response
@ -95,8 +110,7 @@ class _LoadBalancerStatsServicer(test_pb2_grpc.LoadBalancerStatsServiceServicer
def GetClientStats(self, request: messages_pb2.LoadBalancerStatsRequest, def GetClientStats(self, request: messages_pb2.LoadBalancerStatsRequest,
context: grpc.ServicerContext context: grpc.ServicerContext
) -> messages_pb2.LoadBalancerStatsResponse: ) -> messages_pb2.LoadBalancerStatsResponse:
print("Received stats request.") logger.info("Received stats request.")
sys.stdout.flush()
start = None start = None
end = None end = None
watcher = None watcher = None
@ -108,8 +122,62 @@ class _LoadBalancerStatsServicer(test_pb2_grpc.LoadBalancerStatsServiceServicer
response = watcher.await_rpc_stats_response(request.timeout_sec) response = watcher.await_rpc_stats_response(request.timeout_sec)
with _global_lock: with _global_lock:
_watchers.remove(watcher) _watchers.remove(watcher)
logger.info("Returning stats response: {}".format(response))
return response return response
def _start_rpc(request_id: int,
stub: test_pb2_grpc.TestServiceStub,
timeout: float,
futures: Mapping[int, grpc.Future]) -> None:
logger.info(f"[{threading.get_ident()}] Sending request to backend: {request_id}")
future = stub.UnaryCall.future(messages_pb2.SimpleRequest(),
timeout=timeout)
futures[request_id] = future
def _on_rpc_done(rpc_id: int,
future: grpc.Future,
print_response: bool) -> None:
exception = future.exception()
hostname = ""
if exception is not None:
if exception.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
logger.error(f"RPC {rpc_id} timed out")
else:
logger.error(exception)
else:
response = future.result()
logger.info(f"Got result {rpc_id}")
hostname = response.hostname
if print_response:
if future.code() == grpc.StatusCode.OK:
logger.info("Successful response.")
else:
logger.info(f"RPC failed: {call}")
with _global_lock:
for watcher in _watchers:
watcher.on_rpc_complete(rpc_id, hostname)
def _remove_completed_rpcs(futures: Mapping[int, grpc.Future],
print_response: bool) -> None:
logger.debug("Removing completed RPCs")
done = []
for future_id, future in futures.items():
if future.done():
logger.debug("Calling _on_rpc_done")
_on_rpc_done(future_id, future, args.print_response)
logger.debug("Called _on_rpc_done")
done.append(future_id)
for rpc_id in done:
del futures[rpc_id]
logger.debug("Removed completed RPCs")
def _cancel_all_rpcs(futures: Mapping[int, grpc.Future]) -> None:
logger.info("Cancelling all remaining RPCs")
for future in futures.values():
future.cancel()
# TODO: Accept finer-grained arguments. # TODO: Accept finer-grained arguments.
def _run_single_channel(args: argparse.Namespace): def _run_single_channel(args: argparse.Namespace):
@ -117,45 +185,28 @@ def _run_single_channel(args: argparse.Namespace):
duration_per_query = 1.0 / float(args.qps) duration_per_query = 1.0 / float(args.qps)
with grpc.insecure_channel(args.server) as channel: with grpc.insecure_channel(args.server) as channel:
stub = test_pb2_grpc.TestServiceStub(channel) stub = test_pb2_grpc.TestServiceStub(channel)
futures: Dict[int, grpc.Future] = {}
while not _stop_event.is_set(): while not _stop_event.is_set():
request_id = None request_id = None
with _global_lock: with _global_lock:
request_id = _global_rpc_id request_id = _global_rpc_id
_global_rpc_id += 1 _global_rpc_id += 1
print(f"[{threading.get_ident()}] Sending request to backend: {request_id}")
sys.stdout.flush()
start = time.time() start = time.time()
end = start + duration_per_query end = start + duration_per_query
try: _start_rpc(request_id, stub, float(args.rpc_timeout_sec), futures)
response, call = stub.UnaryCall.with_call(messages_pb2.SimpleRequest(), # TODO: Complete RPCs more frequently than 1 / QPS?
timeout=float( _remove_completed_rpcs(futures, args.print_response)
args.rpc_timeout_sec)) logger.debug(f"Currently {len(futures)} in-flight RPCs")
except grpc.RpcError as e:
if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
print(f"RPC timed out after {args.rpc_timeout_sec}")
else:
raise
else:
print(f"Got result {request_id}")
sys.stdout.flush()
with _global_lock:
for watcher in _watchers:
watcher.on_rpc_complete(request_id, response.hostname)
if args.print_response:
if call.code() == grpc.StatusCode.OK:
print("Successful response.")
sys.stdout.flush()
else:
print(f"RPC failed: {call}")
sys.stdout.flush()
now = time.time() now = time.time()
while now < end: while now < end:
time.sleep(end - now) time.sleep(end - now)
now = time.time() now = time.time()
_cancel_all_rpcs(futures)
# TODO: Accept finer-grained arguments. # TODO: Accept finer-grained arguments.
def _run(args: argparse.Namespace) -> None: def _run(args: argparse.Namespace) -> None:
logger.info("Starting python xDS Interop Client.")
global _global_server # pylint: disable=global-statement global _global_server # pylint: disable=global-statement
channel_threads: List[threading.Thread] = [] channel_threads: List[threading.Thread] = []
for i in range(args.num_channels): for i in range(args.num_channels):
@ -190,7 +241,7 @@ if __name__ == "__main__":
type=int, type=int,
help="The number of queries to send from each channel per second.") help="The number of queries to send from each channel per second.")
parser.add_argument("--rpc_timeout_sec", parser.add_argument("--rpc_timeout_sec",
default=10, default=30,
type=int, type=int,
help="The per-RPC timeout in seconds.") help="The per-RPC timeout in seconds.")
parser.add_argument("--server", parser.add_argument("--server",
@ -203,4 +254,6 @@ if __name__ == "__main__":
help="The port on which to expose the peer distribution stats service.") help="The port on which to expose the peer distribution stats service.")
args = parser.parse_args() args = parser.parse_args()
signal.signal(signal.SIGINT, _handle_sigint) signal.signal(signal.SIGINT, _handle_sigint)
logger.setLevel(logging.DEBUG)
# logging.basicConfig(level=logging.INFO, stream=sys.stderr)
_run(args) _run(args)

Loading…
Cancel
Save