xds-k8s driver: wait server channelz - adjust RPC timeouts

pull/25271/head
Sergii Tkachenko 4 years ago
parent ea662ef791
commit bde2b79cbd
  1. 14
      tools/run_tests/xds_k8s_test_driver/framework/helpers/retryers.py
  2. 21
      tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc.py
  3. 61
      tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py
  4. 2
      tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_testing.py
  5. 50
      tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py

@@ -21,13 +21,17 @@ We use tenacity as a general-purpose retrying library.
> - https://tenacity.readthedocs.io/en/latest/index.html > - https://tenacity.readthedocs.io/en/latest/index.html
""" """
import datetime import datetime
import logging
from typing import Any, List, Optional from typing import Any, List, Optional
import tenacity import tenacity
retryers_logger = logging.getLogger(__name__)
# Type aliases # Type aliases
timedelta = datetime.timedelta timedelta = datetime.timedelta
Retrying = tenacity.Retrying Retrying = tenacity.Retrying
_after_log = tenacity.after_log
_before_sleep_log = tenacity.before_sleep_log
_retry_if_exception_type = tenacity.retry_if_exception_type _retry_if_exception_type = tenacity.retry_if_exception_type
_stop_after_delay = tenacity.stop_after_delay _stop_after_delay = tenacity.stop_after_delay
_wait_exponential = tenacity.wait_exponential _wait_exponential = tenacity.wait_exponential
@@ -45,9 +49,15 @@ def exponential_retryer_with_timeout(
wait_min: timedelta, wait_min: timedelta,
wait_max: timedelta, wait_max: timedelta,
timeout: timedelta, timeout: timedelta,
retry_on_exceptions: Optional[List[Any]] = None) -> Retrying: retry_on_exceptions: Optional[List[Any]] = None,
logger: Optional[logging.Logger] = None,
log_level: Optional[int] = logging.DEBUG) -> Retrying:
if logger is None:
logger = retryers_logger
if log_level is None:
log_level = logging.DEBUG
return Retrying(retry=_retry_on_exceptions(retry_on_exceptions), return Retrying(retry=_retry_on_exceptions(retry_on_exceptions),
wait=_wait_exponential(min=wait_min.total_seconds(), wait=_wait_exponential(min=wait_min.total_seconds(),
max=wait_max.total_seconds()), max=wait_max.total_seconds()),
stop=_stop_after_delay(timeout.total_seconds()), stop=_stop_after_delay(timeout.total_seconds()),
reraise=True) before_sleep=_before_sleep_log(logger, log_level))

@@ -29,8 +29,7 @@ Message = google.protobuf.message.Message
class GrpcClientHelper: class GrpcClientHelper:
channel: grpc.Channel channel: grpc.Channel
DEFAULT_CONNECTION_TIMEOUT_SEC = 60 DEFAULT_RPC_DEADLINE_SEC = 90
DEFAULT_WAIT_FOR_READY_SEC = 60
def __init__(self, channel: grpc.Channel, stub_class: ClassVar): def __init__(self, channel: grpc.Channel, stub_class: ClassVar):
self.channel = channel self.channel = channel
@@ -44,20 +43,16 @@ class GrpcClientHelper:
*, *,
rpc: str, rpc: str,
req: Message, req: Message,
wait_for_ready_sec: Optional[int] = DEFAULT_WAIT_FOR_READY_SEC, deadline_sec: Optional[int] = DEFAULT_RPC_DEADLINE_SEC,
connection_timeout_sec: Optional[
int] = DEFAULT_CONNECTION_TIMEOUT_SEC,
log_level: Optional[int] = logging.DEBUG) -> Message: log_level: Optional[int] = logging.DEBUG) -> Message:
if wait_for_ready_sec is None: if deadline_sec is None:
wait_for_ready_sec = self.DEFAULT_WAIT_FOR_READY_SEC deadline_sec = self.DEFAULT_RPC_DEADLINE_SEC
if connection_timeout_sec is None:
connection_timeout_sec = self.DEFAULT_CONNECTION_TIMEOUT_SEC
timeout_sec = wait_for_ready_sec + connection_timeout_sec call_kwargs = dict(wait_for_ready=True, timeout=deadline_sec)
rpc_callable: grpc.UnaryUnaryMultiCallable = getattr(self.stub, rpc)
call_kwargs = dict(wait_for_ready=True, timeout=timeout_sec)
self._log_rpc_request(rpc, req, call_kwargs, log_level) self._log_rpc_request(rpc, req, call_kwargs, log_level)
# Call RPC, e.g. RpcStub(channel).RpcMethod(req, ...options)
rpc_callable: grpc.UnaryUnaryMultiCallable = getattr(self.stub, rpc)
return rpc_callable(req, **call_kwargs) return rpc_callable(req, **call_kwargs)
def _log_rpc_request(self, rpc, req, call_kwargs, log_level=logging.DEBUG): def _log_rpc_request(self, rpc, req, call_kwargs, log_level=logging.DEBUG):

