diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi index dbe673d0bf6..bfbcb8d4fc8 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi @@ -125,7 +125,7 @@ cdef class _AioCall(GrpcCallWrapper): if credentials is not None: set_credentials_error = grpc_call_set_credentials(self.call, credentials.c()) if set_credentials_error != GRPC_CALL_OK: - raise Exception("Credentials couldn't have been set") + raise InternalError("Credentials couldn't have been set: {0}".format(set_credentials_error)) grpc_slice_unref(method_slice) @@ -178,7 +178,7 @@ cdef class _AioCall(GrpcCallWrapper): def cancel(self, str details): """Cancels the RPC in Core with given RPC status. - + Above abstractions must invoke this method to set Core objects into proper state. """ @@ -209,7 +209,7 @@ cdef class _AioCall(GrpcCallWrapper): def done(self): """Returns if the RPC call has finished. - + Checks if the status has been provided, either because the RPC finished or because was cancelled.. @@ -220,7 +220,7 @@ cdef class _AioCall(GrpcCallWrapper): def cancelled(self): """Returns if the RPC was cancelled. - + Returns: True if the RPC was cancelled. """ @@ -231,7 +231,7 @@ cdef class _AioCall(GrpcCallWrapper): async def status(self): """Returns the status of the RPC call. - + It returns the finshed status of the RPC. If the RPC has not finished yet this function will wait until the RPC gets finished. @@ -254,7 +254,7 @@ cdef class _AioCall(GrpcCallWrapper): async def initial_metadata(self): """Returns the initial metadata of the RPC call. - + If the initial metadata has not been received yet this function will wait until the RPC gets finished. @@ -286,7 +286,7 @@ cdef class _AioCall(GrpcCallWrapper): bytes request, tuple outbound_initial_metadata): """Performs a unary unary RPC. - + Args: request: the serialized requests in bytes. outbound_initial_metadata: optional outbound metadata. @@ -420,7 +420,7 @@ cdef class _AioCall(GrpcCallWrapper): tuple outbound_initial_metadata, object metadata_sent_observer): """Actual implementation of the complete unary-stream call. - + Needs to pay extra attention to the raise mechanism. If we want to propagate the final status exception, then we have to raise it. Othersize, it would end normally and raise `StopAsyncIteration()`. @@ -490,7 +490,7 @@ cdef class _AioCall(GrpcCallWrapper): outbound_initial_metadata, self._send_initial_metadata_flags, self._loop) - # Notify upper level that sending messages are allowed now. + # Notify upper level that sending messages are allowed now. metadata_sent_observer() # Receives initial metadata. diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi index f49681a4588..03b4990e488 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi @@ -24,3 +24,4 @@ cdef class AioChannel: object loop bytes _target AioChannelStatus _status + bint _is_secure diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi index a2882e64b7f..beadce67b4a 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi @@ -36,11 +36,13 @@ cdef class AioChannel: self._status = AIO_CHANNEL_STATUS_READY if credentials is None: + self._is_secure = False self.channel = grpc_insecure_channel_create( target, channel_args.c_args(), NULL) else: + self._is_secure = True self.channel = grpc_secure_channel_create( credentials.c(), target, @@ -122,6 +124,9 @@ cdef class AioChannel: cdef CallCredentials cython_call_credentials if python_call_credentials is not None: + if not self._is_secure: + raise UsageError("Call credentials are only valid on secure channels") + cython_call_credentials = python_call_credentials._credentials else: cython_call_credentials = None diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/grpc_aio.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/grpc_aio.pxd.pxi index 1755b702015..ebf0660174d 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/grpc_aio.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/grpc_aio.pxd.pxi @@ -23,10 +23,10 @@ cdef class _AioState: cdef grpc_completion_queue *global_completion_queue() -cdef init_grpc_aio() +cpdef init_grpc_aio() -cdef shutdown_grpc_aio() +cpdef shutdown_grpc_aio() cdef extern from "src/core/lib/iomgr/timer_manager.h": diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/grpc_aio.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/grpc_aio.pyx.pxi index 1612f5e3f25..d570f478391 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/grpc_aio.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/grpc_aio.pyx.pxi @@ -114,9 +114,9 @@ cdef _actual_aio_shutdown(): raise ValueError('Unsupported engine type [%s]' % _global_aio_state.engine) -cdef init_grpc_aio(): - """Initialis the gRPC AsyncIO module. - +cpdef init_grpc_aio(): + """Initializes the gRPC AsyncIO module. + Expected to be invoked on critical class constructors. E.g., AioChannel, AioServer. """ @@ -126,7 +126,7 @@ cdef init_grpc_aio(): _actual_aio_initialization() -cdef shutdown_grpc_aio(): +cpdef shutdown_grpc_aio(): """Shuts down the gRPC AsyncIO module. Expected to be invoked on critical class destructors. diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/iomgr.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/iomgr.pyx.pxi index ac62c41e0f2..f5b62af5287 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/iomgr.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/iomgr/iomgr.pyx.pxi @@ -212,7 +212,18 @@ cdef void asyncio_run_loop(size_t timeout_ms) with gil: pass +def _auth_plugin_callback_wrapper(object cb, + str service_url, + str method_name, + object callback): + asyncio.get_event_loop().call_soon(cb, service_url, method_name, callback) + + def install_asyncio_iomgr(): + # Auth plugins invoke user provided logic in another thread by default. We + # need to override that behavior by registering the call to the event loop. + set_async_callback_func(_auth_plugin_callback_wrapper) + asyncio_resolver_vtable.resolve = asyncio_resolve asyncio_resolver_vtable.resolve_async = asyncio_resolve_async diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi index c736b7a10c5..24d1e2a3b77 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi @@ -34,12 +34,14 @@ cdef class CallCredentials: raise NotImplementedError() -cdef int _get_metadata( - void *state, grpc_auth_metadata_context context, - grpc_credentials_plugin_metadata_cb cb, void *user_data, - grpc_metadata creds_md[GRPC_METADATA_CREDENTIALS_PLUGIN_SYNC_MAX], - size_t *num_creds_md, grpc_status_code *status, - const char **error_details) except * with gil: +cdef int _get_metadata(void *state, + grpc_auth_metadata_context context, + grpc_credentials_plugin_metadata_cb cb, + void *user_data, + grpc_metadata creds_md[GRPC_METADATA_CREDENTIALS_PLUGIN_SYNC_MAX], + size_t *num_creds_md, + grpc_status_code *status, + const char **error_details) except * with gil: cdef size_t metadata_count cdef grpc_metadata *c_metadata def callback(metadata, grpc_status_code status, bytes error_details): diff --git a/src/python/grpcio/grpc/experimental/aio/__init__.py b/src/python/grpcio/grpc/experimental/aio/__init__.py index 31882be24fc..3613908a961 100644 --- a/src/python/grpcio/grpc/experimental/aio/__init__.py +++ b/src/python/grpcio/grpc/experimental/aio/__init__.py @@ -20,7 +20,8 @@ created. AsyncIO doesn't provide thread safety for most of its APIs. from typing import Any, Optional, Sequence, Tuple import grpc -from grpc._cython.cygrpc import (EOF, AbortError, BaseError, InternalError, +from grpc._cython.cygrpc import (init_grpc_aio, shutdown_grpc_aio, EOF, + AbortError, BaseError, InternalError, UsageError) from ._base_call import (Call, RpcContext, StreamStreamCall, StreamUnaryCall, @@ -39,6 +40,8 @@ from ._channel import insecure_channel, secure_channel ################################### __all__ ################################# __all__ = ( + 'init_grpc_aio', + 'shutdown_grpc_aio', 'AioRpcError', 'RpcContext', 'Call', diff --git a/src/python/grpcio_tests/tests_aio/interop/BUILD.bazel b/src/python/grpcio_tests/tests_aio/interop/BUILD.bazel index b5bbdc6df4e..f67ad35cca7 100644 --- a/src/python/grpcio_tests/tests_aio/interop/BUILD.bazel +++ b/src/python/grpcio_tests/tests_aio/interop/BUILD.bazel @@ -56,6 +56,7 @@ py_binary( python_version = "PY3", deps = [ "//src/python/grpcio/grpc:grpcio", + "//src/python/grpcio_tests/tests/interop:resources", "//src/python/grpcio_tests/tests/interop:server", "//src/python/grpcio_tests/tests_aio/unit:_test_server", ], @@ -70,5 +71,6 @@ py_binary( ":methods", "//src/python/grpcio/grpc:grpcio", "//src/python/grpcio_tests/tests/interop:client", + "//src/python/grpcio_tests/tests/interop:resources", ], ) diff --git a/src/python/grpcio_tests/tests_aio/tests.json b/src/python/grpcio_tests/tests_aio/tests.json index d79ed422596..b2b53a3ad65 100644 --- a/src/python/grpcio_tests/tests_aio/tests.json +++ b/src/python/grpcio_tests/tests_aio/tests.json @@ -19,9 +19,11 @@ "unit.compression_test.TestCompression", "unit.connectivity_test.TestConnectivityState", "unit.done_callback_test.TestDoneCallback", - "unit.init_test.TestInsecureChannel", - "unit.init_test.TestSecureChannel", + "unit.init_test.TestChannel", "unit.metadata_test.TestMetadata", + "unit.secure_call_test.TestStreamStreamSecureCall", + "unit.secure_call_test.TestUnaryStreamSecureCall", + "unit.secure_call_test.TestUnaryUnarySecureCall", "unit.server_interceptor_test.TestServerInterceptor", "unit.server_test.TestServer", "unit.timeout_test.TestTimeout", diff --git a/src/python/grpcio_tests/tests_aio/unit/BUILD.bazel b/src/python/grpcio_tests/tests_aio/unit/BUILD.bazel index ab475bcf97c..1847e9cff6e 100644 --- a/src/python/grpcio_tests/tests_aio/unit/BUILD.bazel +++ b/src/python/grpcio_tests/tests_aio/unit/BUILD.bazel @@ -41,6 +41,7 @@ py_library( "//src/proto/grpc/testing:py_messages_proto", "//src/proto/grpc/testing:test_py_pb2_grpc", "//src/python/grpcio/grpc:grpcio", + "//src/python/grpcio_tests/tests/unit:resources", ], ) @@ -76,6 +77,7 @@ _FLAKY_TESTS = [ "//src/proto/grpc/testing:benchmark_service_py_pb2_grpc", "//src/proto/grpc/testing:py_messages_proto", "//src/python/grpcio/grpc:grpcio", + "//src/python/grpcio_tests/tests/unit:resources", "//src/python/grpcio_tests/tests/unit/framework/common", "@six", ], diff --git a/src/python/grpcio_tests/tests_aio/unit/_test_base.py b/src/python/grpcio_tests/tests_aio/unit/_test_base.py index ec5f2112da0..82ec7b456ad 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_test_base.py +++ b/src/python/grpcio_tests/tests_aio/unit/_test_base.py @@ -64,3 +64,6 @@ class AioTestBase(unittest.TestCase): return _async_to_sync_decorator(attr, self._TEST_LOOP) # For other attributes, let them pass. return attr + + +aio.init_grpc_aio() diff --git a/src/python/grpcio_tests/tests_aio/unit/_test_server.py b/src/python/grpcio_tests/tests_aio/unit/_test_server.py index 7c8afa8ff5c..2396608e5cc 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_test_server.py +++ b/src/python/grpcio_tests/tests_aio/unit/_test_server.py @@ -17,6 +17,7 @@ import datetime import grpc from grpc.experimental import aio +from tests.unit import resources from src.proto.grpc.testing import empty_pb2, messages_pb2, test_pb2_grpc from tests_aio.unit import _constants @@ -129,8 +130,9 @@ async def start_test_server(port=0, if secure: if server_credentials is None: - server_credentials = grpc.local_server_credentials( - grpc.LocalConnectionType.LOCAL_TCP) + server_credentials = grpc.ssl_server_credentials([ + (resources.private_key(), resources.certificate_chain()) + ]) port = server.add_secure_port('[::]:%d' % port, server_credentials) else: port = server.add_insecure_port('[::]:%d' % port) diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index f845e078684..93b27853023 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -14,7 +14,6 @@ """Tests behavior of the Call classes.""" import asyncio -import datetime import logging import unittest @@ -24,6 +23,8 @@ from grpc.experimental import aio from src.proto.grpc.testing import messages_pb2, test_pb2_grpc from tests.unit.framework.common import test_constants from tests_aio.unit._test_base import AioTestBase +from tests.unit import resources + from tests_aio.unit._test_server import start_test_server _NUM_STREAM_RESPONSES = 5 @@ -55,7 +56,7 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): self.assertTrue(str(call) is not None) self.assertTrue(repr(call) is not None) - response = await call + await call self.assertTrue(str(call) is not None) self.assertTrue(repr(call) is not None) @@ -202,6 +203,17 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): with self.assertRaises(asyncio.CancelledError): await task + async def test_passing_credentials_fails_over_insecure_channel(self): + call_credentials = grpc.composite_call_credentials( + grpc.access_token_call_credentials("abc"), + grpc.access_token_call_credentials("def"), + ) + with self.assertRaisesRegex( + grpc._cygrpc.UsageError, + "Call credentials are only valid on secure channels"): + self._stub.UnaryCall(messages_pb2.SimpleRequest(), + credentials=call_credentials) + class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase): @@ -410,33 +422,6 @@ class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase): with self.assertRaises(asyncio.CancelledError): await task - def test_call_credentials(self): - - class DummyAuth(grpc.AuthMetadataPlugin): - - def __call__(self, context, callback): - signature = context.method_name[::-1] - callback((("test", signature),), None) - - async def coro(): - server_target, _ = await start_test_server(secure=False) # pylint: disable=unused-variable - - async with aio.insecure_channel(server_target) as channel: - hi = channel.unary_unary('/grpc.testing.TestService/UnaryCall', - request_serializer=messages_pb2. - SimpleRequest.SerializeToString, - response_deserializer=messages_pb2. - SimpleResponse.FromString) - call_credentials = grpc.metadata_call_credentials(DummyAuth()) - call = hi(messages_pb2.SimpleRequest(), - credentials=call_credentials) - response = await call - - self.assertIsInstance(response, messages_pb2.SimpleResponse) - self.assertEqual(await call.code(), grpc.StatusCode.OK) - - self.loop.run_until_complete(coro()) - async def test_time_remaining(self): request = messages_pb2.StreamingOutputCallRequest() # First message comes back immediately diff --git a/src/python/grpcio_tests/tests_aio/unit/init_test.py b/src/python/grpcio_tests/tests_aio/unit/init_test.py index 8b9a03e2dd3..9104a0368c5 100644 --- a/src/python/grpcio_tests/tests_aio/unit/init_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/init_test.py @@ -20,8 +20,14 @@ from grpc.experimental import aio from tests_aio.unit._test_server import start_test_server from tests_aio.unit._test_base import AioTestBase +from tests.unit import resources -class TestInsecureChannel(AioTestBase): +_PRIVATE_KEY = resources.private_key() +_CERTIFICATE_CHAIN = resources.certificate_chain() +_TEST_ROOT_CERTIFICATES = resources.test_root_certificates() + + +class TestChannel(AioTestBase): async def test_insecure_channel(self): server_target, _ = await start_test_server() # pylint: disable=unused-variable @@ -29,21 +35,16 @@ class TestInsecureChannel(AioTestBase): channel = aio.insecure_channel(server_target) self.assertIsInstance(channel, aio.Channel) + async def tests_secure_channel(self): + server_target, _ = await start_test_server(secure=True) # pylint: disable=unused-variable + credentials = grpc.ssl_channel_credentials( + root_certificates=_TEST_ROOT_CERTIFICATES, + private_key=_PRIVATE_KEY, + certificate_chain=_CERTIFICATE_CHAIN, + ) + secure_channel = aio.secure_channel(server_target, credentials) -class TestSecureChannel(AioTestBase): - """Test a secure channel connected to a secure server""" - - def test_secure_channel(self): - - async def coro(): - server_target, _ = await start_test_server(secure=True) # pylint: disable=unused-variable - credentials = grpc.local_channel_credentials( - grpc.LocalConnectionType.LOCAL_TCP) - secure_channel = aio.secure_channel(server_target, credentials) - - self.assertIsInstance(secure_channel, aio.Channel) - - self.loop.run_until_complete(coro()) + self.assertIsInstance(secure_channel, aio.Channel) if __name__ == '__main__': diff --git a/src/python/grpcio_tests/tests_aio/unit/secure_call_test.py b/src/python/grpcio_tests/tests_aio/unit/secure_call_test.py new file mode 100644 index 00000000000..7efaddd607e --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/unit/secure_call_test.py @@ -0,0 +1,130 @@ +# Copyright 2020 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests the behaviour of the Call classes under a secure channel.""" + +import unittest +import logging + +import grpc +from grpc.experimental import aio +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc +from tests_aio.unit._test_base import AioTestBase +from tests_aio.unit._test_server import start_test_server +from tests.unit import resources + +_SERVER_HOST_OVERRIDE = 'foo.test.google.fr' +_NUM_STREAM_RESPONSES = 5 +_RESPONSE_PAYLOAD_SIZE = 42 + + +class _SecureCallMixin: + """A Mixin to run the call tests over a secure channel.""" + + async def setUp(self): + server_credentials = grpc.ssl_server_credentials([ + (resources.private_key(), resources.certificate_chain()) + ]) + channel_credentials = grpc.ssl_channel_credentials( + resources.test_root_certificates()) + + self._server_address, self._server = await start_test_server( + secure=True, server_credentials=server_credentials) + channel_options = (( + 'grpc.ssl_target_name_override', + _SERVER_HOST_OVERRIDE, + ),) + self._channel = aio.secure_channel(self._server_address, + channel_credentials, channel_options) + self._stub = test_pb2_grpc.TestServiceStub(self._channel) + + async def tearDown(self): + await self._channel.close() + await self._server.stop(None) + + +class TestUnaryUnarySecureCall(_SecureCallMixin, AioTestBase): + """unary_unary Calls made over a secure channel.""" + + async def test_call_ok_over_secure_channel(self): + call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) + response = await call + self.assertIsInstance(response, messages_pb2.SimpleResponse) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + async def test_call_with_credentials(self): + call_credentials = grpc.composite_call_credentials( + grpc.access_token_call_credentials("abc"), + grpc.access_token_call_credentials("def"), + ) + call = self._stub.UnaryCall(messages_pb2.SimpleRequest(), + credentials=call_credentials) + response = await call + + self.assertIsInstance(response, messages_pb2.SimpleResponse) + + +class TestUnaryStreamSecureCall(_SecureCallMixin, AioTestBase): + """unary_stream calls over a secure channel""" + + async def test_unary_stream_async_generator_secure(self): + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.extend( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE,) + for _ in range(_NUM_STREAM_RESPONSES)) + call_credentials = grpc.composite_call_credentials( + grpc.access_token_call_credentials("abc"), + grpc.access_token_call_credentials("def"), + ) + call = self._stub.StreamingOutputCall(request, + credentials=call_credentials) + + async for response in call: + self.assertIsInstance(response, + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(len(response.payload.body), _RESPONSE_PAYLOAD_SIZE) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + +# Prepares the request that stream in a ping-pong manner. +_STREAM_OUTPUT_REQUEST_ONE_RESPONSE = messages_pb2.StreamingOutputCallRequest() +_STREAM_OUTPUT_REQUEST_ONE_RESPONSE.response_parameters.append( + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + + +class TestStreamStreamSecureCall(_SecureCallMixin, AioTestBase): + _STREAM_ITERATIONS = 2 + + async def test_async_generator_secure_channel(self): + + async def request_generator(): + for _ in range(self._STREAM_ITERATIONS): + yield _STREAM_OUTPUT_REQUEST_ONE_RESPONSE + + call_credentials = grpc.composite_call_credentials( + grpc.access_token_call_credentials("abc"), + grpc.access_token_call_credentials("def"), + ) + + call = self._stub.FullDuplexCall(request_generator(), + credentials=call_credentials) + async for response in call: + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2)