yapf autoformat

pull/24983/head
Sergii Tkachenko 4 years ago
parent 86f8792136
commit 28a6f740f5
  1. 10
      tools/run_tests/xds_test_driver/bin/run_channelz.py
  2. 56
      tools/run_tests/xds_test_driver/bin/run_td_setup.py
  3. 28
      tools/run_tests/xds_test_driver/bin/run_test_client.py
  4. 22
      tools/run_tests/xds_test_driver/bin/run_test_server.py
  5. 1
      tools/run_tests/xds_test_driver/framework/infrastructure/gcp/__init__.py
  6. 54
      tools/run_tests/xds_test_driver/framework/infrastructure/gcp/api.py
  7. 150
      tools/run_tests/xds_test_driver/framework/infrastructure/gcp/compute.py
  8. 26
      tools/run_tests/xds_test_driver/framework/infrastructure/gcp/network_security.py
  9. 3
      tools/run_tests/xds_test_driver/framework/infrastructure/gcp/network_services.py
  10. 116
      tools/run_tests/xds_test_driver/framework/infrastructure/k8s.py
  11. 149
      tools/run_tests/xds_test_driver/framework/infrastructure/traffic_director.py
  12. 6
      tools/run_tests/xds_test_driver/framework/rpc/__init__.py
  13. 14
      tools/run_tests/xds_test_driver/framework/rpc/grpc_channelz.py
  14. 4
      tools/run_tests/xds_test_driver/framework/rpc/grpc_testing.py
  15. 53
      tools/run_tests/xds_test_driver/framework/test_app/base_runner.py
  16. 26
      tools/run_tests/xds_test_driver/framework/test_app/client_app.py
  17. 51
      tools/run_tests/xds_test_driver/framework/test_app/server_app.py
  18. 49
      tools/run_tests/xds_test_driver/framework/xds_flags.py
  19. 26
      tools/run_tests/xds_test_driver/framework/xds_k8s_flags.py
  20. 146
      tools/run_tests/xds_test_driver/framework/xds_k8s_testcase.py
  21. 1
      tools/run_tests/xds_test_driver/tests/baseline_test.py
  22. 23
      tools/run_tests/xds_test_driver/tests/security_test.py

@ -26,10 +26,12 @@ from framework.test_app import client_app
logger = logging.getLogger(__name__)
# Flags
_SERVER_RPC_HOST = flags.DEFINE_string(
'server_rpc_host', default='127.0.0.1', help='Server RPC host')
_CLIENT_RPC_HOST = flags.DEFINE_string(
'client_rpc_host', default='127.0.0.1', help='Client RPC host')
_SERVER_RPC_HOST = flags.DEFINE_string('server_rpc_host',
default='127.0.0.1',
help='Server RPC host')
_CLIENT_RPC_HOST = flags.DEFINE_string('client_rpc_host',
default='127.0.0.1',
help='Client RPC host')
flags.adopt_module_key_flags(xds_flags)
flags.adopt_module_key_flags(xds_k8s_flags)

