Merge pull request #25075 from grpc/xds-k8s-mtls-error-test

xds-k8s driver: implement PSM security mtls_error test
pull/25079/head
Sergii Tkachenko 4 years ago committed by GitHub
commit d92f41fea4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 143
      tools/run_tests/xds_k8s_test_driver/bin/run_channelz.py
  2. 53
      tools/run_tests/xds_k8s_test_driver/bin/run_td_setup.py
  3. 13
      tools/run_tests/xds_k8s_test_driver/framework/helpers/__init__.py
  4. 53
      tools/run_tests/xds_k8s_test_driver/framework/helpers/retryers.py
  5. 2
      tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/api.py
  6. 10
      tools/run_tests/xds_k8s_test_driver/framework/infrastructure/gcp/compute.py
  7. 77
      tools/run_tests/xds_k8s_test_driver/framework/infrastructure/traffic_director.py
  8. 6
      tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc.py
  9. 29
      tools/run_tests/xds_k8s_test_driver/framework/rpc/grpc_channelz.py
  10. 138
      tools/run_tests/xds_k8s_test_driver/framework/test_app/client_app.py
  11. 28
      tools/run_tests/xds_k8s_test_driver/framework/test_app/server_app.py
  12. 66
      tools/run_tests/xds_k8s_test_driver/framework/xds_k8s_testcase.py
  13. 2
      tools/run_tests/xds_k8s_test_driver/tests/baseline_test.py
  14. 77
      tools/run_tests/xds_k8s_test_driver/tests/security_test.py

