yapf autoformat

Branch: pull/24983/head
Author: Sergii Tkachenko (4 years ago)
Parent: 86f8792136
Commit: 28a6f740f5
  1. tools/run_tests/xds_test_driver/bin/run_channelz.py (10)
  2. tools/run_tests/xds_test_driver/bin/run_td_setup.py (56)
  3. tools/run_tests/xds_test_driver/bin/run_test_client.py (28)
  4. tools/run_tests/xds_test_driver/bin/run_test_server.py (22)
  5. tools/run_tests/xds_test_driver/framework/infrastructure/gcp/__init__.py (1)
  6. tools/run_tests/xds_test_driver/framework/infrastructure/gcp/api.py (54)
  7. tools/run_tests/xds_test_driver/framework/infrastructure/gcp/compute.py (150)
  8. tools/run_tests/xds_test_driver/framework/infrastructure/gcp/network_security.py (26)
  9. tools/run_tests/xds_test_driver/framework/infrastructure/gcp/network_services.py (3)
  10. tools/run_tests/xds_test_driver/framework/infrastructure/k8s.py (116)
  11. tools/run_tests/xds_test_driver/framework/infrastructure/traffic_director.py (149)
  12. tools/run_tests/xds_test_driver/framework/rpc/__init__.py (6)
  13. tools/run_tests/xds_test_driver/framework/rpc/grpc_channelz.py (14)
  14. tools/run_tests/xds_test_driver/framework/rpc/grpc_testing.py (4)
  15. tools/run_tests/xds_test_driver/framework/test_app/base_runner.py (53)
  16. tools/run_tests/xds_test_driver/framework/test_app/client_app.py (26)
  17. tools/run_tests/xds_test_driver/framework/test_app/server_app.py (51)
  18. tools/run_tests/xds_test_driver/framework/xds_flags.py (49)
  19. tools/run_tests/xds_test_driver/framework/xds_k8s_flags.py (26)
  20. tools/run_tests/xds_test_driver/framework/xds_k8s_testcase.py (146)
  21. tools/run_tests/xds_test_driver/tests/baseline_test.py (1)
  22. tools/run_tests/xds_test_driver/tests/security_test.py (23)
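
The hunks below are the output of running yapf over the xds_test_driver tree. The exact command and style configuration are not shown in this commit, so the snippet that follows is only a minimal, hypothetical illustration of the recurring transformation: arguments on a hanging 4-space indent get realigned under the opening parenthesis. The define_string helper is a stand-in (not absl's flags API) used purely to keep the example self-contained.

    # Hypothetical stand-in for flags.DEFINE_string; not part of this commit.
    def define_string(name, default=None, help=None):
        return {'name': name, 'default': default, 'help': help}


    # Before autoformatting: continuation arguments on a 4-space hanging indent.
    _BEFORE = define_string(
        'server_rpc_host', default='127.0.0.1', help='Server RPC host')

    # After running yapf (e.g. `yapf -i -r .`, assuming the repo's style config):
    # arguments are aligned with the opening parenthesis, the pattern repeated
    # throughout the hunks below.
    _AFTER = define_string('server_rpc_host',
                           default='127.0.0.1',
                           help='Server RPC host')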

tools/run_tests/xds_test_driver/bin/run_channelz.py
@@ -26,10 +26,12 @@ from framework.test_app import client_app
 logger = logging.getLogger(__name__)
 
 # Flags
-_SERVER_RPC_HOST = flags.DEFINE_string(
-    'server_rpc_host', default='127.0.0.1', help='Server RPC host')
-_CLIENT_RPC_HOST = flags.DEFINE_string(
-    'client_rpc_host', default='127.0.0.1', help='Client RPC host')
+_SERVER_RPC_HOST = flags.DEFINE_string('server_rpc_host',
+                                       default='127.0.0.1',
+                                       help='Server RPC host')
+_CLIENT_RPC_HOST = flags.DEFINE_string('client_rpc_host',
+                                       default='127.0.0.1',
+                                       help='Client RPC host')
 
 flags.adopt_module_key_flags(xds_flags)
 flags.adopt_module_key_flags(xds_k8s_flags)

tools/run_tests/xds_test_driver/bin/run_td_setup.py
@@ -22,17 +22,19 @@ from framework.infrastructure import gcp
 from framework.infrastructure import k8s
 from framework.infrastructure import traffic_director
 
 logger = logging.getLogger(__name__)
 
 # Flags
-_CMD = flags.DEFINE_enum(
-    'cmd', default='create',
-    enum_values=['cycle', 'create', 'cleanup',
-                 'backends-add', 'backends-cleanup'],
-    help='Command')
-_SECURITY = flags.DEFINE_enum(
-    'security', default=None, enum_values=['mtls', 'tls', 'plaintext'],
-    help='Configure td with security')
+_CMD = flags.DEFINE_enum('cmd',
+                         default='create',
+                         enum_values=[
+                             'cycle', 'create', 'cleanup', 'backends-add',
+                             'backends-cleanup'
+                         ],
+                         help='Command')
+_SECURITY = flags.DEFINE_enum('security',
+                              default=None,
+                              enum_values=['mtls', 'tls', 'plaintext'],
+                              help='Configure td with security')
 
 flags.adopt_module_key_flags(xds_flags)
 flags.adopt_module_key_flags(xds_k8s_flags)
@@ -57,11 +59,10 @@ def main(argv):
     gcp_api_manager = gcp.api.GcpApiManager()
 
     if security_mode is None:
-        td = traffic_director.TrafficDirectorManager(
-            gcp_api_manager,
-            project=project,
-            resource_prefix=namespace,
-            network=network)
+        td = traffic_director.TrafficDirectorManager(gcp_api_manager,
+                                                     project=project,
+                                                     resource_prefix=namespace,
+                                                     network=network)
     else:
         td = traffic_director.TrafficDirectorSecureManager(
             gcp_api_manager,
@@ -80,26 +81,29 @@ def main(argv):
         elif security_mode == 'mtls':
             logger.info('Setting up mtls')
             td.setup_for_grpc(server_xds_host, server_xds_port)
-            td.setup_server_security(server_port,
-                                     tls=True, mtls=True)
-            td.setup_client_security(namespace, server_name,
-                                     tls=True, mtls=True)
+            td.setup_server_security(server_port, tls=True, mtls=True)
+            td.setup_client_security(namespace,
+                                     server_name,
+                                     tls=True,
+                                     mtls=True)
 
         elif security_mode == 'tls':
             logger.info('Setting up tls')
             td.setup_for_grpc(server_xds_host, server_xds_port)
-            td.setup_server_security(server_port,
-                                     tls=True, mtls=False)
-            td.setup_client_security(namespace, server_name,
-                                     tls=True, mtls=False)
+            td.setup_server_security(server_port, tls=True, mtls=False)
+            td.setup_client_security(namespace,
                                      server_name,
+                                     tls=True,
+                                     mtls=False)
 
         elif security_mode == 'plaintext':
             logger.info('Setting up plaintext')
             td.setup_for_grpc(server_xds_host, server_xds_port)
-            td.setup_server_security(server_port,
-                                     tls=False, mtls=False)
-            td.setup_client_security(namespace, server_name,
-                                     tls=False, mtls=False)
+            td.setup_server_security(server_port, tls=False, mtls=False)
+            td.setup_client_security(namespace,
+                                     server_name,
+                                     tls=False,
+                                     mtls=False)
 
         logger.info('Works!')
     except Exception:

tools/run_tests/xds_test_driver/bin/run_test_client.py
@@ -23,21 +23,23 @@ from framework.test_app import client_app
 logger = logging.getLogger(__name__)
 
 # Flags
-_CMD = flags.DEFINE_enum(
-    'cmd', default='run', enum_values=['run', 'cleanup'],
-    help='Command')
-_SECURE = flags.DEFINE_bool(
-    "secure", default=False,
-    help="Run client in the secure mode")
+_CMD = flags.DEFINE_enum('cmd',
+                         default='run',
+                         enum_values=['run', 'cleanup'],
+                         help='Command')
+_SECURE = flags.DEFINE_bool("secure",
+                            default=False,
+                            help="Run client in the secure mode")
 _QPS = flags.DEFINE_integer('qps', default=25, help='Queries per second')
-_PRINT_RESPONSE = flags.DEFINE_bool(
-    "print_response", default=False,
-    help="Client prints responses")
-_REUSE_NAMESPACE = flags.DEFINE_bool(
-    "reuse_namespace", default=True,
-    help="Use existing namespace if exists")
+_PRINT_RESPONSE = flags.DEFINE_bool("print_response",
+                                    default=False,
+                                    help="Client prints responses")
+_REUSE_NAMESPACE = flags.DEFINE_bool("reuse_namespace",
+                                     default=True,
+                                     help="Use existing namespace if exists")
 _CLEANUP_NAMESPACE = flags.DEFINE_bool(
-    "cleanup_namespace", default=False,
+    "cleanup_namespace",
+    default=False,
     help="Delete namespace during resource cleanup")
 
 flags.adopt_module_key_flags(xds_flags)
 flags.adopt_module_key_flags(xds_k8s_flags)

tools/run_tests/xds_test_driver/bin/run_test_server.py
@@ -35,17 +35,19 @@ from framework.test_app import server_app
 logger = logging.getLogger(__name__)
 
 # Flags
-_CMD = flags.DEFINE_enum(
-    'cmd', default='run', enum_values=['run', 'cleanup'],
-    help='Command')
-_SECURE = flags.DEFINE_bool(
-    "secure", default=False,
-    help="Run server in the secure mode")
-_REUSE_NAMESPACE = flags.DEFINE_bool(
-    "reuse_namespace", default=True,
-    help="Use existing namespace if exists")
+_CMD = flags.DEFINE_enum('cmd',
+                         default='run',
+                         enum_values=['run', 'cleanup'],
+                         help='Command')
+_SECURE = flags.DEFINE_bool("secure",
+                            default=False,
+                            help="Run server in the secure mode")
+_REUSE_NAMESPACE = flags.DEFINE_bool("reuse_namespace",
+                                     default=True,
+                                     help="Use existing namespace if exists")
 _CLEANUP_NAMESPACE = flags.DEFINE_bool(
-    "cleanup_namespace", default=False,
+    "cleanup_namespace",
+    default=False,
     help="Delete namespace during resource cleanup")
 
 flags.adopt_module_key_flags(xds_flags)
 flags.adopt_module_key_flags(xds_k8s_flags)

tools/run_tests/xds_test_driver/framework/infrastructure/gcp/__init__.py
@@ -15,4 +15,3 @@ from framework.infrastructure.gcp import api
 from framework.infrastructure.gcp import compute
 from framework.infrastructure.gcp import network_security
 from framework.infrastructure.gcp import network_services
-

tools/run_tests/xds_test_driver/framework/infrastructure/gcp/api.py
@@ -28,14 +28,15 @@ import googleapiclient.errors
 import tenacity
 
 logger = logging.getLogger(__name__)
 
-V1_DISCOVERY_URI = flags.DEFINE_string(
-    "v1_discovery_uri", default=discovery.V1_DISCOVERY_URI,
-    help="Override v1 Discovery URI")
-V2_DISCOVERY_URI = flags.DEFINE_string(
-    "v2_discovery_uri", default=discovery.V2_DISCOVERY_URI,
-    help="Override v2 Discovery URI")
+V1_DISCOVERY_URI = flags.DEFINE_string("v1_discovery_uri",
+                                       default=discovery.V1_DISCOVERY_URI,
+                                       help="Override v1 Discovery URI")
+V2_DISCOVERY_URI = flags.DEFINE_string("v2_discovery_uri",
+                                       default=discovery.V2_DISCOVERY_URI,
+                                       help="Override v2 Discovery URI")
 COMPUTE_V1_DISCOVERY_FILE = flags.DEFINE_string(
-    "compute_v1_discovery_file", default=None,
+    "compute_v1_discovery_file",
+    default=None,
     help="Load compute v1 from discovery file")
 
 # Type aliases
@@ -43,7 +44,9 @@ Operation = operations_pb2.Operation
 
 class GcpApiManager:
 
-    def __init__(self, *,
+    def __init__(self,
+                 *,
                  v1_discovery_uri=None,
                  v2_discovery_uri=None,
                  compute_v1_discovery_file=None,
@@ -73,8 +76,9 @@ class GcpApiManager:
     def networksecurity(self, version):
         api_name = 'networksecurity'
         if version == 'v1alpha1':
-            return self._build_from_discovery_v2(
-                api_name, version, api_key=self.private_api_key)
+            return self._build_from_discovery_v2(api_name,
+                                                 version,
+                                                 api_key=self.private_api_key)
 
         raise NotImplementedError(f'Network Security {version} not supported')
@@ -82,22 +86,26 @@ class GcpApiManager:
     def networkservices(self, version):
         api_name = 'networkservices'
         if version == 'v1alpha1':
-            return self._build_from_discovery_v2(
-                api_name, version, api_key=self.private_api_key)
+            return self._build_from_discovery_v2(api_name,
                                                  version,
+                                                 api_key=self.private_api_key)
 
         raise NotImplementedError(f'Network Services {version} not supported')
 
     def _build_from_discovery_v1(self, api_name, version):
-        api = discovery.build(
-            api_name, version, cache_discovery=False,
-            discoveryServiceUrl=self.v1_discovery_uri)
+        api = discovery.build(api_name,
+                              version,
+                              cache_discovery=False,
+                              discoveryServiceUrl=self.v1_discovery_uri)
         self._exit_stack.enter_context(api)
         return api
 
     def _build_from_discovery_v2(self, api_name, version, *, api_key=None):
         key_arg = f'&key={api_key}' if api_key else ''
         api = discovery.build(
-            api_name, version, cache_discovery=False,
+            api_name,
+            version,
+            cache_discovery=False,
             discoveryServiceUrl=f'{self.v2_discovery_uri}{key_arg}')
         self._exit_stack.enter_context(api)
         return api
@@ -121,6 +129,7 @@ class OperationError(Error):
     https://cloud.google.com/apis/design/design_patterns#long_running_operations
     https://github.com/googleapis/googleapis/blob/master/google/longrunning/operations.proto
     """
+
     def __init__(self, api_name, operation_response, message=None):
         self.api_name = api_name
         operation = json_format.ParseDict(operation_response, Operation())
@@ -175,7 +184,8 @@ class GcpStandardCloudApiResource(GcpProjectApiResource):
                          **kwargs):
         logger.debug("Creating %s", body)
         create_req = collection.create(parent=self.parent(),
-                                       body=body, **kwargs)
+                                       body=body,
+                                       **kwargs)
         self._execute(create_req)
 
     @staticmethod
@@ -191,15 +201,17 @@ class GcpStandardCloudApiResource(GcpProjectApiResource):
         except googleapiclient.errors.HttpError as error:
             # noinspection PyProtectedMember
             reason = error._get_reason()
-            logger.info('Delete failed. Error: %s %s',
-                        error.resp.status, reason)
+            logger.info('Delete failed. Error: %s %s', error.resp.status,
+                        reason)
 
-    def _execute(self, request,
+    def _execute(self,
+                 request,
                  timeout_sec=GcpProjectApiResource._WAIT_FOR_OPERATION_SEC):
         operation = request.execute(num_retries=self._GCP_API_RETRIES)
         self._wait(operation, timeout_sec)
 
-    def _wait(self, operation,
+    def _wait(self,
+              operation,
               timeout_sec=GcpProjectApiResource._WAIT_FOR_OPERATION_SEC):
         op_name = operation['name']
         logger.debug('Waiting for %s operation, timeout %s sec: %s',

tools/run_tests/xds_test_driver/framework/infrastructure/gcp/compute.py
@@ -50,7 +50,8 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
         HTTP2 = enum.auto()
         GRPC = enum.auto()
 
-    def create_health_check_tcp(self, name,
+    def create_health_check_tcp(self,
+                                name,
                                 use_serving_port=False) -> GcpResource:
         health_check_settings = {}
         if use_serving_port:
@@ -66,30 +67,31 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
         self._delete_resource(self.api.healthChecks(), healthCheck=name)
 
     def create_backend_service_traffic_director(
-        self,
-        name: str,
-        health_check: GcpResource,
-        protocol: Optional[BackendServiceProtocol] = None
-    ) -> GcpResource:
+            self,
+            name: str,
+            health_check: GcpResource,
+            protocol: Optional[BackendServiceProtocol] = None) -> GcpResource:
         if not isinstance(protocol, self.BackendServiceProtocol):
             raise TypeError(f'Unexpected Backend Service protocol: {protocol}')
-        return self._insert_resource(self.api.backendServices(), {
-            'name': name,
-            'loadBalancingScheme': 'INTERNAL_SELF_MANAGED',  # Traffic Director
-            'healthChecks': [health_check.url],
-            'protocol': protocol.name,
-        })
+        return self._insert_resource(
+            self.api.backendServices(),
+            {
+                'name': name,
+                'loadBalancingScheme':
+                    'INTERNAL_SELF_MANAGED',  # Traffic Director
+                'healthChecks': [health_check.url],
+                'protocol': protocol.name,
+            })
 
     def get_backend_service_traffic_director(self, name: str) -> GcpResource:
         return self._get_resource(self.api.backendServices(),
                                   backendService=name)
 
     def patch_backend_service(self, backend_service, body, **kwargs):
-        self._patch_resource(
-            collection=self.api.backendServices(),
-            backendService=backend_service.name,
-            body=body,
-            **kwargs)
+        self._patch_resource(collection=self.api.backendServices(),
+                             backendService=backend_service.name,
+                             body=body,
+                             **kwargs)
 
     def backend_service_add_backends(self, backend_service, backends):
         backend_list = [{
@@ -98,16 +100,14 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
             'maxRatePerEndpoint': 5
         } for backend in backends]
 
-        self._patch_resource(
-            collection=self.api.backendServices(),
-            body={'backends': backend_list},
-            backendService=backend_service.name)
+        self._patch_resource(collection=self.api.backendServices(),
+                             body={'backends': backend_list},
+                             backendService=backend_service.name)
 
     def backend_service_remove_all_backends(self, backend_service):
-        self._patch_resource(
-            collection=self.api.backendServices(),
-            body={'backends': []},
-            backendService=backend_service.name)
+        self._patch_resource(collection=self.api.backendServices(),
+                             body={'backends': []},
+                             backendService=backend_service.name)
 
     def delete_backend_service(self, name):
         self._delete_resource(self.api.backendServices(), backendService=name)
@@ -122,18 +122,21 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
     ) -> GcpResource:
         if dst_host_rule_match_backend_service is None:
             dst_host_rule_match_backend_service = dst_default_backend_service
-        return self._insert_resource(self.api.urlMaps(), {
-            'name': name,
-            'defaultService': dst_default_backend_service.url,
-            'hostRules': [{
-                'hosts': src_hosts,
-                'pathMatcher': matcher_name,
-            }],
-            'pathMatchers': [{
-                'name': matcher_name,
-                'defaultService': dst_host_rule_match_backend_service.url,
-            }],
-        })
+        return self._insert_resource(
+            self.api.urlMaps(), {
+                'name':
+                    name,
+                'defaultService':
+                    dst_default_backend_service.url,
+                'hostRules': [{
+                    'hosts': src_hosts,
+                    'pathMatcher': matcher_name,
+                }],
+                'pathMatchers': [{
+                    'name': matcher_name,
+                    'defaultService': dst_host_rule_match_backend_service.url,
+                }],
+            })
 
     def delete_url_map(self, name):
         self._delete_resource(self.api.urlMaps(), urlMap=name)
@@ -174,14 +177,17 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
         target_proxy: GcpResource,
         network_url: str,
     ) -> GcpResource:
-        return self._insert_resource(self.api.globalForwardingRules(), {
-            'name': name,
-            'loadBalancingScheme': 'INTERNAL_SELF_MANAGED',  # Traffic Director
-            'portRange': src_port,
-            'IPAddress': '0.0.0.0',
-            'network': network_url,
-            'target': target_proxy.url,
-        })
+        return self._insert_resource(
+            self.api.globalForwardingRules(),
+            {
+                'name': name,
+                'loadBalancingScheme':
+                    'INTERNAL_SELF_MANAGED',  # Traffic Director
+                'portRange': src_port,
+                'IPAddress': '0.0.0.0',
+                'network': network_url,
+                'target': target_proxy.url,
+            })
 
     def delete_forwarding_rule(self, name):
         self._delete_resource(self.api.globalForwardingRules(),
@@ -192,15 +198,16 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
         return not neg or neg.get('size', 0) == 0
 
     def wait_for_network_endpoint_group(self, name, zone):
+
         @retrying.retry(retry_on_result=self._network_endpoint_group_not_ready,
                         stop_max_delay=60 * 1000,
                         wait_fixed=2 * 1000)
         def _wait_for_network_endpoint_group_ready():
             try:
                 neg = self.get_network_endpoint_group(name, zone)
-                logger.debug('Waiting for endpoints: NEG %s in zone %s, '
-                             'current count %s',
-                             neg['name'], zone, neg.get('size'))
+                logger.debug(
+                    'Waiting for endpoints: NEG %s in zone %s, '
+                    'current count %s', neg['name'], zone, neg.get('size'))
             except googleapiclient.errors.HttpError as error:
                 # noinspection PyProtectedMember
                 reason = error._get_reason()
@@ -211,10 +218,8 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
         network_endpoint_group = _wait_for_network_endpoint_group_ready()
         # @todo(sergiitk): dataclass
-        return self.ZonalGcpResource(
-            network_endpoint_group['name'],
-            network_endpoint_group['selfLink'],
-            zone)
+        return self.ZonalGcpResource(network_endpoint_group['name'],
+                                     network_endpoint_group['selfLink'], zone)
 
     def get_network_endpoint_group(self, name, zone):
         neg = self.api.networkEndpointGroups().get(project=self.project,
@@ -232,10 +237,9 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
     ):
         pending = set(backends)
 
-        @retrying.retry(
-            retry_on_result=lambda result: not result,
-            stop_max_delay=timeout_sec * 1000,
-            wait_fixed=wait_sec * 1000)
+        @retrying.retry(retry_on_result=lambda result: not result,
+                        stop_max_delay=timeout_sec * 1000,
+                        wait_fixed=wait_sec * 1000)
         def _retry_backends_health():
             for backend in pending:
                 result = self.get_backend_service_backend_health(
@@ -250,9 +254,8 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
                 for instance in result['healthStatus']:
                     logger.debug(
                         'Backend %s in zone %s: instance %s:%s health: %s',
-                        backend.name, backend.zone,
-                        instance['ipAddress'], instance['port'],
-                        instance['healthState'])
+                        backend.name, backend.zone, instance['ipAddress'],
+                        instance['port'], instance['healthState'])
                     if instance['healthState'] != 'HEALTHY':
                         backend_healthy = False
@@ -267,8 +270,11 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
     def get_backend_service_backend_health(self, backend_service, backend):
         return self.api.backendServices().getHealth(
-            project=self.project, backendService=backend_service.name,
-            body={"group": backend.url}).execute()
+            project=self.project,
+            backendService=backend_service.name,
+            body={
+                "group": backend.url
+            }).execute()
 
     def _get_resource(self, collection: discovery.Resource,
                       **kwargs) -> GcpResource:
@@ -276,11 +282,8 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
         logger.debug("Loaded %r", resp)
         return self.GcpResource(resp['name'], resp['selfLink'])
 
-    def _insert_resource(
-        self,
-        collection: discovery.Resource,
-        body: Dict[str, Any]
-    ) -> GcpResource:
+    def _insert_resource(self, collection: discovery.Resource,
+                         body: Dict[str, Any]) -> GcpResource:
         logger.debug("Creating %s", body)
         resp = self._execute(collection.insert(project=self.project, body=body))
         return self.GcpResource(body['name'], resp['targetLink'])
@@ -297,14 +300,16 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
         except googleapiclient.errors.HttpError as error:
             # noinspection PyProtectedMember
             reason = error._get_reason()
-            logger.info('Delete failed. Error: %s %s',
-                        error.resp.status, reason)
+            logger.info('Delete failed. Error: %s %s', error.resp.status,
+                        reason)
 
     @staticmethod
     def _operation_status_done(operation):
         return 'status' in operation and operation['status'] == 'DONE'
 
-    def _execute(self, request, *,
+    def _execute(self,
+                 request,
+                 *,
                  test_success_fn=None,
                  timeout_sec=_WAIT_FOR_OPERATION_SEC):
         operation = request.execute(num_retries=self._GCP_API_RETRIES)
@@ -320,10 +325,9 @@ class ComputeV1(gcp.api.GcpProjectApiResource):
         logger.debug('Waiting for global operation %s, timeout %s sec',
                      operation['name'], timeout_sec)
-        response = self.wait_for_operation(
-            operation_request=operation_request,
-            test_success_fn=test_success_fn,
-            timeout_sec=timeout_sec)
+        response = self.wait_for_operation(operation_request=operation_request,
+                                           test_success_fn=test_success_fn,
+                                           timeout_sec=timeout_sec)
 
         if 'error' in response:
             logger.debug('Waiting for global operation failed, response: %r',

tools/run_tests/xds_test_driver/framework/infrastructure/gcp/network_security.py
@@ -52,22 +52,22 @@ class NetworkSecurityV1Alpha1(gcp.api.GcpStandardCloudApiResource):
         self._api_locations = self.api.projects().locations()
 
     def create_server_tls_policy(self, name, body: dict):
-        return self._create_resource(
-            self._api_locations.serverTlsPolicies(),
-            body, serverTlsPolicyId=name)
+        return self._create_resource(self._api_locations.serverTlsPolicies(),
+                                     body,
+                                     serverTlsPolicyId=name)
 
     def get_server_tls_policy(self, name: str) -> ServerTlsPolicy:
         result = self._get_resource(
             collection=self._api_locations.serverTlsPolicies(),
             full_name=self.resource_full_name(name, self.SERVER_TLS_POLICIES))
-        return self.ServerTlsPolicy(
-            name=name,
-            url=result['name'],
-            server_certificate=result.get('serverCertificate', {}),
-            mtls_policy=result.get('mtlsPolicy', {}),
-            create_time=result['createTime'],
-            update_time=result['updateTime'])
+        return self.ServerTlsPolicy(name=name,
+                                    url=result['name'],
+                                    server_certificate=result.get(
+                                        'serverCertificate', {}),
+                                    mtls_policy=result.get('mtlsPolicy', {}),
+                                    create_time=result['createTime'],
+                                    update_time=result['updateTime'])
 
     def delete_server_tls_policy(self, name):
         return self._delete_resource(
@@ -75,9 +75,9 @@ class NetworkSecurityV1Alpha1(gcp.api.GcpStandardCloudApiResource):
             full_name=self.resource_full_name(name, self.SERVER_TLS_POLICIES))
 
     def create_client_tls_policy(self, name, body: dict):
-        return self._create_resource(
-            self._api_locations.clientTlsPolicies(),
-            body, clientTlsPolicyId=name)
+        return self._create_resource(self._api_locations.clientTlsPolicies(),
+                                     body,
+                                     clientTlsPolicyId=name)
 
     def get_client_tls_policy(self, name: str) -> ClientTlsPolicy:
         result = self._get_resource(

tools/run_tests/xds_test_driver/framework/infrastructure/gcp/network_services.py
@@ -49,7 +49,8 @@ class NetworkServicesV1Alpha1(gcp.api.GcpStandardCloudApiResource):
     def create_endpoint_config_selector(self, name, body: dict):
         return self._create_resource(
             self._api_locations.endpointConfigSelectors(),
-            body, endpointConfigSelectorId=name)
+            body,
+            endpointConfigSelectorId=name)
 
     def get_endpoint_config_selector(self, name: str) -> EndpointConfigSelector:
         result = self._get_resource(

tools/run_tests/xds_test_driver/framework/infrastructure/k8s.py
@@ -35,6 +35,7 @@ ApiException = client.ApiException
 
 def simple_resource_get(func):
+
     def wrap_not_found_return_none(*args, **kwargs):
         try:
             return func(*args, **kwargs)
@@ -43,6 +44,7 @@ def simple_resource_get(func):
                 # Ignore 404
                 return None
             raise
+
     return wrap_not_found_return_none
@@ -51,6 +53,7 @@ def label_dict_to_selector(labels: dict) -> str:
 
 class KubernetesApiManager:
+
     def __init__(self, context):
         self.context = context
         self.client = self._cached_api_client_for_context(context)
@@ -80,7 +83,8 @@ class KubernetesNamespace:
         self.api = api
 
     def apply_manifest(self, manifest):
-        return utils.create_from_dict(self.api.client, manifest,
+        return utils.create_from_dict(self.api.client,
+                                      manifest,
                                       namespace=self.name)
 
     @simple_resource_get
@@ -91,24 +95,22 @@ class KubernetesNamespace:
     def get_service_account(self, name) -> V1Service:
         return self.api.core.read_namespaced_service_account(name, self.name)
 
-    def delete_service(
-        self,
-        name,
-        grace_period_seconds=DELETE_GRACE_PERIOD_SEC
-    ):
+    def delete_service(self,
+                       name,
+                       grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
         self.api.core.delete_namespaced_service(
-            name=name, namespace=self.name,
+            name=name,
+            namespace=self.name,
             body=client.V1DeleteOptions(
                 propagation_policy='Foreground',
                 grace_period_seconds=grace_period_seconds))
 
-    def delete_service_account(
-        self,
-        name,
-        grace_period_seconds=DELETE_GRACE_PERIOD_SEC
-    ):
+    def delete_service_account(self,
+                               name,
+                               grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
         self.api.core.delete_namespaced_service_account(
-            name=name, namespace=self.name,
+            name=name,
+            namespace=self.name,
             body=client.V1DeleteOptions(
                 propagation_policy='Foreground',
                 grace_period_seconds=grace_period_seconds))
@@ -124,8 +126,8 @@ class KubernetesNamespace:
                 propagation_policy='Foreground',
                 grace_period_seconds=grace_period_seconds))
 
-    def wait_for_service_deleted(self, name: str,
-                                 timeout_sec=60, wait_sec=1):
+    def wait_for_service_deleted(self, name: str, timeout_sec=60, wait_sec=1):
+
         @retrying.retry(retry_on_result=lambda r: r is not None,
                         stop_max_delay=timeout_sec * 1000,
                         wait_fixed=wait_sec * 1000)
@@ -135,10 +137,14 @@ class KubernetesNamespace:
                 logger.info('Waiting for service %s to be deleted',
                             service.metadata.name)
             return service
+
         _wait_for_deleted_service_with_retry()
 
-    def wait_for_service_account_deleted(self, name: str,
-                                         timeout_sec=60, wait_sec=1):
+    def wait_for_service_account_deleted(self,
+                                         name: str,
+                                         timeout_sec=60,
+                                         wait_sec=1):
+
         @retrying.retry(retry_on_result=lambda r: r is not None,
                         stop_max_delay=timeout_sec * 1000,
                         wait_fixed=wait_sec * 1000)
@@ -148,10 +154,11 @@ class KubernetesNamespace:
                 logger.info('Waiting for service account %s to be deleted',
                             service_account.metadata.name)
             return service_account
+
         _wait_for_deleted_service_account_with_retry()
 
-    def wait_for_namespace_deleted(self,
-                                   timeout_sec=240, wait_sec=2):
+    def wait_for_namespace_deleted(self, timeout_sec=240, wait_sec=2):
+
        @retrying.retry(retry_on_result=lambda r: r is not None,
                        stop_max_delay=timeout_sec * 1000,
                        wait_fixed=wait_sec * 1000)
@@ -161,27 +168,25 @@ class KubernetesNamespace:
                 logger.info('Waiting for namespace %s to be deleted',
                             namespace.metadata.name)
             return namespace
+
         _wait_for_deleted_namespace_with_retry()
 
-    def wait_for_service_neg(self, name: str,
-                             timeout_sec=60, wait_sec=1):
+    def wait_for_service_neg(self, name: str, timeout_sec=60, wait_sec=1):
+
         @retrying.retry(retry_on_result=lambda r: not r,
                         stop_max_delay=timeout_sec * 1000,
                         wait_fixed=wait_sec * 1000)
         def _wait_for_service_neg():
             service = self.get_service(name)
             if self.NEG_STATUS_META not in service.metadata.annotations:
-                logger.info('Waiting for service %s NEG',
-                            service.metadata.name)
+                logger.info('Waiting for service %s NEG', service.metadata.name)
                 return False
             return True
+
         _wait_for_service_neg()
 
-    def get_service_neg(
-        self,
-        service_name: str,
-        service_port: int
-    ) -> Tuple[str, List[str]]:
+    def get_service_neg(self, service_name: str,
+                        service_port: int) -> Tuple[str, List[str]]:
         service = self.get_service(service_name)
         neg_info: dict = json.loads(
             service.metadata.annotations[self.NEG_STATUS_META])
@@ -193,13 +198,12 @@ class KubernetesNamespace:
     def get_deployment(self, name) -> V1Deployment:
         return self.api.apps.read_namespaced_deployment(name, self.name)
 
-    def delete_deployment(
-        self,
-        name,
-        grace_period_seconds=DELETE_GRACE_PERIOD_SEC
-    ):
+    def delete_deployment(self,
+                          name,
+                          grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
         self.api.apps.delete_namespaced_deployment(
-            name=name, namespace=self.name,
+            name=name,
+            namespace=self.name,
             body=client.V1DeleteOptions(
                 propagation_policy='Foreground',
                 grace_period_seconds=grace_period_seconds))
@@ -208,34 +212,43 @@ class KubernetesNamespace:
         # V1LabelSelector.match_expressions not supported at the moment
         return self.list_pods_with_labels(deployment.spec.selector.match_labels)
 
-    def wait_for_deployment_available_replicas(self, name, count=1,
-                                               timeout_sec=60, wait_sec=1):
+    def wait_for_deployment_available_replicas(self,
+                                               name,
+                                               count=1,
+                                               timeout_sec=60,
+                                               wait_sec=1):
+
         @retrying.retry(
             retry_on_result=lambda r: not self._replicas_available(r, count),
             stop_max_delay=timeout_sec * 1000,
             wait_fixed=wait_sec * 1000)
         def _wait_for_deployment_available_replicas():
             deployment = self.get_deployment(name)
-            logger.info('Waiting for deployment %s to have %s available '
-                        'replicas, current count %s',
-                        deployment.metadata.name,
-                        count, deployment.status.available_replicas)
+            logger.info(
+                'Waiting for deployment %s to have %s available '
+                'replicas, current count %s', deployment.metadata.name, count,
+                deployment.status.available_replicas)
             return deployment
+
         _wait_for_deployment_available_replicas()
 
-    def wait_for_deployment_deleted(self, deployment_name: str,
-                                    timeout_sec=60, wait_sec=1):
+    def wait_for_deployment_deleted(self,
+                                    deployment_name: str,
+                                    timeout_sec=60,
+                                    wait_sec=1):
+
         @retrying.retry(retry_on_result=lambda r: r is not None,
                         stop_max_delay=timeout_sec * 1000,
                         wait_fixed=wait_sec * 1000)
         def _wait_for_deleted_deployment_with_retry():
             deployment = self.get_deployment(deployment_name)
             if deployment is not None:
-                logger.info('Waiting for deployment %s to be deleted. '
-                            'Non-terminated replicas: %s',
-                            deployment.metadata.name,
-                            deployment.status.replicas)
+                logger.info(
+                    'Waiting for deployment %s to be deleted. '
+                    'Non-terminated replicas: %s', deployment.metadata.name,
+                    deployment.status.replicas)
             return deployment
+
         _wait_for_deleted_deployment_with_retry()
 
     def list_pods_with_labels(self, labels: dict) -> List[V1Pod]:
@@ -247,15 +260,16 @@ class KubernetesNamespace:
         return self.api.core.read_namespaced_pod(name, self.name)
 
     def wait_for_pod_started(self, pod_name, timeout_sec=60, wait_sec=1):
+
         @retrying.retry(retry_on_result=lambda r: not self._pod_started(r),
                         stop_max_delay=timeout_sec * 1000,
                         wait_fixed=wait_sec * 1000)
         def _wait_for_pod_started():
             pod = self.get_pod(pod_name)
             logger.info('Waiting for pod %s to start, current phase: %s',
-                        pod.metadata.name,
-                        pod.status.phase)
+                        pod.metadata.name, pod.status.phase)
             return pod
+
         _wait_for_pod_started()
 
     def port_forward_pod(
@@ -269,12 +283,12 @@ class KubernetesNamespace:
         local_address = local_address or self.PORT_FORWARD_LOCAL_ADDRESS
         local_port = local_port or remote_port
         cmd = [
-            "kubectl", "--context", self.api.context,
-            "--namespace", self.name,
+            "kubectl", "--context", self.api.context, "--namespace", self.name,
             "port-forward", "--address", local_address,
             f"pod/{pod.metadata.name}", f"{local_port}:{remote_port}"
         ]
-        pf = subprocess.Popen(cmd, stdout=subprocess.PIPE,
+        pf = subprocess.Popen(cmd,
+                              stdout=subprocess.PIPE,
                               stderr=subprocess.STDOUT,
                               universal_newlines=True)
         # Wait for stdout line indicating successful start.

tools/run_tests/xds_test_driver/framework/infrastructure/traffic_director.py
@@ -75,13 +75,11 @@ class TrafficDirectorManager:
     def network_url(self):
         return f'global/networks/{self.network}'
 
-    def setup_for_grpc(
-        self,
-        service_host,
-        service_port,
-        *,
-        backend_protocol=BackendServiceProtocol.GRPC
-    ):
+    def setup_for_grpc(self,
+                       service_host,
+                       service_port,
+                       *,
+                       backend_protocol=BackendServiceProtocol.GRPC):
         self.create_health_check()
         self.create_backend_service(protocol=backend_protocol)
         self.create_url_map(service_host, service_port)
@@ -130,9 +128,8 @@ class TrafficDirectorManager:
         self.health_check = None
 
     def create_backend_service(
-        self,
-        protocol: BackendServiceProtocol = BackendServiceProtocol.GRPC
-    ):
+            self,
+            protocol: BackendServiceProtocol = BackendServiceProtocol.GRPC):
         name = self._ns_name(self.BACKEND_SERVICE_NAME)
         logger.info('Creating %s Backend Service %s', protocol.name, name)
         resource = self.compute.create_backend_service_traffic_director(
@@ -168,8 +165,8 @@ class TrafficDirectorManager:
     def backend_service_add_backends(self):
         logging.info('Adding backends to Backend Service %s: %r',
                      self.backend_service.name, self.backends)
-        self.compute.backend_service_add_backends(
-            self.backend_service, self.backends)
+        self.compute.backend_service_add_backends(self.backend_service,
+                                                  self.backends)
 
     def backend_service_remove_all_backends(self):
         logging.info('Removing backends from Backend Service %s',
@@ -180,8 +177,8 @@ class TrafficDirectorManager:
         logger.debug(
             "Waiting for Backend Service %s to report all backends healthy %r",
             self.backend_service, self.backends)
-        self.compute.wait_for_backends_healthy_status(
-            self.backend_service, self.backends)
+        self.compute.wait_for_backends_healthy_status(self.backend_service,
+                                                      self.backends)
 
     def create_url_map(
         self,
@@ -191,10 +188,11 @@ class TrafficDirectorManager:
         src_address = f'{src_host}:{src_port}'
         name = self._ns_name(self.URL_MAP_NAME)
         matcher_name = self._ns_name(self.URL_MAP_PATH_MATCHER_NAME)
-        logger.info('Creating URL map %s %s -> %s',
-                    name, src_address, self.backend_service.name)
-        resource = self.compute.create_url_map(
-            name, matcher_name, [src_address], self.backend_service)
+        logger.info('Creating URL map %s %s -> %s', name, src_address,
+                    self.backend_service.name)
+        resource = self.compute.create_url_map(name, matcher_name,
+                                               [src_address],
+                                               self.backend_service)
         self.url_map = resource
         return resource
@@ -212,10 +210,9 @@ class TrafficDirectorManager:
     def create_target_grpc_proxy(self):
         # todo: different kinds
         name = self._ns_name(self.TARGET_PROXY_NAME)
-        logger.info('Creating target GRPC proxy %s to url map %s',
-                    name, self.url_map.name)
-        resource = self.compute.create_target_grpc_proxy(
-            name, self.url_map)
+        logger.info('Creating target GRPC proxy %s to url map %s', name,
+                    self.url_map.name)
+        resource = self.compute.create_target_grpc_proxy(name, self.url_map)
         self.target_proxy = resource
 
     def delete_target_grpc_proxy(self, force=False):
@@ -233,10 +230,9 @@ class TrafficDirectorManager:
     def create_target_http_proxy(self):
         # todo: different kinds
         name = self._ns_name(self.TARGET_PROXY_NAME)
-        logger.info('Creating target HTTP proxy %s to url map %s',
-                    name, self.url_map.name)
-        resource = self.compute.create_target_http_proxy(
-            name, self.url_map)
+        logger.info('Creating target HTTP proxy %s to url map %s', name,
+                    self.url_map.name)
+        resource = self.compute.create_target_http_proxy(name, self.url_map)
         self.target_proxy = resource
         self.target_proxy_is_http = True
@@ -255,10 +251,11 @@ class TrafficDirectorManager:
     def create_forwarding_rule(self, src_port: int):
         name = self._ns_name(self.FORWARDING_RULE_NAME)
         src_port = int(src_port)
-        logging.info('Creating forwarding rule %s 0.0.0.0:%s -> %s in %s',
-                     name, src_port, self.target_proxy.url, self.network)
-        resource = self.compute.create_forwarding_rule(
-            name, src_port, self.target_proxy, self.network_url)
+        logging.info('Creating forwarding rule %s 0.0.0.0:%s -> %s in %s', name,
+                     src_port, self.target_proxy.url, self.network)
+        resource = self.compute.create_forwarding_rule(name, src_port,
+                                                       self.target_proxy,
+                                                       self.network_url)
         self.forwarding_rule = resource
         return resource
@@ -289,8 +286,10 @@ class TrafficDirectorSecureManager(TrafficDirectorManager):
         resource_prefix: str,
         network: str = 'default',
     ):
-        super().__init__(gcp_api_manager, project,
-                         resource_prefix=resource_prefix, network=network)
+        super().__init__(gcp_api_manager,
+                         project,
+                         resource_prefix=resource_prefix,
+                         network=network)
 
         # API
         self.netsec = NetworkSecurityV1Alpha1(gcp_api_manager, project)
@@ -301,25 +300,28 @@ class TrafficDirectorSecureManager(TrafficDirectorManager):
         self.ecs: Optional[EndpointConfigSelector] = None
         self.client_tls_policy: Optional[ClientTlsPolicy] = None
 
-    def setup_for_grpc(
-        self,
-        service_host,
-        service_port,
-        *,
-        backend_protocol=BackendServiceProtocol.HTTP2
-    ):
-        super().setup_for_grpc(service_host, service_port,
+    def setup_for_grpc(self,
+                       service_host,
+                       service_port,
+                       *,
+                       backend_protocol=BackendServiceProtocol.HTTP2):
+        super().setup_for_grpc(service_host,
+                               service_port,
                                backend_protocol=backend_protocol)
 
     def setup_server_security(self, server_port, *, tls, mtls):
         self.create_server_tls_policy(tls=tls, mtls=mtls)
         self.create_endpoint_config_selector(server_port)
 
-    def setup_client_security(self, server_namespace, server_name,
-                              *, tls=True, mtls=True):
+    def setup_client_security(self,
+                              server_namespace,
+                              server_name,
+                              *,
+                              tls=True,
+                              mtls=True):
         self.create_client_tls_policy(tls=tls, mtls=mtls)
-        self.backend_service_apply_client_mtls_policy(
-            server_namespace, server_name)
+        self.backend_service_apply_client_mtls_policy(server_namespace,
+                                                      server_name)
 
     def cleanup(self, *, force=False):
         # Cleanup in the reverse order of creation
@@ -334,12 +336,16 @@ class TrafficDirectorSecureManager(TrafficDirectorManager):
         name = self._ns_name(self.SERVER_TLS_POLICY_NAME)
         logger.info('Creating Server TLS Policy %s', name)
         if not tls and not mtls:
-            logger.warning('Server TLS Policy %s neither TLS, nor mTLS '
-                           'policy. Skipping creation', name)
+            logger.warning(
+                'Server TLS Policy %s neither TLS, nor mTLS '
+                'policy. Skipping creation', name)
             return
 
         grpc_endpoint = {
-            "grpcEndpoint": {"targetUri": self.GRPC_ENDPOINT_TARGET_URI}}
+            "grpcEndpoint": {
+                "targetUri": self.GRPC_ENDPOINT_TARGET_URI
+            }
+        }
 
         policy = {}
         if tls:
@@ -381,13 +387,16 @@ class TrafficDirectorSecureManager(TrafficDirectorManager):
             "type": "SIDECAR_PROXY",
             "httpFilters": {},
             "trafficPortSelector": port_selector,
-            "endpointMatcher": {"metadataLabelMatcher": label_matcher_all},
+            "endpointMatcher": {
+                "metadataLabelMatcher": label_matcher_all
+            },
         }
         if self.server_tls_policy:
             config["serverTlsPolicy"] = self.server_tls_policy.name
         else:
-            logger.warning('Creating Endpoint Config Selector %s with '
-                           'no Server TLS policy attached', name)
+            logger.warning(
+                'Creating Endpoint Config Selector %s with '
+                'no Server TLS policy attached', name)
 
         self.netsvc.create_endpoint_config_selector(name, config)
         self.ecs = self.netsvc.get_endpoint_config_selector(name)
@@ -408,12 +417,16 @@ class TrafficDirectorSecureManager(TrafficDirectorManager):
         name = self._ns_name(self.CLIENT_TLS_POLICY_NAME)
         logger.info('Creating Client TLS Policy %s', name)
         if not tls and not mtls:
-            logger.warning('Client TLS Policy %s neither TLS, nor mTLS '
-                           'policy. Skipping creation', name)
+            logger.warning(
+                'Client TLS Policy %s neither TLS, nor mTLS '
+                'policy. Skipping creation', name)
             return
 
         grpc_endpoint = {
-            "grpcEndpoint": {"targetUri": self.GRPC_ENDPOINT_TARGET_URI}}
+            "grpcEndpoint": {
+                "targetUri": self.GRPC_ENDPOINT_TARGET_URI
+            }
+        }
 
         policy = {}
         if tls:
@@ -442,21 +455,23 @@ class TrafficDirectorSecureManager(TrafficDirectorManager):
         server_name,
     ):
         if not self.client_tls_policy:
-            logger.warning('Client TLS policy not created, '
-                           'skipping attaching to Backend Service %s',
-                           self.backend_service.name)
+            logger.warning(
+                'Client TLS policy not created, '
+                'skipping attaching to Backend Service %s',
+                self.backend_service.name)
             return
 
         server_spiffe = (f'spiffe://{self.project}.svc.id.goog/'
                          f'ns/{server_namespace}/sa/{server_name}')
-        logging.info('Adding Client TLS Policy to Backend Service %s: %s, '
-                     'server %s',
-                     self.backend_service.name,
-                     self.client_tls_policy.url,
-                     server_spiffe)
-
-        self.compute.patch_backend_service(self.backend_service, {
-            'securitySettings': {
-                'clientTlsPolicy': self.client_tls_policy.url,
-                'subjectAltNames': [server_spiffe]
-            }})
+        logging.info(
+            'Adding Client TLS Policy to Backend Service %s: %s, '
+            'server %s', self.backend_service.name, self.client_tls_policy.url,
+            server_spiffe)
+
+        self.compute.patch_backend_service(
+            self.backend_service, {
+                'securitySettings': {
+                    'clientTlsPolicy': self.client_tls_policy.url,
+                    'subjectAltNames': [server_spiffe]
+                }
+            })

tools/run_tests/xds_test_driver/framework/rpc/__init__.py
@@ -37,7 +37,8 @@ class GrpcClientHelper:
         self.service_name = re.sub('Stub$', '', self.stub.__class__.__name__)
 
     def call_unary_when_channel_ready(
-            self, *,
+            self,
+            *,
             rpc: str,
             req: Message,
             wait_for_ready_sec: Optional[int] = DEFAULT_WAIT_FOR_READY_SEC,
@@ -56,8 +57,7 @@ class GrpcClientHelper:
         return rpc_callable(req, **call_kwargs)
 
     def _log_debug(self, rpc, req, call_kwargs):
-        logger.debug('RPC %s.%s(request=%s(%r), %s)',
-                     self.service_name, rpc,
+        logger.debug('RPC %s.%s(request=%s(%r), %s)', self.service_name, rpc,
                      req.__class__.__name__, json_format.MessageToDict(req),
                      ', '.join({f'{k}={v}' for k, v in call_kwargs.items()}))

tools/run_tests/xds_test_driver/framework/rpc/grpc_channelz.py
@@ -83,10 +83,8 @@ class ChannelzServiceClient(framework.rpc.GrpcClientHelper):
                 f'remote={cls.sock_address_to_str(socket.remote)}')
 
     @staticmethod
-    def find_server_socket_matching_client(
-        server_sockets: Iterator[Socket],
-        client_socket: Socket
-    ) -> Socket:
+    def find_server_socket_matching_client(server_sockets: Iterator[Socket],
+                                           client_socket: Socket) -> Socket:
         for server_socket in server_sockets:
             if server_socket.remote == client_socket.local:
                 return server_socket
@@ -103,7 +101,7 @@ class ChannelzServiceClient(framework.rpc.GrpcClientHelper):
             listen_socket = self.get_socket(listen_socket_ref.socket_id)
             listen_address: Address = listen_socket.local
             if (self.is_sock_tcpip_address(listen_address) and
-                listen_address.tcpip_address.port == port):
+                    listen_address.tcpip_address.port == port):
                 return server
         return None
@@ -136,8 +134,7 @@ class ChannelzServiceClient(framework.rpc.GrpcClientHelper):
             # value by adding 1 to the highest seen result ID.
             start += 1
             response = self.call_unary_when_channel_ready(
-                rpc='GetServers',
-                req=GetServersRequest(start_server_id=start))
+                rpc='GetServers', req=GetServersRequest(start_server_id=start))
             for server in response.server:
                 start = max(start, server.ref.server_id)
                 yield server
@@ -170,6 +167,5 @@ class ChannelzServiceClient(framework.rpc.GrpcClientHelper):
     def get_socket(self, socket_id) -> Socket:
         """Return a single Socket, otherwise raises RpcError."""
         response: GetSocketResponse = self.call_unary_when_channel_ready(
-            rpc='GetSocket',
-            req=GetSocketRequest(socket_id=socket_id))
+            rpc='GetSocket', req=GetSocketRequest(socket_id=socket_id))
         return response.socket

tools/run_tests/xds_test_driver/framework/rpc/grpc_testing.py
@@ -19,7 +19,6 @@ import framework.rpc
 from src.proto.grpc.testing import test_pb2_grpc
 from src.proto.grpc.testing import messages_pb2
 
-
 # Type aliases
 LoadBalancerStatsRequest = messages_pb2.LoadBalancerStatsRequest
 LoadBalancerStatsResponse = messages_pb2.LoadBalancerStatsResponse
@@ -33,7 +32,8 @@ class LoadBalancerStatsServiceClient(framework.rpc.GrpcClientHelper):
         super().__init__(channel, test_pb2_grpc.LoadBalancerStatsServiceStub)
 
     def get_client_stats(
-            self, *,
+            self,
+            *,
             num_rpcs: int,
             timeout_sec: Optional[int] = STATS_PARTIAL_RESULTS_TIMEOUT_SEC,
     ) -> LoadBalancerStatsResponse:

@@ -33,6 +33,7 @@ TEMPLATE_DIR = '../../kubernetes-manifests'
 class KubernetesBaseRunner:
 
     def __init__(self,
                  k8s_namespace,
                  namespace_template=None,
@@ -50,8 +51,7 @@ class KubernetesBaseRunner:
             self.namespace = self._reuse_namespace()
         if not self.namespace:
             self.namespace = self._create_namespace(
-                self.namespace_template,
-                namespace_name=self.k8s_namespace.name)
+                self.namespace_template, namespace_name=self.k8s_namespace.name)
 
     def cleanup(self, *, force=False):
         if (self.namespace and not self.reuse_namespace) or force:
@@ -127,27 +127,21 @@ class KubernetesBaseRunner:
             raise RunnerError('Expected V1Namespace to be created '
                               f'from manifest {template}')
         if namespace.metadata.name != kwargs['namespace_name']:
-            raise RunnerError(
-                'Namespace created with unexpected name: '
-                f'{namespace.metadata.name}')
-        logger.info('Deployment %s created at %s',
-                    namespace.metadata.self_link,
+            raise RunnerError('Namespace created with unexpected name: '
+                              f'{namespace.metadata.name}')
+        logger.info('Deployment %s created at %s', namespace.metadata.self_link,
                     namespace.metadata.creation_timestamp)
         return namespace
 
-    def _create_service_account(
-            self,
-            template,
-            **kwargs
-    ) -> k8s.V1ServiceAccount:
+    def _create_service_account(self, template,
+                                **kwargs) -> k8s.V1ServiceAccount:
         resource = self._create_from_template(template, **kwargs)
         if not isinstance(resource, k8s.V1ServiceAccount):
             raise RunnerError('Expected V1ServiceAccount to be created '
                               f'from manifest {template}')
         if resource.metadata.name != kwargs['service_account_name']:
-            raise RunnerError(
-                'V1ServiceAccount created with unexpected name: '
-                f'{resource.metadata.name}')
+            raise RunnerError('V1ServiceAccount created with unexpected name: '
+                              f'{resource.metadata.name}')
         logger.info('V1ServiceAccount %s created at %s',
                     resource.metadata.self_link,
                     resource.metadata.creation_timestamp)
@@ -159,9 +153,8 @@ class KubernetesBaseRunner:
             raise RunnerError('Expected V1Deployment to be created '
                               f'from manifest {template}')
         if deployment.metadata.name != kwargs['deployment_name']:
-            raise RunnerError(
-                'Deployment created with unexpected name: '
-                f'{deployment.metadata.name}')
+            raise RunnerError('Deployment created with unexpected name: '
+                              f'{deployment.metadata.name}')
         logger.info('Deployment %s created at %s',
                     deployment.metadata.self_link,
                     deployment.metadata.creation_timestamp)
@@ -173,11 +166,9 @@ class KubernetesBaseRunner:
             raise RunnerError('Expected V1Service to be created '
                               f'from manifest {template}')
         if service.metadata.name != kwargs['service_name']:
-            raise RunnerError(
-                'Service created with unexpected name: '
-                f'{service.metadata.name}')
-        logger.info('Service %s created at %s',
-                    service.metadata.self_link,
+            raise RunnerError('Service created with unexpected name: '
+                              f'{service.metadata.name}')
+        logger.info('Service %s created at %s', service.metadata.self_link,
                     service.metadata.creation_timestamp)
         return service
@@ -185,8 +176,8 @@ class KubernetesBaseRunner:
         try:
             self.k8s_namespace.delete_deployment(name)
         except k8s.ApiException as e:
-            logger.info('Deployment %s deletion failed, error: %s %s',
-                        name, e.status, e.reason)
+            logger.info('Deployment %s deletion failed, error: %s %s', name,
+                        e.status, e.reason)
             return
 
         if wait_for_deletion:
@@ -197,8 +188,8 @@ class KubernetesBaseRunner:
         try:
             self.k8s_namespace.delete_service(name)
         except k8s.ApiException as e:
-            logger.info('Service %s deletion failed, error: %s %s',
-                        name, e.status, e.reason)
+            logger.info('Service %s deletion failed, error: %s %s', name,
+                        e.status, e.reason)
             return
 
         if wait_for_deletion:
@@ -232,8 +223,8 @@ class KubernetesBaseRunner:
     def _wait_deployment_with_available_replicas(self, name, count=1, **kwargs):
         logger.info('Waiting for deployment %s to have %s available replicas',
                     name, count)
-        self.k8s_namespace.wait_for_deployment_available_replicas(name, count,
-                                                                   **kwargs)
+        self.k8s_namespace.wait_for_deployment_available_replicas(
+            name, count, **kwargs)
         deployment = self.k8s_namespace.get_deployment(name)
         logger.info('Deployment %s has %i replicas available',
                     deployment.metadata.name,
@@ -251,5 +242,5 @@ class KubernetesBaseRunner:
         self.k8s_namespace.wait_for_service_neg(name, **kwargs)
         neg_name, neg_zones = self.k8s_namespace.get_service_neg(
             name, service_port)
-        logger.info("Service %s: detected NEG=%s in zones=%s", name,
-                    neg_name, neg_zones)
+        logger.info("Service %s: detected NEG=%s in zones=%s", name, neg_name,
+                    neg_zones)

@@ -32,7 +32,9 @@ LoadBalancerStatsServiceClient = grpc_testing.LoadBalancerStatsServiceClient
 class XdsTestClient(framework.rpc.GrpcApp):
 
-    def __init__(self, *,
+    def __init__(self,
+                 *,
                  ip: str,
                  rpc_port: int,
                  server_target: str,
@@ -55,7 +57,8 @@ class XdsTestClient(framework.rpc.GrpcApp):
         return ChannelzServiceClient(self._make_channel(self.maintenance_port))
 
     def get_load_balancer_stats(
-            self, *,
+            self,
+            *,
             num_rpcs: int,
             timeout_sec: Optional[int] = None,
     ) -> grpc_testing.LoadBalancerStatsResponse:
@@ -76,16 +79,14 @@ class XdsTestClient(framework.rpc.GrpcApp):
             stop=tenacity.stop_after_delay(60 * 3),
             reraise=True)
         channel = retryer(self.get_active_server_channel)
-        logger.info(
-            'Active server channel found: channel_id: %s, %s',
-            channel.ref.channel_id, channel.ref.name)
+        logger.info('Active server channel found: channel_id: %s, %s',
+                    channel.ref.channel_id, channel.ref.name)
         logger.debug('Server channel:\n%r', channel)
 
     def get_active_server_channel(self) -> Optional[grpc_channelz.Channel]:
         for channel in self.get_server_channels():
             state: ChannelConnectivityState = channel.data.state
-            logger.debug('Server channel: %s, state: %s',
-                         channel.ref.name,
+            logger.debug('Server channel: %s, state: %s', channel.ref.name,
                          ChannelConnectivityState.State.Name(state.state))
             if state.state is ChannelConnectivityState.READY:
                 return channel
@@ -107,6 +108,7 @@ class XdsTestClient(framework.rpc.GrpcApp):
 class KubernetesClientRunner(base_runner.KubernetesBaseRunner):
 
     def __init__(self,
                  k8s_namespace,
                  *,
@@ -142,9 +144,11 @@ class KubernetesClientRunner(base_runner.KubernetesBaseRunner):
         self.service_account: Optional[k8s.V1ServiceAccount] = None
         self.port_forwarder = None
 
-    def run(self, *,
+    def run(self,
+            *,
             server_target,
-            rpc='UnaryCall', qps=25,
+            rpc='UnaryCall',
+            qps=25,
             secure_mode=False,
             print_response=False) -> XdsTestClient:
         super().run()
@@ -183,8 +187,8 @@ class KubernetesClientRunner(base_runner.KubernetesBaseRunner):
         # Experimental, for local debugging.
         if self.debug_use_port_forwarding:
-            logger.info('Enabling port forwarding from %s:%s',
-                        pod_ip, self.stats_port)
+            logger.info('Enabling port forwarding from %s:%s', pod_ip,
+                        self.stats_port)
             self.port_forwarder = self.k8s_namespace.port_forward_pod(
                 pod, remote_port=self.stats_port)
             rpc_host = self.k8s_namespace.PORT_FORWARD_LOCAL_ADDRESS

@@ -27,7 +27,9 @@ ChannelzServiceClient = grpc_channelz.ChannelzServiceClient
 class XdsTestServer(framework.rpc.GrpcApp):
 
-    def __init__(self, *,
+    def __init__(self,
+                 *,
                  ip: str,
                  rpc_port: int,
                  maintenance_port: Optional[int] = None,
@@ -54,13 +56,16 @@ class XdsTestServer(framework.rpc.GrpcApp):
     @property
     def xds_address(self) -> str:
-        if not self.xds_host: return ''
-        if not self.xds_port: return self.xds_host
+        if not self.xds_host:
+            return ''
+        if not self.xds_port:
+            return self.xds_host
         return f'{self.xds_host}:{self.xds_port}'
 
     @property
     def xds_uri(self) -> str:
-        if not self.xds_host: return ''
+        if not self.xds_host:
+            return ''
         return f'xds:///{self.xds_address}'
 
     def get_test_server(self):
@@ -74,10 +79,8 @@ class XdsTestServer(framework.rpc.GrpcApp):
         server = self.get_test_server()
         return self.channelz.list_server_sockets(server.ref.server_id)
 
-    def get_server_socket_matching_client(
-            self,
-            client_socket: grpc_channelz.Socket
-    ):
+    def get_server_socket_matching_client(self,
+                                          client_socket: grpc_channelz.Socket):
         client_local = self.channelz.sock_address_to_str(client_socket.local)
         logger.debug('Looking for a server socket connected to the client %s',
                      client_local)
@@ -95,6 +98,7 @@ class XdsTestServer(framework.rpc.GrpcApp):
 class KubernetesServerRunner(base_runner.KubernetesBaseRunner):
 
     def __init__(self,
                  k8s_namespace,
                  *,
@@ -140,9 +144,12 @@ class KubernetesServerRunner(base_runner.KubernetesBaseRunner):
         self.service: Optional[k8s.V1Service] = None
         self.port_forwarder = None
 
-    def run(self, *,
-            test_port=8080, maintenance_port=None,
-            secure_mode=False, server_id=None,
+    def run(self,
+            *,
+            test_port=8080,
+            maintenance_port=None,
+            secure_mode=False,
+            server_id=None,
             replica_count=1) -> XdsTestServer:
         # todo(sergiitk): multiple replicas
         if replica_count != 1:
@@ -201,8 +208,9 @@ class KubernetesServerRunner(base_runner.KubernetesBaseRunner):
             server_id=server_id,
             secure_mode=secure_mode)
 
-        self._wait_deployment_with_available_replicas(
-            self.deployment_name, replica_count, timeout_sec=120)
+        self._wait_deployment_with_available_replicas(self.deployment_name,
+                                                      replica_count,
+                                                      timeout_sec=120)
 
         # Wait for pods running
         pods = self.k8s_namespace.list_deployment_pods(self.deployment)
@@ -215,19 +223,18 @@ class KubernetesServerRunner(base_runner.KubernetesBaseRunner):
         rpc_host = None
         # Experimental, for local debugging.
         if self.debug_use_port_forwarding:
-            logger.info('Enabling port forwarding from %s:%s',
-                        pod_ip, maintenance_port)
+            logger.info('Enabling port forwarding from %s:%s', pod_ip,
+                        maintenance_port)
             self.port_forwarder = self.k8s_namespace.port_forward_pod(
                 pod, remote_port=maintenance_port)
             rpc_host = self.k8s_namespace.PORT_FORWARD_LOCAL_ADDRESS
 
-        return XdsTestServer(
-            ip=pod_ip,
-            rpc_port=test_port,
-            maintenance_port=maintenance_port,
-            secure_mode=secure_mode,
-            server_id=server_id,
-            rpc_host=rpc_host)
+        return XdsTestServer(ip=pod_ip,
+                             rpc_port=test_port,
+                             maintenance_port=maintenance_port,
+                             secure_mode=secure_mode,
+                             server_id=server_id,
+                             rpc_host=rpc_host)
 
     def cleanup(self, *, force=False, force_namespace=False):
         if self.port_forwarder:

@@ -15,35 +15,38 @@ from absl import flags
 import googleapiclient.discovery
 
 # GCP
-PROJECT = flags.DEFINE_string(
-    "project", default=None, help="GCP Project ID. Required")
+PROJECT = flags.DEFINE_string("project",
+                              default=None,
+                              help="GCP Project ID. Required")
 NAMESPACE = flags.DEFINE_string(
-    "namespace", default=None,
+    "namespace",
+    default=None,
     help="Isolate GCP resources using given namespace / name prefix. Required")
-NETWORK = flags.DEFINE_string(
-    "network", default="default", help="GCP Network ID")
+NETWORK = flags.DEFINE_string("network",
+                              default="default",
+                              help="GCP Network ID")
 
 # Test server
-SERVER_NAME = flags.DEFINE_string(
-    "server_name", default="psm-grpc-server",
+SERVER_NAME = flags.DEFINE_string("server_name",
+                                  default="psm-grpc-server",
                                   help="Server deployment and service name")
-SERVER_PORT = flags.DEFINE_integer(
-    "server_port", default=8080,
+SERVER_PORT = flags.DEFINE_integer("server_port",
+                                   default=8080,
                                    help="Server test port")
-SERVER_XDS_HOST = flags.DEFINE_string(
-    "server_xds_host", default='xds-test-server',
+SERVER_XDS_HOST = flags.DEFINE_string("server_xds_host",
+                                      default='xds-test-server',
                                       help="Test server xDS hostname")
-SERVER_XDS_PORT = flags.DEFINE_integer(
-    "server_xds_port", default=8000, help="Test server xDS port")
+SERVER_XDS_PORT = flags.DEFINE_integer("server_xds_port",
+                                       default=8000,
+                                       help="Test server xDS port")
 
 # Test client
-CLIENT_NAME = flags.DEFINE_string(
-    "client_name", default="psm-grpc-client",
+CLIENT_NAME = flags.DEFINE_string("client_name",
+                                  default="psm-grpc-client",
                                   help="Client deployment and service name")
-CLIENT_PORT = flags.DEFINE_integer(
-    "client_port", default=8079,
+CLIENT_PORT = flags.DEFINE_integer("client_port",
+                                   default=8079,
                                    help="Client test port")
 
 flags.mark_flags_as_required([
     "project",

@@ -14,24 +14,28 @@
 from absl import flags
 
 # GCP
-KUBE_CONTEXT = flags.DEFINE_string(
-    "kube_context", default=None, help="Kubectl context to use")
+KUBE_CONTEXT = flags.DEFINE_string("kube_context",
+                                   default=None,
+                                   help="Kubectl context to use")
 GCP_SERVICE_ACCOUNT = flags.DEFINE_string(
-    "gcp_service_account", default=None,
+    "gcp_service_account",
+    default=None,
     help="GCP Service account for GKE workloads to impersonate")
 TD_BOOTSTRAP_IMAGE = flags.DEFINE_string(
-    "td_bootstrap_image", default=None,
+    "td_bootstrap_image",
+    default=None,
     help="Traffic Director gRPC Bootstrap Docker image")
 
 # Test app
-SERVER_IMAGE = flags.DEFINE_string(
-    "server_image", default=None,
+SERVER_IMAGE = flags.DEFINE_string("server_image",
+                                   default=None,
                                    help="Server Docker image name")
-CLIENT_IMAGE = flags.DEFINE_string(
-    "client_image", default=None,
+CLIENT_IMAGE = flags.DEFINE_string("client_image",
+                                   default=None,
                                    help="Client Docker image name")
 CLIENT_PORT_FORWARDING = flags.DEFINE_bool(
-    "client_debug_use_port_forwarding", default=False,
+    "client_debug_use_port_forwarding",
+    default=False,
     help="Development only: use kubectl port-forward to connect to test client")
 
 flags.mark_flags_as_required([

@@ -107,11 +107,9 @@ class XdsKubernetesTestCase(absltest.TestCase):
         # Add backends to the Backend Service
         self.td.backend_service_add_neg_backends(neg_name, neg_zones)
 
-    def assertSuccessfulRpcs(
-            self,
-            test_client: XdsTestClient,
-            num_rpcs: int = 100
-    ):
+    def assertSuccessfulRpcs(self,
+                             test_client: XdsTestClient,
+                             num_rpcs: int = 100):
         # Run the test
         lb_stats = test_client.get_load_balancer_stats(num_rpcs=num_rpcs)
         # Check the results
@@ -123,17 +121,20 @@ class XdsKubernetesTestCase(absltest.TestCase):
         logger.info(lb_stats.rpcs_by_peer)
         for backend, rpcs_count in lb_stats.rpcs_by_peer.items():
             self.assertGreater(
-                int(rpcs_count), 0,
+                int(rpcs_count),
+                0,
                 msg='Backend {backend} did not receive a single RPC')
 
     def assertFailedRpcsAtMost(self, lb_stats, limit):
         failed = int(lb_stats.num_failures)
         self.assertLessEqual(
-            failed, limit,
+            failed,
+            limit,
             msg=f'Unexpected number of RPC failures {failed} > {limit}')
 
 
 class RegularXdsKubernetesTestCase(XdsKubernetesTestCase):
 
     def setUp(self):
         super().setUp()
@@ -168,15 +169,13 @@ class RegularXdsKubernetesTestCase(XdsKubernetesTestCase):
             reuse_namespace=self.server_namespace == self.client_namespace)
 
     def startTestServer(self, replica_count=1, **kwargs) -> XdsTestServer:
-        test_server = self.server_runner.run(
-            replica_count=replica_count,
-            test_port=self.server_port,
-            **kwargs)
+        test_server = self.server_runner.run(replica_count=replica_count,
+                                             test_port=self.server_port,
+                                             **kwargs)
         test_server.set_xds_address(self.server_xds_host, self.server_xds_port)
         return test_server
 
-    def startTestClient(self,
-                        test_server: XdsTestServer,
+    def startTestClient(self, test_server: XdsTestServer,
                         **kwargs) -> XdsTestClient:
         test_client = self.client_runner.run(server_target=test_server.xds_uri,
                                              **kwargs)
@@ -187,6 +186,7 @@ class RegularXdsKubernetesTestCase(XdsKubernetesTestCase):
 class SecurityXdsKubernetesTestCase(XdsKubernetesTestCase):
 
     class SecurityMode(enum.Enum):
         MTLS = enum.auto()
         TLS = enum.auto()
@@ -229,43 +229,39 @@ class SecurityXdsKubernetesTestCase(XdsKubernetesTestCase):
             debug_use_port_forwarding=self.client_port_forwarding)
 
     def startSecureTestServer(self, replica_count=1, **kwargs) -> XdsTestServer:
-        test_server = self.server_runner.run(
-            replica_count=replica_count,
-            test_port=self.server_port,
-            maintenance_port=8081,
-            secure_mode=True,
-            **kwargs)
+        test_server = self.server_runner.run(replica_count=replica_count,
+                                             test_port=self.server_port,
+                                             maintenance_port=8081,
+                                             secure_mode=True,
+                                             **kwargs)
         test_server.set_xds_address(self.server_xds_host, self.server_xds_port)
         return test_server
 
-    def setupSecurityPolicies(self, *,
-                              server_tls, server_mtls,
-                              client_tls, client_mtls):
-        self.td.setup_client_security(self.server_namespace, self.server_name,
-                                      tls=client_tls, mtls=client_mtls)
+    def setupSecurityPolicies(self, *, server_tls, server_mtls, client_tls,
+                              client_mtls):
+        self.td.setup_client_security(self.server_namespace,
+                                      self.server_name,
+                                      tls=client_tls,
+                                      mtls=client_mtls)
         self.td.setup_server_security(self.server_port,
-                                      tls=server_tls, mtls=server_mtls)
+                                      tls=server_tls,
+                                      mtls=server_mtls)
 
-    def startSecureTestClient(
-            self,
-            test_server: XdsTestServer,
-            **kwargs
-    ) -> XdsTestClient:
-        test_client = self.client_runner.run(
-            server_target=test_server.xds_uri,
-            secure_mode=True,
-            **kwargs)
+    def startSecureTestClient(self, test_server: XdsTestServer,
+                              **kwargs) -> XdsTestClient:
+        test_client = self.client_runner.run(server_target=test_server.xds_uri,
+                                             secure_mode=True,
+                                             **kwargs)
         logger.debug('Waiting fot the client to establish healthy channel with '
                      'the server')
         test_client.wait_for_active_server_channel()
         return test_client
 
-    def assertTestAppSecurity(self,
-                              mode: SecurityMode,
+    def assertTestAppSecurity(self, mode: SecurityMode,
                               test_client: XdsTestClient,
                               test_server: XdsTestServer):
-        client_socket, server_socket = self.getConnectedSockets(test_client,
-                                                                test_server)
+        client_socket, server_socket = self.getConnectedSockets(
+            test_client, test_server)
         server_security: grpc_channelz.Security = server_socket.security
         client_security: grpc_channelz.Security = client_socket.security
         logger.info('Server certs: %s', self.debug_sock_certs(server_security))
@@ -280,72 +276,70 @@ class SecurityXdsKubernetesTestCase(XdsKubernetesTestCase):
         else:
             raise TypeError(f'Incorrect security mode')
 
-    def assertSecurityMtls(self,
-                           client_security: grpc_channelz.Security,
+    def assertSecurityMtls(self, client_security: grpc_channelz.Security,
                            server_security: grpc_channelz.Security):
-        self.assertEqual(client_security.WhichOneof('model'), 'tls',
+        self.assertEqual(client_security.WhichOneof('model'),
+                         'tls',
                          msg='(mTLS) Client socket security model must be TLS')
-        self.assertEqual(server_security.WhichOneof('model'), 'tls',
+        self.assertEqual(server_security.WhichOneof('model'),
+                         'tls',
                          msg='(mTLS) Server socket security model must be TLS')
         server_tls, client_tls = server_security.tls, client_security.tls
         # Confirm regular TLS: server local cert == client remote cert
-        self.assertNotEmpty(
-            server_tls.local_certificate,
-            msg="(mTLS) Server local certificate is missing")
-        self.assertNotEmpty(
-            client_tls.remote_certificate,
-            msg="(mTLS) Client remote certificate is missing")
+        self.assertNotEmpty(server_tls.local_certificate,
+                            msg="(mTLS) Server local certificate is missing")
+        self.assertNotEmpty(client_tls.remote_certificate,
+                            msg="(mTLS) Client remote certificate is missing")
         self.assertEqual(
-            server_tls.local_certificate, client_tls.remote_certificate,
+            server_tls.local_certificate,
+            client_tls.remote_certificate,
             msg="(mTLS) Server local certificate must match client's "
             "remote certificate")
         # mTLS: server remote cert == client local cert
-        self.assertNotEmpty(
-            server_tls.remote_certificate,
-            msg="(mTLS) Server remote certificate is missing")
-        self.assertNotEmpty(
-            client_tls.local_certificate,
-            msg="(mTLS) Client local certificate is missing")
+        self.assertNotEmpty(server_tls.remote_certificate,
+                            msg="(mTLS) Server remote certificate is missing")
+        self.assertNotEmpty(client_tls.local_certificate,
+                            msg="(mTLS) Client local certificate is missing")
         self.assertEqual(
-            server_tls.remote_certificate, client_tls.local_certificate,
+            server_tls.remote_certificate,
+            client_tls.local_certificate,
             msg="(mTLS) Server remote certificate must match client's "
            "local certificate")
         # Success
         logger.info('mTLS security mode confirmed!')
 
-    def assertSecurityTls(self,
-                          client_security: grpc_channelz.Security,
+    def assertSecurityTls(self, client_security: grpc_channelz.Security,
                           server_security: grpc_channelz.Security):
-        self.assertEqual(client_security.WhichOneof('model'), 'tls',
+        self.assertEqual(client_security.WhichOneof('model'),
+                         'tls',
                          msg='(TLS) Client socket security model must be TLS')
-        self.assertEqual(server_security.WhichOneof('model'), 'tls',
+        self.assertEqual(server_security.WhichOneof('model'),
+                         'tls',
                          msg='(TLS) Server socket security model must be TLS')
         server_tls, client_tls = server_security.tls, client_security.tls
         # Regular TLS: server local cert == client remote cert
-        self.assertNotEmpty(
-            server_tls.local_certificate,
-            msg="(TLS) Server local certificate is missing")
-        self.assertNotEmpty(
-            client_tls.remote_certificate,
-            msg="(TLS) Client remote certificate is missing")
-        self.assertEqual(
-            server_tls.local_certificate, client_tls.remote_certificate,
-            msg="(TLS) Server local certificate must match client "
-            "remote certificate")
+        self.assertNotEmpty(server_tls.local_certificate,
+                            msg="(TLS) Server local certificate is missing")
+        self.assertNotEmpty(client_tls.remote_certificate,
+                            msg="(TLS) Client remote certificate is missing")
+        self.assertEqual(server_tls.local_certificate,
+                         client_tls.remote_certificate,
+                         msg="(TLS) Server local certificate must match client "
+                         "remote certificate")
         # mTLS must not be used
         self.assertEmpty(
             server_tls.remote_certificate,
             msg="(TLS) Server remote certificate must be empty in TLS mode. "
            "Is server security incorrectly configured for mTLS?")
         self.assertEmpty(
             client_tls.local_certificate,
             msg="(TLS) Client local certificate must be empty in TLS mode. "
            "Is client security incorrectly configured for mTLS?")
         # Success
         logger.info('TLS security mode confirmed!')
@@ -373,8 +367,7 @@ class SecurityXdsKubernetesTestCase(XdsKubernetesTestCase):
     @staticmethod
     def getConnectedSockets(
-            test_client: XdsTestClient,
-            test_server: XdsTestServer
+            test_client: XdsTestClient, test_server: XdsTestServer
     ) -> Tuple[grpc_channelz.Socket, grpc_channelz.Socket]:
         client_sock = test_client.get_client_socket_with_test_server()
         server_sock = test_server.get_server_socket_matching_client(client_sock)
@@ -390,6 +383,7 @@ class SecurityXdsKubernetesTestCase(XdsKubernetesTestCase):
     @staticmethod
     def debug_cert(cert):
-        if not cert: return 'missing'
+        if not cert:
+            return 'missing'
         sha1 = hashlib.sha1(cert)
         return f'sha1={sha1.hexdigest()}, len={len(cert)}'
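The debug_cert() helper reformatted above reduces a channelz certificate blob to a short fingerprint for logging. A self-contained sketch of the same idea using only the standard library; the sample bytes below are made up:

    import hashlib

    def cert_fingerprint(cert: bytes) -> str:
        # Mirrors debug_cert(): empty input means the peer sent no certificate.
        if not cert:
            return 'missing'
        return f'sha1={hashlib.sha1(cert).hexdigest()}, len={len(cert)}'

    print(cert_fingerprint(b''))               # -> missing
    print(cert_fingerprint(b'0\x82fakecert'))  # -> sha1=..., len=10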

@@ -27,6 +27,7 @@ XdsTestClient = xds_k8s_testcase.XdsTestClient
 class BaselineTest(xds_k8s_testcase.RegularXdsKubernetesTestCase):
 
     def test_ping_pong(self):
         self.setupTrafficDirectorGrpc()

@@ -29,10 +29,13 @@ SecurityMode = xds_k8s_testcase.SecurityXdsKubernetesTestCase.SecurityMode
 class SecurityTest(xds_k8s_testcase.SecurityXdsKubernetesTestCase):
 
     def test_mtls(self):
         self.setupTrafficDirectorGrpc()
-        self.setupSecurityPolicies(server_tls=True, server_mtls=True,
-                                   client_tls=True, client_mtls=True)
+        self.setupSecurityPolicies(server_tls=True,
+                                   server_mtls=True,
+                                   client_tls=True,
+                                   client_mtls=True)
 
         test_server: XdsTestServer = self.startSecureTestServer()
         self.setupServerBackends()
@@ -43,8 +46,10 @@ class SecurityTest(xds_k8s_testcase.SecurityXdsKubernetesTestCase):
     def test_tls(self):
         self.setupTrafficDirectorGrpc()
-        self.setupSecurityPolicies(server_tls=True, server_mtls=False,
-                                   client_tls=True, client_mtls=False)
+        self.setupSecurityPolicies(server_tls=True,
+                                   server_mtls=False,
+                                   client_tls=True,
+                                   client_mtls=False)
 
         test_server: XdsTestServer = self.startSecureTestServer()
         self.setupServerBackends()
@@ -55,15 +60,17 @@ class SecurityTest(xds_k8s_testcase.SecurityXdsKubernetesTestCase):
     def test_plaintext_fallback(self):
         self.setupTrafficDirectorGrpc()
-        self.setupSecurityPolicies(server_tls=False, server_mtls=False,
-                                   client_tls=False, client_mtls=False)
+        self.setupSecurityPolicies(server_tls=False,
+                                   server_mtls=False,
+                                   client_tls=False,
+                                   client_mtls=False)
 
         test_server: XdsTestServer = self.startSecureTestServer()
         self.setupServerBackends()
         test_client: XdsTestClient = self.startSecureTestClient(test_server)
 
-        self.assertTestAppSecurity(
-            SecurityMode.PLAINTEXT, test_client, test_server)
+        self.assertTestAppSecurity(SecurityMode.PLAINTEXT, test_client,
+                                   test_server)
         self.assertSuccessfulRpcs(test_client)
 
     @absltest.skip(SKIP_REASON)
