From 77df7f5f172e9b3bbbeb5805a9e4fb76965bd0b2 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Sun, 17 Nov 2019 21:05:20 +0100 Subject: [PATCH] Client unary unary interceptor Implements the unary unary interceptor for the client-side. Interceptors can be now installed by passing them as a new parameter of the `Channel` constructor or by giving them as part of the `insecure_channel` function. Interceptors are executed within an Asyncio task for making some work before the RPC invocation, and after for accessing to the intercepted call that has been invoked. --- .../grpc/_cython/_cygrpc/aio/call.pyx.pxi | 1 + .../grpcio/grpc/experimental/BUILD.bazel | 2 + .../grpcio/grpc/experimental/aio/__init__.py | 24 +- .../grpcio/grpc/experimental/aio/_channel.py | 89 ++-- .../grpc/experimental/aio/_interceptor.py | 336 ++++++++++++ .../grpcio/grpc/experimental/aio/_utils.py | 23 + src/python/grpcio_tests/tests_aio/tests.json | 2 + .../tests_aio/unit/_test_server.py | 2 +- .../grpcio_tests/tests_aio/unit/call_test.py | 3 +- .../tests_aio/unit/channel_test.py | 1 + .../tests_aio/unit/interceptor_test.py | 504 ++++++++++++++++++ 11 files changed, 950 insertions(+), 37 deletions(-) create mode 100644 src/python/grpcio/grpc/experimental/aio/_interceptor.py create mode 100644 src/python/grpcio/grpc/experimental/aio/_utils.py create mode 100644 src/python/grpcio_tests/tests_aio/unit/interceptor_test.py diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi index c10d79cb7d3..2d013afe6cb 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi @@ -15,6 +15,7 @@ cimport cpython import grpc + _EMPTY_FLAGS = 0 _EMPTY_MASK = 0 _EMPTY_METADATA = None diff --git a/src/python/grpcio/grpc/experimental/BUILD.bazel b/src/python/grpcio/grpc/experimental/BUILD.bazel index 46e7bf29ab6..e436233210b 100644 --- a/src/python/grpcio/grpc/experimental/BUILD.bazel +++ b/src/python/grpcio/grpc/experimental/BUILD.bazel @@ -7,8 +7,10 @@ py_library( "aio/_base_call.py", "aio/_call.py", "aio/_channel.py", + "aio/_interceptor.py", "aio/_server.py", "aio/_typing.py", + "aio/_utils.py", ], deps = [ "//src/python/grpcio/grpc/_cython:cygrpc", diff --git a/src/python/grpcio/grpc/experimental/aio/__init__.py b/src/python/grpcio/grpc/experimental/aio/__init__.py index b20a3524508..2ece2ed3775 100644 --- a/src/python/grpcio/grpc/experimental/aio/__init__.py +++ b/src/python/grpcio/grpc/experimental/aio/__init__.py @@ -18,18 +18,26 @@ created. AsyncIO doesn't provide thread safety for most of its APIs. """ import abc +from typing import Any, Optional, Sequence, Text, Tuple import six import grpc from grpc._cython.cygrpc import init_grpc_aio from ._base_call import RpcContext, Call, UnaryUnaryCall, UnaryStreamCall +from ._call import AioRpcError from ._channel import Channel from ._channel import UnaryUnaryMultiCallable +from ._interceptor import ClientCallDetails, UnaryUnaryClientInterceptor +from ._interceptor import InterceptedUnaryUnaryCall from ._server import server -def insecure_channel(target, options=None, compression=None): +def insecure_channel( + target: Text, + options: Optional[Sequence[Tuple[Text, Any]]] = None, + compression: Optional[grpc.Compression] = None, + interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] = None): """Creates an insecure asynchronous Channel to a server. Args: @@ -38,16 +46,22 @@ def insecure_channel(target, options=None, compression=None): in gRPC Core runtime) to configure the channel. compression: An optional value indicating the compression method to be used over the lifetime of the channel. This is an EXPERIMENTAL option. + interceptors: An optional sequence of interceptors that will be executed for + any call executed with this channel. Returns: A Channel. """ - return Channel(target, () if options is None else options, None, - compression) + return Channel( + target, () if options is None else options, + None, + compression, + interceptors=interceptors) ################################### __all__ ################################# -__all__ = ('RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall', +__all__ = ('AioRpcError', 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall', 'init_grpc_aio', 'Channel', 'UnaryUnaryMultiCallable', - 'insecure_channel', 'server') + 'ClientCallDetails', 'UnaryUnaryClientInterceptor', + 'InterceptedUnaryUnaryCall', 'insecure_channel', 'server') diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index c8f7541bfb5..9c91101c43d 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -18,29 +18,35 @@ from typing import Any, Optional, Sequence, Text, Tuple import grpc from grpc import _common from grpc._cython import cygrpc + from . import _base_call from ._call import UnaryUnaryCall, UnaryStreamCall +from ._interceptor import UnaryUnaryClientInterceptor, InterceptedUnaryUnaryCall from ._typing import (DeserializingFunction, MetadataType, SerializingFunction) - - -def _timeout_to_deadline(loop: asyncio.AbstractEventLoop, - timeout: Optional[float]) -> Optional[float]: - if timeout is None: - return None - return loop.time() + timeout +from ._utils import _timeout_to_deadline class UnaryUnaryMultiCallable: """Factory an asynchronous unary-unary RPC stub call from client-side.""" + _channel: cygrpc.AioChannel + _method: bytes + _request_serializer: SerializingFunction + _response_deserializer: DeserializingFunction + _interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] + _loop: asyncio.AbstractEventLoop + def __init__(self, channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction) -> None: + response_deserializer: DeserializingFunction, + interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] + ) -> None: self._loop = asyncio.get_event_loop() self._channel = channel self._method = method self._request_serializer = request_serializer self._response_deserializer = response_deserializer + self._interceptors = interceptors def __call__(self, request: Any, @@ -74,7 +80,6 @@ class UnaryUnaryMultiCallable: raised RpcError will also be a Call for the RPC affording the RPC's metadata, status code, and details. """ - if metadata: raise NotImplementedError("TODO: metadata not implemented yet") @@ -88,16 +93,25 @@ class UnaryUnaryMultiCallable: if compression: raise NotImplementedError("TODO: compression not implemented yet") - deadline = _timeout_to_deadline(self._loop, timeout) - - return UnaryUnaryCall( - request, - deadline, - self._channel, - self._method, - self._request_serializer, - self._response_deserializer, - ) + if not self._interceptors: + return UnaryUnaryCall( + request, + _timeout_to_deadline(self._loop, timeout), + self._channel, + self._method, + self._request_serializer, + self._response_deserializer, + ) + else: + return InterceptedUnaryUnaryCall( + self._interceptors, + request, + timeout, + self._channel, + self._method, + self._request_serializer, + self._response_deserializer, + ) class UnaryStreamMultiCallable: @@ -138,13 +152,7 @@ class UnaryStreamMultiCallable: Returns: A Call object instance which is an awaitable object. - - Raises: - RpcError: Indicating that the RPC terminated with non-OK status. The - raised RpcError will also be a Call for the RPC affording the RPC's - metadata, status code, and details. """ - if metadata: raise NotImplementedError("TODO: metadata not implemented yet") @@ -175,11 +183,14 @@ class Channel: A cygrpc.AioChannel-backed implementation. """ + _channel: cygrpc.AioChannel + _unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] def __init__(self, target: Text, options: Optional[Sequence[Tuple[Text, Any]]], credentials: Optional[grpc.ChannelCredentials], - compression: Optional[grpc.Compression]): + compression: Optional[grpc.Compression], + interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]): """Constructor. Args: @@ -188,8 +199,9 @@ class Channel: credentials: A cygrpc.ChannelCredentials or None. compression: An optional value indicating the compression method to be used over the lifetime of the channel. + interceptors: An optional list of interceptors that would be used for + intercepting any RPC executed with that channel. """ - if options: raise NotImplementedError("TODO: options not implemented yet") @@ -199,6 +211,23 @@ class Channel: if compression: raise NotImplementedError("TODO: compression not implemented yet") + if interceptors is None: + self._unary_unary_interceptors = None + else: + self._unary_unary_interceptors = list( + filter( + lambda interceptor: isinstance(interceptor, UnaryUnaryClientInterceptor), + interceptors)) + + invalid_interceptors = set(interceptors) - set( + self._unary_unary_interceptors) + + if invalid_interceptors: + raise ValueError( + "Interceptor must be "+\ + "UnaryUnaryClientInterceptors, the following are invalid: {}"\ + .format(invalid_interceptors)) + self._channel = cygrpc.AioChannel(_common.encode(target)) def unary_unary( @@ -220,9 +249,9 @@ class Channel: Returns: A UnaryUnaryMultiCallable value for the named unary-unary method. """ - return UnaryUnaryMultiCallable(self._channel, _common.encode(method), - request_serializer, - response_deserializer) + return UnaryUnaryMultiCallable( + self._channel, _common.encode(method), request_serializer, + response_deserializer, self._unary_unary_interceptors) def unary_stream( self, diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py new file mode 100644 index 00000000000..205a710b249 --- /dev/null +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -0,0 +1,336 @@ +# Copyright 2019 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Interceptors implementation of gRPC Asyncio Python.""" +import asyncio +import collections +import functools +from abc import ABCMeta, abstractmethod +from typing import Callable, Optional, Iterator, Sequence, Text, Union + +import grpc +from grpc._cython import cygrpc + +from . import _base_call +from ._call import UnaryUnaryCall +from ._utils import _timeout_to_deadline +from ._typing import (RequestType, SerializingFunction, DeserializingFunction, + MetadataType, ResponseType) + +_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!' + + +class ClientCallDetails( + collections.namedtuple( + 'ClientCallDetails', + ('method', 'timeout', 'metadata', 'credentials')), + grpc.ClientCallDetails): + + method: Text + timeout: Optional[float] + metadata: Optional[MetadataType] + credentials: Optional[grpc.CallCredentials] + + +class UnaryUnaryClientInterceptor(metaclass=ABCMeta): + """Affords intercepting unary-unary invocations.""" + + @abstractmethod + async def intercept_unary_unary( + self, continuation: Callable[[ClientCallDetails, RequestType], + UnaryUnaryCall], + client_call_details: ClientCallDetails, + request: RequestType) -> Union[UnaryUnaryCall, ResponseType]: + """Intercepts a unary-unary invocation asynchronously. + Args: + continuation: A coroutine that proceeds with the invocation by + executing the next interceptor in chain or invoking the + actual RPC on the underlying Channel. It is the interceptor's + responsibility to call it if it decides to move the RPC forward. + The interceptor can use + `response_future = await continuation(client_call_details, request)` + to continue with the RPC. `continuation` returns the response of the + RPC. + client_call_details: A ClientCallDetails object describing the + outgoing RPC. + request: The request value for the RPC. + Returns: + An object with the RPC response. + Raises: + AioRpcError: Indicating that the RPC terminated with non-OK status. + asyncio.CancelledError: Indicating that the RPC was canceled. + """ + + +class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): + """Used for running a `UnaryUnaryCall` wrapped by interceptors. + + Interceptors might have some work to do before the RPC invocation with + the capacity of changing the invocation parameters, and some work to do + after the RPC invocation with the capacity for accessing to the wrapped + `UnaryUnaryCall`. + + It handles also early and later cancellations, when the RPC has not even + started and the execution is still held by the interceptors or when the + RPC has finished but again the execution is still held by the interceptors. + + Once the RPC is finally executed, all methods are finally done against the + intercepted call, being at the same time the same call returned to the + interceptors. + + For most of the methods, like `initial_metadata()` the caller does not need + to wait until the interceptors task is finished, once the RPC is done the + caller will have the freedom for accessing to the results. + + For the `__await__` method is it is proxied to the intercepted call only when + the interceptor task is finished. + """ + + _loop: asyncio.AbstractEventLoop + _channel: cygrpc.AioChannel + _cancelled_before_rpc: bool + _intercepted_call: Optional[_base_call.UnaryUnaryCall] + _intercepted_call_created: asyncio.Event + _interceptors_task: asyncio.Task + + def __init__( # pylint: disable=R0913 + self, interceptors: Sequence[UnaryUnaryClientInterceptor], + request: RequestType, timeout: Optional[float], + channel: cygrpc.AioChannel, method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction) -> None: + self._channel = channel + self._loop = asyncio.get_event_loop() + self._interceptors_task = asyncio.ensure_future( + self._invoke(interceptors, method, timeout, request, + request_serializer, response_deserializer)) + + def __del__(self): + self.cancel() + + async def _invoke( + self, interceptors: Sequence[UnaryUnaryClientInterceptor], + method: bytes, timeout: Optional[float], request: RequestType, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction) -> UnaryUnaryCall: + """Run the RPC call wrapped in interceptors""" + + async def _run_interceptor( + interceptors: Iterator[UnaryUnaryClientInterceptor], + client_call_details: ClientCallDetails, + request: RequestType) -> _base_call.UnaryUnaryCall: + try: + interceptor = next(interceptors) + except StopIteration: + interceptor = None + + if interceptor: + continuation = functools.partial(_run_interceptor, interceptors) + try: + call_or_response = await interceptor.intercept_unary_unary( + continuation, client_call_details, request) + except grpc.RpcError as err: + # gRPC error is masked inside an artificial call, + # caller will see this error if and only + # if it runs an `await call` operation + return UnaryUnaryCallRpcError(err) + except asyncio.CancelledError: + # Cancellation is masked inside an artificial call, + # caller will see this error if and only + # if it runs an `await call` operation + return UnaryUnaryCancelledError() + + if isinstance(call_or_response, _base_call.UnaryUnaryCall): + return call_or_response + else: + return UnaryUnaryCallResponse(call_or_response) + + else: + return UnaryUnaryCall(request, + _timeout_to_deadline( + self._loop, + client_call_details.timeout), + self._channel, client_call_details.method, + request_serializer, response_deserializer) + + client_call_details = ClientCallDetails(method, timeout, None, None) + return await _run_interceptor( + iter(interceptors), client_call_details, request) + + def cancel(self) -> bool: + if self._interceptors_task.done(): + return False + + return self._interceptors_task.cancel() + + def cancelled(self) -> bool: + if not self._interceptors_task.done(): + return False + + call = self._interceptors_task.result() + return call.cancelled() + + def done(self) -> bool: + if not self._interceptors_task.done(): + return False + + return True + + def add_done_callback(self, unused_callback) -> None: + raise NotImplementedError() + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + async def initial_metadata(self) -> Optional[MetadataType]: + return await (await self._interceptors_task).initial_metadata() + + async def trailing_metadata(self) -> Optional[MetadataType]: + return await (await self._interceptors_task).trailing_metadata() + + async def code(self) -> grpc.StatusCode: + return await (await self._interceptors_task).code() + + async def details(self) -> str: + return await (await self._interceptors_task).details() + + async def debug_error_string(self) -> Optional[str]: + return await (await self._interceptors_task).debug_error_string() + + def __await__(self): + call = yield from self._interceptors_task.__await__() + response = yield from call.__await__() + return response + + +class UnaryUnaryCallRpcError(_base_call.UnaryUnaryCall): + """Final UnaryUnaryCall class finished with an RpcError.""" + _error: grpc.RpcError + + def __init__(self, error: grpc.RpcError) -> None: + self._error = error + + def cancel(self) -> bool: + return False + + def cancelled(self) -> bool: + return False + + def done(self) -> bool: + return True + + def add_done_callback(self, unused_callback) -> None: + raise NotImplementedError() + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + async def initial_metadata(self) -> Optional[MetadataType]: + return None + + async def trailing_metadata(self) -> Optional[MetadataType]: + return self._error.initial_metadata() + + async def code(self) -> grpc.StatusCode: + return self._error.code() + + async def details(self) -> str: + return self._error.details() + + async def debug_error_string(self) -> Optional[str]: + return self._error.debug_error_string() + + def __await__(self): + raise self._error + + +class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall): + """Final UnaryUnaryCall class finished with a response.""" + _response: ResponseType + + def __init__(self, response: ResponseType) -> None: + self._response = response + + def cancel(self) -> bool: + return False + + def cancelled(self) -> bool: + return False + + def done(self) -> bool: + return True + + def add_done_callback(self, unused_callback) -> None: + raise NotImplementedError() + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + async def initial_metadata(self) -> Optional[MetadataType]: + return None + + async def trailing_metadata(self) -> Optional[MetadataType]: + return None + + async def code(self) -> grpc.StatusCode: + return grpc.StatusCode.OK + + async def details(self) -> str: + return '' + + async def debug_error_string(self) -> Optional[str]: + return None + + def __await__(self): + if False: # pylint: disable=W0125 + # This code path is never used, but a yield statement is needed + # for telling the interpreter that __await__ is a generator. + yield None + return self._response + + +class UnaryUnaryCancelledError(_base_call.UnaryUnaryCall): + """Final UnaryUnaryCall class finished with an asyncio.CancelledError.""" + + def cancel(self) -> bool: + return False + + def cancelled(self) -> bool: + return True + + def done(self) -> bool: + return True + + def add_done_callback(self, unused_callback) -> None: + raise NotImplementedError() + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + async def initial_metadata(self) -> Optional[MetadataType]: + return None + + async def trailing_metadata(self) -> Optional[MetadataType]: + return None + + async def code(self) -> grpc.StatusCode: + return grpc.StatusCode.CANCELLED + + async def details(self) -> str: + return _LOCAL_CANCELLATION_DETAILS + + async def debug_error_string(self) -> Optional[str]: + return None + + def __await__(self): + raise asyncio.CancelledError() diff --git a/src/python/grpcio/grpc/experimental/aio/_utils.py b/src/python/grpcio/grpc/experimental/aio/_utils.py new file mode 100644 index 00000000000..17fabbb5bff --- /dev/null +++ b/src/python/grpcio/grpc/experimental/aio/_utils.py @@ -0,0 +1,23 @@ +# Copyright 2019 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Internal utilities used by the gRPC Aio module.""" +import asyncio +from typing import Optional + + +def _timeout_to_deadline(loop: asyncio.AbstractEventLoop, + timeout: Optional[float]) -> Optional[float]: + if timeout is None: + return None + return loop.time() + timeout diff --git a/src/python/grpcio_tests/tests_aio/tests.json b/src/python/grpcio_tests/tests_aio/tests.json index 5d6634b8a8e..26545c29dc9 100644 --- a/src/python/grpcio_tests/tests_aio/tests.json +++ b/src/python/grpcio_tests/tests_aio/tests.json @@ -5,5 +5,7 @@ "unit.call_test.TestUnaryUnaryCall", "unit.channel_test.TestChannel", "unit.init_test.TestInsecureChannel", + "unit.interceptor_test.TestInterceptedUnaryUnaryCall", + "unit.interceptor_test.TestUnaryUnaryClientInterceptor", "unit.server_test.TestServer" ] diff --git a/src/python/grpcio_tests/tests_aio/unit/_test_server.py b/src/python/grpcio_tests/tests_aio/unit/_test_server.py index 26b920f6c17..c3a04f29a00 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_test_server.py +++ b/src/python/grpcio_tests/tests_aio/unit/_test_server.py @@ -17,9 +17,9 @@ import logging import datetime from grpc.experimental import aio +from tests.unit.framework.common import test_constants from src.proto.grpc.testing import messages_pb2 from src.proto.grpc.testing import test_pb2_grpc -from tests.unit.framework.common import test_constants class _TestServiceServicer(test_pb2_grpc.TestServiceServicer): diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index 5ef6e0b9a43..ee28b0a966a 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -26,6 +26,7 @@ from src.proto.grpc.testing import test_pb2_grpc from tests.unit.framework.common import test_constants from tests_aio.unit._test_server import start_test_server from tests_aio.unit._test_base import AioTestBase +from src.proto.grpc.testing import messages_pb2 _NUM_STREAM_RESPONSES = 5 _RESPONSE_PAYLOAD_SIZE = 42 @@ -399,5 +400,5 @@ class TestUnaryStreamCall(AioTestBase): if __name__ == '__main__': - logging.basicConfig(level=logging.DEBUG) + logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_test.py index 27b04e6875b..017ad0ae06c 100644 --- a/src/python/grpcio_tests/tests_aio/unit/channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/channel_test.py @@ -25,6 +25,7 @@ from src.proto.grpc.testing import test_pb2_grpc from tests.unit.framework.common import test_constants from tests_aio.unit._test_server import start_test_server from tests_aio.unit._test_base import AioTestBase +from src.proto.grpc.testing import messages_pb2 _UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall' _STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall' diff --git a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py new file mode 100644 index 00000000000..f55c83eb48b --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py @@ -0,0 +1,504 @@ +# Copyright 2019 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import logging +import unittest + +import grpc + +from grpc.experimental import aio +from tests_aio.unit._test_server import start_test_server +from tests_aio.unit._test_base import AioTestBase +from src.proto.grpc.testing import messages_pb2 + +_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' + + +class TestUnaryUnaryClientInterceptor(AioTestBase): + + def test_invalid_interceptor(self): + + class InvalidInterceptor: + """Just an invalid Interceptor""" + + with self.assertRaises(ValueError): + aio.insecure_channel("", interceptors=[InvalidInterceptor()]) + + async def test_executed_right_order(self): + + interceptors_executed = [] + + class Interceptor(aio.UnaryUnaryClientInterceptor): + """Interceptor used for testing if the interceptor is being called""" + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + interceptors_executed.append(self) + call = await continuation(client_call_details, request) + return call + + interceptors = [Interceptor() for i in range(2)] + + server_target, _ = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel( + server_target, interceptors=interceptors) as channel: + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + response = await call + + # Check that all interceptors were executed, and were executed + # in the right order. + self.assertSequenceEqual(interceptors_executed, interceptors) + + self.assertIsInstance(response, messages_pb2.SimpleResponse) + + @unittest.expectedFailure + # TODO(https://github.com/grpc/grpc/issues/20144) Once metadata support is + # implemented in the client-side, this test must be implemented. + def test_modify_metadata(self): + raise NotImplementedError() + + @unittest.expectedFailure + # TODO(https://github.com/grpc/grpc/issues/20532) Once credentials support is + # implemented in the client-side, this test must be implemented. + def test_modify_credentials(self): + raise NotImplementedError() + + async def test_status_code_Ok(self): + + class StatusCodeOkInterceptor(aio.UnaryUnaryClientInterceptor): + """Interceptor used for observing status code Ok returned by the RPC""" + + def __init__(self): + self.status_code_Ok_observed = False + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + code = await call.code() + if code == grpc.StatusCode.OK: + self.status_code_Ok_observed = True + + return call + + interceptor = StatusCodeOkInterceptor() + server_target, server = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel( + server_target, interceptors=[interceptor]) as channel: + + # when no error StatusCode.OK must be observed + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + + await multicallable(messages_pb2.SimpleRequest()) + + self.assertTrue(interceptor.status_code_Ok_observed) + + async def test_add_timeout(self): + + class TimeoutInterceptor(aio.UnaryUnaryClientInterceptor): + """Interceptor used for adding a timeout to the RPC""" + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + new_client_call_details = aio.ClientCallDetails( + method=client_call_details.method, + timeout=0.1, + metadata=client_call_details.metadata, + credentials=client_call_details.credentials) + return await continuation(new_client_call_details, request) + + interceptor = TimeoutInterceptor() + server_target, server = await start_test_server() + + async with aio.insecure_channel( + server_target, interceptors=[interceptor]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + + call = multicallable(messages_pb2.SimpleRequest()) + + await server.stop(None) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + self.assertEqual(exception_context.exception.code(), + grpc.StatusCode.DEADLINE_EXCEEDED) + + self.assertTrue(call.done()) + self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await + call.code()) + + async def test_retry(self): + + class RetryInterceptor(aio.UnaryUnaryClientInterceptor): + """Simulates a Retry Interceptor which ends up by making + two RPC calls.""" + + def __init__(self): + self.calls = [] + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + + new_client_call_details = aio.ClientCallDetails( + method=client_call_details.method, + timeout=0.1, + metadata=client_call_details.metadata, + credentials=client_call_details.credentials) + + try: + call = await continuation(new_client_call_details, request) + await call + except grpc.RpcError: + pass + + self.calls.append(call) + + new_client_call_details = aio.ClientCallDetails( + method=client_call_details.method, + timeout=None, + metadata=client_call_details.metadata, + credentials=client_call_details.credentials) + + call = await continuation(new_client_call_details, request) + self.calls.append(call) + return call + + interceptor = RetryInterceptor() + server_target, server = await start_test_server() + + async with aio.insecure_channel( + server_target, interceptors=[interceptor]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + + call = multicallable(messages_pb2.SimpleRequest()) + + await call + + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + # Check that two calls were made, first one finishing with + # a deadline and second one finishing ok.. + self.assertEqual(len(interceptor.calls), 2) + self.assertEqual(await interceptor.calls[0].code(), + grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertEqual(await interceptor.calls[1].code(), + grpc.StatusCode.OK) + + async def test_rpcerror_raised_when_call_is_awaited(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + """RpcErrors are only seen when the call is awaited""" + + def __init__(self): + self.deadline_seen = False + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + + try: + await call + except aio.AioRpcError as err: + if err.code() == grpc.StatusCode.DEADLINE_EXCEEDED: + self.deadline_seen = True + raise + + # This point should never be reached + raise Exception() + + interceptor_a, interceptor_b = (Interceptor(), Interceptor()) + server_target, server = await start_test_server() + + async with aio.insecure_channel( + server_target, interceptors=[interceptor_a, + interceptor_b]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + + call = multicallable(messages_pb2.SimpleRequest(), timeout=0.1) + + with self.assertRaises(grpc.RpcError) as exception_context: + await call + + # Check that the two interceptors catch the deadline exception + # only when the call was awaited + self.assertTrue(interceptor_a.deadline_seen) + self.assertTrue(interceptor_b.deadline_seen) + + # Check all of the UnaryUnaryCallRpcError attributes + self.assertTrue(call.done()) + self.assertFalse(call.cancel()) + self.assertFalse(call.cancelled()) + self.assertEqual(await call.code(), + grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertEqual(await call.details(), 'Deadline Exceeded') + self.assertEqual(await call.initial_metadata(), None) + self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.debug_error_string(), None) + + async def test_rpcresponse(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + """Raw responses are seen as reegular calls""" + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + response = await call + return call + + class ResponseInterceptor(aio.UnaryUnaryClientInterceptor): + """Return a raw response""" + response = messages_pb2.SimpleResponse() + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + return ResponseInterceptor.response + + interceptor, interceptor_response = Interceptor(), ResponseInterceptor() + server_target, server = await start_test_server() + + async with aio.insecure_channel( + server_target, interceptors=[interceptor, + interceptor_response]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + + call = multicallable(messages_pb2.SimpleRequest()) + response = await call + + # Check that the response returned is the one returned by the + # interceptor + self.assertEqual(id(response), id(ResponseInterceptor.response)) + + # Check all of the UnaryUnaryCallResponse attributes + self.assertTrue(call.done()) + self.assertFalse(call.cancel()) + self.assertFalse(call.cancelled()) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.initial_metadata(), None) + self.assertEqual(await call.trailing_metadata(), None) + self.assertEqual(await call.debug_error_string(), None) + + +class TestInterceptedUnaryUnaryCall(AioTestBase): + + async def test_call_ok(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + return call + + server_target, _ = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel( + server_target, interceptors=[Interceptor()]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + response = await call + + self.assertTrue(call.done()) + self.assertFalse(call.cancelled()) + self.assertEqual(type(response), messages_pb2.SimpleResponse) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.initial_metadata(), ()) + self.assertEqual(await call.trailing_metadata(), ()) + + async def test_cancel_before_rpc(self): + + interceptor_reached = asyncio.Event() + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + interceptor_reached.set() + await asyncio.sleep(0) + + # This line should never be reached + raise Exception() + + server_target, _ = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel( + server_target, interceptors=[Interceptor()]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + + self.assertFalse(call.cancelled()) + self.assertFalse(call.done()) + + await interceptor_reached.wait() + self.assertTrue(call.cancel()) + + with self.assertRaises(asyncio.CancelledError): + await call + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.details(), + _LOCAL_CANCEL_DETAILS_EXPECTATION) + self.assertEqual(await call.initial_metadata(), None) + self.assertEqual(await call.trailing_metadata(), None) + + async def test_cancel_after_rpc(self): + + interceptor_reached = asyncio.Event() + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + await call + interceptor_reached.set() + await asyncio.sleep(0) + + # This line should never be reached + raise Exception() + + server_target, _ = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel( + server_target, interceptors=[Interceptor()]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + + self.assertFalse(call.cancelled()) + self.assertFalse(call.done()) + + await interceptor_reached.wait() + self.assertTrue(call.cancel()) + + with self.assertRaises(asyncio.CancelledError): + await call + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.details(), + _LOCAL_CANCEL_DETAILS_EXPECTATION) + self.assertEqual(await call.initial_metadata(), None) + self.assertEqual(await call.trailing_metadata(), None) + + async def test_cancel_inside_interceptor_after_rpc_awaiting(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + call.cancel() + await call + return call + + server_target, _ = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel( + server_target, interceptors=[Interceptor()]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + + with self.assertRaises(asyncio.CancelledError): + await call + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.details(), + _LOCAL_CANCEL_DETAILS_EXPECTATION) + self.assertEqual(await call.initial_metadata(), None) + self.assertEqual(await call.trailing_metadata(), None) + + async def test_cancel_inside_interceptor_after_rpc_not_awaiting(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + call.cancel() + return call + + server_target, _ = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel( + server_target, interceptors=[Interceptor()]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + + with self.assertRaises(asyncio.CancelledError): + await call + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.details(), + _LOCAL_CANCEL_DETAILS_EXPECTATION) + self.assertEqual(await call.initial_metadata(), tuple()) + self.assertEqual(await call.trailing_metadata(), None) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2)