@@ -95,22 +95,25 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
return server_socket return server_socket
return None return None
def find_channels_for_target(self, target: str) -> Iterator[Channel]: def find_channels_for_target(self, target: str,
return (channel for channel in self.list_channels() **kwargs) -> Iterator[Channel]:
return (channel for channel in self.list_channels(**kwargs)
if channel.data.target == target) if channel.data.target == target)
def find_server_listening_on_port(self, port: int) -> Optional[Server]: def find_server_listening_on_port(self, port: int,
for server in self.list_servers(): **kwargs) -> Optional[Server]:
for server in self.list_servers(**kwargs):
listen_socket_ref: SocketRef listen_socket_ref: SocketRef
for listen_socket_ref in server.listen_socket: for listen_socket_ref in server.listen_socket:
listen_socket = self.get_socket(listen_socket_ref.socket_id) listen_socket = self.get_socket(listen_socket_ref.socket_id,
**kwargs)
listen_address: Address = listen_socket.local listen_address: Address = listen_socket.local
if (self.is_sock_tcpip_address(listen_address) and if (self.is_sock_tcpip_address(listen_address) and
listen_address.tcpip_address.port == port): listen_address.tcpip_address.port == port):
return server return server
return None return None
def list_channels(self) -> Iterator[Channel]: def list_channels(self, **kwargs) -> Iterator[Channel]:
""" """
Iterate over all pages of all root channels. Iterate over all pages of all root channels.
@@ -125,12 +128,13 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
start += 1 start += 1
response = self.call_unary_with_deadline( response = self.call_unary_with_deadline(
rpc='GetTopChannels', rpc='GetTopChannels',
req=_GetTopChannelsRequest(start_channel_id=start)) req=_GetTopChannelsRequest(start_channel_id=start),
**kwargs)
for channel in response.channel: for channel in response.channel:
start = max(start, channel.ref.channel_id) start = max(start, channel.ref.channel_id)
yield channel yield channel
def list_servers(self) -> Iterator[Server]: def list_servers(self, **kwargs) -> Iterator[Server]:
"""Iterate over all pages of all servers that exist in the process.""" """Iterate over all pages of all servers that exist in the process."""
start: int = -1 start: int = -1
response: Optional[_GetServersResponse] = None response: Optional[_GetServersResponse] = None
@@ -139,12 +143,14 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
# value by adding 1 to the highest seen result ID. # value by adding 1 to the highest seen result ID.
start += 1 start += 1
response = self.call_unary_with_deadline( response = self.call_unary_with_deadline(
rpc='GetServers', req=_GetServersRequest(start_server_id=start)) rpc='GetServers',
req=_GetServersRequest(start_server_id=start),
**kwargs)
for server in response.server: for server in response.server:
start = max(start, server.ref.server_id) start = max(start, server.ref.server_id)
yield server yield server
def list_server_sockets(self, server: Server) -> Iterator[Socket]: def list_server_sockets(self, server: Server, **kwargs) -> Iterator[Socket]:
"""List all server sockets that exist in server process. """List all server sockets that exist in server process.
Iterating over the results will resolve additional pages automatically. Iterating over the results will resolve additional pages automatically.
@@ -158,39 +164,44 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
response = self.call_unary_with_deadline( response = self.call_unary_with_deadline(
rpc='GetServerSockets', rpc='GetServerSockets',
req=_GetServerSocketsRequest(server_id=server.ref.server_id, req=_GetServerSocketsRequest(server_id=server.ref.server_id,
start_socket_id=start)) start_socket_id=start),
**kwargs)
socket_ref: SocketRef socket_ref: SocketRef
for socket_ref in response.socket_ref: for socket_ref in response.socket_ref:
start = max(start, socket_ref.socket_id) start = max(start, socket_ref.socket_id)
# Yield actual socket # Yield actual socket
yield self.get_socket(socket_ref.socket_id) yield self.get_socket(socket_ref.socket_id, **kwargs)
def list_channel_sockets(self, channel: Channel) -> Iterator[Socket]: def list_channel_sockets(self, channel: Channel,
**kwargs) -> Iterator[Socket]:
"""List all sockets of all subchannels of a given channel.""" """List all sockets of all subchannels of a given channel."""
for subchannel in self.list_channel_subchannels(channel): for subchannel in self.list_channel_subchannels(channel, **kwargs):
yield from self.list_subchannels_sockets(subchannel) yield from self.list_subchannels_sockets(subchannel, **kwargs)
def list_channel_subchannels(self, def list_channel_subchannels(self, channel: Channel,
channel: Channel) -> Iterator[Subchannel]: **kwargs) -> Iterator[Subchannel]:
"""List all subchannels of a given channel.""" """List all subchannels of a given channel."""
for subchannel_ref in channel.subchannel_ref: for subchannel_ref in channel.subchannel_ref:
yield self.get_subchannel(subchannel_ref.subchannel_id) yield self.get_subchannel(subchannel_ref.subchannel_id, **kwargs)
def list_subchannels_sockets(self, def list_subchannels_sockets(self, subchannel: Subchannel,
subchannel: Subchannel) -> Iterator[Socket]: **kwargs) -> Iterator[Socket]:
"""List all sockets of a given subchannel.""" """List all sockets of a given subchannel."""
for socket_ref in subchannel.socket_ref: for socket_ref in subchannel.socket_ref:
yield self.get_socket(socket_ref.socket_id) yield self.get_socket(socket_ref.socket_id, **kwargs)
def get_subchannel(self, subchannel_id) -> Subchannel: def get_subchannel(self, subchannel_id, **kwargs) -> Subchannel:
"""Return a single Subchannel, otherwise raises RpcError.""" """Return a single Subchannel, otherwise raises RpcError."""
response: _GetSubchannelResponse = self.call_unary_with_deadline( response: _GetSubchannelResponse = self.call_unary_with_deadline(
rpc='GetSubchannel', rpc='GetSubchannel',
req=_GetSubchannelRequest(subchannel_id=subchannel_id)) req=_GetSubchannelRequest(subchannel_id=subchannel_id),
**kwargs)
return response.subchannel return response.subchannel
def get_socket(self, socket_id) -> Socket: def get_socket(self, socket_id, **kwargs) -> Socket:
"""Return a single Socket, otherwise raises RpcError.""" """Return a single Socket, otherwise raises RpcError."""
response: _GetSocketResponse = self.call_unary_with_deadline( response: _GetSocketResponse = self.call_unary_with_deadline(
rpc='GetSocket', req=_GetSocketRequest(socket_id=socket_id)) rpc='GetSocket',
req=_GetSocketRequest(socket_id=socket_id),
**kwargs)
return response.socket return response.socket