@ -11,6 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Channelz debugging tool for xDS test client/server.
This is intended as a debugging / local development helper and not executed
as a part of interop test suites.
Typical usage examples:
# Show channel and socket info
python -m bin.run_channelz --flagfile=config/local-dev.cfg
# Evaluate setup for mtls_error test case
python -m bin.run_channelz --flagfile=config/local-dev.cfg --security=mtls_error
# More information and usage options
python -m bin.run_channelz --helpfull
"""
import hashlib
import logging
@ -21,8 +37,8 @@ from framework import xds_flags
from framework import xds_k8s_flags
from framework.infrastructure import k8s
from framework.rpc import grpc_channelz
from framework.test_app import server_app
from framework.test_app import client_app
from framework.test_app import server_app
logger = logging.getLogger(__name__)
# Flags
@ -32,11 +48,17 @@ _SERVER_RPC_HOST = flags.DEFINE_string('server_rpc_host',
_CLIENT_RPC_HOST = flags.DEFINE_string('client_rpc_host',
default='127.0.0.1',
help='Client RPC host')
_SECURITY = flags.DEFINE_enum('security',
default='positive_cases',
enum_values=['positive_cases', 'mtls_error'],
help='Test for security setup')
flags.adopt_module_key_flags(xds_flags)
flags.adopt_module_key_flags(xds_k8s_flags)
# Type aliases
_Channel = grpc_channelz.Channel
_Socket = grpc_channelz.Socket
_ChannelState = grpc_channelz.ChannelState
_XdsTestServer = server_app.XdsTestServer
_XdsTestClient = client_app.XdsTestClient
@ -59,65 +81,112 @@ def get_deployment_pod_ips(k8s_ns, deployment_name):
return [pod.status.pod_ip for pod in pods]
def negative_case_mtls(test_client, test_server):
"""Debug mTLS Error case.
Server expects client mTLS cert, but client configured only for TLS.
"""
# Client side.
client_correct_setup = True
channel: _Channel = test_client.wait_for_server_channel_state(
state=_ChannelState.TRANSIENT_FAILURE)
try:
subchannel, *subchannels = list(
test_client.channelz.list_channel_subchannels(channel))
except ValueError:
print("(mTLS-error) Client setup fail: subchannel not found. "
"Common causes: test client didn't connect to TD; "
"test client exhausted retries, and closed all subchannels.")
return
# Client must have exactly one subchannel.
logger.debug('Found subchannel, %s', subchannel)
if subchannels:
client_correct_setup = False
print(f'(mTLS-error) Unexpected subchannels {subchannels}')
subchannel_state: _ChannelState = subchannel.data.state.state
if subchannel_state is not _ChannelState.TRANSIENT_FAILURE:
client_correct_setup = False
print('(mTLS-error) Subchannel expected to be in '
'TRANSIENT_FAILURE, same as its channel')
# Client subchannel must have no sockets.
sockets = list(test_client.channelz.list_subchannels_sockets(subchannel))
if sockets:
client_correct_setup = False
print(f'(mTLS-error) Unexpected subchannel sockets {sockets}')
# Results.
if client_correct_setup:
print('(mTLS-error) Client setup pass: the channel '
'to the server has exactly one subchannel '
'in TRANSIENT_FAILURE, and no sockets')
def positive_case_all(test_client, test_server):
"""Debug positive cases: mTLS, TLS, Plaintext."""
test_client.wait_for_active_server_channel()
client_sock: _Socket = test_client.get_active_server_channel_socket()
server_sock: _Socket = test_server.get_server_socket_matching_client(
client_sock)
server_tls = server_sock.security.tls
client_tls = client_sock.security.tls
print(f'\nServer certs:\n{debug_sock_tls(server_tls)}')
print(f'\nClient certs:\n{debug_sock_tls(client_tls)}')
print()
if server_tls.local_certificate:
eq = server_tls.local_certificate == client_tls.remote_certificate
print(f'(TLS) Server local matches client remote: {eq}')
else:
print('(TLS) Not detected')
if server_tls.remote_certificate:
eq = server_tls.remote_certificate == client_tls.local_certificate
print(f'(mTLS) Server remote matches client local: {eq}')
else:
print('(mTLS) Not detected')
def main(argv):
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
k8s_api_manager = k8s.KubernetesApiManager(xds_k8s_flags.KUBE_CONTEXT.value)
# Namespaces
namespace = xds_flags.NAMESPACE.value
server_namespace = namespace
client_namespace = namespace
# Server
server_k8s_ns = k8s.KubernetesNamespace(k8s_api_manager, server_namespace)
server_name = xds_flags.SERVER_NAME.value
server_port = xds_flags.SERVER_PORT.value
server_namespace = xds_flags.NAMESPACE.value
server_k8s_ns = k8s.KubernetesNamespace(k8s_api_manager, server_namespace)
server_pod_ip = get_deployment_pod_ips(server_k8s_ns, server_name)[0]
test_server: _XdsTestServer = _XdsTestServer(
ip=server_pod_ip,
rpc_port=server_port,
rpc_port=xds_flags.SERVER_PORT.value,
xds_host=xds_flags.SERVER_XDS_HOST.value,
xds_port=xds_flags.SERVER_XDS_PORT.value,
rpc_host=_SERVER_RPC_HOST.value)
# Client
client_k8s_ns = k8s.KubernetesNamespace(k8s_api_manager, client_namespace)
client_name = xds_flags.CLIENT_NAME.value
client_port = xds_flags.CLIENT_PORT.value
client_namespace = xds_flags.NAMESPACE.value
client_k8s_ns = k8s.KubernetesNamespace(k8s_api_manager, client_namespace)
client_pod_ip = get_deployment_pod_ips(client_k8s_ns, client_name)[0]
test_client: _XdsTestClient = _XdsTestClient(
ip=client_pod_ip,
server_target=test_server.xds_uri,
rpc_port=client_port,
rpc_port=xds_flags.CLIENT_PORT.value,
rpc_host=_CLIENT_RPC_HOST.value)
with test_client, test_server:
test_client.wait_for_active_server_channel()
client_sock: _Socket = test_client.get_client_socket_with_test_server()
server_sock: _Socket = test_server.get_server_socket_matching_client(
client_sock)
server_tls = server_sock.security.tls
client_tls = client_sock.security.tls
print(f'\nServer certs:\n{debug_sock_tls(server_tls)}')
print(f'\nClient certs:\n{debug_sock_tls(client_tls)}')
print()
if server_tls.local_certificate:
eq = server_tls.local_certificate == client_tls.remote_certificate
print(f'(TLS) Server local matches client remote: {eq}')
else:
print('(TLS) Not detected')
if server_tls.remote_certificate:
eq = server_tls.remote_certificate == client_tls.local_certificate
print(f'(mTLS) Server remote matches client local: {eq}')
else:
print('(mTLS) Not detected')
# Run checks
if _SECURITY.value in 'positive_cases':
positive_case_all(test_client, test_server)
elif _SECURITY.value == 'mtls_error':
negative_case_mtls(test_client, test_server)
test_client.close()
test_server.close()
if __name__ == '__main__':

@ -11,6 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configure Traffic Director for different GRPC Proxyless.
This is intended as a debugging / local development helper and not executed
as a part of interop test suites.
Typical usage examples:
# Regular proxyless setup
python -m bin.run_td_setup --flagfile=config/local-dev.cfg
# Additional commands: cleanup, backend management, etc.
python -m bin.run_td_setup --flagfile=config/local-dev.cfg --cmd=cleanup
# PSM security setup options: mtls, tls, etc.
python -m bin.run_td_setup --flagfile=config/local-dev.cfg --security=mtls
# More information and usage options
python -m bin.run_td_setup --helpfull
"""
import logging
from absl import app
@ -31,10 +50,11 @@ _CMD = flags.DEFINE_enum('cmd',
'backends-cleanup'
],
help='Command')
_SECURITY = flags.DEFINE_enum('security',
default=None,
enum_values=['mtls', 'tls', 'plaintext'],
help='Configure td with security')
_SECURITY = flags.DEFINE_enum(
'security',
default=None,
enum_values=['mtls', 'tls', 'plaintext', 'mtls_error'],
help='Configure TD with security')
flags.adopt_module_key_flags(xds_flags)
flags.adopt_module_key_flags(xds_k8s_flags)
@ -70,10 +90,9 @@ def main(argv):
resource_prefix=namespace,
network=network)
# noinspection PyBroadException
try:
if command == 'create' or command == 'cycle':
logger.info('Create-only mode')
if command in ('create', 'cycle'):
logger.info('Create mode')
if security_mode is None:
logger.info('No security')
td.setup_for_grpc(server_xds_host, server_xds_port)
@ -117,11 +136,26 @@ def main(argv):
tls=False,
mtls=False)
elif security_mode == 'mtls_error':
# Error case: server expects client mTLS cert,
# but client configured only for TLS
logger.info('Setting up mtls_error')
td.setup_for_grpc(server_xds_host, server_xds_port)
td.setup_server_security(server_namespace=namespace,
server_name=server_name,
server_port=server_port,
tls=True,
mtls=True)
td.setup_client_security(server_namespace=namespace,
server_name=server_name,
tls=True,
mtls=False)
logger.info('Works!')
except Exception:
except Exception: # noqa pylint: disable=broad-except
logger.exception('Got error during creation')
if command == 'cleanup' or command == 'cycle':
if command in ('cleanup', 'cycle'):
logger.info('Cleaning up')
td.cleanup(force=True)
@ -136,6 +170,7 @@ def main(argv):
td.load_backend_service()
td.backend_service_add_neg_backends(neg_name, neg_zones)
td.wait_for_backends_healthy_status()
# TODO(sergiitk): wait until client reports rpc health
elif command == 'backends-cleanup':
td.load_backend_service()

