From bde2b79cbdd5f8a7acc4f77b61d413f24b9fd2cd Mon Sep 17 00:00:00 2001 From: Sergii Tkachenko <sergiitk@google.com> Date: Mon, 25 Jan 2021 21:49:48 -0500 Subject: [PATCH] xds-k8s driver: wait server channelz - adjust RPC timeouts --- .../framework/helpers/retryers.py | 14 ++++- .../xds_k8s_test_driver/framework/rpc/grpc.py | 21 +++---- .../framework/rpc/grpc_channelz.py | 61 +++++++++++-------- .../framework/rpc/grpc_testing.py | 2 +- .../framework/test_app/client_app.py | 50 ++++++++------- 5 files changed, 86 insertions(+), 62 deletions(-) diff --git a/tools/run_tests/xds_k8s_test_driver/framework/helpers/retryers.py b/tools/run_tests/xds_k8s_test_driver/framework/helpers/retryers.py index d76a4066127..248c5b648d7 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/helpers/retryers.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/helpers/retryers.py @@ -21,13 +21,17 @@ We use tenacity as a general-purpose retrying library. > - https://tenacity.readthedocs.io/en/latest/index.html """ import datetime +import logging from typing import Any, List, Optional import tenacity +retryers_logger = logging.getLogger(__name__) # Type aliases timedelta = datetime.timedelta Retrying = tenacity.Retrying +_after_log = tenacity.after_log +_before_sleep_log = tenacity.before_sleep_log _retry_if_exception_type = tenacity.retry_if_exception_type _stop_after_delay = tenacity.stop_after_delay _wait_exponential = tenacity.wait_exponential @@ -45,9 +49,15 @@ def exponential_retryer_with_timeout( wait_min: timedelta, wait_max: timedelta, timeout: timedelta, - retry_on_exceptions: Optional[List[Any]] = None) -> Retrying: + retry_on_exceptions: Optional[List[Any]] = None, + logger: Optional[logging.Logger] = None, + log_level: Optional[int] = logging.DEBUG) -> Retrying: + if logger is None: + logger = retryers_logger + if log_level is None: + log_level = logging.DEBUG return Retrying(retry=_retry_on_exceptions(retry_on_exceptions), wait=_wait_exponential(min=wait_min.total_seconds(), max=wait_max.total_seconds()), stop=_stop_after_delay(timeout.total_seconds()), - reraise=True) + before_sleep=_before_sleep_log(logger, log_level)) diff --git a/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc.py b/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc.py index 79ab84ad531..3e155532d58 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc.py @@ -29,8 +29,7 @@ Message = google.protobuf.message.Message class GrpcClientHelper: channel: grpc.Channel - DEFAULT_CONNECTION_TIMEOUT_SEC = 60 - DEFAULT_WAIT_FOR_READY_SEC = 60 + DEFAULT_RPC_DEADLINE_SEC = 90 def __init__(self, channel: grpc.Channel, stub_class: ClassVar): self.channel = channel @@ -44,20 +43,16 @@ class GrpcClientHelper: *, rpc: str, req: Message, - wait_for_ready_sec: Optional[int] = DEFAULT_WAIT_FOR_READY_SEC, - connection_timeout_sec: Optional[ - int] = DEFAULT_CONNECTION_TIMEOUT_SEC, + deadline_sec: Optional[int] = DEFAULT_RPC_DEADLINE_SEC, log_level: Optional[int] = logging.DEBUG) -> Message: - if wait_for_ready_sec is None: - wait_for_ready_sec = self.DEFAULT_WAIT_FOR_READY_SEC - if connection_timeout_sec is None: - connection_timeout_sec = self.DEFAULT_CONNECTION_TIMEOUT_SEC + if deadline_sec is None: + deadline_sec = self.DEFAULT_RPC_DEADLINE_SEC - timeout_sec = wait_for_ready_sec + connection_timeout_sec - rpc_callable: grpc.UnaryUnaryMultiCallable = getattr(self.stub, rpc) - - call_kwargs = dict(wait_for_ready=True, timeout=timeout_sec) + call_kwargs = dict(wait_for_ready=True, timeout=deadline_sec) self._log_rpc_request(rpc, req, call_kwargs, log_level) + + # Call RPC, e.g. RpcStub(channel).RpcMethod(req, ...options) + rpc_callable: grpc.UnaryUnaryMultiCallable = getattr(self.stub, rpc) return rpc_callable(req, **call_kwargs) def _log_rpc_request(self, rpc, req, call_kwargs, log_level=logging.DEBUG): diff --git a/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py b/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py index 3bf4b261313..b4e6b18761d 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py @@ -95,22 +95,25 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper): return server_socket return None - def find_channels_for_target(self, target: str) -> Iterator[Channel]: - return (channel for channel in self.list_channels() + def find_channels_for_target(self, target: str, + **kwargs) -> Iterator[Channel]: + return (channel for channel in self.list_channels(**kwargs) if channel.data.target == target) - def find_server_listening_on_port(self, port: int) -> Optional[Server]: - for server in self.list_servers(): + def find_server_listening_on_port(self, port: int, + **kwargs) -> Optional[Server]: + for server in self.list_servers(**kwargs): listen_socket_ref: SocketRef for listen_socket_ref in server.listen_socket: - listen_socket = self.get_socket(listen_socket_ref.socket_id) + listen_socket = self.get_socket(listen_socket_ref.socket_id, + **kwargs) listen_address: Address = listen_socket.local if (self.is_sock_tcpip_address(listen_address) and listen_address.tcpip_address.port == port): return server return None - def list_channels(self) -> Iterator[Channel]: + def list_channels(self, **kwargs) -> Iterator[Channel]: """ Iterate over all pages of all root channels. @@ -125,12 +128,13 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper): start += 1 response = self.call_unary_with_deadline( rpc='GetTopChannels', - req=_GetTopChannelsRequest(start_channel_id=start)) + req=_GetTopChannelsRequest(start_channel_id=start), + **kwargs) for channel in response.channel: start = max(start, channel.ref.channel_id) yield channel - def list_servers(self) -> Iterator[Server]: + def list_servers(self, **kwargs) -> Iterator[Server]: """Iterate over all pages of all servers that exist in the process.""" start: int = -1 response: Optional[_GetServersResponse] = None @@ -139,12 +143,14 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper): # value by adding 1 to the highest seen result ID. start += 1 response = self.call_unary_with_deadline( - rpc='GetServers', req=_GetServersRequest(start_server_id=start)) + rpc='GetServers', + req=_GetServersRequest(start_server_id=start), + **kwargs) for server in response.server: start = max(start, server.ref.server_id) yield server - def list_server_sockets(self, server: Server) -> Iterator[Socket]: + def list_server_sockets(self, server: Server, **kwargs) -> Iterator[Socket]: """List all server sockets that exist in server process. Iterating over the results will resolve additional pages automatically. @@ -158,39 +164,44 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper): response = self.call_unary_with_deadline( rpc='GetServerSockets', req=_GetServerSocketsRequest(server_id=server.ref.server_id, - start_socket_id=start)) + start_socket_id=start), + **kwargs) socket_ref: SocketRef for socket_ref in response.socket_ref: start = max(start, socket_ref.socket_id) # Yield actual socket - yield self.get_socket(socket_ref.socket_id) + yield self.get_socket(socket_ref.socket_id, **kwargs) - def list_channel_sockets(self, channel: Channel) -> Iterator[Socket]: + def list_channel_sockets(self, channel: Channel, + **kwargs) -> Iterator[Socket]: """List all sockets of all subchannels of a given channel.""" - for subchannel in self.list_channel_subchannels(channel): - yield from self.list_subchannels_sockets(subchannel) + for subchannel in self.list_channel_subchannels(channel, **kwargs): + yield from self.list_subchannels_sockets(subchannel, **kwargs) - def list_channel_subchannels(self, - channel: Channel) -> Iterator[Subchannel]: + def list_channel_subchannels(self, channel: Channel, + **kwargs) -> Iterator[Subchannel]: """List all subchannels of a given channel.""" for subchannel_ref in channel.subchannel_ref: - yield self.get_subchannel(subchannel_ref.subchannel_id) + yield self.get_subchannel(subchannel_ref.subchannel_id, **kwargs) - def list_subchannels_sockets(self, - subchannel: Subchannel) -> Iterator[Socket]: + def list_subchannels_sockets(self, subchannel: Subchannel, + **kwargs) -> Iterator[Socket]: """List all sockets of a given subchannel.""" for socket_ref in subchannel.socket_ref: - yield self.get_socket(socket_ref.socket_id) + yield self.get_socket(socket_ref.socket_id, **kwargs) - def get_subchannel(self, subchannel_id) -> Subchannel: + def get_subchannel(self, subchannel_id, **kwargs) -> Subchannel: """Return a single Subchannel, otherwise raises RpcError.""" response: _GetSubchannelResponse = self.call_unary_with_deadline( rpc='GetSubchannel', - req=_GetSubchannelRequest(subchannel_id=subchannel_id)) + req=_GetSubchannelRequest(subchannel_id=subchannel_id), + **kwargs) return response.subchannel - def get_socket(self, socket_id) -> Socket: + def get_socket(self, socket_id, **kwargs) -> Socket: """Return a single Socket, otherwise raises RpcError.""" response: _GetSocketResponse = self.call_unary_with_deadline( - rpc='GetSocket', req=_GetSocketRequest(socket_id=socket_id)) + rpc='GetSocket', + req=_GetSocketRequest(socket_id=socket_id), + **kwargs) return response.socket diff --git a/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_testing.py b/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_testing.py index 8c56cede09d..31485f9d561 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_testing.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_testing.py @@ -49,5 +49,5 @@ class LoadBalancerStatsServiceClient(framework.rpc.grpc.GrpcClientHelper): req=_LoadBalancerStatsRequest( num_rpcs=num_rpcs, timeout_sec=timeout_sec), - wait_for_ready_sec=timeout_sec, + deadline_sec=timeout_sec, log_level=logging.INFO) diff --git a/tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py b/tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py index 34086cf5d52..ea1ab8d4a96 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py @@ -83,9 +83,6 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp): return self.load_balancer_stats.get_client_stats( num_rpcs=num_rpcs, timeout_sec=timeout_sec) - def get_server_channels(self) -> Iterator[_ChannelzChannel]: - return self.channelz.find_channels_for_target(self.server_target) - def wait_for_active_server_channel(self) -> _ChannelzChannel: """Wait for the channel to the server to transition to READY. @@ -94,16 +91,9 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp): """ return self.wait_for_server_channel_state(_ChannelzChannelState.READY) - def get_active_server_channel(self) -> _ChannelzChannel: - """Return a READY channel to the server. - - Raises: - GrpcApp.NotFound: If there's no READY channel to the server. - """ - return self.find_server_channel_with_state(_ChannelzChannelState.READY) - def get_active_server_channel_socket(self) -> _ChannelzSocket: - channel = self.get_active_server_channel() + channel = self.find_server_channel_with_state( + _ChannelzChannelState.READY) # Get the first subchannel of the active channel to the server. logger.debug( 'Retrieving client -> server socket, ' @@ -125,17 +115,25 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp): self, state: _ChannelzChannelState, *, - timeout: Optional[_timedelta] = None) -> _ChannelzChannel: + timeout: Optional[_timedelta] = None, + rpc_deadline: Optional[_timedelta] = None) -> _ChannelzChannel: + # When polling for a state, prefer smaller wait times to avoid + # exhausting all allowed time on a single long RPC. + if rpc_deadline is None: + rpc_deadline = _timedelta(seconds=30) + # Fine-tuned to wait for the channel to the server. retryer = retryers.exponential_retryer_with_timeout( wait_min=_timedelta(seconds=10), wait_max=_timedelta(seconds=25), - timeout=_timedelta(minutes=3) if timeout is None else timeout) + timeout=_timedelta(minutes=5) if timeout is None else timeout) logger.info('Waiting for client %s to report a %s channel to %s', self.ip, _ChannelzChannelState.Name(state), self.server_target) - channel = retryer(self.find_server_channel_with_state, state) + channel = retryer(self.find_server_channel_with_state, + state, + rpc_deadline=rpc_deadline) logger.info('Client %s channel to %s transitioned to state %s:\n%s', self.ip, self.server_target, _ChannelzChannelState.Name(state), channel) @@ -145,8 +143,13 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp): self, state: _ChannelzChannelState, *, + rpc_deadline: Optional[_timedelta] = None, check_subchannel=True) -> _ChannelzChannel: - for channel in self.get_server_channels(): + rpc_params = {} + if rpc_deadline is not None: + rpc_params['deadline_sec'] = rpc_deadline.total_seconds() + + for channel in self.get_server_channels(**rpc_params): channel_state: _ChannelzChannelState = channel.data.state.state logger.info('Server channel: %s, state: %s', channel.ref.name, _ChannelzChannelState.Name(channel_state)) @@ -156,7 +159,7 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp): # one subchannel in the requested state. try: subchannel = self.find_subchannel_with_state( - channel, state) + channel, state, **rpc_params) logger.info('Found subchannel in state %s: %s', state, subchannel) except self.NotFound as e: @@ -169,10 +172,15 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp): f'Client has no {_ChannelzChannelState.Name(state)} channel with ' 'the server') - def find_subchannel_with_state( - self, channel: _ChannelzChannel, - state: _ChannelzChannelState) -> _ChannelzSubchannel: - for subchannel in self.channelz.list_channel_subchannels(channel): + def get_server_channels(self, **kwargs) -> Iterator[_ChannelzChannel]: + return self.channelz.find_channels_for_target(self.server_target, + **kwargs) + + def find_subchannel_with_state(self, channel: _ChannelzChannel, + state: _ChannelzChannelState, + **kwargs) -> _ChannelzSubchannel: + subchannels = self.channelz.list_channel_subchannels(channel, **kwargs) + for subchannel in subchannels: if subchannel.data.state.state is state: return subchannel