Merge branch 'shush-chaos' into shush-tsan

pull/35650/head
Craig Tiller 10 months ago
commit fd55ea1be3
      src/python/grpcio_tests/tests_py3_only/interop/xds_interop_client.py

@@ -14,16 +14,25 @@
 import argparse
 import collections
-from concurrent import futures
+import concurrent.futures
 import datetime
 import logging
 import signal
-import sys
 import threading
 import time
-from typing import DefaultDict, Dict, List, Mapping, Sequence, Set, Tuple
+from typing import (
+    DefaultDict,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Sequence,
+    Set,
+    Tuple,
+)
 
 import grpc
+from grpc import _typing as grpc_typing
 import grpc_admin
 from grpc_channelz.v1 import channelz
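
Note on the concurrent.futures change: this file keeps its in-flight RPCs in local variables named futures (see _run_single_channel below), and a module imported as futures would be shadowed inside any function that uses both. A minimal sketch of the hazard, with hypothetical names:

    import concurrent.futures

    def sketch() -> None:
        futures = {}  # a local dict of in-flight RPCs, as in _run_single_channel
        # Under "from concurrent import futures", the next line would fail,
        # since the local dict shadows the module; the qualified name is immune.
        executor = concurrent.futures.ThreadPoolExecutor()
        executor.shutdown(wait=False)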
@@ -57,6 +66,12 @@ _METHOD_ENUM_TO_STR = {v: k for k, v in _METHOD_STR_TO_ENUM.items()}
 PerMethodMetadataType = Mapping[str, Sequence[Tuple[str, str]]]
 
 
+# FutureFromCall is both a grpc.Call and grpc.Future
+class FutureFromCallType(grpc.Call, grpc.Future):
+    pass
+
+
 _CONFIG_CHANGE_TIMEOUT = datetime.timedelta(milliseconds=500)
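
FutureFromCallType appears to exist purely as a typing aid: the object returned by a multicallable's .future() implements both grpc.Call (status and metadata accessors) and grpc.Future (completion and result accessors), and neither base interface alone covers everything _on_rpc_done touches. A minimal sketch of an annotation of this type in use (assumes a reachable TestService backend; request type per the interop messages proto):

    def sketch_call(stub: test_pb2_grpc.TestServiceStub) -> None:
        future: FutureFromCallType = stub.UnaryCall.future(
            messages_pb2.SimpleRequest(), timeout=10
        )
        response = future.result()  # grpc.Future: block until completion
        status = future.code()  # grpc.Call: status code accessor
        trailers = future.trailing_metadata()  # grpc.Call: metadata accessor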
@@ -69,8 +84,13 @@ class _StatsWatcher:
     _no_remote_peer: int
     _lock: threading.Lock
     _condition: threading.Condition
+    _metadata_keys: frozenset
+    _include_all_metadata: bool
+    _metadata_by_peer: DefaultDict[
+        str, messages_pb2.LoadBalancerStatsResponse.MetadataByPeer
+    ]
 
-    def __init__(self, start: int, end: int):
+    def __init__(self, start: int, end: int, metadata_keys: Iterable[str]):
         self._start = start
         self._end = end
         self._rpcs_needed = end - start
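
The three added declarations follow the bare class-level annotation style already used above them: they document instance attributes for type checkers and create no class attributes at runtime. A standalone illustration with a hypothetical class:

    class Sketch:
        _metadata_keys: frozenset  # annotation only, no value assigned

    assert "_metadata_keys" not in Sketch.__dict__  # no attribute created
    assert "_metadata_keys" in Sketch.__annotations__  # visible to type tools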
@@ -80,8 +100,44 @@ class _StatsWatcher:
         )
         self._condition = threading.Condition()
         self._no_remote_peer = 0
+        self._metadata_keys = frozenset(
+            self._sanitize_metadata_key(key) for key in metadata_keys
+        )
+        self._include_all_metadata = "*" in self._metadata_keys
+        self._metadata_by_peer = collections.defaultdict(
+            messages_pb2.LoadBalancerStatsResponse.MetadataByPeer
+        )
 
+    @classmethod
+    def _sanitize_metadata_key(cls, metadata_key: str) -> str:
+        return metadata_key.strip().lower()
+
-    def on_rpc_complete(self, request_id: int, peer: str, method: str) -> None:
+    def _add_metadata(
+        self,
+        rpc_metadata: messages_pb2.LoadBalancerStatsResponse.RpcMetadata,
+        metadata_to_add: grpc_typing.MetadataType,
+        metadata_type: messages_pb2.LoadBalancerStatsResponse.MetadataType,
+    ) -> None:
+        for key, value in metadata_to_add:
+            if (
+                self._include_all_metadata
+                or self._sanitize_metadata_key(key) in self._metadata_keys
+            ):
+                rpc_metadata.metadata.append(
+                    messages_pb2.LoadBalancerStatsResponse.MetadataEntry(
+                        key=key, value=value, type=metadata_type
+                    )
+                )
+
+    def on_rpc_complete(
+        self,
+        request_id: int,
+        peer: str,
+        method: str,
+        *,
+        initial_metadata: grpc_typing.MetadataType,
+        trailing_metadata: grpc_typing.MetadataType,
+    ) -> None:
         """Records statistics for a single RPC."""
         if self._start <= request_id < self._end:
             with self._condition:
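
The sanitization step leans on gRPC metadata keys being matched case-insensitively: requested keys are normalized once with strip().lower(), and "*" acts as a wildcard admitting every entry. The same filtering rule reduced to a standalone sketch (hypothetical function name; metadata modeled as (key, value) pairs):

    def sketch_filter(metadata, requested_keys):
        keys = frozenset(key.strip().lower() for key in requested_keys)
        include_all = "*" in keys
        return [
            (key, value)
            for key, value in metadata
            if include_all or key.strip().lower() in keys
        ]

    # Matching ignores case and surrounding whitespace:
    assert sketch_filter([("X-Env", "prod"), ("other", "1")], ["x-env "]) == [
        ("X-Env", "prod")
    ]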
@@ -90,6 +146,23 @@ class _StatsWatcher:
                 else:
                     self._rpcs_by_peer[peer] += 1
                     self._rpcs_by_method[method][peer] += 1
+                    if self._metadata_keys:
+                        rpc_metadata = (
+                            messages_pb2.LoadBalancerStatsResponse.RpcMetadata()
+                        )
+                        self._add_metadata(
+                            rpc_metadata,
+                            initial_metadata,
+                            messages_pb2.LoadBalancerStatsResponse.MetadataType.INITIAL,
+                        )
+                        self._add_metadata(
+                            rpc_metadata,
+                            trailing_metadata,
+                            messages_pb2.LoadBalancerStatsResponse.MetadataType.TRAILING,
+                        )
+                        self._metadata_by_peer[peer].rpc_metadata.append(
+                            rpc_metadata
+                        )
                 self._rpcs_needed -= 1
                 self._condition.notify()
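
_metadata_by_peer uses the generated MetadataByPeer class itself as the defaultdict factory, so the first completed RPC for a peer materializes an empty message in place. A quick sketch of the pattern (assumes the generated messages_pb2 module is importable):

    import collections

    by_peer = collections.defaultdict(
        messages_pb2.LoadBalancerStatsResponse.MetadataByPeer
    )
    rpc_metadata = messages_pb2.LoadBalancerStatsResponse.RpcMetadata()
    by_peer["backend-1"].rpc_metadata.append(rpc_metadata)  # entry created here
    assert len(by_peer["backend-1"].rpc_metadata) == 1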
@@ -107,6 +180,8 @@ class _StatsWatcher:
             for method, count_by_peer in self._rpcs_by_method.items():
                 for peer, count in count_by_peer.items():
                     response.rpcs_by_method[method].rpcs_by_peer[peer] = count
+            for peer, metadata_by_peer in self._metadata_by_peer.items():
+                response.metadatas_by_peer[peer].CopyFrom(metadata_by_peer)
             response.num_failures = self._no_remote_peer + self._rpcs_needed
         return response
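
CopyFrom is the right tool here because protobuf map fields with message values do not support direct assignment; reading response.metadatas_by_peer[peer] materializes an empty entry, which is then filled by copy. A sketch of the constraint (assumes messages_pb2 is importable):

    response = messages_pb2.LoadBalancerStatsResponse()
    peer_metadata = messages_pb2.LoadBalancerStatsResponse.MetadataByPeer()
    # response.metadatas_by_peer["peer"] = peer_metadata  # raises: unsupported
    response.metadatas_by_peer["peer"].CopyFrom(peer_metadata)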
@@ -150,7 +225,7 @@ class _LoadBalancerStatsServicer(
         with _global_lock:
             start = _global_rpc_id + 1
             end = start + request.num_rpcs
-        watcher = _StatsWatcher(start, end)
+        watcher = _StatsWatcher(start, end, request.metadata_keys)
         _watchers.add(watcher)
         response = watcher.await_rpc_stats_response(request.timeout_sec)
         with _global_lock:
@@ -192,7 +267,7 @@ def _start_rpc(
     request_id: int,
     stub: test_pb2_grpc.TestServiceStub,
     timeout: float,
-    futures: Mapping[int, Tuple[grpc.Future, str]],
+    futures: Mapping[int, Tuple[FutureFromCallType, str]],
 ) -> None:
     logger.debug(f"Sending {method} request to backend: {request_id}")
     if method == "UnaryCall":
@@ -209,7 +284,7 @@ def _start_rpc(
 
 
 def _on_rpc_done(
-    rpc_id: int, future: grpc.Future, method: str, print_response: bool
+    rpc_id: int, future: FutureFromCallType, method: str, print_response: bool
 ) -> None:
     exception = future.exception()
     hostname = ""
@@ -241,23 +316,29 @@ def _on_rpc_done(
         if future.code() == grpc.StatusCode.OK:
             logger.debug("Successful response.")
         else:
-            logger.debug(f"RPC failed: {call}")
+            logger.debug(f"RPC failed: {rpc_id}")
     with _global_lock:
         for watcher in _watchers:
-            watcher.on_rpc_complete(rpc_id, hostname, method)
+            watcher.on_rpc_complete(
+                rpc_id,
+                hostname,
+                method,
+                initial_metadata=future.initial_metadata(),
+                trailing_metadata=future.trailing_metadata(),
+            )
 
 
 def _remove_completed_rpcs(
-    futures: Mapping[int, grpc.Future], print_response: bool
+    rpc_futures: Mapping[int, FutureFromCallType], print_response: bool
 ) -> None:
     logger.debug("Removing completed RPCs")
     done = []
-    for future_id, (future, method) in futures.items():
+    for future_id, (future, method) in rpc_futures.items():
         if future.done():
             _on_rpc_done(future_id, future, method, args.print_response)
             done.append(future_id)
     for rpc_id in done:
-        del futures[rpc_id]
+        del rpc_futures[rpc_id]
 
 
 def _cancel_all_rpcs(futures: Mapping[int, Tuple[grpc.Future, str]]) -> None:
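
Two things worth noting in this hunk: the failure log previously referenced an undefined name (call), which would have raised a NameError on the failure path and now logs rpc_id instead; and the bare * in on_rpc_complete makes both metadata parameters keyword-only, so the two identically typed arguments cannot be transposed silently at a call site. A minimal sketch of the keyword-only style (hypothetical function):

    def sketch(rpc_id: int, *, initial_metadata, trailing_metadata) -> None:
        pass

    sketch(1, initial_metadata=[], trailing_metadata=[])  # OK
    # sketch(1, [], [])  # TypeError: takes 1 positional argument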
@@ -309,7 +390,7 @@ def _run_single_channel(config: _ChannelConfiguration) -> None:
         channel = grpc.insecure_channel(server)
         with channel:
             stub = test_pb2_grpc.TestServiceStub(channel)
-            futures: Dict[int, Tuple[grpc.Future, str]] = {}
+            futures: Dict[int, Tuple[FutureFromCallType, str]] = {}
             while not _stop_event.is_set():
                 with config.condition:
                     if config.qps == 0:
@@ -438,7 +519,7 @@ def _run(
         )
         channel_configs[method] = channel_config
         method_handles.append(_MethodHandle(args.num_channels, channel_config))
-    _global_server = grpc.server(futures.ThreadPoolExecutor())
+    _global_server = grpc.server(concurrent.futures.ThreadPoolExecutor())
     _global_server.add_insecure_port(f"0.0.0.0:{args.stats_port}")
     test_pb2_grpc.add_LoadBalancerStatsServiceServicer_to_server(
         _LoadBalancerStatsServicer(), _global_server
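
With the import qualified, the stats server's executor is spelled concurrent.futures.ThreadPoolExecutor and can no longer collide with the local futures dictionaries used for in-flight RPCs. A minimal standalone sketch of the same bootstrap (hypothetical port, no servicers registered):

    import concurrent.futures
    import grpc

    server = grpc.server(concurrent.futures.ThreadPoolExecutor())
    server.add_insecure_port("0.0.0.0:50052")
    server.start()
    server.stop(grace=None)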