@ -0,0 +1,13 @@
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,53 @@
# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This contains common retrying helpers (retryers).
We use tenacity as a general-purpose retrying library.
> It [tenacity] originates from a fork of retrying which is sadly no
> longer maintained. Tenacity isnt api compatible with retrying but >
> adds significant new functionality and fixes a number of longstanding bugs.
> - https://tenacity.readthedocs.io/en/latest/index.html
"""
import datetime
from typing import Any, List, Optional
import tenacity
# Type aliases
timedelta = datetime.timedelta
Retrying = tenacity.Retrying
_retry_if_exception_type = tenacity.retry_if_exception_type
_stop_after_delay = tenacity.stop_after_delay
_wait_exponential = tenacity.wait_exponential
def _retry_on_exceptions(retry_on_exceptions: Optional[List[Any]] = None):
# Retry on all exceptions by default
if retry_on_exceptions is None:
retry_on_exceptions = (Exception,)
return _retry_if_exception_type(retry_on_exceptions)
def exponential_retryer_with_timeout(
*,
wait_min: timedelta,
wait_max: timedelta,
timeout: timedelta,
retry_on_exceptions: Optional[List[Any]] = None) -> Retrying:
return Retrying(retry=_retry_on_exceptions(retry_on_exceptions),
wait=_wait_exponential(min=wait_min.total_seconds(),
max=wait_max.total_seconds()),
stop=_stop_after_delay(timeout.total_seconds()),
reraise=True)

@ -20,7 +20,7 @@ from typing import Optional
# Workaround: `grpc` must be imported before `google.protobuf.json_format`,
# to prevent "Segmentation fault". Ref https://github.com/grpc/grpc/issues/24897
# TODO(sergiitk): Remove after #24897 is solved
import grpc # noqa # pylint: disable=unused-import
import grpc # noqa pylint: disable=unused-import
from absl import flags
from google.cloud import secretmanager_v1
from google.longrunning import operations_pb2

@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import enum
import logging
from typing import Optional, Dict, Any
from typing import Any, Dict, Optional
import dataclasses
import googleapiclient.errors
from googleapiclient import discovery
import googleapiclient.errors
# TODO(sergiitk): replace with tenacity
import retrying
@ -28,8 +28,8 @@ logger = logging.getLogger(__name__)
class ComputeV1(gcp.api.GcpProjectApiResource):
# TODO(sergiitk): move someplace better
_WAIT_FOR_BACKEND_SEC = 1200
_WAIT_FOR_OPERATION_SEC = 1200
_WAIT_FOR_BACKEND_SEC = 60 * 5
_WAIT_FOR_OPERATION_SEC = 60 * 5
@dataclasses.dataclass(frozen=True)
class GcpResource:

@ -21,10 +21,11 @@ logger = logging.getLogger(__name__)
# Type aliases
# Compute
_ComputeV1 = gcp.compute.ComputeV1
HealthCheckProtocol = _ComputeV1.HealthCheckProtocol
BackendServiceProtocol = _ComputeV1.BackendServiceProtocol
GcpResource = _ComputeV1.GcpResource
HealthCheckProtocol = _ComputeV1.HealthCheckProtocol
ZonalGcpResource = _ComputeV1.ZonalGcpResource
BackendServiceProtocol = _ComputeV1.BackendServiceProtocol
_BackendGRPC = BackendServiceProtocol.GRPC
# Network Security
_NetworkSecurityV1Alpha1 = gcp.network_security.NetworkSecurityV1Alpha1
@ -64,6 +65,8 @@ class TrafficDirectorManager:
# Managed resources
self.health_check: Optional[GcpResource] = None
self.backend_service: Optional[GcpResource] = None
# TODO(sergiitk): remove this flag once backend service resource loaded
self.backend_service_protocol: Optional[BackendServiceProtocol] = None
self.url_map: Optional[GcpResource] = None
self.target_proxy: Optional[GcpResource] = None
# TODO(sergiitk): remove this flag once target proxy resource loaded
@ -75,18 +78,23 @@ class TrafficDirectorManager:
def network_url(self):
return f'global/networks/{self.network}'
def setup_for_grpc(self,
service_host,
service_port,
*,
backend_protocol=BackendServiceProtocol.GRPC):
def setup_for_grpc(
self,
service_host,
service_port,
*,
backend_protocol: Optional[BackendServiceProtocol] = _BackendGRPC):
self.setup_backend_for_grpc(protocol=backend_protocol)
self.setup_routing_rule_map_for_grpc(service_host, service_port)
def setup_backend_for_grpc(
self, *, protocol: Optional[BackendServiceProtocol] = _BackendGRPC):
self.create_health_check()
self.create_backend_service(protocol=backend_protocol)
self.create_backend_service(protocol)
def setup_routing_rule_map_for_grpc(self, service_host, service_port):
self.create_url_map(service_host, service_port)
if backend_protocol is BackendServiceProtocol.GRPC:
self.create_target_grpc_proxy()
else:
self.create_target_http_proxy()
self.create_target_proxy()
self.create_forwarding_rule(service_port)
def cleanup(self, *, force=False):
@ -105,8 +113,8 @@ class TrafficDirectorManager:
def create_health_check(self, protocol=HealthCheckProtocol.TCP):
if self.health_check:
raise ValueError('Health check %s already created, delete it first',
self.health_check.name)
raise ValueError(f'Health check {self.health_check.name} '
'already created, delete it first')
name = self._ns_name(self.HEALTH_CHECK_NAME)
logger.info('Creating %s Health Check "%s"', protocol.name, name)
if protocol is HealthCheckProtocol.TCP:
@ -128,13 +136,16 @@ class TrafficDirectorManager:
self.health_check = None
def create_backend_service(
self,
protocol: BackendServiceProtocol = BackendServiceProtocol.GRPC):
self, protocol: Optional[BackendServiceProtocol] = _BackendGRPC):
if protocol is None:
protocol = _BackendGRPC
name = self._ns_name(self.BACKEND_SERVICE_NAME)
logger.info('Creating %s Backend Service "%s"', protocol.name, name)
resource = self.compute.create_backend_service_traffic_director(
name, health_check=self.health_check, protocol=protocol)
self.backend_service = resource
self.backend_service_protocol = protocol
def load_backend_service(self):
name = self._ns_name(self.BACKEND_SERVICE_NAME)
@ -153,15 +164,13 @@ class TrafficDirectorManager:
self.backend_service = None
def backend_service_add_neg_backends(self, name, zones):
logger.info('Waiting for Network Endpoint Groups recognize endpoints.')
logger.info('Waiting for Network Endpoint Groups to load endpoints.')
for zone in zones:
backend = self.compute.wait_for_network_endpoint_group(name, zone)
logger.info('Loaded NEG "%s" in zone %s', backend.name,
backend.zone)
self.backends.add(backend)
self.backend_service_add_backends()
self.wait_for_backends_healthy_status()
def backend_service_add_backends(self):
logging.info('Adding backends to Backend Service %s: %r',
@ -208,13 +217,22 @@ class TrafficDirectorManager:
self.compute.delete_url_map(name)
self.url_map = None
def create_target_grpc_proxy(self):
# TODO(sergiitk): merge with create_target_http_proxy()
def create_target_proxy(self):
name = self._ns_name(self.TARGET_PROXY_NAME)
logger.info('Creating target GRPC proxy "%s" to URL map %s', name,
self.url_map.name)
resource = self.compute.create_target_grpc_proxy(name, self.url_map)
self.target_proxy = resource
if self.backend_service_protocol is BackendServiceProtocol.GRPC:
target_proxy_type = 'GRPC'
create_proxy_fn = self.compute.create_target_grpc_proxy
self.target_proxy_is_http = False
elif self.backend_service_protocol is BackendServiceProtocol.HTTP2:
target_proxy_type = 'HTTP'
create_proxy_fn = self.compute.create_target_http_proxy
self.target_proxy_is_http = True
else:
raise TypeError('Unexpected backend service protocol')
logger.info('Creating target %s proxy "%s" to URL map %s', name,
target_proxy_type, self.url_map.name)
self.target_proxy = create_proxy_fn(name, self.url_map)
def delete_target_grpc_proxy(self, force=False):
if force:
@ -228,15 +246,6 @@ class TrafficDirectorManager:
self.target_proxy = None
self.target_proxy_is_http = False
def create_target_http_proxy(self):
# TODO(sergiitk): merge with create_target_grpc_proxy()
name = self._ns_name(self.TARGET_PROXY_NAME)
logger.info('Creating target HTTP proxy "%s" to url map %s', name,
self.url_map.name)
resource = self.compute.create_target_http_proxy(name, self.url_map)
self.target_proxy = resource
self.target_proxy_is_http = True
def delete_target_http_proxy(self, force=False):
if force:
name = self._ns_name(self.TARGET_PROXY_NAME)

@ -13,7 +13,7 @@
# limitations under the License.
import logging
import re
from typing import Optional, ClassVar, Dict
from typing import ClassVar, Dict, Optional
# Workaround: `grpc` must be imported before `google.protobuf.json_format`,
# to prevent "Segmentation fault". Ref https://github.com/grpc/grpc/issues/24897
@ -73,6 +73,10 @@ class GrpcApp:
class NotFound(Exception):
"""Requested resource not found"""
def __init__(self, message):
self.message = message
super().__init__(message)
def __init__(self, rpc_host):
self.rpc_host = rpc_host
# Cache gRPC channels per port

@ -17,7 +17,7 @@ https://github.com/grpc/grpc-proto/blob/master/grpc/channelz/v1/channelz.proto
"""
import ipaddress
import logging
from typing import Optional, Iterator
from typing import Iterator, Optional
import grpc
from grpc_channelz.v1 import channelz_pb2
@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
# Channel
Channel = channelz_pb2.Channel
ChannelConnectivityState = channelz_pb2.ChannelConnectivityState
ChannelState = ChannelConnectivityState.State # pylint: disable=no-member
_GetTopChannelsRequest = channelz_pb2.GetTopChannelsRequest
_GetTopChannelsResponse = channelz_pb2.GetTopChannelsResponse
# Subchannel
@ -143,8 +144,11 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
start = max(start, server.ref.server_id)
yield server
def list_server_sockets(self, server_id) -> Iterator[Socket]:
"""Iterate over all server sockets that exist in server process."""
def list_server_sockets(self, server: Server) -> Iterator[Socket]:
"""List all server sockets that exist in server process.
Iterating over the results will resolve additional pages automatically.
"""
start: int = -1
response: Optional[_GetServerSocketsResponse] = None
while start < 0 or not response.end:
@ -153,7 +157,7 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
start += 1
response = self.call_unary_with_deadline(
rpc='GetServerSockets',
req=_GetServerSocketsRequest(server_id=server_id,
req=_GetServerSocketsRequest(server_id=server.ref.server_id,
start_socket_id=start))
socket_ref: SocketRef
for socket_ref in response.socket_ref:
@ -161,6 +165,23 @@ class ChannelzServiceClient(framework.rpc.grpc.GrpcClientHelper):
# Yield actual socket
yield self.get_socket(socket_ref.socket_id)
def list_channel_sockets(self, channel: Channel) -> Iterator[Socket]:
"""List all sockets of all subchannels of a given channel."""
for subchannel in self.list_channel_subchannels(channel):
yield from self.list_subchannels_sockets(subchannel)
def list_channel_subchannels(self,
channel: Channel) -> Iterator[Subchannel]:
"""List all subchannels of a given channel."""
for subchannel_ref in channel.subchannel_ref:
yield self.get_subchannel(subchannel_ref.subchannel_id)
def list_subchannels_sockets(self,
subchannel: Subchannel) -> Iterator[Socket]:
"""List all sockets of a given subchannel."""
for socket_ref in subchannel.socket_ref:
yield self.get_socket(socket_ref.socket_id)
def get_subchannel(self, subchannel_id) -> Subchannel:
"""Return a single Subchannel, otherwise raises RpcError."""
response: _GetSubchannelResponse = self.call_unary_with_deadline(

@ -17,12 +17,12 @@ xDS Test Client.
TODO(sergiitk): separate XdsTestClient and KubernetesClientRunner to individual
modules.
"""
import datetime
import functools
import logging
from typing import Optional, Iterator
import tenacity
from typing import Iterator, Optional
from framework.helpers import retryers
from framework.infrastructure import k8s
import framework.rpc
from framework.rpc import grpc_channelz
@ -32,9 +32,13 @@ from framework.test_app import base_runner
logger = logging.getLogger(__name__)
# Type aliases
_ChannelzServiceClient = grpc_channelz.ChannelzServiceClient
_ChannelConnectivityState = grpc_channelz.ChannelConnectivityState
_timedelta = datetime.timedelta
_LoadBalancerStatsServiceClient = grpc_testing.LoadBalancerStatsServiceClient
_ChannelzServiceClient = grpc_channelz.ChannelzServiceClient
_ChannelzChannel = grpc_channelz.Channel
_ChannelzChannelState = grpc_channelz.ChannelState
_ChannelzSubchannel = grpc_channelz.Subchannel
_ChannelzSocket = grpc_channelz.Socket
class XdsTestClient(framework.rpc.grpc.GrpcApp):
@ -79,47 +83,103 @@ class XdsTestClient(framework.rpc.grpc.GrpcApp):
return self.load_balancer_stats.get_client_stats(
num_rpcs=num_rpcs, timeout_sec=timeout_sec)
def get_server_channels(self) -> Iterator[grpc_channelz.Channel]:
def get_server_channels(self) -> Iterator[_ChannelzChannel]:
return self.channelz.find_channels_for_target(self.server_target)
def wait_for_active_server_channel(self):
retryer = tenacity.Retrying(
retry=(tenacity.retry_if_result(lambda r: r is None) |
tenacity.retry_if_exception_type()),
wait=tenacity.wait_exponential(min=10, max=25),
stop=tenacity.stop_after_delay(60 * 3),
reraise=True)
logger.info(
'Waiting for client %s to establish READY gRPC channel with %s',
self.ip, self.server_target)
channel = retryer(self.get_active_server_channel)
logger.info(
'gRPC channel between client %s and %s transitioned to READY:\n%s',
self.ip, self.server_target, channel)
def get_active_server_channel(self) -> Optional[grpc_channelz.Channel]:
for channel in self.get_server_channels():
state: _ChannelConnectivityState = channel.data.state
logger.info('Server channel: %s, state: %s', channel.ref.name,
_ChannelConnectivityState.State.Name(state.state))
if state.state is _ChannelConnectivityState.READY:
return channel
raise self.NotFound('Client has no active channel with the server')
def wait_for_active_server_channel(self) -> _ChannelzChannel:
"""Wait for the channel to the server to transition to READY.
Raises:
GrpcApp.NotFound: If the channel never transitioned to READY.
"""
return self.wait_for_server_channel_state(_ChannelzChannelState.READY)
def get_active_server_channel(self) -> _ChannelzChannel:
"""Return a READY channel to the server.
Raises:
GrpcApp.NotFound: If there's no READY channel to the server.
"""
return self.find_server_channel_with_state(_ChannelzChannelState.READY)
def get_client_socket_with_test_server(self) -> grpc_channelz.Socket:
def get_active_server_channel_socket(self) -> _ChannelzSocket:
channel = self.get_active_server_channel()
logger.debug('Retrieving client->server socket: channel %s',
channel.ref.name)
# Get the first subchannel of the active server channel
subchannel_id = channel.subchannel_ref[0].subchannel_id
subchannel = self.channelz.get_subchannel(subchannel_id)
logger.debug('Retrieving client->server socket: subchannel %s',
subchannel.ref.name)
# Get the first subchannel of the active channel to the server.
logger.debug(
'Retrieving client -> server socket, '
'channel_id: %s, subchannel: %s', channel.ref.channel_id,
channel.subchannel_ref[0].name)
subchannel, *subchannels = list(
self.channelz.list_channel_subchannels(channel))
if subchannels:
logger.warning('Unexpected subchannels: %r', subchannels)
# Get the first socket of the subchannel
socket = self.channelz.get_socket(subchannel.socket_ref[0].socket_id)
logger.debug('Found client->server socket: %s', socket.ref.name)
socket, *sockets = list(
self.channelz.list_subchannels_sockets(subchannel))
if sockets:
logger.warning('Unexpected sockets: %r', subchannels)
logger.debug('Found client -> server socket: %s', socket.ref.name)
return socket
def wait_for_server_channel_state(self,
state: _ChannelzChannelState,
*,
timeout: Optional[_timedelta] = None
) -> _ChannelzChannel:
# Fine-tuned to wait for the channel to the server.
retryer = retryers.exponential_retryer_with_timeout(
wait_min=_timedelta(seconds=10),
wait_max=_timedelta(seconds=25),
timeout=_timedelta(minutes=3) if timeout is None else timeout)
logger.info('Waiting for client %s to report a %s channel to %s',
self.ip, _ChannelzChannelState.Name(state),
self.server_target)
channel = retryer(self.find_server_channel_with_state, state)
logger.info('Client %s channel to %s transitioned to state %s:\n%s',
self.ip, self.server_target,
_ChannelzChannelState.Name(state), channel)
return channel
def find_server_channel_with_state(self,
state: _ChannelzChannelState,
*,
check_subchannel=True
) -> _ChannelzChannel:
for channel in self.get_server_channels():
channel_state: _ChannelzChannelState = channel.data.state.state
logger.info('Server channel: %s, state: %s', channel.ref.name,
_ChannelzChannelState.Name(channel_state))
if channel_state is state:
if check_subchannel:
# When requested, check if the channel has at least
# one subchannel in the requested state.
try:
subchannel = self.find_subchannel_with_state(
channel, state)
logger.info('Found subchannel in state %s: %s', state,
subchannel)
except self.NotFound as e:
# Otherwise, keep searching.
logger.info(e.message)
continue
return channel
raise self.NotFound(
f'Client has no {_ChannelzChannelState.Name(state)} channel with '
'the server')
def find_subchannel_with_state(self, channel: _ChannelzChannel,
state: _ChannelzChannelState
) -> _ChannelzSubchannel:
for subchannel in self.channelz.list_channel_subchannels(channel):
if subchannel.data.state.state is state:
return subchannel
raise self.NotFound(
f'Not found a {_ChannelzChannelState.Name(state)} '
f'subchannel for channel_id {channel.ref.channel_id}')
class KubernetesClientRunner(base_runner.KubernetesBaseRunner):

@ -19,7 +19,7 @@ modules.
"""
import functools
import logging
from typing import Optional
from typing import Iterator, Optional
from framework.infrastructure import k8s
import framework.rpc
@ -78,19 +78,37 @@ class XdsTestServer(framework.rpc.grpc.GrpcApp):
return ''
return f'xds:///{self.xds_address}'
def get_test_server(self):
def get_test_server(self) -> grpc_channelz.Server:
"""Return channelz representation of a server running TestService.
Raises:
GrpcApp.NotFound: Test server not found.
"""
server = self.channelz.find_server_listening_on_port(self.rpc_port)
if not server:
raise self.NotFound(
f'Server listening on port {self.rpc_port} not found')
return server
def get_test_server_sockets(self):
def get_test_server_sockets(self) -> Iterator[grpc_channelz.Socket]:
"""List all sockets of the test server.
Raises:
GrpcApp.NotFound: Test server not found.
"""
server = self.get_test_server()
return self.channelz.list_server_sockets(server.ref.server_id)
return self.channelz.list_server_sockets(server)
def get_server_socket_matching_client(self,
client_socket: grpc_channelz.Socket):
"""Find test server socket that matches given test client socket.
Sockets are matched using TCP endpoints (ip:port), further on "address".
Server socket remote address matched with client socket local address.
Raises:
GrpcApp.NotFound: Server socket matching client socket not found.
"""
client_local = self.channelz.sock_address_to_str(client_socket.local)
logger.debug('Looking for a server socket connected to the client %s',
client_local)
@ -99,7 +117,7 @@ class XdsTestServer(framework.rpc.grpc.GrpcApp):
self.get_test_server_sockets(), client_socket)
if not server_socket:
raise self.NotFound(
f'Server socket for client {client_local} not found')
f'Server socket to client {client_local} not found')
logger.info('Found matching socket pair: server(%s) <-> client(%s)',
self.channelz.sock_addresses_pretty(server_socket),

@ -14,7 +14,7 @@
import enum
import hashlib
import logging
from typing import Tuple
from typing import Optional, Tuple
from absl import flags
from absl.testing import absltest
@ -40,16 +40,14 @@ flags.adopt_module_key_flags(xds_k8s_flags)
# Type aliases
XdsTestServer = server_app.XdsTestServer
XdsTestClient = client_app.XdsTestClient
_LoadBalancerStatsResponse = grpc_testing.LoadBalancerStatsResponse
LoadBalancerStatsResponse = grpc_testing.LoadBalancerStatsResponse
_ChannelState = grpc_channelz.ChannelState
class XdsKubernetesTestCase(absltest.TestCase):
k8s_api_manager: k8s.KubernetesApiManager
gcp_api_manager: gcp.api.GcpApiManager
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@classmethod
def setUpClass(cls):
# GCP
@ -110,26 +108,41 @@ class XdsKubernetesTestCase(absltest.TestCase):
def setupTrafficDirectorGrpc(self):
self.td.setup_for_grpc(self.server_xds_host, self.server_xds_port)
def setupServerBackends(self):
def setupServerBackends(self, *, wait_for_healthy_status=True):
# Load Backends
neg_name, neg_zones = self.server_runner.k8s_namespace.get_service_neg(
self.server_runner.service_name, self.server_port)
# Add backends to the Backend Service
self.td.backend_service_add_neg_backends(neg_name, neg_zones)
if wait_for_healthy_status:
self.td.wait_for_backends_healthy_status()
def assertSuccessfulRpcs(self,
test_client: XdsTestClient,
num_rpcs: int = 100):
# Run the test
lb_stats: _LoadBalancerStatsResponse
lb_stats = self.sendRpcs(test_client, num_rpcs)
self.assertAllBackendsReceivedRpcs(lb_stats)
self.assertFailedRpcsAtMost(lb_stats, 0)
def assertFailedRpcs(self,
test_client: XdsTestClient,
num_rpcs: Optional[int] = 100):
lb_stats = self.sendRpcs(test_client, num_rpcs)
failed = int(lb_stats.num_failures)
self.assertEqual(
failed,
num_rpcs,
msg=f'Expected all {num_rpcs} RPCs to fail, but {failed} failed')
@staticmethod
def sendRpcs(test_client: XdsTestClient,
num_rpcs: int) -> LoadBalancerStatsResponse:
lb_stats = test_client.get_load_balancer_stats(num_rpcs=num_rpcs)
logger.info(
'Received LoadBalancerStatsResponse from test client %s:\n%s',
test_client.ip, lb_stats)
# Check the results
self.assertAllBackendsReceivedRpcs(lb_stats)
self.assertFailedRpcsAtMost(lb_stats, 0)
return lb_stats
def assertAllBackendsReceivedRpcs(self, lb_stats):
# TODO(sergiitk): assert backends length
@ -261,12 +274,16 @@ class SecurityXdsKubernetesTestCase(XdsKubernetesTestCase):
tls=server_tls,
mtls=server_mtls)
def startSecureTestClient(self, test_server: XdsTestServer,
def startSecureTestClient(self,
test_server: XdsTestServer,
*,
wait_for_active_server_channel=True,
**kwargs) -> XdsTestClient:
test_client = self.client_runner.run(server_target=test_server.xds_uri,
secure_mode=True,
**kwargs)
test_client.wait_for_active_server_channel()
if wait_for_active_server_channel:
test_client.wait_for_active_server_channel()
return test_client
def assertTestAppSecurity(self, mode: SecurityMode,
@ -286,7 +303,7 @@ class SecurityXdsKubernetesTestCase(XdsKubernetesTestCase):
elif mode is self.SecurityMode.PLAINTEXT:
self.assertSecurityPlaintext(client_security, server_security)
else:
raise TypeError(f'Incorrect security mode')
raise TypeError('Incorrect security mode')
def assertSecurityMtls(self, client_security: grpc_channelz.Security,
server_security: grpc_channelz.Security):
@ -377,11 +394,30 @@ class SecurityXdsKubernetesTestCase(XdsKubernetesTestCase):
# Success
logger.info('Plaintext security mode confirmed!')
def assertMtlsErrorSetup(self, test_client: XdsTestClient):
channel = test_client.wait_for_server_channel_state(
state=_ChannelState.TRANSIENT_FAILURE)
subchannels = list(
test_client.channelz.list_channel_subchannels(channel))
self.assertLen(subchannels,
1,
msg="Client channel must have exactly one subchannel "
"in state TRANSIENT_FAILURE.")
sockets = list(
test_client.channelz.list_subchannels_sockets(subchannels[0]))
self.assertEmpty(sockets, msg="Client subchannel must have no sockets")
# With negative tests we can't be absolutely certain expected
# failure state is not caused by something else.
logger.info(
"Client's connectivity state is consistent with a mTLS error "
"caused by not presenting mTLS certificate to the server.")
@staticmethod
def getConnectedSockets(
test_client: XdsTestClient, test_server: XdsTestServer
) -> Tuple[grpc_channelz.Socket, grpc_channelz.Socket]:
client_sock = test_client.get_client_socket_with_test_server()
client_sock = test_client.get_active_server_channel_socket()
server_sock = test_server.get_server_socket_matching_client(client_sock)
return client_sock, server_sock

@ -39,7 +39,7 @@ class BaselineTest(xds_k8s_testcase.RegularXdsKubernetesTestCase):
self.td.create_url_map(self.server_xds_host, self.server_xds_port)
with self.subTest('3_create_target_proxy'):
self.td.create_target_grpc_proxy()
self.td.create_target_proxy()
with self.subTest('4_create_forwarding_rule'):
self.td.create_forwarding_rule(self.server_xds_port)

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from absl import flags
from absl.testing import absltest
@ -31,6 +32,10 @@ _SecurityMode = xds_k8s_testcase.SecurityXdsKubernetesTestCase.SecurityMode
class SecurityTest(xds_k8s_testcase.SecurityXdsKubernetesTestCase):
def test_mtls(self):
"""mTLS test.
Both client and server configured to use TLS and mTLS.
"""
self.setupTrafficDirectorGrpc()
self.setupSecurityPolicies(server_tls=True,
server_mtls=True,
@ -45,6 +50,10 @@ class SecurityTest(xds_k8s_testcase.SecurityXdsKubernetesTestCase):
self.assertSuccessfulRpcs(test_client)
def test_tls(self):
"""TLS test.
Both client and server configured to use TLS and not use mTLS.
"""
self.setupTrafficDirectorGrpc()
self.setupSecurityPolicies(server_tls=True,
server_mtls=False,
@ -59,6 +68,11 @@ class SecurityTest(xds_k8s_testcase.SecurityXdsKubernetesTestCase):
self.assertSuccessfulRpcs(test_client)
def test_plaintext_fallback(self):
"""Plain-text fallback test.
Control plane provides no security config so both client and server
fallback to plaintext based on fallback-credentials.
"""
self.setupTrafficDirectorGrpc()
self.setupSecurityPolicies(server_tls=False,
server_mtls=False,
@ -73,13 +87,70 @@ class SecurityTest(xds_k8s_testcase.SecurityXdsKubernetesTestCase):
test_server)
self.assertSuccessfulRpcs(test_client)
@absltest.skip(SKIP_REASON)
def test_mtls_error(self):
pass
"""Negative test: mTLS Error.
Server expects client mTLS cert, but client configured only for TLS.
Note: because this is a negative test we need to make sure the mTLS
failure happens after receiving the correct configuration at the
client. To ensure that we will perform the following steps in that
sequence:
- Creation of a backendService, and attaching the backend (NEG)
- Creation of the Server mTLS Policy, and attaching to the ECS
- Creation of the Client TLS Policy, and attaching to the backendService
- Creation of the urlMap, targetProxy, and forwardingRule
With this sequence we are sure that when the client receives the
endpoints of the backendService the security-config would also have
been received as confirmed by the TD team.
"""
# Create backend service
self.td.setup_backend_for_grpc()
# Start server and attach its NEGs to the backend service
test_server: _XdsTestServer = self.startSecureTestServer()
self.setupServerBackends(wait_for_healthy_status=False)
# Setup policies and attach them.
self.setupSecurityPolicies(server_tls=True,
server_mtls=True,
client_tls=True,
client_mtls=False)
# Create the routing rule map
self.td.setup_routing_rule_map_for_grpc(self.server_xds_host,
self.server_xds_port)
# Wait for backends healthy after url map is created
self.td.wait_for_backends_healthy_status()
# Start the client.
test_client: _XdsTestClient = self.startSecureTestClient(
test_server, wait_for_active_server_channel=False)
# With negative tests we can't be absolutely certain expected
# failure state is not caused by something else.
# To mitigate for this, we repeat the checks a few times in case
# the channel eventually stabilizes and RPCs pass.
# TODO(sergiitk): use tenacity retryer, move nums to constants
wait_sec = 10
checks = 3
for check in range(1, checks + 1):
self.assertMtlsErrorSetup(test_client)
self.assertFailedRpcs(test_client)
if check != checks:
logger.info(
'Check %s successful, waiting %s sec before the next check',
check, wait_sec)
time.sleep(wait_sec)
@absltest.skip(SKIP_REASON)
def test_server_authz_error(self):
pass
"""Negative test: AuthZ error.
Client does not authorize server because of mismatched SAN name.
"""
if __name__ == '__main__':

Loading…
Cancel
Save