Merge pull request #25271 from sergiitk/xds-k8s-adjust-grpc-retry-deadline

xds-k8s driver: wait server channelz - adjust RPC timeouts
pull/25278/head
Sergii Tkachenko 4 years ago committed by GitHub
commit c2d03a691e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 14
      tools/run_tests/xds_k8s_test_driver/framework/helpers/retryers.py
  2. 21
      tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc.py
  3. 61
      tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py
  4. 2
      tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_testing.py
  5. 50
      tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py

@ -21,13 +21,17 @@ We use tenacity as a general-purpose retrying library.
> - https://tenacity.readthedocs.io/en/latest/index.html
"""
import datetime
import logging
from typing import Any, List, Optional
import tenacity
retryers_logger = logging.getLogger(__name__)
# Type aliases
timedelta = datetime.timedelta
Retrying = tenacity.Retrying
_after_log = tenacity.after_log
_before_sleep_log = tenacity.before_sleep_log
_retry_if_exception_type = tenacity.retry_if_exception_type
_stop_after_delay = tenacity.stop_after_delay
_wait_exponential = tenacity.wait_exponential
@ -45,9 +49,15 @@ def exponential_retryer_with_timeout(
wait_min: timedelta,
wait_max: timedelta,
timeout: timedelta,
retry_on_exceptions: Optional[List[Any]] = None) -> Retrying:
retry_on_exceptions: Optional[List[Any]] = None,
logger: Optional[logging.Logger] = None,
log_level: Optional[int] = logging.DEBUG) -> Retrying:
if logger is None:
logger = retryers_logger
if log_level is None:
log_level = logging.DEBUG
return Retrying(retry=_retry_on_exceptions(retry_on_exceptions),
wait=_wait_exponential(min=wait_min.total_seconds(),
max=wait_max.total_seconds()),
stop=_stop_after_delay(timeout.total_seconds()),
reraise=True)
before_sleep=_before_sleep_log(logger, log_level))

@ -29,8 +29,7 @@ Message = google.protobuf.message.Message
class GrpcClientHelper:
channel: grpc.Channel
DEFAULT_CONNECTION_TIMEOUT_SEC = 60
DEFAULT_WAIT_FOR_READY_SEC = 60
DEFAULT_RPC_DEADLINE_SEC = 90
def __init__(self, channel: grpc.Channel, stub_class: ClassVar):
self.channel = channel
@ -44,20 +43,16 @@ class GrpcClientHelper:
*,
rpc: str,
req: Message,
wait_for_ready_sec: Optional[int] = DEFAULT_WAIT_FOR_READY_SEC,
connection_timeout_sec: Optional[
int] = DEFAULT_CONNECTION_TIMEOUT_SEC,
deadline_sec: Optional[int] = DEFAULT_RPC_DEADLINE_SEC,
log_level: Optional[int] = logging.DEBUG) -> Message:
if wait_for_ready_sec is None:
wait_for_ready_sec = self.DEFAULT_WAIT_FOR_READY_SEC
if connection_timeout_sec is None:
connection_timeout_sec = self.DEFAULT_CONNECTION_TIMEOUT_SEC
if deadline_sec is None:
deadline_sec = self.DEFAULT_RPC_DEADLINE_SEC
timeout_sec = wait_for_ready_sec + connection_timeout_sec
rpc_callable: grpc.UnaryUnaryMultiCallable = getattr(self.stub, rpc)
call_kwargs = dict(wait_for_ready=True, timeout=timeout_sec)
call_kwargs = dict(wait_for_ready=True, timeout=deadline_sec)
self._log_rpc_request(rpc, req, call_kwargs, log_level)
# Call RPC, e.g. RpcStub(channel).RpcMethod(req, ...options)
rpc_callable: grpc.UnaryUnaryMultiCallable = getattr(self.stub, rpc)
return rpc_callable(req, **call_kwargs)
def _log_rpc_request(self, rpc, req, call_kwargs, log_level=logging.DEBUG):

@ -95,22 +95,25 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
return server_socket
return None
def find_channels_for_target(self, target: str) -> Iterator[Channel]:
return (channel for channel in self.list_channels()
def find_channels_for_target(self, target: str,
**kwargs) -> Iterator[Channel]:
return (channel for channel in self.list_channels(**kwargs)
if channel.data.target == target)
def find_server_listening_on_port(self, port: int) -> Optional[Server]:
for server in self.list_servers():
def find_server_listening_on_port(self, port: int,
**kwargs) -> Optional[Server]:
for server in self.list_servers(**kwargs):
listen_socket_ref: SocketRef
for listen_socket_ref in server.listen_socket:
listen_socket = self.get_socket(listen_socket_ref.socket_id)
listen_socket = self.get_socket(listen_socket_ref.socket_id,
**kwargs)
listen_address: Address = listen_socket.local
if (self.is_sock_tcpip_address(listen_address) and
listen_address.tcpip_address.port == port):
return server
return None
def list_channels(self) -> Iterator[Channel]:
def list_channels(self, **kwargs) -> Iterator[Channel]:
"""
Iterate over all pages of all root channels.
@ -125,12 +128,13 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
start += 1
response = self.call_unary_with_deadline(
rpc='GetTopChannels',
req=_GetTopChannelsRequest(start_channel_id=start))
req=_GetTopChannelsRequest(start_channel_id=start),
**kwargs)
for channel in response.channel:
start = max(start, channel.ref.channel_id)
yield channel
def list_servers(self) -> Iterator[Server]:
def list_servers(self, **kwargs) -> Iterator[Server]:
"""Iterate over all pages of all servers that exist in the process."""
start: int = -1
response: Optional[_GetServersResponse] = None
@ -139,12 +143,14 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
# value by adding 1 to the highest seen result ID.
start += 1
response = self.call_unary_with_deadline(
rpc='GetServers', req=_GetServersRequest(start_server_id=start))
rpc='GetServers',
req=_GetServersRequest(start_server_id=start),
**kwargs)
for server in response.server:
start = max(start, server.ref.server_id)
yield server
def list_server_sockets(self, server: Server) -> Iterator[Socket]:
def list_server_sockets(self, server: Server, **kwargs) -> Iterator[Socket]:
"""List all server sockets that exist in server process.
Iterating over the results will resolve additional pages automatically.
@ -158,39 +164,44 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
response = self.call_unary_with_deadline(
rpc='GetServerSockets',
req=_GetServerSocketsRequest(server_id=server.ref.server_id,
start_socket_id=start))
start_socket_id=start),
**kwargs)
socket_ref: SocketRef
for socket_ref in response.socket_ref:
start = max(start, socket_ref.socket_id)
# Yield actual socket
yield self.get_socket(socket_ref.socket_id)
yield self.get_socket(socket_ref.socket_id, **kwargs)
def list_channel_sockets(self, channel: Channel) -> Iterator[Socket]:
def list_channel_sockets(self, channel: Channel,
**kwargs) -> Iterator[Socket]:
"""List all sockets of all subchannels of a given channel."""
for subchannel in self.list_channel_subchannels(channel):
yield from self.list_subchannels_sockets(subchannel)
for subchannel in self.list_channel_subchannels(channel, **kwargs):
yield from self.list_subchannels_sockets(subchannel, **kwargs)
def list_channel_subchannels(self,
channel: Channel) -> Iterator[Subchannel]:
def list_channel_subchannels(self, channel: Channel,
**kwargs) -> Iterator[Subchannel]:
"""List all subchannels of a given channel."""
for subchannel_ref in channel.subchannel_ref:
yield self.get_subchannel(subchannel_ref.subchannel_id)
yield self.get_subchannel(subchannel_ref.subchannel_id, **kwargs)
def list_subchannels_sockets(self,
subchannel: Subchannel) -> Iterator[Socket]:
def list_subchannels_sockets(self, subchannel: Subchannel,
**kwargs) -> Iterator[Socket]:
"""List all sockets of a given subchannel."""
for socket_ref in subchannel.socket_ref:
yield self.get_socket(socket_ref.socket_id)
yield self.get_socket(socket_ref.socket_id, **kwargs)
def get_subchannel(self, subchannel_id) -> Subchannel:
def get_subchannel(self, subchannel_id, **kwargs) -> Subchannel:
"""Return a single Subchannel, otherwise raises RpcError."""
response: _GetSubchannelResponse = self.call_unary_with_deadline(
rpc='GetSubchannel',
req=_GetSubchannelRequest(subchannel_id=subchannel_id))
req=_GetSubchannelRequest(subchannel_id=subchannel_id),
**kwargs)
return response.subchannel
def get_socket(self, socket_id) -> Socket:
def get_socket(self, socket_id, **kwargs) -> Socket:
"""Return a single Socket, otherwise raises RpcError."""
response: _GetSocketResponse = self.call_unary_with_deadline(
rpc='GetSocket', req=_GetSocketRequest(socket_id=socket_id))
rpc='GetSocket',
req=_GetSocketRequest(socket_id=socket_id),
**kwargs)
return response.socket

@ -49,5 +49,5 @@ class LoadBalancerStatsServiceClient(framework.rpc.grpc.GrpcClientHelper):
req=_LoadBalancerStatsRequest(
num_rpcs=num_rpcs,
timeout_sec=timeout_sec),
wait_for_ready_sec=timeout_sec,
deadline_sec=timeout_sec,
log_level=logging.INFO)

@ -83,9 +83,6 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp):
return self.load_balancer_stats.get_client_stats(
num_rpcs=num_rpcs, timeout_sec=timeout_sec)
def get_server_channels(self) -> Iterator[_ChannelzChannel]:
return self.channelz.find_channels_for_target(self.server_target)
def wait_for_active_server_channel(self) -> _ChannelzChannel:
"""Wait for the channel to the server to transition to READY.
@ -94,16 +91,9 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp):
"""
return self.wait_for_server_channel_state(_ChannelzChannelState.READY)
def get_active_server_channel(self) -> _ChannelzChannel:
"""Return a READY channel to the server.
Raises:
GrpcApp.NotFound: If there's no READY channel to the server.
"""
return self.find_server_channel_with_state(_ChannelzChannelState.READY)
def get_active_server_channel_socket(self) -> _ChannelzSocket:
channel = self.get_active_server_channel()
channel = self.find_server_channel_with_state(
_ChannelzChannelState.READY)
# Get the first subchannel of the active channel to the server.
logger.debug(
'Retrieving client -> server socket, '
@ -125,17 +115,25 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp):
self,
state: _ChannelzChannelState,
*,
timeout: Optional[_timedelta] = None) -> _ChannelzChannel:
timeout: Optional[_timedelta] = None,
rpc_deadline: Optional[_timedelta] = None) -> _ChannelzChannel:
# When polling for a state, prefer smaller wait times to avoid
# exhausting all allowed time on a single long RPC.
if rpc_deadline is None:
rpc_deadline = _timedelta(seconds=30)
# Fine-tuned to wait for the channel to the server.
retryer = retryers.exponential_retryer_with_timeout(
wait_min=_timedelta(seconds=10),
wait_max=_timedelta(seconds=25),
timeout=_timedelta(minutes=3) if timeout is None else timeout)
timeout=_timedelta(minutes=5) if timeout is None else timeout)
logger.info('Waiting for client %s to report a %s channel to %s',
self.ip, _ChannelzChannelState.Name(state),
self.server_target)
channel = retryer(self.find_server_channel_with_state, state)
channel = retryer(self.find_server_channel_with_state,
state,
rpc_deadline=rpc_deadline)
logger.info('Client %s channel to %s transitioned to state %s:\n%s',
self.ip, self.server_target,
_ChannelzChannelState.Name(state), channel)
@ -145,8 +143,13 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp):
self,
state: _ChannelzChannelState,
*,
rpc_deadline: Optional[_timedelta] = None,
check_subchannel=True) -> _ChannelzChannel:
for channel in self.get_server_channels():
rpc_params = {}
if rpc_deadline is not None:
rpc_params['deadline_sec'] = rpc_deadline.total_seconds()
for channel in self.get_server_channels(**rpc_params):
channel_state: _ChannelzChannelState = channel.data.state.state
logger.info('Server channel: %s, state: %s', channel.ref.name,
_ChannelzChannelState.Name(channel_state))
@ -156,7 +159,7 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp):
# one subchannel in the requested state.
try:
subchannel = self.find_subchannel_with_state(
channel, state)
channel, state, **rpc_params)
logger.info('Found subchannel in state %s: %s', state,
subchannel)
except self.NotFound as e:
@ -169,10 +172,15 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp):
f'Client has no {_ChannelzChannelState.Name(state)} channel with '
'the server')
def find_subchannel_with_state(
self, channel: _ChannelzChannel,
state: _ChannelzChannelState) -> _ChannelzSubchannel:
for subchannel in self.channelz.list_channel_subchannels(channel):
def get_server_channels(self, **kwargs) -> Iterator[_ChannelzChannel]:
return self.channelz.find_channels_for_target(self.server_target,
**kwargs)
def find_subchannel_with_state(self, channel: _ChannelzChannel,
state: _ChannelzChannelState,
**kwargs) -> _ChannelzSubchannel:
subchannels = self.channelz.list_channel_subchannels(channel, **kwargs)
for subchannel in subchannels:
if subchannel.data.state.state is state:
return subchannel

Loading…
Cancel
Save