@ -22,17 +22,19 @@ from framework.infrastructure import gcp
from framework.infrastructure import k8s
from framework.infrastructure import traffic_director
logger = logging.getLogger(__name__)
# Flags
_CMD = flags.DEFINE_enum(
'cmd', default='create',
enum_values=['cycle', 'create', 'cleanup',
'backends-add', 'backends-cleanup'],
help='Command')
_SECURITY = flags.DEFINE_enum(
'security', default=None, enum_values=['mtls', 'tls', 'plaintext'],
help='Configure td with security')
_CMD = flags.DEFINE_enum('cmd',
default='create',
enum_values=[
'cycle', 'create', 'cleanup', 'backends-add',
'backends-cleanup'
],
help='Command')
_SECURITY = flags.DEFINE_enum('security',
default=None,
enum_values=['mtls', 'tls', 'plaintext'],
help='Configure td with security')
flags.adopt_module_key_flags(xds_flags)
flags.adopt_module_key_flags(xds_k8s_flags)
@ -57,11 +59,10 @@ def main(argv):
gcp_api_manager = gcp.api.GcpApiManager()
if security_mode is None:
td = traffic_director.TrafficDirectorManager(
gcp_api_manager,
project=project,
resource_prefix=namespace,
network=network)
td = traffic_director.TrafficDirectorManager(gcp_api_manager,
project=project,
resource_prefix=namespace,
network=network)
else:
td = traffic_director.TrafficDirectorSecureManager(
gcp_api_manager,
@ -80,26 +81,29 @@ def main(argv):
elif security_mode == 'mtls':
logger.info('Setting up mtls')
td.setup_for_grpc(server_xds_host, server_xds_port)
td.setup_server_security(server_port,
tls=True, mtls=True)
td.setup_client_security(namespace, server_name,
tls=True, mtls=True)
td.setup_server_security(server_port, tls=True, mtls=True)
td.setup_client_security(namespace,
server_name,
tls=True,
mtls=True)
elif security_mode == 'tls':
logger.info('Setting up tls')
td.setup_for_grpc(server_xds_host, server_xds_port)
td.setup_server_security(server_port,
tls=True, mtls=False)
td.setup_client_security(namespace, server_name,
tls=True, mtls=False)
td.setup_server_security(server_port, tls=True, mtls=False)
td.setup_client_security(namespace,
server_name,
tls=True,
mtls=False)
elif security_mode == 'plaintext':
logger.info('Setting up plaintext')
td.setup_for_grpc(server_xds_host, server_xds_port)
td.setup_server_security(server_port,
tls=False, mtls=False)
td.setup_client_security(namespace, server_name,
tls=False, mtls=False)
td.setup_server_security(server_port, tls=False, mtls=False)
td.setup_client_security(namespace,
server_name,
tls=False,
mtls=False)
logger.info('Works!')
except Exception:

@ -23,21 +23,23 @@ from framework.test_app import client_app
logger = logging.getLogger(__name__)
# Flags
_CMD = flags.DEFINE_enum(
'cmd', default='run', enum_values=['run', 'cleanup'],
help='Command')
_SECURE = flags.DEFINE_bool(
"secure", default=False,
help="Run client in the secure mode")
_CMD = flags.DEFINE_enum('cmd',
default='run',
enum_values=['run', 'cleanup'],
help='Command')
_SECURE = flags.DEFINE_bool("secure",
default=False,
help="Run client in the secure mode")
_QPS = flags.DEFINE_integer('qps', default=25, help='Queries per second')
_PRINT_RESPONSE = flags.DEFINE_bool(
"print_response", default=False,
help="Client prints responses")
_REUSE_NAMESPACE = flags.DEFINE_bool(
"reuse_namespace", default=True,
help="Use existing namespace if exists")
_PRINT_RESPONSE = flags.DEFINE_bool("print_response",
default=False,
help="Client prints responses")
_REUSE_NAMESPACE = flags.DEFINE_bool("reuse_namespace",
default=True,
help="Use existing namespace if exists")
_CLEANUP_NAMESPACE = flags.DEFINE_bool(
"cleanup_namespace", default=False,
"cleanup_namespace",
default=False,
help="Delete namespace during resource cleanup")
flags.adopt_module_key_flags(xds_flags)
flags.adopt_module_key_flags(xds_k8s_flags)

@ -35,17 +35,19 @@ from framework.test_app import server_app
logger = logging.getLogger(__name__)
# Flags
_CMD = flags.DEFINE_enum(
'cmd', default='run', enum_values=['run', 'cleanup'],
help='Command')
_SECURE = flags.DEFINE_bool(
"secure", default=False,
help="Run server in the secure mode")
_REUSE_NAMESPACE = flags.DEFINE_bool(
"reuse_namespace", default=True,
help="Use existing namespace if exists")
_CMD = flags.DEFINE_enum('cmd',
default='run',
enum_values=['run', 'cleanup'],
help='Command')
_SECURE = flags.DEFINE_bool("secure",
default=False,
help="Run server in the secure mode")
_REUSE_NAMESPACE = flags.DEFINE_bool("reuse_namespace",
default=True,
help="Use existing namespace if exists")
_CLEANUP_NAMESPACE = flags.DEFINE_bool(
"cleanup_namespace", default=False,
"cleanup_namespace",
default=False,
help="Delete namespace during resource cleanup")
flags.adopt_module_key_flags(xds_flags)
flags.adopt_module_key_flags(xds_k8s_flags)

@ -15,4 +15,3 @@ from framework.infrastructure.gcp import api
from framework.infrastructure.gcp import compute
from framework.infrastructure.gcp import network_security
from framework.infrastructure.gcp import network_services

@ -28,14 +28,15 @@ import googleapiclient.errors
import tenacity
logger = logging.getLogger(__name__)
V1_DISCOVERY_URI = flags.DEFINE_string(
"v1_discovery_uri", default=discovery.V1_DISCOVERY_URI,
help="Override v1 Discovery URI")
V2_DISCOVERY_URI = flags.DEFINE_string(
"v2_discovery_uri", default=discovery.V2_DISCOVERY_URI,
help="Override v2 Discovery URI")
V1_DISCOVERY_URI = flags.DEFINE_string("v1_discovery_uri",
default=discovery.V1_DISCOVERY_URI,
help="Override v1 Discovery URI")
V2_DISCOVERY_URI = flags.DEFINE_string("v2_discovery_uri",
default=discovery.V2_DISCOVERY_URI,
help="Override v2 Discovery URI")
COMPUTE_V1_DISCOVERY_FILE = flags.DEFINE_string(
"compute_v1_discovery_file", default=None,
"compute_v1_discovery_file",
default=None,
help="Load compute v1 from discovery file")
# Type aliases
@ -43,7 +44,9 @@ Operation = operations_pb2.Operation
class GcpApiManager:
def __init__(self, *,
def __init__(self,
*,
v1_discovery_uri=None,
v2_discovery_uri=None,
compute_v1_discovery_file=None,
@ -73,8 +76,9 @@ class GcpApiManager:
def networksecurity(self, version):
api_name = 'networksecurity'
if version == 'v1alpha1':
return self._build_from_discovery_v2(
api_name, version, api_key=self.private_api_key)
return self._build_from_discovery_v2(api_name,
version,
api_key=self.private_api_key)
raise NotImplementedError(f'Network Security {version} not supported')
@ -82,22 +86,26 @@ class GcpApiManager:
def networkservices(self, version):
api_name = 'networkservices'
if version == 'v1alpha1':
return self._build_from_discovery_v2(
api_name, version, api_key=self.private_api_key)
return self._build_from_discovery_v2(api_name,
version,
api_key=self.private_api_key)
raise NotImplementedError(f'Network Services {version} not supported')
def _build_from_discovery_v1(self, api_name, version):
api = discovery.build(
api_name, version, cache_discovery=False,
discoveryServiceUrl=self.v1_discovery_uri)
api = discovery.build(api_name,
version,
cache_discovery=False,
discoveryServiceUrl=self.v1_discovery_uri)
self._exit_stack.enter_context(api)
return api
def _build_from_discovery_v2(self, api_name, version, *, api_key=None):
key_arg = f'&key={api_key}' if api_key else ''
api = discovery.build(
api_name, version, cache_discovery=False,
api_name,
version,
cache_discovery=False,
discoveryServiceUrl=f'{self.v2_discovery_uri}{key_arg}')
self._exit_stack.enter_context(api)
return api
@ -121,6 +129,7 @@ class OperationError(Error):
https://cloud.google.com/apis/design/design_patterns#long_running_operations
https://github.com/googleapis/googleapis/blob/master/google/longrunning/operations.proto
"""
def __init__(self, api_name, operation_response, message=None):
self.api_name = api_name
operation = json_format.ParseDict(operation_response, Operation())
@ -175,7 +184,8 @@ class GcpStandardCloudApiResource(GcpProjectApiResource):
**kwargs):
logger.debug("Creating %s", body)
create_req = collection.create(parent=self.parent(),
body=body, **kwargs)
body=body,
**kwargs)
self._execute(create_req)
@staticmethod
@ -191,15 +201,17 @@ class GcpStandardCloudApiResource(GcpProjectApiResource):
except googleapiclient.errors.HttpError as error:
# noinspection PyProtectedMember
reason = error._get_reason()
logger.info('Delete failed. Error: %s %s',
error.resp.status, reason)
logger.info('Delete failed. Error: %s %s', error.resp.status,
reason)
def _execute(self, request,
def _execute(self,
request,
timeout_sec=GcpProjectApiResource._WAIT_FOR_OPERATION_SEC):
operation = request.execute(num_retries=self._GCP_API_RETRIES)
self._wait(operation, timeout_sec)
def _wait(self, operation,
def _wait(self,
operation,
timeout_sec=GcpProjectApiResource._WAIT_FOR_OPERATION_SEC):
op_name = operation['name']
logger.debug('Waiting for %s operation, timeout %s sec: %s',

@ -50,7 +50,8 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
HTTP2 = enum.auto()
GRPC = enum.auto()
def create_health_check_tcp(self, name,
def create_health_check_tcp(self,
name,
use_serving_port=False) -> GcpResource:
health_check_settings = {}
if use_serving_port:
@ -66,30 +67,31 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
self._delete_resource(self.api.healthChecks(), healthCheck=name)
def create_backend_service_traffic_director(
self,
name: str,
health_check: GcpResource,
protocol: Optional[BackendServiceProtocol] = None
) -> GcpResource:
self,
name: str,
health_check: GcpResource,
protocol: Optional[BackendServiceProtocol] = None) -> GcpResource:
if not isinstance(protocol, self.BackendServiceProtocol):
raise TypeError(f'Unexpected Backend Service protocol: {protocol}')
return self._insert_resource(self.api.backendServices(), {
'name': name,
'loadBalancingScheme': 'INTERNAL_SELF_MANAGED', # Traffic Director
'healthChecks': [health_check.url],
'protocol': protocol.name,
})
return self._insert_resource(
self.api.backendServices(),
{
'name': name,
'loadBalancingScheme':
'INTERNAL_SELF_MANAGED', # Traffic Director
'healthChecks': [health_check.url],
'protocol': protocol.name,
})
def get_backend_service_traffic_director(self, name: str) -> GcpResource:
return self._get_resource(self.api.backendServices(),
backendService=name)
def patch_backend_service(self, backend_service, body, **kwargs):
self._patch_resource(
collection=self.api.backendServices(),
backendService=backend_service.name,
body=body,
**kwargs)
self._patch_resource(collection=self.api.backendServices(),
backendService=backend_service.name,
body=body,
**kwargs)
def backend_service_add_backends(self, backend_service, backends):
backend_list = [{
@ -98,16 +100,14 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
'maxRatePerEndpoint': 5
} for backend in backends]
self._patch_resource(
collection=self.api.backendServices(),
body={'backends': backend_list},
backendService=backend_service.name)
self._patch_resource(collection=self.api.backendServices(),
body={'backends': backend_list},
backendService=backend_service.name)
def backend_service_remove_all_backends(self, backend_service):
self._patch_resource(
collection=self.api.backendServices(),
body={'backends': []},
backendService=backend_service.name)
self._patch_resource(collection=self.api.backendServices(),
body={'backends': []},
backendService=backend_service.name)
def delete_backend_service(self, name):
self._delete_resource(self.api.backendServices(), backendService=name)
@ -122,18 +122,21 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
) -> GcpResource:
if dst_host_rule_match_backend_service is None:
dst_host_rule_match_backend_service = dst_default_backend_service
return self._insert_resource(self.api.urlMaps(), {
'name': name,
'defaultService': dst_default_backend_service.url,
'hostRules': [{
'hosts': src_hosts,
'pathMatcher': matcher_name,
}],
'pathMatchers': [{
'name': matcher_name,
'defaultService': dst_host_rule_match_backend_service.url,
}],
})
return self._insert_resource(
self.api.urlMaps(), {
'name':
name,
'defaultService':
dst_default_backend_service.url,
'hostRules': [{
'hosts': src_hosts,
'pathMatcher': matcher_name,
}],
'pathMatchers': [{
'name': matcher_name,
'defaultService': dst_host_rule_match_backend_service.url,
}],
})
def delete_url_map(self, name):
self._delete_resource(self.api.urlMaps(), urlMap=name)
@ -174,14 +177,17 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
target_proxy: GcpResource,
network_url: str,
) -> GcpResource:
return self._insert_resource(self.api.globalForwardingRules(), {
'name': name,
'loadBalancingScheme': 'INTERNAL_SELF_MANAGED', # Traffic Director
'portRange': src_port,
'IPAddress': '0.0.0.0',
'network': network_url,
'target': target_proxy.url,
})
return self._insert_resource(
self.api.globalForwardingRules(),
{
'name': name,
'loadBalancingScheme':
'INTERNAL_SELF_MANAGED', # Traffic Director
'portRange': src_port,
'IPAddress': '0.0.0.0',
'network': network_url,
'target': target_proxy.url,
})
def delete_forwarding_rule(self, name):
self._delete_resource(self.api.globalForwardingRules(),
@ -192,15 +198,16 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
return not neg or neg.get('size', 0) == 0
def wait_for_network_endpoint_group(self, name, zone):
@retrying.retry(retry_on_result=self._network_endpoint_group_not_ready,
stop_max_delay=60 * 1000,
wait_fixed=2 * 1000)
def _wait_for_network_endpoint_group_ready():
try:
neg = self.get_network_endpoint_group(name, zone)
logger.debug('Waiting for endpoints: NEG %s in zone %s, '
'current count %s',
neg['name'], zone, neg.get('size'))
logger.debug(
'Waiting for endpoints: NEG %s in zone %s, '
'current count %s', neg['name'], zone, neg.get('size'))
except googleapiclient.errors.HttpError as error:
# noinspection PyProtectedMember
reason = error._get_reason()
@ -211,10 +218,8 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
network_endpoint_group = _wait_for_network_endpoint_group_ready()
# @todo(sergiitk): dataclass
return self.ZonalGcpResource(
network_endpoint_group['name'],
network_endpoint_group['selfLink'],
zone)
return self.ZonalGcpResource(network_endpoint_group['name'],
network_endpoint_group['selfLink'], zone)
def get_network_endpoint_group(self, name, zone):
neg = self.api.networkEndpointGroups().get(project=self.project,
@ -232,10 +237,9 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
):
pending = set(backends)
@retrying.retry(
retry_on_result=lambda result: not result,
stop_max_delay=timeout_sec * 1000,
wait_fixed=wait_sec * 1000)
@retrying.retry(retry_on_result=lambda result: not result,
stop_max_delay=timeout_sec * 1000,
wait_fixed=wait_sec * 1000)
def _retry_backends_health():
for backend in pending:
result = self.get_backend_service_backend_health(
@ -250,9 +254,8 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
for instance in result['healthStatus']:
logger.debug(
'Backend %s in zone %s: instance %s:%s health: %s',
backend.name, backend.zone,
instance['ipAddress'], instance['port'],
instance['healthState'])
backend.name, backend.zone, instance['ipAddress'],
instance['port'], instance['healthState'])
if instance['healthState'] != 'HEALTHY':
backend_healthy = False
@ -267,8 +270,11 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
def get_backend_service_backend_health(self, backend_service, backend):
return self.api.backendServices().getHealth(
project=self.project, backendService=backend_service.name,
body={"group": backend.url}).execute()
project=self.project,
backendService=backend_service.name,
body={
"group": backend.url
}).execute()
def _get_resource(self, collection: discovery.Resource,
**kwargs) -> GcpResource:
@ -276,11 +282,8 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
logger.debug("Loaded %r", resp)
return self.GcpResource(resp['name'], resp['selfLink'])
def _insert_resource(
self,
collection: discovery.Resource,
body: Dict[str, Any]
) -> GcpResource:
def _insert_resource(self, collection: discovery.Resource,
body: Dict[str, Any]) -> GcpResource:
logger.debug("Creating %s", body)
resp = self._execute(collection.insert(project=self.project, body=body))
return self.GcpResource(body['name'], resp['targetLink'])
@ -297,14 +300,16 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
except googleapiclient.errors.HttpError as error:
# noinspection PyProtectedMember
reason = error._get_reason()
logger.info('Delete failed. Error: %s %s',
error.resp.status, reason)
logger.info('Delete failed. Error: %s %s', error.resp.status,
reason)
@staticmethod
def _operation_status_done(operation):
return 'status' in operation and operation['status'] == 'DONE'
def _execute(self, request, *,
def _execute(self,
request,
*,
test_success_fn=None,
timeout_sec=_WAIT_FOR_OPERATION_SEC):
operation = request.execute(num_retries=self._GCP_API_RETRIES)
@ -320,10 +325,9 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
logger.debug('Waiting for global operation %s, timeout %s sec',
operation['name'], timeout_sec)
response = self.wait_for_operation(
operation_request=operation_request,
test_success_fn=test_success_fn,
timeout_sec=timeout_sec)
response = self.wait_for_operation(operation_request=operation_request,
test_success_fn=test_success_fn,
timeout_sec=timeout_sec)
if 'error' in response:
logger.debug('Waiting for global operation failed, response: %r',

@ -52,22 +52,22 @@ class NetworkSecurityV1Alpha1(gcp.api.GcpStandardCloudApiResource):
self._api_locations = self.api.projects().locations()
def create_server_tls_policy(self, name, body: dict):
return self._create_resource(
self._api_locations.serverTlsPolicies(),
body, serverTlsPolicyId=name)
return self._create_resource(self._api_locations.serverTlsPolicies(),
body,
serverTlsPolicyId=name)
def get_server_tls_policy(self, name: str) -> ServerTlsPolicy:
result = self._get_resource(
collection=self._api_locations.serverTlsPolicies(),
full_name=self.resource_full_name(name, self.SERVER_TLS_POLICIES))
return self.ServerTlsPolicy(
name=name,
url=result['name'],
server_certificate=result.get('serverCertificate', {}),
mtls_policy=result.get('mtlsPolicy', {}),
create_time=result['createTime'],
update_time=result['updateTime'])
return self.ServerTlsPolicy(name=name,
url=result['name'],
server_certificate=result.get(
'serverCertificate', {}),
mtls_policy=result.get('mtlsPolicy', {}),
create_time=result['createTime'],
update_time=result['updateTime'])
def delete_server_tls_policy(self, name):
return self._delete_resource(
@ -75,9 +75,9 @@ class NetworkSecurityV1Alpha1(gcp.api.GcpStandardCloudApiResource):
full_name=self.resource_full_name(name, self.SERVER_TLS_POLICIES))
def create_client_tls_policy(self, name, body: dict):
return self._create_resource(
self._api_locations.clientTlsPolicies(),
body, clientTlsPolicyId=name)
return self._create_resource(self._api_locations.clientTlsPolicies(),
body,
clientTlsPolicyId=name)
def get_client_tls_policy(self, name: str) -> ClientTlsPolicy:
result = self._get_resource(

@ -49,7 +49,8 @@ class NetworkServicesV1Alpha1(gcp.api.GcpStandardCloudApiResource):
def create_endpoint_config_selector(self, name, body: dict):
return self._create_resource(
self._api_locations.endpointConfigSelectors(),
body, endpointConfigSelectorId=name)
body,
endpointConfigSelectorId=name)
def get_endpoint_config_selector(self, name: str) -> EndpointConfigSelector:
result = self._get_resource(

@ -35,6 +35,7 @@ ApiException = client.ApiException
def simple_resource_get(func):
def wrap_not_found_return_none(*args, **kwargs):
try:
return func(*args, **kwargs)
@ -43,6 +44,7 @@ def simple_resource_get(func):
# Ignore 404
return None
raise
return wrap_not_found_return_none
@ -51,6 +53,7 @@ def label_dict_to_selector(labels: dict) -> str:
class KubernetesApiManager:
def __init__(self, context):
self.context = context
self.client = self._cached_api_client_for_context(context)
@ -80,7 +83,8 @@ class KubernetesNamespace:
self.api = api
def apply_manifest(self, manifest):
return utils.create_from_dict(self.api.client, manifest,
return utils.create_from_dict(self.api.client,
manifest,
namespace=self.name)
@simple_resource_get
@ -91,24 +95,22 @@ class KubernetesNamespace:
def get_service_account(self, name) -> V1Service:
return self.api.core.read_namespaced_service_account(name, self.name)
def delete_service(
self,
name,
grace_period_seconds=DELETE_GRACE_PERIOD_SEC
):
def delete_service(self,
name,
grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
self.api.core.delete_namespaced_service(
name=name, namespace=self.name,
name=name,
namespace=self.name,
body=client.V1DeleteOptions(
propagation_policy='Foreground',
grace_period_seconds=grace_period_seconds))
def delete_service_account(
self,
name,
grace_period_seconds=DELETE_GRACE_PERIOD_SEC
):
def delete_service_account(self,
name,
grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
self.api.core.delete_namespaced_service_account(
name=name, namespace=self.name,
name=name,
namespace=self.name,
body=client.V1DeleteOptions(
propagation_policy='Foreground',
grace_period_seconds=grace_period_seconds))
@ -124,8 +126,8 @@ class KubernetesNamespace:
propagation_policy='Foreground',
grace_period_seconds=grace_period_seconds))
def wait_for_service_deleted(self, name: str,
timeout_sec=60, wait_sec=1):
def wait_for_service_deleted(self, name: str, timeout_sec=60, wait_sec=1):
@retrying.retry(retry_on_result=lambda r: r is not None,
stop_max_delay=timeout_sec * 1000,
wait_fixed=wait_sec * 1000)
@ -135,10 +137,14 @@ class KubernetesNamespace:
logger.info('Waiting for service %s to be deleted',
service.metadata.name)
return service
_wait_for_deleted_service_with_retry()
def wait_for_service_account_deleted(self, name: str,
timeout_sec=60, wait_sec=1):
def wait_for_service_account_deleted(self,
name: str,
timeout_sec=60,
wait_sec=1):
@retrying.retry(retry_on_result=lambda r: r is not None,
stop_max_delay=timeout_sec * 1000,
wait_fixed=wait_sec * 1000)
@ -148,10 +154,11 @@ class KubernetesNamespace:
logger.info('Waiting for service account %s to be deleted',
service_account.metadata.name)
return service_account
_wait_for_deleted_service_account_with_retry()
def wait_for_namespace_deleted(self,
timeout_sec=240, wait_sec=2):
def wait_for_namespace_deleted(self, timeout_sec=240, wait_sec=2):
@retrying.retry(retry_on_result=lambda r: r is not None,
stop_max_delay=timeout_sec * 1000,
wait_fixed=wait_sec * 1000)
@ -161,27 +168,25 @@ class KubernetesNamespace:
logger.info('Waiting for namespace %s to be deleted',
namespace.metadata.name)
return namespace
_wait_for_deleted_namespace_with_retry()
def wait_for_service_neg(self, name: str,
timeout_sec=60, wait_sec=1):
def wait_for_service_neg(self, name: str, timeout_sec=60, wait_sec=1):
@retrying.retry(retry_on_result=lambda r: not r,
stop_max_delay=timeout_sec * 1000,
wait_fixed=wait_sec * 1000)
def _wait_for_service_neg():
service = self.get_service(name)
if self.NEG_STATUS_META not in service.metadata.annotations:
logger.info('Waiting for service %s NEG',
service.metadata.name)
logger.info('Waiting for service %s NEG', service.metadata.name)
return False
return True
_wait_for_service_neg()
def get_service_neg(
self,
service_name: str,
service_port: int
) -> Tuple[str, List[str]]:
def get_service_neg(self, service_name: str,
service_port: int) -> Tuple[str, List[str]]:
service = self.get_service(service_name)
neg_info: dict = json.loads(
service.metadata.annotations[self.NEG_STATUS_META])
@ -193,13 +198,12 @@ class KubernetesNamespace:
def get_deployment(self, name) -> V1Deployment:
return self.api.apps.read_namespaced_deployment(name, self.name)
def delete_deployment(
self,
name,
grace_period_seconds=DELETE_GRACE_PERIOD_SEC
):
def delete_deployment(self,
name,
grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
self.api.apps.delete_namespaced_deployment(
name=name, namespace=self.name,
name=name,
namespace=self.name,
body=client.V1DeleteOptions(
propagation_policy='Foreground',
grace_period_seconds=grace_period_seconds))
@ -208,34 +212,43 @@ class KubernetesNamespace:
# V1LabelSelector.match_expressions not supported at the moment
return self.list_pods_with_labels(deployment.spec.selector.match_labels)
def wait_for_deployment_available_replicas(self, name, count=1,
timeout_sec=60, wait_sec=1):
def wait_for_deployment_available_replicas(self,
name,
count=1,
timeout_sec=60,
wait_sec=1):
@retrying.retry(
retry_on_result=lambda r: not self._replicas_available(r, count),
stop_max_delay=timeout_sec * 1000,
wait_fixed=wait_sec * 1000)
def _wait_for_deployment_available_replicas():
deployment = self.get_deployment(name)
logger.info('Waiting for deployment %s to have %s available '
'replicas, current count %s',
deployment.metadata.name,
count, deployment.status.available_replicas)
logger.info(
'Waiting for deployment %s to have %s available '
'replicas, current count %s', deployment.metadata.name, count,
deployment.status.available_replicas)
return deployment
_wait_for_deployment_available_replicas()
def wait_for_deployment_deleted(self, deployment_name: str,
timeout_sec=60, wait_sec=1):
def wait_for_deployment_deleted(self,
deployment_name: str,
timeout_sec=60,
wait_sec=1):
@retrying.retry(retry_on_result=lambda r: r is not None,
stop_max_delay=timeout_sec * 1000,
wait_fixed=wait_sec * 1000)
def _wait_for_deleted_deployment_with_retry():
deployment = self.get_deployment(deployment_name)
if deployment is not None:
logger.info('Waiting for deployment %s to be deleted. '
'Non-terminated replicas: %s',
deployment.metadata.name,
deployment.status.replicas)
logger.info(
'Waiting for deployment %s to be deleted. '
'Non-terminated replicas: %s', deployment.metadata.name,
deployment.status.replicas)
return deployment
_wait_for_deleted_deployment_with_retry()
def list_pods_with_labels(self, labels: dict) -> List[V1Pod]:
@ -247,15 +260,16 @@ class KubernetesNamespace:
return self.api.core.read_namespaced_pod(name, self.name)
def wait_for_pod_started(self, pod_name, timeout_sec=60, wait_sec=1):
@retrying.retry(retry_on_result=lambda r: not self._pod_started(r),
stop_max_delay=timeout_sec * 1000,
wait_fixed=wait_sec * 1000)
def _wait_for_pod_started():
pod = self.get_pod(pod_name)
logger.info('Waiting for pod %s to start, current phase: %s',
pod.metadata.name,
pod.status.phase)
pod.metadata.name, pod.status.phase)
return pod
_wait_for_pod_started()
def port_forward_pod(
@ -269,12 +283,12 @@ class KubernetesNamespace:
local_address = local_address or self.PORT_FORWARD_LOCAL_ADDRESS
local_port = local_port or remote_port
cmd = [
"kubectl", "--context", self.api.context,
"--namespace", self.name,
"kubectl", "--context", self.api.context, "--namespace", self.name,
"port-forward", "--address", local_address,
f"pod/{pod.metadata.name}", f"{local_port}:{remote_port}"
]
pf = subprocess.Popen(cmd, stdout=subprocess.PIPE,
pf = subprocess.Popen(cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True)
# Wait for stdout line indicating successful start.

@ -75,13 +75,11 @@ class TrafficDirectorManager:
def network_url(self):
return f'global/networks/{self.network}'
def setup_for_grpc(
self,
service_host,
service_port,
*,
backend_protocol=BackendServiceProtocol.GRPC
):
def setup_for_grpc(self,
service_host,
service_port,
*,
backend_protocol=BackendServiceProtocol.GRPC):
self.create_health_check()
self.create_backend_service(protocol=backend_protocol)
self.create_url_map(service_host, service_port)
@ -130,9 +128,8 @@ class TrafficDirectorManager:
self.health_check = None
def create_backend_service(
self,
protocol: BackendServiceProtocol = BackendServiceProtocol.GRPC
):
self,
protocol: BackendServiceProtocol = BackendServiceProtocol.GRPC):
name = self._ns_name(self.BACKEND_SERVICE_NAME)
logger.info('Creating %s Backend Service %s', protocol.name, name)
resource = self.compute.create_backend_service_traffic_director(
@ -168,8 +165,8 @@ class TrafficDirectorManager:
def backend_service_add_backends(self):
logging.info('Adding backends to Backend Service %s: %r',
self.backend_service.name, self.backends)
self.compute.backend_service_add_backends(
self.backend_service, self.backends)
self.compute.backend_service_add_backends(self.backend_service,
self.backends)
def backend_service_remove_all_backends(self):
logging.info('Removing backends from Backend Service %s',
@ -180,8 +177,8 @@ class TrafficDirectorManager:
logger.debug(
"Waiting for Backend Service %s to report all backends healthy %r",
self.backend_service, self.backends)
self.compute.wait_for_backends_healthy_status(
self.backend_service, self.backends)
self.compute.wait_for_backends_healthy_status(self.backend_service,
self.backends)
def create_url_map(
self,
@ -191,10 +188,11 @@ class TrafficDirectorManager:
src_address = f'{src_host}:{src_port}'
name = self._ns_name(self.URL_MAP_NAME)
matcher_name = self._ns_name(self.URL_MAP_PATH_MATCHER_NAME)
logger.info('Creating URL map %s %s -> %s',
name, src_address, self.backend_service.name)
resource = self.compute.create_url_map(
name, matcher_name, [src_address], self.backend_service)
logger.info('Creating URL map %s %s -> %s', name, src_address,
self.backend_service.name)
resource = self.compute.create_url_map(name, matcher_name,
[src_address],
self.backend_service)
self.url_map = resource
return resource
@ -212,10 +210,9 @@ class TrafficDirectorManager:
def create_target_grpc_proxy(self):
# todo: different kinds
name = self._ns_name(self.TARGET_PROXY_NAME)
logger.info('Creating target GRPC proxy %s to url map %s',
name, self.url_map.name)
resource = self.compute.create_target_grpc_proxy(
name, self.url_map)
logger.info('Creating target GRPC proxy %s to url map %s', name,
self.url_map.name)
resource = self.compute.create_target_grpc_proxy(name, self.url_map)
self.target_proxy = resource
def delete_target_grpc_proxy(self, force=False):
@ -233,10 +230,9 @@ class TrafficDirectorManager:
def create_target_http_proxy(self):
# todo: different kinds
name = self._ns_name(self.TARGET_PROXY_NAME)
logger.info('Creating target HTTP proxy %s to url map %s',
name, self.url_map.name)
resource = self.compute.create_target_http_proxy(
name, self.url_map)
logger.info('Creating target HTTP proxy %s to url map %s', name,
self.url_map.name)
resource = self.compute.create_target_http_proxy(name, self.url_map)
self.target_proxy = resource
self.target_proxy_is_http = True
@ -255,10 +251,11 @@ class TrafficDirectorManager:
def create_forwarding_rule(self, src_port: int):
name = self._ns_name(self.FORWARDING_RULE_NAME)
src_port = int(src_port)
logging.info('Creating forwarding rule %s 0.0.0.0:%s -> %s in %s',
name, src_port, self.target_proxy.url, self.network)
resource = self.compute.create_forwarding_rule(
name, src_port, self.target_proxy, self.network_url)
logging.info('Creating forwarding rule %s 0.0.0.0:%s -> %s in %s', name,
src_port, self.target_proxy.url, self.network)
resource = self.compute.create_forwarding_rule(name, src_port,
self.target_proxy,
self.network_url)
self.forwarding_rule = resource
return resource
@ -289,8 +286,10 @@ class TrafficDirectorSecureManager(TrafficDirectorManager):
resource_prefix: str,
network: str = 'default',
):
super().__init__(gcp_api_manager, project,
resource_prefix=resource_prefix, network=network)
super().__init__(gcp_api_manager,
project,
resource_prefix=resource_prefix,
network=network)
# API
self.netsec = NetworkSecurityV1Alpha1(gcp_api_manager, project)
@ -301,25 +300,28 @@ class TrafficDirectorSecureManager(TrafficDirectorManager):
self.ecs: Optional[EndpointConfigSelector] = None
self.client_tls_policy: Optional[ClientTlsPolicy] = None
def setup_for_grpc(
self,
service_host,
service_port,
*,
backend_protocol=BackendServiceProtocol.HTTP2
):
super().setup_for_grpc(service_host, service_port,
def setup_for_grpc(self,
service_host,
service_port,
*,
backend_protocol=BackendServiceProtocol.HTTP2):
super().setup_for_grpc(service_host,
service_port,
backend_protocol=backend_protocol)
def setup_server_security(self, server_port, *, tls, mtls):
self.create_server_tls_policy(tls=tls, mtls=mtls)
self.create_endpoint_config_selector(server_port)
def setup_client_security(self, server_namespace, server_name,
*, tls=True, mtls=True):
def setup_client_security(self,
server_namespace,
server_name,
*,
tls=True,
mtls=True):
self.create_client_tls_policy(tls=tls, mtls=mtls)
self.backend_service_apply_client_mtls_policy(
server_namespace, server_name)
self.backend_service_apply_client_mtls_policy(server_namespace,
server_name)
def cleanup(self, *, force=False):
# Cleanup in the reverse order of creation
@ -334,12 +336,16 @@ class TrafficDirectorSecureManager(TrafficDirectorManager):
name = self._ns_name(self.SERVER_TLS_POLICY_NAME)
logger.info('Creating Server TLS Policy %s', name)
if not tls and not mtls:
logger.warning('Server TLS Policy %s neither TLS, nor mTLS '
'policy. Skipping creation', name)
logger.warning(
'Server TLS Policy %s neither TLS, nor mTLS '
'policy. Skipping creation', name)
return
grpc_endpoint = {
"grpcEndpoint": {"targetUri": self.GRPC_ENDPOINT_TARGET_URI}}
"grpcEndpoint": {
"targetUri": self.GRPC_ENDPOINT_TARGET_URI
}
}
policy = {}
if tls:
@ -381,13 +387,16 @@ class TrafficDirectorSecureManager(TrafficDirectorManager):
"type": "SIDECAR_PROXY",
"httpFilters": {},
"trafficPortSelector": port_selector,
"endpointMatcher": {"metadataLabelMatcher": label_matcher_all},
"endpointMatcher": {
"metadataLabelMatcher": label_matcher_all
},
}
if self.server_tls_policy:
config["serverTlsPolicy"] = self.server_tls_policy.name
else:
logger.warning('Creating Endpoint Config Selector %s with '
'no Server TLS policy attached', name)
logger.warning(
'Creating Endpoint Config Selector %s with '
'no Server TLS policy attached', name)
self.netsvc.create_endpoint_config_selector(name, config)
self.ecs = self.netsvc.get_endpoint_config_selector(name)
@ -408,12 +417,16 @@ class TrafficDirectorSecureManager(TrafficDirectorManager):
name = self._ns_name(self.CLIENT_TLS_POLICY_NAME)
logger.info('Creating Client TLS Policy %s', name)
if not tls and not mtls:
logger.warning('Client TLS Policy %s neither TLS, nor mTLS '
'policy. Skipping creation', name)
logger.warning(
'Client TLS Policy %s neither TLS, nor mTLS '
'policy. Skipping creation', name)
return
grpc_endpoint = {
"grpcEndpoint": {"targetUri": self.GRPC_ENDPOINT_TARGET_URI}}
"grpcEndpoint": {
"targetUri": self.GRPC_ENDPOINT_TARGET_URI
}
}
policy = {}
if tls:
@ -442,21 +455,23 @@ class TrafficDirectorSecureManager(TrafficDirectorManager):
server_name,
):
if not self.client_tls_policy:
logger.warning('Client TLS policy not created, '
'skipping attaching to Backend Service %s',
self.backend_service.name)
logger.warning(
'Client TLS policy not created, '
'skipping attaching to Backend Service %s',
self.backend_service.name)
return
server_spiffe = (f'spiffe://{self.project}.svc.id.goog/'
f'ns/{server_namespace}/sa/{server_name}')
logging.info('Adding Client TLS Policy to Backend Service %s: %s, '
'server %s',
self.backend_service.name,
self.client_tls_policy.url,
server_spiffe)
self.compute.patch_backend_service(self.backend_service, {
'securitySettings': {
'clientTlsPolicy': self.client_tls_policy.url,
'subjectAltNames': [server_spiffe]
}})
logging.info(
'Adding Client TLS Policy to Backend Service %s: %s, '
'server %s', self.backend_service.name, self.client_tls_policy.url,
server_spiffe)
self.compute.patch_backend_service(
self.backend_service, {
'securitySettings': {
'clientTlsPolicy': self.client_tls_policy.url,
'subjectAltNames': [server_spiffe]
}
})

@ -37,7 +37,8 @@ class GrpcClientHelper:
self.service_name = re.sub('Stub$', '', self.stub.__class__.__name__)
def call_unary_when_channel_ready(
self, *,
self,
*,
rpc: str,
req: Message,
wait_for_ready_sec: Optional[int] = DEFAULT_WAIT_FOR_READY_SEC,
@ -56,8 +57,7 @@ class GrpcClientHelper:
return rpc_callable(req, **call_kwargs)
def _log_debug(self, rpc, req, call_kwargs):
logger.debug('RPC %s.%s(request=%s(%r), %s)',
self.service_name, rpc,
logger.debug('RPC %s.%s(request=%s(%r), %s)', self.service_name, rpc,
req.__class__.__name__, json_format.MessageToDict(req),
', '.join({f'{k}={v}' for k, v in call_kwargs.items()}))

@ -83,10 +83,8 @@ class ChannelzServiceClient(framework.rpc.GrpcClientHelper):
f'remote={cls.sock_address_to_str(socket.remote)}')
@staticmethod
def find_server_socket_matching_client(
server_sockets: Iterator[Socket],
client_socket: Socket
) -> Socket:
def find_server_socket_matching_client(server_sockets: Iterator[Socket],
client_socket: Socket) -> Socket:
for server_socket in server_sockets:
if server_socket.remote == client_socket.local:
return server_socket
@ -103,7 +101,7 @@ class ChannelzServiceClient(framework.rpc.GrpcClientHelper):
listen_socket = self.get_socket(listen_socket_ref.socket_id)
listen_address: Address = listen_socket.local
if (self.is_sock_tcpip_address(listen_address) and
listen_address.tcpip_address.port == port):
listen_address.tcpip_address.port == port):
return server
return None
@ -136,8 +134,7 @@ class ChannelzServiceClient(framework.rpc.GrpcClientHelper):
# value by adding 1 to the highest seen result ID.
start += 1
response = self.call_unary_when_channel_ready(
rpc='GetServers',
req=GetServersRequest(start_server_id=start))
rpc='GetServers', req=GetServersRequest(start_server_id=start))
for server in response.server:
start = max(start, server.ref.server_id)
yield server
@ -170,6 +167,5 @@ class ChannelzServiceClient(framework.rpc.GrpcClientHelper):
def get_socket(self, socket_id) -> Socket:
"""Return a single Socket, otherwise raises RpcError."""
response: GetSocketResponse = self.call_unary_when_channel_ready(
rpc='GetSocket',
req=GetSocketRequest(socket_id=socket_id))
rpc='GetSocket', req=GetSocketRequest(socket_id=socket_id))
return response.socket

@ -19,7 +19,6 @@ import framework.rpc
from src.proto.grpc.testing import test_pb2_grpc
from src.proto.grpc.testing import messages_pb2
# Type aliases
LoadBalancerStatsRequest = messages_pb2.LoadBalancerStatsRequest
LoadBalancerStatsResponse = messages_pb2.LoadBalancerStatsResponse
@ -33,7 +32,8 @@ class LoadBalancerStatsServiceClient(framework.rpc.GrpcClientHelper):
super().__init__(channel, test_pb2_grpc.LoadBalancerStatsServiceStub)
def get_client_stats(
self, *,
self,
*,
num_rpcs: int,
timeout_sec: Optional[int] = STATS_PARTIAL_RESULTS_TIMEOUT_SEC,
) -> LoadBalancerStatsResponse:

@ -33,6 +33,7 @@ TEMPLATE_DIR = '../../kubernetes-manifests'
class KubernetesBaseRunner:
def __init__(self,
k8s_namespace,
namespace_template=None,
@ -50,8 +51,7 @@ class KubernetesBaseRunner:
self.namespace = self._reuse_namespace()
if not self.namespace:
self.namespace = self._create_namespace(
self.namespace_template,
namespace_name=self.k8s_namespace.name)
self.namespace_template, namespace_name=self.k8s_namespace.name)
def cleanup(self, *, force=False):
if (self.namespace and not self.reuse_namespace) or force:
@ -127,27 +127,21 @@ class KubernetesBaseRunner:
raise RunnerError('Expected V1Namespace to be created '
f'from manifest {template}')
if namespace.metadata.name != kwargs['namespace_name']:
raise RunnerError(
'Namespace created with unexpected name: '
f'{namespace.metadata.name}')
logger.info('Deployment %s created at %s',
namespace.metadata.self_link,
raise RunnerError('Namespace created with unexpected name: '
f'{namespace.metadata.name}')
logger.info('Deployment %s created at %s', namespace.metadata.self_link,
namespace.metadata.creation_timestamp)
return namespace
def _create_service_account(
self,
template,
**kwargs
) -> k8s.V1ServiceAccount:
def _create_service_account(self, template,
**kwargs) -> k8s.V1ServiceAccount:
resource = self._create_from_template(template, **kwargs)
if not isinstance(resource, k8s.V1ServiceAccount):
raise RunnerError('Expected V1ServiceAccount to be created '
f'from manifest {template}')
if resource.metadata.name != kwargs['service_account_name']:
raise RunnerError(
'V1ServiceAccount created with unexpected name: '
f'{resource.metadata.name}')
raise RunnerError('V1ServiceAccount created with unexpected name: '
f'{resource.metadata.name}')
logger.info('V1ServiceAccount %s created at %s',
resource.metadata.self_link,
resource.metadata.creation_timestamp)
@ -159,9 +153,8 @@ class KubernetesBaseRunner:
raise RunnerError('Expected V1Deployment to be created '
f'from manifest {template}')
if deployment.metadata.name != kwargs['deployment_name']:
raise RunnerError(
'Deployment created with unexpected name: '
f'{deployment.metadata.name}')
raise RunnerError('Deployment created with unexpected name: '
f'{deployment.metadata.name}')
logger.info('Deployment %s created at %s',
deployment.metadata.self_link,
deployment.metadata.creation_timestamp)
@ -173,11 +166,9 @@ class KubernetesBaseRunner:
raise RunnerError('Expected V1Service to be created '
f'from manifest {template}')
if service.metadata.name != kwargs['service_name']:
raise RunnerError(
'Service created with unexpected name: '
f'{service.metadata.name}')
logger.info('Service %s created at %s',
service.metadata.self_link,
raise RunnerError('Service created with unexpected name: '
f'{service.metadata.name}')
logger.info('Service %s created at %s', service.metadata.self_link,
service.metadata.creation_timestamp)
return service
@ -185,8 +176,8 @@ class KubernetesBaseRunner:
try:
self.k8s_namespace.delete_deployment(name)
except k8s.ApiException as e:
logger.info('Deployment %s deletion failed, error: %s %s',
name, e.status, e.reason)
logger.info('Deployment %s deletion failed, error: %s %s', name,
e.status, e.reason)
return
if wait_for_deletion:
@ -197,8 +188,8 @@ class KubernetesBaseRunner:
try:
self.k8s_namespace.delete_service(name)
except k8s.ApiException as e:
logger.info('Service %s deletion failed, error: %s %s',
name, e.status, e.reason)
logger.info('Service %s deletion failed, error: %s %s', name,
e.status, e.reason)
return
if wait_for_deletion:
@ -232,8 +223,8 @@ class KubernetesBaseRunner:
def _wait_deployment_with_available_replicas(self, name, count=1, **kwargs):
logger.info('Waiting for deployment %s to have %s available replicas',
name, count)
self.k8s_namespace.wait_for_deployment_available_replicas(name, count,
**kwargs)
self.k8s_namespace.wait_for_deployment_available_replicas(
name, count, **kwargs)
deployment = self.k8s_namespace.get_deployment(name)
logger.info('Deployment %s has %i replicas available',
deployment.metadata.name,
@ -251,5 +242,5 @@ class KubernetesBaseRunner:
self.k8s_namespace.wait_for_service_neg(name, **kwargs)
neg_name, neg_zones = self.k8s_namespace.get_service_neg(
name, service_port)
logger.info("Service %s: detected NEG=%s in zones=%s", name,
neg_name, neg_zones)
logger.info("Service %s: detected NEG=%s in zones=%s", name, neg_name,
neg_zones)

@ -32,7 +32,9 @@ LoadBalancerStatsServiceClient = grpc_testing.LoadBalancerStatsServiceClient
class XdsTestClient(framework.rpc.GrpcApp):
def __init__(self, *,
def __init__(self,
*,
ip: str,
rpc_port: int,
server_target: str,
@ -55,7 +57,8 @@ class XdsTestClient(framework.rpc.GrpcApp):
return ChannelzServiceClient(self._make_channel(self.maintenance_port))
def get_load_balancer_stats(
self, *,
self,
*,
num_rpcs: int,
timeout_sec: Optional[int] = None,
) -> grpc_testing.LoadBalancerStatsResponse:
@ -76,16 +79,14 @@ class XdsTestClient(framework.rpc.GrpcApp):
stop=tenacity.stop_after_delay(60 * 3),
reraise=True)
channel = retryer(self.get_active_server_channel)
logger.info(
'Active server channel found: channel_id: %s, %s',
channel.ref.channel_id, channel.ref.name)
logger.info('Active server channel found: channel_id: %s, %s',
channel.ref.channel_id, channel.ref.name)
logger.debug('Server channel:\n%r', channel)
def get_active_server_channel(self) -> Optional[grpc_channelz.Channel]:
for channel in self.get_server_channels():
state: ChannelConnectivityState = channel.data.state
logger.debug('Server channel: %s, state: %s',
channel.ref.name,
logger.debug('Server channel: %s, state: %s', channel.ref.name,
ChannelConnectivityState.State.Name(state.state))
if state.state is ChannelConnectivityState.READY:
return channel
@ -107,6 +108,7 @@ class XdsTestClient(framework.rpc.GrpcApp):
class KubernetesClientRunner(base_runner.KubernetesBaseRunner):
def __init__(self,
k8s_namespace,
*,
@ -142,9 +144,11 @@ class KubernetesClientRunner(base_runner.KubernetesBaseRunner):
self.service_account: Optional[k8s.V1ServiceAccount] = None
self.port_forwarder = None
def run(self, *,
def run(self,
*,
server_target,
rpc='UnaryCall', qps=25,
rpc='UnaryCall',
qps=25,
secure_mode=False,
print_response=False) -> XdsTestClient:
super().run()
@ -183,8 +187,8 @@ class KubernetesClientRunner(base_runner.KubernetesBaseRunner):
# Experimental, for local debugging.
if self.debug_use_port_forwarding:
logger.info('Enabling port forwarding from %s:%s',
pod_ip, self.stats_port)
logger.info('Enabling port forwarding from %s:%s', pod_ip,
self.stats_port)
self.port_forwarder = self.k8s_namespace.port_forward_pod(
pod, remote_port=self.stats_port)
rpc_host = self.k8s_namespace.PORT_FORWARD_LOCAL_ADDRESS

@ -27,7 +27,9 @@ ChannelzServiceClient = grpc_channelz.ChannelzServiceClient
class XdsTestServer(framework.rpc.GrpcApp):
def __init__(self, *,
def __init__(self,
*,
ip: str,
rpc_port: int,
maintenance_port: Optional[int] = None,
@ -54,13 +56,16 @@ class XdsTestServer(framework.rpc.GrpcApp):
@property
def xds_address(self) -> str:
if not self.xds_host: return ''
if not self.xds_port: return self.xds_host
if not self.xds_host:
return ''
if not self.xds_port:
return self.xds_host
return f'{self.xds_host}:{self.xds_port}'
@property
def xds_uri(self) -> str:
if not self.xds_host: return ''
if not self.xds_host:
return ''
return f'xds:///{self.xds_address}'
def get_test_server(self):
@ -74,10 +79,8 @@ class XdsTestServer(framework.rpc.GrpcApp):
server = self.get_test_server()
return self.channelz.list_server_sockets(server.ref.server_id)
def get_server_socket_matching_client(
self,
client_socket: grpc_channelz.Socket
):
def get_server_socket_matching_client(self,
client_socket: grpc_channelz.Socket):
client_local = self.channelz.sock_address_to_str(client_socket.local)
logger.debug('Looking for a server socket connected to the client %s',
client_local)
@ -95,6 +98,7 @@ class XdsTestServer(framework.rpc.GrpcApp):
class KubernetesServerRunner(base_runner.KubernetesBaseRunner):
def __init__(self,
k8s_namespace,
*,
@ -140,9 +144,12 @@ class KubernetesServerRunner(base_runner.KubernetesBaseRunner):
self.service: Optional[k8s.V1Service] = None
self.port_forwarder = None
def run(self, *,
test_port=8080, maintenance_port=None,
secure_mode=False, server_id=None,
def run(self,
*,
test_port=8080,
maintenance_port=None,
secure_mode=False,
server_id=None,
replica_count=1) -> XdsTestServer:
# todo(sergiitk): multiple replicas
if replica_count != 1:
@ -201,8 +208,9 @@ class KubernetesServerRunner(base_runner.KubernetesBaseRunner):
server_id=server_id,
secure_mode=secure_mode)
self._wait_deployment_with_available_replicas(
self.deployment_name, replica_count, timeout_sec=120)
self._wait_deployment_with_available_replicas(self.deployment_name,
replica_count,
timeout_sec=120)
# Wait for pods running
pods = self.k8s_namespace.list_deployment_pods(self.deployment)
@ -215,19 +223,18 @@ class KubernetesServerRunner(base_runner.KubernetesBaseRunner):
rpc_host = None
# Experimental, for local debugging.
if self.debug_use_port_forwarding:
logger.info('Enabling port forwarding from %s:%s',
pod_ip, maintenance_port)
logger.info('Enabling port forwarding from %s:%s', pod_ip,
maintenance_port)
self.port_forwarder = self.k8s_namespace.port_forward_pod(
pod, remote_port=maintenance_port)
rpc_host = self.k8s_namespace.PORT_FORWARD_LOCAL_ADDRESS
return XdsTestServer(
ip=pod_ip,
rpc_port=test_port,
maintenance_port=maintenance_port,
secure_mode=secure_mode,
server_id=server_id,
rpc_host=rpc_host)
return XdsTestServer(ip=pod_ip,
rpc_port=test_port,
maintenance_port=maintenance_port,
secure_mode=secure_mode,
server_id=server_id,
rpc_host=rpc_host)
def cleanup(self, *, force=False, force_namespace=False):
if self.port_forwarder:

@ -15,35 +15,38 @@ from absl import flags
import googleapiclient.discovery
# GCP
PROJECT = flags.DEFINE_string(
"project", default=None, help="GCP Project ID. Required")
PROJECT = flags.DEFINE_string("project",
default=None,
help="GCP Project ID. Required")
NAMESPACE = flags.DEFINE_string(
"namespace", default=None,
"namespace",
default=None,
help="Isolate GCP resources using given namespace / name prefix. Required")
NETWORK = flags.DEFINE_string(
"network", default="default", help="GCP Network ID")
NETWORK = flags.DEFINE_string("network",
default="default",
help="GCP Network ID")
# Test server
SERVER_NAME = flags.DEFINE_string(
"server_name", default="psm-grpc-server",
help="Server deployment and service name")
SERVER_PORT = flags.DEFINE_integer(
"server_port", default=8080,
help="Server test port")
SERVER_XDS_HOST = flags.DEFINE_string(
"server_xds_host", default='xds-test-server',
help="Test server xDS hostname")
SERVER_XDS_PORT = flags.DEFINE_integer(
"server_xds_port", default=8000, help="Test server xDS port")
SERVER_NAME = flags.DEFINE_string("server_name",
default="psm-grpc-server",
help="Server deployment and service name")
SERVER_PORT = flags.DEFINE_integer("server_port",
default=8080,
help="Server test port")
SERVER_XDS_HOST = flags.DEFINE_string("server_xds_host",
default='xds-test-server',
help="Test server xDS hostname")
SERVER_XDS_PORT = flags.DEFINE_integer("server_xds_port",
default=8000,
help="Test server xDS port")
# Test client
CLIENT_NAME = flags.DEFINE_string(
"client_name", default="psm-grpc-client",
help="Client deployment and service name")
CLIENT_PORT = flags.DEFINE_integer(
"client_port", default=8079,
help="Client test port")
CLIENT_NAME = flags.DEFINE_string("client_name",
default="psm-grpc-client",
help="Client deployment and service name")
CLIENT_PORT = flags.DEFINE_integer("client_port",
default=8079,
help="Client test port")
flags.mark_flags_as_required([
"project",

@ -14,24 +14,28 @@
from absl import flags
# GCP
KUBE_CONTEXT = flags.DEFINE_string(
"kube_context", default=None, help="Kubectl context to use")
KUBE_CONTEXT = flags.DEFINE_string("kube_context",
default=None,
help="Kubectl context to use")
GCP_SERVICE_ACCOUNT = flags.DEFINE_string(
"gcp_service_account", default=None,
"gcp_service_account",
default=None,
help="GCP Service account for GKE workloads to impersonate")
TD_BOOTSTRAP_IMAGE = flags.DEFINE_string(
"td_bootstrap_image", default=None,
"td_bootstrap_image",
default=None,
help="Traffic Director gRPC Bootstrap Docker image")
# Test app
SERVER_IMAGE = flags.DEFINE_string(
"server_image", default=None,
help="Server Docker image name")
CLIENT_IMAGE = flags.DEFINE_string(
"client_image", default=None,
help="Client Docker image name")
SERVER_IMAGE = flags.DEFINE_string("server_image",
default=None,
help="Server Docker image name")
CLIENT_IMAGE = flags.DEFINE_string("client_image",
default=None,
help="Client Docker image name")
CLIENT_PORT_FORWARDING = flags.DEFINE_bool(
"client_debug_use_port_forwarding", default=False,
"client_debug_use_port_forwarding",
default=False,
help="Development only: use kubectl port-forward to connect to test client")
flags.mark_flags_as_required([

@ -107,11 +107,9 @@ class XdsKubernetesTestCase(absltest.TestCase):
# Add backends to the Backend Service
self.td.backend_service_add_neg_backends(neg_name, neg_zones)
def assertSuccessfulRpcs(
self,
test_client: XdsTestClient,
num_rpcs: int = 100
):
def assertSuccessfulRpcs(self,
test_client: XdsTestClient,
num_rpcs: int = 100):
# Run the test
lb_stats = test_client.get_load_balancer_stats(num_rpcs=num_rpcs)
# Check the results
@ -123,17 +121,20 @@ class XdsKubernetesTestCase(absltest.TestCase):
logger.info(lb_stats.rpcs_by_peer)
for backend, rpcs_count in lb_stats.rpcs_by_peer.items():
self.assertGreater(
int(rpcs_count), 0,
int(rpcs_count),
0,
msg='Backend {backend} did not receive a single RPC')
def assertFailedRpcsAtMost(self, lb_stats, limit):
failed = int(lb_stats.num_failures)
self.assertLessEqual(
failed, limit,
failed,
limit,
msg=f'Unexpected number of RPC failures {failed} > {limit}')
class RegularXdsKubernetesTestCase(XdsKubernetesTestCase):
def setUp(self):
super().setUp()
@ -168,15 +169,13 @@ class RegularXdsKubernetesTestCase(XdsKubernetesTestCase):
reuse_namespace=self.server_namespace == self.client_namespace)
def startTestServer(self, replica_count=1, **kwargs) -> XdsTestServer:
test_server = self.server_runner.run(
replica_count=replica_count,
test_port=self.server_port,
**kwargs)
test_server = self.server_runner.run(replica_count=replica_count,
test_port=self.server_port,
**kwargs)
test_server.set_xds_address(self.server_xds_host, self.server_xds_port)
return test_server
def startTestClient(self,
test_server: XdsTestServer,
def startTestClient(self, test_server: XdsTestServer,
**kwargs) -> XdsTestClient:
test_client = self.client_runner.run(server_target=test_server.xds_uri,
**kwargs)
@ -187,6 +186,7 @@ class RegularXdsKubernetesTestCase(XdsKubernetesTestCase):
class SecurityXdsKubernetesTestCase(XdsKubernetesTestCase):
class SecurityMode(enum.Enum):
MTLS = enum.auto()
TLS = enum.auto()
@ -229,43 +229,39 @@ class SecurityXdsKubernetesTestCase(XdsKubernetesTestCase):
debug_use_port_forwarding=self.client_port_forwarding)
def startSecureTestServer(self, replica_count=1, **kwargs) -> XdsTestServer:
test_server = self.server_runner.run(
replica_count=replica_count,
test_port=self.server_port,
maintenance_port=8081,
secure_mode=True,
**kwargs)
test_server = self.server_runner.run(replica_count=replica_count,
test_port=self.server_port,
maintenance_port=8081,
secure_mode=True,
**kwargs)
test_server.set_xds_address(self.server_xds_host, self.server_xds_port)
return test_server
def setupSecurityPolicies(self, *,
server_tls, server_mtls,
client_tls, client_mtls):
self.td.setup_client_security(self.server_namespace, self.server_name,
tls=client_tls, mtls=client_mtls)
def setupSecurityPolicies(self, *, server_tls, server_mtls, client_tls,
client_mtls):
self.td.setup_client_security(self.server_namespace,
self.server_name,
tls=client_tls,
mtls=client_mtls)
self.td.setup_server_security(self.server_port,
tls=server_tls, mtls=server_mtls)
def startSecureTestClient(
self,
test_server: XdsTestServer,
**kwargs
) -> XdsTestClient:
test_client = self.client_runner.run(
server_target=test_server.xds_uri,
secure_mode=True,
**kwargs)
tls=server_tls,
mtls=server_mtls)
def startSecureTestClient(self, test_server: XdsTestServer,
**kwargs) -> XdsTestClient:
test_client = self.client_runner.run(server_target=test_server.xds_uri,
secure_mode=True,
**kwargs)
logger.debug('Waiting fot the client to establish healthy channel with '
'the server')
test_client.wait_for_active_server_channel()
return test_client
def assertTestAppSecurity(self,
mode: SecurityMode,
def assertTestAppSecurity(self, mode: SecurityMode,
test_client: XdsTestClient,
test_server: XdsTestServer):
client_socket, server_socket = self.getConnectedSockets(test_client,
test_server)
client_socket, server_socket = self.getConnectedSockets(
test_client, test_server)
server_security: grpc_channelz.Security = server_socket.security
client_security: grpc_channelz.Security = client_socket.security
logger.info('Server certs: %s', self.debug_sock_certs(server_security))
@ -280,72 +276,70 @@ class SecurityXdsKubernetesTestCase(XdsKubernetesTestCase):
else:
raise TypeError(f'Incorrect security mode')
def assertSecurityMtls(self,
client_security: grpc_channelz.Security,
def assertSecurityMtls(self, client_security: grpc_channelz.Security,
server_security: grpc_channelz.Security):
self.assertEqual(client_security.WhichOneof('model'), 'tls',
self.assertEqual(client_security.WhichOneof('model'),
'tls',
msg='(mTLS) Client socket security model must be TLS')
self.assertEqual(server_security.WhichOneof('model'), 'tls',
self.assertEqual(server_security.WhichOneof('model'),
'tls',
msg='(mTLS) Server socket security model must be TLS')
server_tls, client_tls = server_security.tls, client_security.tls
# Confirm regular TLS: server local cert == client remote cert
self.assertNotEmpty(
self.assertNotEmpty(server_tls.local_certificate,
msg="(mTLS) Server local certificate is missing")
self.assertNotEmpty(client_tls.remote_certificate,
msg="(mTLS) Client remote certificate is missing")
self.assertEqual(
server_tls.local_certificate,
msg="(mTLS) Server local certificate is missing")
self.assertNotEmpty(
client_tls.remote_certificate,
msg="(mTLS) Client remote certificate is missing")
self.assertEqual(
server_tls.local_certificate, client_tls.remote_certificate,
msg="(mTLS) Server local certificate must match client's "
"remote certificate")
"remote certificate")
# mTLS: server remote cert == client local cert
self.assertNotEmpty(
self.assertNotEmpty(server_tls.remote_certificate,
msg="(mTLS) Server remote certificate is missing")
self.assertNotEmpty(client_tls.local_certificate,
msg="(mTLS) Client local certificate is missing")
self.assertEqual(
server_tls.remote_certificate,
msg="(mTLS) Server remote certificate is missing")
self.assertNotEmpty(
client_tls.local_certificate,
msg="(mTLS) Client local certificate is missing")
self.assertEqual(
server_tls.remote_certificate, client_tls.local_certificate,
msg="(mTLS) Server remote certificate must match client's "
"local certificate")
"local certificate")
# Success
logger.info('mTLS security mode confirmed!')
def assertSecurityTls(self,
client_security: grpc_channelz.Security,
def assertSecurityTls(self, client_security: grpc_channelz.Security,
server_security: grpc_channelz.Security):
self.assertEqual(client_security.WhichOneof('model'), 'tls',
self.assertEqual(client_security.WhichOneof('model'),
'tls',
msg='(TLS) Client socket security model must be TLS')
self.assertEqual(server_security.WhichOneof('model'), 'tls',
self.assertEqual(server_security.WhichOneof('model'),
'tls',
msg='(TLS) Server socket security model must be TLS')
server_tls, client_tls = server_security.tls, client_security.tls
# Regular TLS: server local cert == client remote cert
self.assertNotEmpty(
server_tls.local_certificate,
msg="(TLS) Server local certificate is missing")
self.assertNotEmpty(
client_tls.remote_certificate,
msg="(TLS) Client remote certificate is missing")
self.assertEqual(
server_tls.local_certificate, client_tls.remote_certificate,
msg="(TLS) Server local certificate must match client "
"remote certificate")
self.assertNotEmpty(server_tls.local_certificate,
msg="(TLS) Server local certificate is missing")
self.assertNotEmpty(client_tls.remote_certificate,
msg="(TLS) Client remote certificate is missing")
self.assertEqual(server_tls.local_certificate,
client_tls.remote_certificate,
msg="(TLS) Server local certificate must match client "
"remote certificate")
# mTLS must not be used
self.assertEmpty(
server_tls.remote_certificate,
msg="(TLS) Server remote certificate must be empty in TLS mode. "
"Is server security incorrectly configured for mTLS?")
"Is server security incorrectly configured for mTLS?")
self.assertEmpty(
client_tls.local_certificate,
msg="(TLS) Client local certificate must be empty in TLS mode. "
"Is client security incorrectly configured for mTLS?")
"Is client security incorrectly configured for mTLS?")
# Success
logger.info('TLS security mode confirmed!')
@ -373,8 +367,7 @@ class SecurityXdsKubernetesTestCase(XdsKubernetesTestCase):
@staticmethod
def getConnectedSockets(
test_client: XdsTestClient,
test_server: XdsTestServer
test_client: XdsTestClient, test_server: XdsTestServer
) -> Tuple[grpc_channelz.Socket, grpc_channelz.Socket]:
client_sock = test_client.get_client_socket_with_test_server()
server_sock = test_server.get_server_socket_matching_client(client_sock)
@ -390,6 +383,7 @@ class SecurityXdsKubernetesTestCase(XdsKubernetesTestCase):
@staticmethod
def debug_cert(cert):
if not cert: return 'missing'
if not cert:
return 'missing'
sha1 = hashlib.sha1(cert)
return f'sha1={sha1.hexdigest()}, len={len(cert)}'

@ -27,6 +27,7 @@ XdsTestClient = xds_k8s_testcase.XdsTestClient
class BaselineTest(xds_k8s_testcase.RegularXdsKubernetesTestCase):
def test_ping_pong(self):
self.setupTrafficDirectorGrpc()

@ -29,10 +29,13 @@ SecurityMode = xds_k8s_testcase.SecurityXdsKubernetesTestCase.SecurityMode
class SecurityTest(xds_k8s_testcase.SecurityXdsKubernetesTestCase):
def test_mtls(self):
self.setupTrafficDirectorGrpc()
self.setupSecurityPolicies(server_tls=True, server_mtls=True,
client_tls=True, client_mtls=True)
self.setupSecurityPolicies(server_tls=True,
server_mtls=True,
client_tls=True,
client_mtls=True)
test_server: XdsTestServer = self.startSecureTestServer()
self.setupServerBackends()
@ -43,8 +46,10 @@ class SecurityTest(xds_k8s_testcase.SecurityXdsKubernetesTestCase):
def test_tls(self):
self.setupTrafficDirectorGrpc()
self.setupSecurityPolicies(server_tls=True, server_mtls=False,
client_tls=True, client_mtls=False)
self.setupSecurityPolicies(server_tls=True,
server_mtls=False,
client_tls=True,
client_mtls=False)
test_server: XdsTestServer = self.startSecureTestServer()
self.setupServerBackends()
@ -55,15 +60,17 @@ class SecurityTest(xds_k8s_testcase.SecurityXdsKubernetesTestCase):
def test_plaintext_fallback(self):
self.setupTrafficDirectorGrpc()
self.setupSecurityPolicies(server_tls=False, server_mtls=False,
client_tls=False, client_mtls=False)
self.setupSecurityPolicies(server_tls=False,
server_mtls=False,
client_tls=False,
client_mtls=False)
test_server: XdsTestServer = self.startSecureTestServer()
self.setupServerBackends()
test_client: XdsTestClient = self.startSecureTestClient(test_server)
self.assertTestAppSecurity(
SecurityMode.PLAINTEXT, test_client, test_server)
self.assertTestAppSecurity(SecurityMode.PLAINTEXT, test_client,
test_server)
self.assertSuccessfulRpcs(test_client)
@absltest.skip(SKIP_REASON)

Loading…
Cancel
Save