diff --git a/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/compute.py b/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/compute.py index 8d128ef3f9e..c17f4c7559d 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/compute.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/compute.py @@ -113,7 +113,8 @@ class ComputeV1(gcp.api.GcpProjectApiResource): name: str, health_check: GcpResource, affinity_header: str = None, - protocol: Optional[BackendServiceProtocol] = None) -> GcpResource: + protocol: Optional[BackendServiceProtocol] = None, + subset_size: Optional[int] = None) -> GcpResource: if not isinstance(protocol, self.BackendServiceProtocol): raise TypeError(f'Unexpected Backend Service protocol: {protocol}') body = { @@ -130,6 +131,11 @@ class ComputeV1(gcp.api.GcpProjectApiResource): body['consistentHash'] = { 'httpHeaderName': affinity_header, } + if subset_size: + body['subsetting'] = { + 'policy': 'CONSISTENT_HASH_SUBSETTING', + 'subsetSize': subset_size + } return self._insert_resource(self.api.backendServices(), body) def get_backend_service_traffic_director(self, name: str) -> GcpResource: diff --git a/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/traffic_director.py b/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/traffic_director.py index 9de4e2852c0..3634b869aea 100644 --- a/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/traffic_director.py +++ b/tools/run_tests/xds_k8s_test_driver/framework/infrastructure/traffic_director.py @@ -177,14 +177,19 @@ class TrafficDirectorManager: self.health_check = None def create_backend_service( - self, protocol: Optional[BackendServiceProtocol] = _BackendGRPC): + self, + protocol: Optional[BackendServiceProtocol] = _BackendGRPC, + subset_size: Optional[int] = None): if protocol is None: protocol = _BackendGRPC name = self.make_resource_name(self.BACKEND_SERVICE_NAME) logger.info('Creating %s Backend Service "%s"', protocol.name, name) resource = self.compute.create_backend_service_traffic_director( - name, health_check=self.health_check, protocol=protocol) + name, + health_check=self.health_check, + protocol=protocol, + subset_size=subset_size) self.backend_service = resource self.backend_service_protocol = protocol diff --git a/tools/run_tests/xds_k8s_test_driver/tests/subsetting_test.py b/tools/run_tests/xds_k8s_test_driver/tests/subsetting_test.py new file mode 100644 index 00000000000..2b652240450 --- /dev/null +++ b/tools/run_tests/xds_k8s_test_driver/tests/subsetting_test.py @@ -0,0 +1,104 @@ +# Copyright 2021 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +from typing import List, Optional + +from absl import flags +from absl import logging +from absl.testing import absltest +from google.protobuf import json_format + +from framework import xds_k8s_testcase +from framework import xds_url_map_testcase +from framework.infrastructure import k8s +from framework.test_app import server_app + +flags.adopt_module_key_flags(xds_k8s_testcase) + +# Type aliases +_XdsTestServer = xds_k8s_testcase.XdsTestServer +_XdsTestClient = xds_k8s_testcase.XdsTestClient + +_SUBSET_SIZE = 4 +_NUM_BACKENDS = 8 +_NUM_CLIENTS = 3 + + +class SubsettingTest(xds_k8s_testcase.RegularXdsKubernetesTestCase): + + def test_subsetting_basic(self) -> None: + with self.subTest('00_create_health_check'): + self.td.create_health_check() + + with self.subTest('01_create_backend_services'): + self.td.create_backend_service(subset_size=_SUBSET_SIZE) + + with self.subTest('02_create_url_map'): + self.td.create_url_map(self.server_xds_host, self.server_xds_port) + + with self.subTest('03_create_target_proxy'): + self.td.create_target_proxy() + + with self.subTest('04_create_forwarding_rule'): + self.td.create_forwarding_rule(self.server_xds_port) + + with self.subTest('05_start_test_servers'): + self.test_servers: List[_XdsTestServer] = self.startTestServers( + replica_count=_NUM_BACKENDS) + + with self.subTest('06_add_server_backends_to_backend_services'): + self.setupServerBackends() + + rpc_distribution = collections.defaultdict(int) + with self.subTest('07_start_test_client'): + for i in range(_NUM_CLIENTS): + # Clean created client pods if there is any + self.client_runner.cleanup(force=True) + # Create a test client + test_client: _XdsTestClient = self.startTestClient( + self.test_servers[0]) + # Validate the number of received endpoints + config = test_client.csds.fetch_client_status( + log_level=logging.INFO) + self.assertIsNotNone(config) + json_config = json_format.MessageToDict(config) + parsed = xds_url_map_testcase.DumpedXdsConfig(json_config) + logging.info('Client %d received endpoints (len=%s): %s', i, + len(parsed.endpoints), parsed.endpoints) + self.assertLen(parsed.endpoints, _SUBSET_SIZE) + # Record RPC stats + lb_stats = self.getClientRpcStats(test_client, + _NUM_BACKENDS * 25) + for key, value in lb_stats.rpcs_by_peer.items(): + rpc_distribution[key] += value + + with self.subTest('08_log_rpc_distribution'): + server_entries = sorted(rpc_distribution.items(), + key=lambda x: -x[1]) + # Validate if clients are receiving different sets of backends (3 + # client received a total of 4 unique backends == FAIL, a total of 5 + # unique backends == PASS) + self.assertGreater(len(server_entries), _SUBSET_SIZE) + logging.info('RPC distribution (len=%s): %s', len(server_entries), + server_entries) + peak = server_entries[0][1] + mean = sum(map(lambda x: x[1], + server_entries)) / len(server_entries) + logging.info('Peak=%d Mean=%.1f Peak-to-Mean-Ratio=%.2f', peak, + mean, peak / mean) + + +if __name__ == '__main__': + absltest.main(failfast=True)