From 5e23b2dcb7f65f24d638a0c5246592e398684419 Mon Sep 17 00:00:00 2001 From: Richard Belleville Date: Mon, 8 Feb 2021 13:49:46 -0800 Subject: [PATCH] Pull out context manager --- .../tests/unit/_xds_credentials_test.py | 78 ++++++++++--------- 1 file changed, 42 insertions(+), 36 deletions(-) diff --git a/src/python/grpcio_tests/tests/unit/_xds_credentials_test.py b/src/python/grpcio_tests/tests/unit/_xds_credentials_test.py index 55a2bee6021..d8839174398 100644 --- a/src/python/grpcio_tests/tests/unit/_xds_credentials_test.py +++ b/src/python/grpcio_tests/tests/unit/_xds_credentials_test.py @@ -17,6 +17,7 @@ import unittest import logging from concurrent import futures +import contextlib import grpc import grpc.experimental @@ -31,54 +32,59 @@ class _GenericHandler(grpc.GenericRpcHandler): lambda request, unused_context: request) +@contextlib.contextmanager +def xds_channel_server_without_xds(server_fallback_creds): + server = grpc.server(futures.ThreadPoolExecutor()) + server.add_generic_rpc_handlers((_GenericHandler(),)) + server_server_fallback_creds = grpc.ssl_server_credentials( + ((resources.private_key(), resources.certificate_chain()),)) + server_creds = grpc.xds_server_credentials(server_fallback_creds) + port = server.add_secure_port("localhost:0", server_creds) + server.start() + try: + yield "localhost:{}".format(port) + finally: + server.stop(None) + + class XdsCredentialsTest(unittest.TestCase): def test_xds_creds_fallback_ssl(self): # Since there is no xDS server, the fallback credentials will be used. # In this case, SSL credentials. - server = grpc.server(futures.ThreadPoolExecutor()) - server.add_generic_rpc_handlers((_GenericHandler(),)) server_fallback_creds = grpc.ssl_server_credentials( ((resources.private_key(), resources.certificate_chain()),)) - server_creds = grpc.xds_server_credentials(server_fallback_creds) - port = server.add_secure_port("localhost:0", server_creds) - server.start() - channel_fallback_creds = grpc.ssl_channel_credentials( - root_certificates=resources.test_root_certificates(), - private_key=resources.private_key(), - certificate_chain=resources.certificate_chain()) - channel_creds = grpc.xds_channel_credentials(channel_fallback_creds) - server_address = "localhost:{}".format(port) - override_options = (("grpc.ssl_target_name_override", - "foo.test.google.fr"),) - with grpc.secure_channel(server_address, - channel_creds, - options=override_options) as channel: - request = b"abc" - response = channel.unary_unary("/test/method")(request, - wait_for_ready=True) - self.assertEqual(response, request) - server.stop(None) + with xds_channel_server_without_xds( + server_fallback_creds) as server_address: + override_options = (("grpc.ssl_target_name_override", + "foo.test.google.fr"),) + channel_fallback_creds = grpc.ssl_channel_credentials( + root_certificates=resources.test_root_certificates(), + private_key=resources.private_key(), + certificate_chain=resources.certificate_chain()) + channel_creds = grpc.xds_channel_credentials(channel_fallback_creds) + with grpc.secure_channel(server_address, + channel_creds, + options=override_options) as channel: + request = b"abc" + response = channel.unary_unary("/test/method")( + request, wait_for_ready=True) + self.assertEqual(response, request) def test_xds_creds_fallback_insecure(self): # Since there is no xDS server, the fallback credentials will be used. # In this case, insecure. - server = grpc.server(futures.ThreadPoolExecutor()) - server.add_generic_rpc_handlers((_GenericHandler(),)) server_fallback_creds = grpc.insecure_server_credentials() - server_creds = grpc.xds_server_credentials(server_fallback_creds) - port = server.add_secure_port("localhost:0", server_creds) - server.start() - channel_fallback_creds = grpc.experimental.insecure_channel_credentials( - ) - channel_creds = grpc.xds_channel_credentials(channel_fallback_creds) - server_address = "localhost:{}".format(port) - with grpc.secure_channel(server_address, channel_creds) as channel: - request = b"abc" - response = channel.unary_unary("/test/method")(request, - wait_for_ready=True) - self.assertEqual(response, request) - server.stop(None) + with xds_channel_server_without_xds( + server_fallback_creds) as server_address: + channel_fallback_creds = grpc.experimental.insecure_channel_credentials( + ) + channel_creds = grpc.xds_channel_credentials(channel_fallback_creds) + with grpc.secure_channel(server_address, channel_creds) as channel: + request = b"abc" + response = channel.unary_unary("/test/method")( + request, wait_for_ready=True) + self.assertEqual(response, request) def test_start_xds_server(self): server = grpc.server(futures.ThreadPoolExecutor(), xds=True)