From bde2b79cbdd5f8a7acc4f77b61d413f24b9fd2cd Mon Sep 17 00:00:00 2001
From: Sergii Tkachenko <sergiitk@google.com>
Date: Mon, 25 Jan 2021 21:49:48 -0500
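Subject: [PATCH] xds-k8s driver: wait server channelz - adjust RPC timeouts

Rework how the test driver bounds RPCs while it polls channelz for the
state of the client -> server channel.

framework/rpc/grpc.py:
* Replace the wait_for_ready_sec and connection_timeout_sec arguments of
  GrpcClientHelper's unary call helper (60s each, summed into a 120s
  timeout) with a single deadline_sec argument that defaults to
  DEFAULT_RPC_DEADLINE_SEC = 90.

framework/rpc/grpc_channelz.py:
* Let the list_*, get_subchannel, get_socket, find_channels_for_target
  and find_server_listening_on_port helpers accept **kwargs and forward
  them to call_unary_with_deadline, so callers can set a per-RPC
  deadline.

framework/rpc/grpc_testing.py:
* Pass the stats timeout as deadline_sec. The RPC deadline is now exactly
  timeout_sec; it used to be timeout_sec plus the 60s connection timeout.

framework/test_app/client_app.py:
* wait_for_server_channel_state takes an rpc_deadline (30s by default)
  that is applied to each channelz RPC made while polling, and its
  default overall timeout grows from 3 to 5 minutes.
* find_server_channel_with_state, find_subchannel_with_state and
  get_server_channels pass the deadline through.
* Remove get_active_server_channel; get_active_server_channel_socket
  now calls find_server_channel_with_state directly.

framework/helpers/retryers.py:
* exponential_retryer_with_timeout accepts an optional logger and
  log_level and logs each failed attempt before sleeping.
* reraise=True is no longer passed to tenacity.Retrying, so once the
  timeout elapses callers get tenacity.RetryError wrapping the last
  attempt instead of the original exception.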
---
 .../framework/helpers/retryers.py             | 14 ++++-
 .../xds_k8s_test_driver/framework/rpc/grpc.py | 21 +++----
 .../framework/rpc/grpc_channelz.py            | 61 +++++++++++--------
 .../framework/rpc/grpc_testing.py             |  2 +-
 .../framework/test_app/client_app.py          | 50 ++++++++-------
 5 files changed, 86 insertions(+), 62 deletions(-)

diff --git a/tools/run_tests/xds_k8s_test_driver/framework/helpers/retryers.py b/tools/run_tests/xds_k8s_test_driver/framework/helpers/retryers.py
index d76a4066127..248c5b648d7 100644
--- a/tools/run_tests/xds_k8s_test_driver/framework/helpers/retryers.py
+++ b/tools/run_tests/xds_k8s_test_driver/framework/helpers/retryers.py
@@ -21,13 +21,17 @@ We use tenacity as a general-purpose retrying library.
 > - https://tenacity.readthedocs.io/en/latest/index.html
 """
 import datetime
+import logging
 from typing import Any, List, Optional
 
 import tenacity
 
+retryers_logger = logging.getLogger(__name__)
 # Type aliases
 timedelta = datetime.timedelta
 Retrying = tenacity.Retrying
+_after_log = tenacity.after_log
+_before_sleep_log = tenacity.before_sleep_log
 _retry_if_exception_type = tenacity.retry_if_exception_type
 _stop_after_delay = tenacity.stop_after_delay
 _wait_exponential = tenacity.wait_exponential
@@ -45,9 +49,15 @@ def exponential_retryer_with_timeout(
         wait_min: timedelta,
         wait_max: timedelta,
         timeout: timedelta,
-        retry_on_exceptions: Optional[List[Any]] = None) -> Retrying:
+        retry_on_exceptions: Optional[List[Any]] = None,
+        logger: Optional[logging.Logger] = None,
+        log_level: Optional[int] = logging.DEBUG) -> Retrying:
+    if logger is None:
+        logger = retryers_logger
+    if log_level is None:
+        log_level = logging.DEBUG
     return Retrying(retry=_retry_on_exceptions(retry_on_exceptions),
                     wait=_wait_exponential(min=wait_min.total_seconds(),
                                            max=wait_max.total_seconds()),
                     stop=_stop_after_delay(timeout.total_seconds()),
-                    reraise=True)
+                    before_sleep=_before_sleep_log(logger, log_level))
diff --git a/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc.py b/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc.py
index 79ab84ad531..3e155532d58 100644
--- a/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc.py
+++ b/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc.py
@@ -29,8 +29,7 @@ Message = google.protobuf.message.Message
 
 class GrpcClientHelper:
     channel: grpc.Channel
-    DEFAULT_CONNECTION_TIMEOUT_SEC = 60
-    DEFAULT_WAIT_FOR_READY_SEC = 60
+    DEFAULT_RPC_DEADLINE_SEC = 90
 
     def __init__(self, channel: grpc.Channel, stub_class: ClassVar):
         self.channel = channel
@@ -44,20 +43,16 @@ class GrpcClientHelper:
             *,
             rpc: str,
             req: Message,
-            wait_for_ready_sec: Optional[int] = DEFAULT_WAIT_FOR_READY_SEC,
-            connection_timeout_sec: Optional[
-                int] = DEFAULT_CONNECTION_TIMEOUT_SEC,
+            deadline_sec: Optional[int] = DEFAULT_RPC_DEADLINE_SEC,
             log_level: Optional[int] = logging.DEBUG) -> Message:
-        if wait_for_ready_sec is None:
-            wait_for_ready_sec = self.DEFAULT_WAIT_FOR_READY_SEC
-        if connection_timeout_sec is None:
-            connection_timeout_sec = self.DEFAULT_CONNECTION_TIMEOUT_SEC
+        if deadline_sec is None:
+            deadline_sec = self.DEFAULT_RPC_DEADLINE_SEC
 
-        timeout_sec = wait_for_ready_sec + connection_timeout_sec
-        rpc_callable: grpc.UnaryUnaryMultiCallable = getattr(self.stub, rpc)
-
-        call_kwargs = dict(wait_for_ready=True, timeout=timeout_sec)
+        call_kwargs = dict(wait_for_ready=True, timeout=deadline_sec)
         self._log_rpc_request(rpc, req, call_kwargs, log_level)
+
+        # Call RPC, e.g. RpcStub(channel).RpcMethod(req, ...options)
+        rpc_callable: grpc.UnaryUnaryMultiCallable = getattr(self.stub, rpc)
         return rpc_callable(req, **call_kwargs)
 
     def _log_rpc_request(self, rpc, req, call_kwargs, log_level=logging.DEBUG):
diff --git a/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py b/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py
index 3bf4b261313..b4e6b18761d 100644
--- a/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py
+++ b/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py
@@ -95,22 +95,25 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
                 return server_socket
         return None
 
-    def find_channels_for_target(self, target: str) -> Iterator[Channel]:
-        return (channel for channel in self.list_channels()
+    def find_channels_for_target(self, target: str,
+                                 **kwargs) -> Iterator[Channel]:
+        return (channel for channel in self.list_channels(**kwargs)
                 if channel.data.target == target)
 
-    def find_server_listening_on_port(self, port: int) -> Optional[Server]:
-        for server in self.list_servers():
+    def find_server_listening_on_port(self, port: int,
+                                      **kwargs) -> Optional[Server]:
+        for server in self.list_servers(**kwargs):
             listen_socket_ref: SocketRef
             for listen_socket_ref in server.listen_socket:
-                listen_socket = self.get_socket(listen_socket_ref.socket_id)
+                listen_socket = self.get_socket(listen_socket_ref.socket_id,
+                                                **kwargs)
                 listen_address: Address = listen_socket.local
                 if (self.is_sock_tcpip_address(listen_address) and
                         listen_address.tcpip_address.port == port):
                     return server
         return None
 
-    def list_channels(self) -> Iterator[Channel]:
+    def list_channels(self, **kwargs) -> Iterator[Channel]:
         """
         Iterate over all pages of all root channels.
 
@@ -125,12 +128,13 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
             start += 1
             response = self.call_unary_with_deadline(
                 rpc='GetTopChannels',
-                req=_GetTopChannelsRequest(start_channel_id=start))
+                req=_GetTopChannelsRequest(start_channel_id=start),
+                **kwargs)
             for channel in response.channel:
                 start = max(start, channel.ref.channel_id)
                 yield channel
 
-    def list_servers(self) -> Iterator[Server]:
+    def list_servers(self, **kwargs) -> Iterator[Server]:
         """Iterate over all pages of all servers that exist in the process."""
         start: int = -1
         response: Optional[_GetServersResponse] = None
@@ -139,12 +143,14 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
             # value by adding 1 to the highest seen result ID.
             start += 1
             response = self.call_unary_with_deadline(
-                rpc='GetServers', req=_GetServersRequest(start_server_id=start))
+                rpc='GetServers',
+                req=_GetServersRequest(start_server_id=start),
+                **kwargs)
             for server in response.server:
                 start = max(start, server.ref.server_id)
                 yield server
 
-    def list_server_sockets(self, server: Server) -> Iterator[Socket]:
+    def list_server_sockets(self, server: Server, **kwargs) -> Iterator[Socket]:
         """List all server sockets that exist in server process.
 
         Iterating over the results will resolve additional pages automatically.
@@ -158,39 +164,44 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
             response = self.call_unary_with_deadline(
                 rpc='GetServerSockets',
                 req=_GetServerSocketsRequest(server_id=server.ref.server_id,
-                                             start_socket_id=start))
+                                             start_socket_id=start),
+                **kwargs)
             socket_ref: SocketRef
             for socket_ref in response.socket_ref:
                 start = max(start, socket_ref.socket_id)
                 # Yield actual socket
-                yield self.get_socket(socket_ref.socket_id)
+                yield self.get_socket(socket_ref.socket_id, **kwargs)
 
-    def list_channel_sockets(self, channel: Channel) -> Iterator[Socket]:
+    def list_channel_sockets(self, channel: Channel,
+                             **kwargs) -> Iterator[Socket]:
         """List all sockets of all subchannels of a given channel."""
-        for subchannel in self.list_channel_subchannels(channel):
-            yield from self.list_subchannels_sockets(subchannel)
+        for subchannel in self.list_channel_subchannels(channel, **kwargs):
+            yield from self.list_subchannels_sockets(subchannel, **kwargs)
 
-    def list_channel_subchannels(self,
-                                 channel: Channel) -> Iterator[Subchannel]:
+    def list_channel_subchannels(self, channel: Channel,
+                                 **kwargs) -> Iterator[Subchannel]:
         """List all subchannels of a given channel."""
         for subchannel_ref in channel.subchannel_ref:
-            yield self.get_subchannel(subchannel_ref.subchannel_id)
+            yield self.get_subchannel(subchannel_ref.subchannel_id, **kwargs)
 
-    def list_subchannels_sockets(self,
-                                 subchannel: Subchannel) -> Iterator[Socket]:
+    def list_subchannels_sockets(self, subchannel: Subchannel,
+                                 **kwargs) -> Iterator[Socket]:
         """List all sockets of a given subchannel."""
         for socket_ref in subchannel.socket_ref:
-            yield self.get_socket(socket_ref.socket_id)
+            yield self.get_socket(socket_ref.socket_id, **kwargs)
 
-    def get_subchannel(self, subchannel_id) -> Subchannel:
+    def get_subchannel(self, subchannel_id, **kwargs) -> Subchannel:
         """Return a single Subchannel, otherwise raises RpcError."""
         response: _GetSubchannelResponse = self.call_unary_with_deadline(
             rpc='GetSubchannel',
-            req=_GetSubchannelRequest(subchannel_id=subchannel_id))
+            req=_GetSubchannelRequest(subchannel_id=subchannel_id),
+            **kwargs)
         return response.subchannel
 
-    def get_socket(self, socket_id) -> Socket:
+    def get_socket(self, socket_id, **kwargs) -> Socket:
         """Return a single Socket, otherwise raises RpcError."""
         response: _GetSocketResponse = self.call_unary_with_deadline(
-            rpc='GetSocket', req=_GetSocketRequest(socket_id=socket_id))
+            rpc='GetSocket',
+            req=_GetSocketRequest(socket_id=socket_id),
+            **kwargs)
         return response.socket
diff --git a/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_testing.py b/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_testing.py
index 8c56cede09d..31485f9d561 100644
--- a/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_testing.py
+++ b/tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_testing.py
@@ -49,5 +49,5 @@ class LoadBalancerStatsServiceClient(framework.rpc.grpc.GrpcClientHelper):
                                              req=_LoadBalancerStatsRequest(
                                                  num_rpcs=num_rpcs,
                                                  timeout_sec=timeout_sec),
-                                             wait_for_ready_sec=timeout_sec,
+                                             deadline_sec=timeout_sec,
                                              log_level=logging.INFO)
diff --git a/tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py b/tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py
index 34086cf5d52..ea1ab8d4a96 100644
--- a/tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py
+++ b/tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py
@@ -83,9 +83,6 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp):
         return self.load_balancer_stats.get_client_stats(
             num_rpcs=num_rpcs, timeout_sec=timeout_sec)
 
-    def get_server_channels(self) -> Iterator[_ChannelzChannel]:
-        return self.channelz.find_channels_for_target(self.server_target)
-
     def wait_for_active_server_channel(self) -> _ChannelzChannel:
         """Wait for the channel to the server to transition to READY.
 
@@ -94,16 +91,9 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp):
         """
         return self.wait_for_server_channel_state(_ChannelzChannelState.READY)
 
-    def get_active_server_channel(self) -> _ChannelzChannel:
-        """Return a READY channel to the server.
-
-        Raises:
-            GrpcApp.NotFound: If there's no READY channel to the server.
-        """
-        return self.find_server_channel_with_state(_ChannelzChannelState.READY)
-
     def get_active_server_channel_socket(self) -> _ChannelzSocket:
-        channel = self.get_active_server_channel()
+        channel = self.find_server_channel_with_state(
+            _ChannelzChannelState.READY)
         # Get the first subchannel of the active channel to the server.
         logger.debug(
             'Retrieving client -> server socket, '
@@ -125,17 +115,25 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp):
             self,
             state: _ChannelzChannelState,
             *,
-            timeout: Optional[_timedelta] = None) -> _ChannelzChannel:
+            timeout: Optional[_timedelta] = None,
+            rpc_deadline: Optional[_timedelta] = None) -> _ChannelzChannel:
+        # When polling for a state, prefer smaller wait times to avoid
+        # exhausting all allowed time on a single long RPC.
+        if rpc_deadline is None:
+            rpc_deadline = _timedelta(seconds=30)
+
         # Fine-tuned to wait for the channel to the server.
         retryer = retryers.exponential_retryer_with_timeout(
             wait_min=_timedelta(seconds=10),
             wait_max=_timedelta(seconds=25),
-            timeout=_timedelta(minutes=3) if timeout is None else timeout)
+            timeout=_timedelta(minutes=5) if timeout is None else timeout)
 
         logger.info('Waiting for client %s to report a %s channel to %s',
                     self.ip, _ChannelzChannelState.Name(state),
                     self.server_target)
-        channel = retryer(self.find_server_channel_with_state, state)
+        channel = retryer(self.find_server_channel_with_state,
+                          state,
+                          rpc_deadline=rpc_deadline)
         logger.info('Client %s channel to %s transitioned to state %s:\n%s',
                     self.ip, self.server_target,
                     _ChannelzChannelState.Name(state), channel)
@@ -145,8 +143,13 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp):
             self,
             state: _ChannelzChannelState,
             *,
+            rpc_deadline: Optional[_timedelta] = None,
             check_subchannel=True) -> _ChannelzChannel:
-        for channel in self.get_server_channels():
+        rpc_params = {}
+        if rpc_deadline is not None:
+            rpc_params['deadline_sec'] = rpc_deadline.total_seconds()
+
+        for channel in self.get_server_channels(**rpc_params):
             channel_state: _ChannelzChannelState = channel.data.state.state
             logger.info('Server channel: %s, state: %s', channel.ref.name,
                         _ChannelzChannelState.Name(channel_state))
@@ -156,7 +159,7 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp):
                     # one subchannel in the requested state.
                     try:
                         subchannel = self.find_subchannel_with_state(
-                            channel, state)
+                            channel, state, **rpc_params)
                         logger.info('Found subchannel in state %s: %s', state,
                                     subchannel)
                     except self.NotFound as e:
@@ -169,10 +172,15 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp):
             f'Client has no {_ChannelzChannelState.Name(state)} channel with '
             'the server')
 
-    def find_subchannel_with_state(
-            self, channel: _ChannelzChannel,
-            state: _ChannelzChannelState) -> _ChannelzSubchannel:
-        for subchannel in self.channelz.list_channel_subchannels(channel):
+    def get_server_channels(self, **kwargs) -> Iterator[_ChannelzChannel]:
+        return self.channelz.find_channels_for_target(self.server_target,
+                                                      **kwargs)
+
+    def find_subchannel_with_state(self, channel: _ChannelzChannel,
+                                   state: _ChannelzChannelState,
+                                   **kwargs) -> _ChannelzSubchannel:
+        subchannels = self.channelz.list_channel_subchannels(channel, **kwargs)
+        for subchannel in subchannels:
             if subchannel.data.state.state is state:
                 return subchannel