diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi index f49681a4588..03b4990e488 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi @@ -24,3 +24,4 @@ cdef class AioChannel: object loop bytes _target AioChannelStatus _status + bint _is_secure diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi index a2882e64b7f..21100444863 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi @@ -36,11 +36,13 @@ cdef class AioChannel: self._status = AIO_CHANNEL_STATUS_READY if credentials is None: + self._is_secure = False self.channel = grpc_insecure_channel_create( target, channel_args.c_args(), NULL) else: + self._is_secure = True self.channel = grpc_secure_channel_create( credentials.c(), target, @@ -122,6 +124,9 @@ cdef class AioChannel: cdef CallCredentials cython_call_credentials if python_call_credentials is not None: + if not self._is_secure: + raise RuntimeError("Call credentials are only valid on secure channels") + cython_call_credentials = python_call_credentials._credentials else: cython_call_credentials = None diff --git a/src/python/grpcio_tests/tests_aio/unit/_test_server.py b/src/python/grpcio_tests/tests_aio/unit/_test_server.py index 7c8afa8ff5c..bcc6e3bc304 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_test_server.py +++ b/src/python/grpcio_tests/tests_aio/unit/_test_server.py @@ -17,6 +17,7 @@ import datetime import grpc from grpc.experimental import aio +from tests.unit import resources from src.proto.grpc.testing import empty_pb2, messages_pb2, test_pb2_grpc from tests_aio.unit import _constants @@ -37,6 +38,11 @@ async def _maybe_echo_metadata(servicer_context): invocation_metadata[_TRAILING_METADATA_KEY]) servicer_context.set_trailing_metadata((trailing_metadatum,)) +_PRIVATE_KEY = resources.private_key() +_CERTIFICATE_CHAIN = resources.certificate_chain() +_TEST_ROOT_CERTIFICATES = resources.test_root_certificates() +_SERVER_CERTS = ((_PRIVATE_KEY, _CERTIFICATE_CHAIN),) + async def _maybe_echo_status(request: messages_pb2.SimpleRequest, servicer_context): @@ -129,8 +135,11 @@ async def start_test_server(port=0, if secure: if server_credentials is None: - server_credentials = grpc.local_server_credentials( - grpc.LocalConnectionType.LOCAL_TCP) + server_credentials = grpc.ssl_server_credentials( + _SERVER_CERTS, + root_certificates=_TEST_ROOT_CERTIFICATES, + require_client_auth=True + ) port = server.add_secure_port('[::]:%d' % port, server_credentials) else: port = server.add_insecure_port('[::]:%d' % port) diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index f845e078684..68cfe3831a8 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -24,6 +24,7 @@ from grpc.experimental import aio from src.proto.grpc.testing import messages_pb2, test_pb2_grpc from tests.unit.framework.common import test_constants from tests_aio.unit._test_base import AioTestBase + from tests_aio.unit._test_server import start_test_server _NUM_STREAM_RESPONSES = 5 diff --git a/src/python/grpcio_tests/tests_aio/unit/init_test.py b/src/python/grpcio_tests/tests_aio/unit/init_test.py index 8b9a03e2dd3..c415faaf6ac 100644 --- a/src/python/grpcio_tests/tests_aio/unit/init_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/init_test.py @@ -20,6 +20,12 @@ from grpc.experimental import aio from tests_aio.unit._test_server import start_test_server from tests_aio.unit._test_base import AioTestBase +from tests.unit import resources + +_PRIVATE_KEY = resources.private_key() +_CERTIFICATE_CHAIN = resources.certificate_chain() +_TEST_ROOT_CERTIFICATES = resources.test_root_certificates() + class TestInsecureChannel(AioTestBase): @@ -37,8 +43,11 @@ class TestSecureChannel(AioTestBase): async def coro(): server_target, _ = await start_test_server(secure=True) # pylint: disable=unused-variable - credentials = grpc.local_channel_credentials( - grpc.LocalConnectionType.LOCAL_TCP) + credentials = grpc.ssl_channel_credentials( + root_certificates=_TEST_ROOT_CERTIFICATES, + private_key=_PRIVATE_KEY, + certificate_chain=_CERTIFICATE_CHAIN, + ) secure_channel = aio.secure_channel(server_target, credentials) self.assertIsInstance(secure_channel, aio.Channel)