@@ -49,5 +49,5 @@ class LoadBalancerStatsServiceClient(framework.rpc.grpc.GrpcClientHelper):
req=_LoadBalancerStatsRequest( req=_LoadBalancerStatsRequest(
num_rpcs=num_rpcs, num_rpcs=num_rpcs,
timeout_sec=timeout_sec), timeout_sec=timeout_sec),
wait_for_ready_sec=timeout_sec, deadline_sec=timeout_sec,
log_level=logging.INFO) log_level=logging.INFO)

@@ -83,9 +83,6 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp):
return self.load_balancer_stats.get_client_stats( return self.load_balancer_stats.get_client_stats(
num_rpcs=num_rpcs, timeout_sec=timeout_sec) num_rpcs=num_rpcs, timeout_sec=timeout_sec)
def get_server_channels(self) -> Iterator[_ChannelzChannel]:
return self.channelz.find_channels_for_target(self.server_target)
def wait_for_active_server_channel(self) -> _ChannelzChannel: def wait_for_active_server_channel(self) -> _ChannelzChannel:
"""Wait for the channel to the server to transition to READY. """Wait for the channel to the server to transition to READY.
@@ -94,16 +91,9 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp):
""" """
return self.wait_for_server_channel_state(_ChannelzChannelState.READY) return self.wait_for_server_channel_state(_ChannelzChannelState.READY)
def get_active_server_channel(self) -> _ChannelzChannel:
"""Return a READY channel to the server.
Raises:
GrpcApp.NotFound: If there's no READY channel to the server.
"""
return self.find_server_channel_with_state(_ChannelzChannelState.READY)
def get_active_server_channel_socket(self) -> _ChannelzSocket: def get_active_server_channel_socket(self) -> _ChannelzSocket:
channel = self.get_active_server_channel() channel = self.find_server_channel_with_state(
_ChannelzChannelState.READY)
# Get the first subchannel of the active channel to the server. # Get the first subchannel of the active channel to the server.
logger.debug( logger.debug(
'Retrieving client -> server socket, ' 'Retrieving client -> server socket, '
@@ -125,17 +115,25 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp):
self, self,
state: _ChannelzChannelState, state: _ChannelzChannelState,
*, *,
timeout: Optional[_timedelta] = None) -> _ChannelzChannel: timeout: Optional[_timedelta] = None,
rpc_deadline: Optional[_timedelta] = None) -> _ChannelzChannel:
# When polling for a state, prefer smaller wait times to avoid
# exhausting all allowed time on a single long RPC.
if rpc_deadline is None:
rpc_deadline = _timedelta(seconds=30)
# Fine-tuned to wait for the channel to the server. # Fine-tuned to wait for the channel to the server.
retryer = retryers.exponential_retryer_with_timeout( retryer = retryers.exponential_retryer_with_timeout(
wait_min=_timedelta(seconds=10), wait_min=_timedelta(seconds=10),
wait_max=_timedelta(seconds=25), wait_max=_timedelta(seconds=25),
timeout=_timedelta(minutes=3) if timeout is None else timeout) timeout=_timedelta(minutes=5) if timeout is None else timeout)
logger.info('Waiting for client %s to report a %s channel to %s', logger.info('Waiting for client %s to report a %s channel to %s',
self.ip, _ChannelzChannelState.Name(state), self.ip, _ChannelzChannelState.Name(state),
self.server_target) self.server_target)
channel = retryer(self.find_server_channel_with_state, state) channel = retryer(self.find_server_channel_with_state,
state,
rpc_deadline=rpc_deadline)
logger.info('Client %s channel to %s transitioned to state %s:\n%s', logger.info('Client %s channel to %s transitioned to state %s:\n%s',
self.ip, self.server_target, self.ip, self.server_target,
_ChannelzChannelState.Name(state), channel) _ChannelzChannelState.Name(state), channel)
@@ -145,8 +143,13 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp):
self, self,
state: _ChannelzChannelState, state: _ChannelzChannelState,
*, *,
rpc_deadline: Optional[_timedelta] = None,
check_subchannel=True) -> _ChannelzChannel: check_subchannel=True) -> _ChannelzChannel:
for channel in self.get_server_channels(): rpc_params = {}
if rpc_deadline is not None:
rpc_params['deadline_sec'] = rpc_deadline.total_seconds()
for channel in self.get_server_channels(**rpc_params):
channel_state: _ChannelzChannelState = channel.data.state.state channel_state: _ChannelzChannelState = channel.data.state.state
logger.info('Server channel: %s, state: %s', channel.ref.name, logger.info('Server channel: %s, state: %s', channel.ref.name,
_ChannelzChannelState.Name(channel_state)) _ChannelzChannelState.Name(channel_state))
@@ -156,7 +159,7 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp):
# one subchannel in the requested state. # one subchannel in the requested state.
try: try:
subchannel = self.find_subchannel_with_state( subchannel = self.find_subchannel_with_state(
channel, state) channel, state, **rpc_params)
logger.info('Found subchannel in state %s: %s', state, logger.info('Found subchannel in state %s: %s', state,
subchannel) subchannel)
except self.NotFound as e: except self.NotFound as e:
@@ -169,10 +172,15 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp):
f'Client has no {_ChannelzChannelState.Name(state)} channel with ' f'Client has no {_ChannelzChannelState.Name(state)} channel with '
'the server') 'the server')
def find_subchannel_with_state( def get_server_channels(self, **kwargs) -> Iterator[_ChannelzChannel]:
self, channel: _ChannelzChannel, return self.channelz.find_channels_for_target(self.server_target,
state: _ChannelzChannelState) -> _ChannelzSubchannel: **kwargs)
for subchannel in self.channelz.list_channel_subchannels(channel):
def find_subchannel_with_state(self, channel: _ChannelzChannel,
state: _ChannelzChannelState,
**kwargs) -> _ChannelzSubchannel:
subchannels = self.channelz.list_channel_subchannels(channel, **kwargs)
for subchannel in subchannels:
if subchannel.data.state.state is state: if subchannel.data.state.state is state:
return subchannel return subchannel

Loading…
Cancel
Save