From 44e5409a4e61756b9bccd51edd3d03c837f2e41d Mon Sep 17 00:00:00 2001
From: Yash Tibrewal <yashkt@google.com>
Date: Wed, 3 Mar 2021 19:31:52 -0800
Subject: [PATCH] Option to disable local_certificate check from PSM security
 tests (#25595)

* Remove local_certificate check from PSM security tests

* Reviewer comments

* Update tools/run_tests/xds_k8s_test_driver/framework/xds_k8s_testcase.py

Co-authored-by: Sergii Tkachenko <hi@sergii.org>

* YAPF code

Co-authored-by: Sergii Tkachenko <hi@sergii.org>
---
 .../framework/xds_k8s_testcase.py             | 53 ++++++++++++-------
 1 file changed, 33 insertions(+), 20 deletions(-)

diff --git a/tools/run_tests/xds_k8s_test_driver/framework/xds_k8s_testcase.py b/tools/run_tests/xds_k8s_test_driver/framework/xds_k8s_testcase.py
index 9669239db1d..7c7f60a093b 100644
--- a/tools/run_tests/xds_k8s_test_driver/framework/xds_k8s_testcase.py
+++ b/tools/run_tests/xds_k8s_test_driver/framework/xds_k8s_testcase.py
@@ -36,6 +36,12 @@ _FORCE_CLEANUP = flags.DEFINE_bool(
     "force_cleanup",
     default=False,
     help="Force resource cleanup, even if not created by this test run")
+# TODO(yashkt): We will no longer need this flag once Core exposes local certs
+# from channelz
+_CHECK_LOCAL_CERTS = flags.DEFINE_bool(
+    "check_local_certs",
+    default=True,
+    help="Security Tests also check the value of local certs")
 flags.adopt_module_key_flags(xds_flags)
 flags.adopt_module_key_flags(xds_k8s_flags)
 
@@ -83,6 +89,7 @@ class XdsKubernetesTestCase(absltest.TestCase):
         cls.force_cleanup = _FORCE_CLEANUP.value
         cls.debug_use_port_forwarding = \
             xds_k8s_flags.DEBUG_USE_PORT_FORWARDING.value
+        cls.check_local_certs = _CHECK_LOCAL_CERTS.value
 
         # Resource managers
         cls.k8s_api_manager = k8s.KubernetesApiManager(
@@ -340,26 +347,30 @@ class SecurityXdsKubernetesTestCase(XdsKubernetesTestCase):
         server_tls, client_tls = server_security.tls, client_security.tls
 
         # Confirm regular TLS: server local cert == client remote cert
-        self.assertNotEmpty(server_tls.local_certificate,
-                            msg="(mTLS) Server local certificate is missing")
         self.assertNotEmpty(client_tls.remote_certificate,
                             msg="(mTLS) Client remote certificate is missing")
-        self.assertEqual(
-            server_tls.local_certificate,
-            client_tls.remote_certificate,
-            msg="(mTLS) Server local certificate must match client's "
-            "remote certificate")
+        if cls.check_local_certs:
+            self.assertNotEmpty(
+                server_tls.local_certificate,
+                msg="(mTLS) Server local certificate is missing")
+            self.assertEqual(
+                server_tls.local_certificate,
+                client_tls.remote_certificate,
+                msg="(mTLS) Server local certificate must match client's "
+                "remote certificate")
 
         # mTLS: server remote cert == client local cert
         self.assertNotEmpty(server_tls.remote_certificate,
                             msg="(mTLS) Server remote certificate is missing")
-        self.assertNotEmpty(client_tls.local_certificate,
-                            msg="(mTLS) Client local certificate is missing")
-        self.assertEqual(
-            server_tls.remote_certificate,
-            client_tls.local_certificate,
-            msg="(mTLS) Server remote certificate must match client's "
-            "local certificate")
+        if cls.check_local_certs:
+            self.assertNotEmpty(
+                client_tls.local_certificate,
+                msg="(mTLS) Client local certificate is missing")
+            self.assertEqual(
+                server_tls.remote_certificate,
+                client_tls.local_certificate,
+                msg="(mTLS) Server remote certificate must match client's "
+                "local certificate")
 
     def assertSecurityTls(self, client_security: grpc_channelz.Security,
                           server_security: grpc_channelz.Security):
@@ -372,14 +383,16 @@ class SecurityXdsKubernetesTestCase(XdsKubernetesTestCase):
         server_tls, client_tls = server_security.tls, client_security.tls
 
         # Regular TLS: server local cert == client remote cert
-        self.assertNotEmpty(server_tls.local_certificate,
-                            msg="(TLS) Server local certificate is missing")
         self.assertNotEmpty(client_tls.remote_certificate,
                             msg="(TLS) Client remote certificate is missing")
-        self.assertEqual(server_tls.local_certificate,
-                         client_tls.remote_certificate,
-                         msg="(TLS) Server local certificate must match client "
-                         "remote certificate")
+        if cls.check_local_certs:
+            self.assertNotEmpty(server_tls.local_certificate,
+                                msg="(TLS) Server local certificate is missing")
+            self.assertEqual(
+                server_tls.local_certificate,
+                client_tls.remote_certificate,
+                msg="(TLS) Server local certificate must match client "
+                "remote certificate")
 
         # mTLS must not be used
         self.assertEmpty(