diff --git a/src/python/grpcio_tests/tests_py3_only/unit/_simple_stubs_test.py b/src/python/grpcio_tests/tests_py3_only/unit/_simple_stubs_test.py index 45ed9691bcc..87d31a7b33c 100644 --- a/src/python/grpcio_tests/tests_py3_only/unit/_simple_stubs_test.py +++ b/src/python/grpcio_tests/tests_py3_only/unit/_simple_stubs_test.py @@ -32,6 +32,7 @@ import time from typing import Callable, Optional from tests.unit import test_common +from tests.unit import resources import grpc import grpc.experimental @@ -51,6 +52,12 @@ _STREAM_UNARY = "/test/StreamUnary" _STREAM_STREAM = "/test/StreamStream" +@contextlib.contextmanager +def _env(key: str, value: str): + os.environ[key] = value + yield + del os.environ[key] + def _unary_unary_handler(request, context): return request @@ -153,115 +160,140 @@ class SimpleStubsTest(unittest.TestCase): else: self.fail(message() + " after " + str(timeout)) - def test_unary_unary_insecure(self): - with _server(None) as port: - target = f'localhost:{port}' - response = grpc.experimental.unary_unary( - _REQUEST, - target, - _UNARY_UNARY, - channel_credentials=grpc.experimental. - insecure_channel_credentials()) - self.assertEqual(_REQUEST, response) - - def test_unary_unary_secure(self): - with _server(grpc.local_server_credentials()) as port: - target = f'localhost:{port}' - response = grpc.experimental.unary_unary( - _REQUEST, - target, - _UNARY_UNARY, - channel_credentials=grpc.local_channel_credentials()) - self.assertEqual(_REQUEST, response) - - def test_channels_cached(self): - with _server(grpc.local_server_credentials()) as port: - target = f'localhost:{port}' - test_name = inspect.stack()[0][3] - args = (_REQUEST, target, _UNARY_UNARY) - kwargs = {"channel_credentials": grpc.local_channel_credentials()} - - def _invoke(seed: str): - run_kwargs = dict(kwargs) - run_kwargs["options"] = ((test_name + seed, ""),) - grpc.experimental.unary_unary(*args, **run_kwargs) - - self.assert_cached(_invoke) - - def test_channels_evicted(self): - with _server(grpc.local_server_credentials()) as port: - target = f'localhost:{port}' - response = grpc.experimental.unary_unary( - _REQUEST, - target, - _UNARY_UNARY, - channel_credentials=grpc.local_channel_credentials()) - self.assert_eventually( - lambda: grpc._simple_stubs.ChannelCache.get( - )._test_only_channel_count() == 0, - message=lambda: - f"{grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()} remain" - ) - - def test_total_channels_enforced(self): - with _server(grpc.local_server_credentials()) as port: - target = f'localhost:{port}' - for i in range(_STRESS_EPOCHS): - # Ensure we get a new channel each time. - options = (("foo", str(i)),) - # Send messages at full blast. - grpc.experimental.unary_unary( - _REQUEST, - target, - _UNARY_UNARY, - options=options, - channel_credentials=grpc.local_channel_credentials()) - self.assert_eventually( - lambda: grpc._simple_stubs.ChannelCache.get( - )._test_only_channel_count() <= _MAXIMUM_CHANNELS + 1, - message=lambda: - f"{grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()} channels remain" - ) - - def test_unary_stream(self): - with _server(grpc.local_server_credentials()) as port: - target = f'localhost:{port}' - for response in grpc.experimental.unary_stream( - _REQUEST, - target, - _UNARY_STREAM, - channel_credentials=grpc.local_channel_credentials()): - self.assertEqual(_REQUEST, response) - - def test_stream_unary(self): - - def request_iter(): - for _ in range(_CLIENT_REQUEST_COUNT): - yield _REQUEST - - with _server(grpc.local_server_credentials()) as port: - target = f'localhost:{port}' - response = grpc.experimental.stream_unary( - request_iter(), - target, - _STREAM_UNARY, - channel_credentials=grpc.local_channel_credentials()) - self.assertEqual(_REQUEST, response) - - def test_stream_stream(self): - - def request_iter(): - for _ in range(_CLIENT_REQUEST_COUNT): - yield _REQUEST - - with _server(grpc.local_server_credentials()) as port: - target = f'localhost:{port}' - for response in grpc.experimental.stream_stream( - request_iter(), - target, - _STREAM_STREAM, - channel_credentials=grpc.local_channel_credentials()): - self.assertEqual(_REQUEST, response) + # def test_unary_unary_insecure(self): + # with _server(None) as port: + # target = f'localhost:{port}' + # response = grpc.experimental.unary_unary( + # _REQUEST, + # target, + # _UNARY_UNARY, + # channel_credentials=grpc.experimental. + # insecure_channel_credentials()) + # self.assertEqual(_REQUEST, response) + + # def test_unary_unary_secure(self): + # with _server(grpc.local_server_credentials()) as port: + # target = f'localhost:{port}' + # response = grpc.experimental.unary_unary( + # _REQUEST, + # target, + # _UNARY_UNARY, + # channel_credentials=grpc.local_channel_credentials()) + # self.assertEqual(_REQUEST, response) + + # def test_channels_cached(self): + # with _server(grpc.local_server_credentials()) as port: + # target = f'localhost:{port}' + # test_name = inspect.stack()[0][3] + # args = (_REQUEST, target, _UNARY_UNARY) + # kwargs = {"channel_credentials": grpc.local_channel_credentials()} + + # def _invoke(seed: str): + # run_kwargs = dict(kwargs) + # run_kwargs["options"] = ((test_name + seed, ""),) + # grpc.experimental.unary_unary(*args, **run_kwargs) + + # self.assert_cached(_invoke) + + # def test_channels_evicted(self): + # with _server(grpc.local_server_credentials()) as port: + # target = f'localhost:{port}' + # response = grpc.experimental.unary_unary( + # _REQUEST, + # target, + # _UNARY_UNARY, + # channel_credentials=grpc.local_channel_credentials()) + # self.assert_eventually( + # lambda: grpc._simple_stubs.ChannelCache.get( + # )._test_only_channel_count() == 0, + # message=lambda: + # f"{grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()} remain" + # ) + + # def test_total_channels_enforced(self): + # with _server(grpc.local_server_credentials()) as port: + # target = f'localhost:{port}' + # for i in range(_STRESS_EPOCHS): + # # Ensure we get a new channel each time. + # options = (("foo", str(i)),) + # # Send messages at full blast. + # grpc.experimental.unary_unary( + # _REQUEST, + # target, + # _UNARY_UNARY, + # options=options, + # channel_credentials=grpc.local_channel_credentials()) + # self.assert_eventually( + # lambda: grpc._simple_stubs.ChannelCache.get( + # )._test_only_channel_count() <= _MAXIMUM_CHANNELS + 1, + # message=lambda: + # f"{grpc._simple_stubs.ChannelCache.get()._test_only_channel_count()} channels remain" + # ) + + # def test_unary_stream(self): + # with _server(grpc.local_server_credentials()) as port: + # target = f'localhost:{port}' + # for response in grpc.experimental.unary_stream( + # _REQUEST, + # target, + # _UNARY_STREAM, + # channel_credentials=grpc.local_channel_credentials()): + # self.assertEqual(_REQUEST, response) + + # def test_stream_unary(self): + + # def request_iter(): + # for _ in range(_CLIENT_REQUEST_COUNT): + # yield _REQUEST + + # with _server(grpc.local_server_credentials()) as port: + # target = f'localhost:{port}' + # response = grpc.experimental.stream_unary( + # request_iter(), + # target, + # _STREAM_UNARY, + # channel_credentials=grpc.local_channel_credentials()) + # self.assertEqual(_REQUEST, response) + + # def test_stream_stream(self): + + # def request_iter(): + # for _ in range(_CLIENT_REQUEST_COUNT): + # yield _REQUEST + + # with _server(grpc.local_server_credentials()) as port: + # target = f'localhost:{port}' + # for response in grpc.experimental.stream_stream( + # request_iter(), + # target, + # _STREAM_STREAM, + # channel_credentials=grpc.local_channel_credentials()): + # self.assertEqual(_REQUEST, response) + + def test_default_ssl(self): + _PRIVATE_KEY = resources.private_key() + _CERTIFICATE_CHAIN = resources.certificate_chain() + _SERVER_CERTS = ((_PRIVATE_KEY, _CERTIFICATE_CHAIN),) + _SERVER_HOST_OVERRIDE = 'foo.test.google.fr' + _TEST_ROOT_CERTIFICATES = resources.test_root_certificates() + _PROPERTY_OPTIONS = (( + 'grpc.ssl_target_name_override', + _SERVER_HOST_OVERRIDE, + ),) + cert_dir = os.path.join(os.path.dirname(resources.__file__), "credentials") + print(f"cert_dir: {cert_dir}") + cert_file = os.path.join(cert_dir, "ca.pem") + with _env("SSL_CERT_FILE", cert_file): + server_creds = grpc.ssl_server_credentials(_SERVER_CERTS) + with _server(server_creds) as port: + target = f'localhost:{port}' + # channel_creds = grpc.ssl_channel_credentials(root_certificates=_TEST_ROOT_CERTIFICATES) + channel_creds = grpc.ssl_channel_credentials() + response = grpc.experimental.unary_unary(_REQUEST, + target, + _UNARY_UNARY, + options=_PROPERTY_OPTIONS, + channel_credentials=channel_creds) if __name__ == "__main__":