Merge branch 'shush-chaos' into shush-tsan

pull/35650/head
Craig Tiller 10 months ago
commit fd55ea1be3
  1. 111
      src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py

@ -14,16 +14,25 @@
import argparse
import collections
from concurrent import futures
import concurrent.futures
import datetime
import logging
import signal
import sys
import threading
import time
from typing import DefaultDict, Dict, List, Mapping, Sequence, Set, Tuple
from typing import (
DefaultDict,
Dict,
Iterable,
List,
Mapping,
Sequence,
Set,
Tuple,
)
import grpc
from grpc import _typing as grpc_typing
import grpc_admin
from grpc_channelz.v1 import channelz
@ -57,6 +66,12 @@ _METHOD_ENUM_TO_STR = {v: k for k, v in _METHOD_STR_TO_ENUM.items()}
PerMethodMetadataType = Mapping[str, Sequence[Tuple[str, str]]]
# FutureFromCall is both a grpc.Call and grpc.Future
class FutureFromCallType(grpc.Call, grpc.Future):
pass
_CONFIG_CHANGE_TIMEOUT = datetime.timedelta(milliseconds=500)
@ -69,8 +84,13 @@ class _StatsWatcher:
_no_remote_peer: int
_lock: threading.Lock
_condition: threading.Condition
_metadata_keys: frozenset
_include_all_metadata: bool
_metadata_by_peer: DefaultDict[
str, messages_pb2.LoadBalancerStatsResponse.MetadataByPeer
]
def __init__(self, start: int, end: int):
def __init__(self, start: int, end: int, metadata_keys: Iterable[str]):
self._start = start
self._end = end
self._rpcs_needed = end - start
@ -80,8 +100,44 @@ class _StatsWatcher:
)
self._condition = threading.Condition()
self._no_remote_peer = 0
self._metadata_keys = frozenset(
self._sanitize_metadata_key(key) for key in metadata_keys
)
self._include_all_metadata = "*" in self._metadata_keys
self._metadata_by_peer = collections.defaultdict(
messages_pb2.LoadBalancerStatsResponse.MetadataByPeer
)
@classmethod
def _sanitize_metadata_key(cls, metadata_key: str) -> str:
return metadata_key.strip().lower()
def on_rpc_complete(self, request_id: int, peer: str, method: str) -> None:
def _add_metadata(
self,
rpc_metadata: messages_pb2.LoadBalancerStatsResponse.RpcMetadata,
metadata_to_add: grpc_typing.MetadataType,
metadata_type: messages_pb2.LoadBalancerStatsResponse.MetadataType,
) -> None:
for key, value in metadata_to_add:
if (
self._include_all_metadata
or self._sanitize_metadata_key(key) in self._metadata_keys
):
rpc_metadata.metadata.append(
messages_pb2.LoadBalancerStatsResponse.MetadataEntry(
key=key, value=value, type=metadata_type
)
)
def on_rpc_complete(
self,
request_id: int,
peer: str,
method: str,
*,
initial_metadata: grpc_typing.MetadataType,
trailing_metadata: grpc_typing.MetadataType,
) -> None:
"""Records statistics for a single RPC."""
if self._start <= request_id < self._end:
with self._condition:
@ -90,6 +146,23 @@ class _StatsWatcher:
else:
self._rpcs_by_peer[peer] += 1
self._rpcs_by_method[method][peer] += 1
if self._metadata_keys:
rpc_metadata = (
messages_pb2.LoadBalancerStatsResponse.RpcMetadata()
)
self._add_metadata(
rpc_metadata,
initial_metadata,
messages_pb2.LoadBalancerStatsResponse.MetadataType.INITIAL,
)
self._add_metadata(
rpc_metadata,
trailing_metadata,
messages_pb2.LoadBalancerStatsResponse.MetadataType.TRAILING,
)
self._metadata_by_peer[peer].rpc_metadata.append(
rpc_metadata
)
self._rpcs_needed -= 1
self._condition.notify()
@ -107,6 +180,8 @@ class _StatsWatcher:
for method, count_by_peer in self._rpcs_by_method.items():
for peer, count in count_by_peer.items():
response.rpcs_by_method[method].rpcs_by_peer[peer] = count
for peer, metadata_by_peer in self._metadata_by_peer.items():
response.metadatas_by_peer[peer].CopyFrom(metadata_by_peer)
response.num_failures = self._no_remote_peer + self._rpcs_needed
return response
@ -150,7 +225,7 @@ class _LoadBalancerStatsServicer(
with _global_lock:
start = _global_rpc_id + 1
end = start + request.num_rpcs
watcher = _StatsWatcher(start, end)
watcher = _StatsWatcher(start, end, request.metadata_keys)
_watchers.add(watcher)
response = watcher.await_rpc_stats_response(request.timeout_sec)
with _global_lock:
@ -192,7 +267,7 @@ def _start_rpc(
request_id: int,
stub: test_pb2_grpc.TestServiceStub,
timeout: float,
futures: Mapping[int, Tuple[grpc.Future, str]],
futures: Mapping[int, Tuple[FutureFromCallType, str]],
) -> None:
logger.debug(f"Sending {method} request to backend: {request_id}")
if method == "UnaryCall":
@ -209,7 +284,7 @@ def _start_rpc(
def _on_rpc_done(
rpc_id: int, future: grpc.Future, method: str, print_response: bool
rpc_id: int, future: FutureFromCallType, method: str, print_response: bool
) -> None:
exception = future.exception()
hostname = ""
@ -241,23 +316,29 @@ def _on_rpc_done(
if future.code() == grpc.StatusCode.OK:
logger.debug("Successful response.")
else:
logger.debug(f"RPC failed: {call}")
logger.debug(f"RPC failed: {rpc_id}")
with _global_lock:
for watcher in _watchers:
watcher.on_rpc_complete(rpc_id, hostname, method)
watcher.on_rpc_complete(
rpc_id,
hostname,
method,
initial_metadata=future.initial_metadata(),
trailing_metadata=future.trailing_metadata(),
)
def _remove_completed_rpcs(
futures: Mapping[int, grpc.Future], print_response: bool
rpc_futures: Mapping[int, FutureFromCallType], print_response: bool
) -> None:
logger.debug("Removing completed RPCs")
done = []
for future_id, (future, method) in futures.items():
for future_id, (future, method) in rpc_futures.items():
if future.done():
_on_rpc_done(future_id, future, method, args.print_response)
done.append(future_id)
for rpc_id in done:
del futures[rpc_id]
del rpc_futures[rpc_id]
def _cancel_all_rpcs(futures: Mapping[int, Tuple[grpc.Future, str]]) -> None:
@ -309,7 +390,7 @@ def _run_single_channel(config: _ChannelConfiguration) -> None:
channel = grpc.insecure_channel(server)
with channel:
stub = test_pb2_grpc.TestServiceStub(channel)
futures: Dict[int, Tuple[grpc.Future, str]] = {}
futures: Dict[int, Tuple[FutureFromCallType, str]] = {}
while not _stop_event.is_set():
with config.condition:
if config.qps == 0:
@ -438,7 +519,7 @@ def _run(
)
channel_configs[method] = channel_config
method_handles.append(_MethodHandle(args.num_channels, channel_config))
_global_server = grpc.server(futures.ThreadPoolExecutor())
_global_server = grpc.server(concurrent.futures.ThreadPoolExecutor())
_global_server.add_insecure_port(f"0.0.0.0:{args.stats_port}")
test_pb2_grpc.add_LoadBalancerStatsServiceServicer_to_server(
_LoadBalancerStatsServicer(), _global_server

Loading…
Cancel
Save