From 77df7f5f172e9b3bbbeb5805a9e4fb76965bd0b2 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Sun, 17 Nov 2019 21:05:20 +0100 Subject: [PATCH 01/12] Client unary unary interceptor Implements the unary unary interceptor for the client-side. Interceptors can be now installed by passing them as a new parameter of the `Channel` constructor or by giving them as part of the `insecure_channel` function. Interceptors are executed within an Asyncio task for making some work before the RPC invocation, and after for accessing to the intercepted call that has been invoked. --- .../grpc/_cython/_cygrpc/aio/call.pyx.pxi | 1 + .../grpcio/grpc/experimental/BUILD.bazel | 2 + .../grpcio/grpc/experimental/aio/__init__.py | 24 +- .../grpcio/grpc/experimental/aio/_channel.py | 89 ++-- .../grpc/experimental/aio/_interceptor.py | 336 ++++++++++++ .../grpcio/grpc/experimental/aio/_utils.py | 23 + src/python/grpcio_tests/tests_aio/tests.json | 2 + .../tests_aio/unit/_test_server.py | 2 +- .../grpcio_tests/tests_aio/unit/call_test.py | 3 +- .../tests_aio/unit/channel_test.py | 1 + .../tests_aio/unit/interceptor_test.py | 504 ++++++++++++++++++ 11 files changed, 950 insertions(+), 37 deletions(-) create mode 100644 src/python/grpcio/grpc/experimental/aio/_interceptor.py create mode 100644 src/python/grpcio/grpc/experimental/aio/_utils.py create mode 100644 src/python/grpcio_tests/tests_aio/unit/interceptor_test.py diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi index c10d79cb7d3..2d013afe6cb 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi @@ -15,6 +15,7 @@ cimport cpython import grpc + _EMPTY_FLAGS = 0 _EMPTY_MASK = 0 _EMPTY_METADATA = None diff --git a/src/python/grpcio/grpc/experimental/BUILD.bazel b/src/python/grpcio/grpc/experimental/BUILD.bazel index 46e7bf29ab6..e436233210b 100644 --- a/src/python/grpcio/grpc/experimental/BUILD.bazel +++ b/src/python/grpcio/grpc/experimental/BUILD.bazel @@ -7,8 +7,10 @@ py_library( "aio/_base_call.py", "aio/_call.py", "aio/_channel.py", + "aio/_interceptor.py", "aio/_server.py", "aio/_typing.py", + "aio/_utils.py", ], deps = [ "//src/python/grpcio/grpc/_cython:cygrpc", diff --git a/src/python/grpcio/grpc/experimental/aio/__init__.py b/src/python/grpcio/grpc/experimental/aio/__init__.py index b20a3524508..2ece2ed3775 100644 --- a/src/python/grpcio/grpc/experimental/aio/__init__.py +++ b/src/python/grpcio/grpc/experimental/aio/__init__.py @@ -18,18 +18,26 @@ created. AsyncIO doesn't provide thread safety for most of its APIs. """ import abc +from typing import Any, Optional, Sequence, Text, Tuple import six import grpc from grpc._cython.cygrpc import init_grpc_aio from ._base_call import RpcContext, Call, UnaryUnaryCall, UnaryStreamCall +from ._call import AioRpcError from ._channel import Channel from ._channel import UnaryUnaryMultiCallable +from ._interceptor import ClientCallDetails, UnaryUnaryClientInterceptor +from ._interceptor import InterceptedUnaryUnaryCall from ._server import server -def insecure_channel(target, options=None, compression=None): +def insecure_channel( + target: Text, + options: Optional[Sequence[Tuple[Text, Any]]] = None, + compression: Optional[grpc.Compression] = None, + interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] = None): """Creates an insecure asynchronous Channel to a server. Args: @@ -38,16 +46,22 @@ def insecure_channel(target, options=None, compression=None): in gRPC Core runtime) to configure the channel. compression: An optional value indicating the compression method to be used over the lifetime of the channel. This is an EXPERIMENTAL option. + interceptors: An optional sequence of interceptors that will be executed for + any call executed with this channel. Returns: A Channel. """ - return Channel(target, () if options is None else options, None, - compression) + return Channel( + target, () if options is None else options, + None, + compression, + interceptors=interceptors) ################################### __all__ ################################# -__all__ = ('RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall', +__all__ = ('AioRpcError', 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall', 'init_grpc_aio', 'Channel', 'UnaryUnaryMultiCallable', - 'insecure_channel', 'server') + 'ClientCallDetails', 'UnaryUnaryClientInterceptor', + 'InterceptedUnaryUnaryCall', 'insecure_channel', 'server') diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index c8f7541bfb5..9c91101c43d 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -18,29 +18,35 @@ from typing import Any, Optional, Sequence, Text, Tuple import grpc from grpc import _common from grpc._cython import cygrpc + from . import _base_call from ._call import UnaryUnaryCall, UnaryStreamCall +from ._interceptor import UnaryUnaryClientInterceptor, InterceptedUnaryUnaryCall from ._typing import (DeserializingFunction, MetadataType, SerializingFunction) - - -def _timeout_to_deadline(loop: asyncio.AbstractEventLoop, - timeout: Optional[float]) -> Optional[float]: - if timeout is None: - return None - return loop.time() + timeout +from ._utils import _timeout_to_deadline class UnaryUnaryMultiCallable: """Factory an asynchronous unary-unary RPC stub call from client-side.""" + _channel: cygrpc.AioChannel + _method: bytes + _request_serializer: SerializingFunction + _response_deserializer: DeserializingFunction + _interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] + _loop: asyncio.AbstractEventLoop + def __init__(self, channel: cygrpc.AioChannel, method: bytes, request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction) -> None: + response_deserializer: DeserializingFunction, + interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] + ) -> None: self._loop = asyncio.get_event_loop() self._channel = channel self._method = method self._request_serializer = request_serializer self._response_deserializer = response_deserializer + self._interceptors = interceptors def __call__(self, request: Any, @@ -74,7 +80,6 @@ class UnaryUnaryMultiCallable: raised RpcError will also be a Call for the RPC affording the RPC's metadata, status code, and details. """ - if metadata: raise NotImplementedError("TODO: metadata not implemented yet") @@ -88,16 +93,25 @@ class UnaryUnaryMultiCallable: if compression: raise NotImplementedError("TODO: compression not implemented yet") - deadline = _timeout_to_deadline(self._loop, timeout) - - return UnaryUnaryCall( - request, - deadline, - self._channel, - self._method, - self._request_serializer, - self._response_deserializer, - ) + if not self._interceptors: + return UnaryUnaryCall( + request, + _timeout_to_deadline(self._loop, timeout), + self._channel, + self._method, + self._request_serializer, + self._response_deserializer, + ) + else: + return InterceptedUnaryUnaryCall( + self._interceptors, + request, + timeout, + self._channel, + self._method, + self._request_serializer, + self._response_deserializer, + ) class UnaryStreamMultiCallable: @@ -138,13 +152,7 @@ class UnaryStreamMultiCallable: Returns: A Call object instance which is an awaitable object. - - Raises: - RpcError: Indicating that the RPC terminated with non-OK status. The - raised RpcError will also be a Call for the RPC affording the RPC's - metadata, status code, and details. """ - if metadata: raise NotImplementedError("TODO: metadata not implemented yet") @@ -175,11 +183,14 @@ class Channel: A cygrpc.AioChannel-backed implementation. """ + _channel: cygrpc.AioChannel + _unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] def __init__(self, target: Text, options: Optional[Sequence[Tuple[Text, Any]]], credentials: Optional[grpc.ChannelCredentials], - compression: Optional[grpc.Compression]): + compression: Optional[grpc.Compression], + interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]): """Constructor. Args: @@ -188,8 +199,9 @@ class Channel: credentials: A cygrpc.ChannelCredentials or None. compression: An optional value indicating the compression method to be used over the lifetime of the channel. + interceptors: An optional list of interceptors that would be used for + intercepting any RPC executed with that channel. """ - if options: raise NotImplementedError("TODO: options not implemented yet") @@ -199,6 +211,23 @@ class Channel: if compression: raise NotImplementedError("TODO: compression not implemented yet") + if interceptors is None: + self._unary_unary_interceptors = None + else: + self._unary_unary_interceptors = list( + filter( + lambda interceptor: isinstance(interceptor, UnaryUnaryClientInterceptor), + interceptors)) + + invalid_interceptors = set(interceptors) - set( + self._unary_unary_interceptors) + + if invalid_interceptors: + raise ValueError( + "Interceptor must be "+\ + "UnaryUnaryClientInterceptors, the following are invalid: {}"\ + .format(invalid_interceptors)) + self._channel = cygrpc.AioChannel(_common.encode(target)) def unary_unary( @@ -220,9 +249,9 @@ class Channel: Returns: A UnaryUnaryMultiCallable value for the named unary-unary method. """ - return UnaryUnaryMultiCallable(self._channel, _common.encode(method), - request_serializer, - response_deserializer) + return UnaryUnaryMultiCallable( + self._channel, _common.encode(method), request_serializer, + response_deserializer, self._unary_unary_interceptors) def unary_stream( self, diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py new file mode 100644 index 00000000000..205a710b249 --- /dev/null +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -0,0 +1,336 @@ +# Copyright 2019 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Interceptors implementation of gRPC Asyncio Python.""" +import asyncio +import collections +import functools +from abc import ABCMeta, abstractmethod +from typing import Callable, Optional, Iterator, Sequence, Text, Union + +import grpc +from grpc._cython import cygrpc + +from . import _base_call +from ._call import UnaryUnaryCall +from ._utils import _timeout_to_deadline +from ._typing import (RequestType, SerializingFunction, DeserializingFunction, + MetadataType, ResponseType) + +_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!' + + +class ClientCallDetails( + collections.namedtuple( + 'ClientCallDetails', + ('method', 'timeout', 'metadata', 'credentials')), + grpc.ClientCallDetails): + + method: Text + timeout: Optional[float] + metadata: Optional[MetadataType] + credentials: Optional[grpc.CallCredentials] + + +class UnaryUnaryClientInterceptor(metaclass=ABCMeta): + """Affords intercepting unary-unary invocations.""" + + @abstractmethod + async def intercept_unary_unary( + self, continuation: Callable[[ClientCallDetails, RequestType], + UnaryUnaryCall], + client_call_details: ClientCallDetails, + request: RequestType) -> Union[UnaryUnaryCall, ResponseType]: + """Intercepts a unary-unary invocation asynchronously. + Args: + continuation: A coroutine that proceeds with the invocation by + executing the next interceptor in chain or invoking the + actual RPC on the underlying Channel. It is the interceptor's + responsibility to call it if it decides to move the RPC forward. + The interceptor can use + `response_future = await continuation(client_call_details, request)` + to continue with the RPC. `continuation` returns the response of the + RPC. + client_call_details: A ClientCallDetails object describing the + outgoing RPC. + request: The request value for the RPC. + Returns: + An object with the RPC response. + Raises: + AioRpcError: Indicating that the RPC terminated with non-OK status. + asyncio.CancelledError: Indicating that the RPC was canceled. + """ + + +class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): + """Used for running a `UnaryUnaryCall` wrapped by interceptors. + + Interceptors might have some work to do before the RPC invocation with + the capacity of changing the invocation parameters, and some work to do + after the RPC invocation with the capacity for accessing to the wrapped + `UnaryUnaryCall`. + + It handles also early and later cancellations, when the RPC has not even + started and the execution is still held by the interceptors or when the + RPC has finished but again the execution is still held by the interceptors. + + Once the RPC is finally executed, all methods are finally done against the + intercepted call, being at the same time the same call returned to the + interceptors. + + For most of the methods, like `initial_metadata()` the caller does not need + to wait until the interceptors task is finished, once the RPC is done the + caller will have the freedom for accessing to the results. + + For the `__await__` method is it is proxied to the intercepted call only when + the interceptor task is finished. + """ + + _loop: asyncio.AbstractEventLoop + _channel: cygrpc.AioChannel + _cancelled_before_rpc: bool + _intercepted_call: Optional[_base_call.UnaryUnaryCall] + _intercepted_call_created: asyncio.Event + _interceptors_task: asyncio.Task + + def __init__( # pylint: disable=R0913 + self, interceptors: Sequence[UnaryUnaryClientInterceptor], + request: RequestType, timeout: Optional[float], + channel: cygrpc.AioChannel, method: bytes, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction) -> None: + self._channel = channel + self._loop = asyncio.get_event_loop() + self._interceptors_task = asyncio.ensure_future( + self._invoke(interceptors, method, timeout, request, + request_serializer, response_deserializer)) + + def __del__(self): + self.cancel() + + async def _invoke( + self, interceptors: Sequence[UnaryUnaryClientInterceptor], + method: bytes, timeout: Optional[float], request: RequestType, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction) -> UnaryUnaryCall: + """Run the RPC call wrapped in interceptors""" + + async def _run_interceptor( + interceptors: Iterator[UnaryUnaryClientInterceptor], + client_call_details: ClientCallDetails, + request: RequestType) -> _base_call.UnaryUnaryCall: + try: + interceptor = next(interceptors) + except StopIteration: + interceptor = None + + if interceptor: + continuation = functools.partial(_run_interceptor, interceptors) + try: + call_or_response = await interceptor.intercept_unary_unary( + continuation, client_call_details, request) + except grpc.RpcError as err: + # gRPC error is masked inside an artificial call, + # caller will see this error if and only + # if it runs an `await call` operation + return UnaryUnaryCallRpcError(err) + except asyncio.CancelledError: + # Cancellation is masked inside an artificial call, + # caller will see this error if and only + # if it runs an `await call` operation + return UnaryUnaryCancelledError() + + if isinstance(call_or_response, _base_call.UnaryUnaryCall): + return call_or_response + else: + return UnaryUnaryCallResponse(call_or_response) + + else: + return UnaryUnaryCall(request, + _timeout_to_deadline( + self._loop, + client_call_details.timeout), + self._channel, client_call_details.method, + request_serializer, response_deserializer) + + client_call_details = ClientCallDetails(method, timeout, None, None) + return await _run_interceptor( + iter(interceptors), client_call_details, request) + + def cancel(self) -> bool: + if self._interceptors_task.done(): + return False + + return self._interceptors_task.cancel() + + def cancelled(self) -> bool: + if not self._interceptors_task.done(): + return False + + call = self._interceptors_task.result() + return call.cancelled() + + def done(self) -> bool: + if not self._interceptors_task.done(): + return False + + return True + + def add_done_callback(self, unused_callback) -> None: + raise NotImplementedError() + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + async def initial_metadata(self) -> Optional[MetadataType]: + return await (await self._interceptors_task).initial_metadata() + + async def trailing_metadata(self) -> Optional[MetadataType]: + return await (await self._interceptors_task).trailing_metadata() + + async def code(self) -> grpc.StatusCode: + return await (await self._interceptors_task).code() + + async def details(self) -> str: + return await (await self._interceptors_task).details() + + async def debug_error_string(self) -> Optional[str]: + return await (await self._interceptors_task).debug_error_string() + + def __await__(self): + call = yield from self._interceptors_task.__await__() + response = yield from call.__await__() + return response + + +class UnaryUnaryCallRpcError(_base_call.UnaryUnaryCall): + """Final UnaryUnaryCall class finished with an RpcError.""" + _error: grpc.RpcError + + def __init__(self, error: grpc.RpcError) -> None: + self._error = error + + def cancel(self) -> bool: + return False + + def cancelled(self) -> bool: + return False + + def done(self) -> bool: + return True + + def add_done_callback(self, unused_callback) -> None: + raise NotImplementedError() + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + async def initial_metadata(self) -> Optional[MetadataType]: + return None + + async def trailing_metadata(self) -> Optional[MetadataType]: + return self._error.initial_metadata() + + async def code(self) -> grpc.StatusCode: + return self._error.code() + + async def details(self) -> str: + return self._error.details() + + async def debug_error_string(self) -> Optional[str]: + return self._error.debug_error_string() + + def __await__(self): + raise self._error + + +class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall): + """Final UnaryUnaryCall class finished with a response.""" + _response: ResponseType + + def __init__(self, response: ResponseType) -> None: + self._response = response + + def cancel(self) -> bool: + return False + + def cancelled(self) -> bool: + return False + + def done(self) -> bool: + return True + + def add_done_callback(self, unused_callback) -> None: + raise NotImplementedError() + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + async def initial_metadata(self) -> Optional[MetadataType]: + return None + + async def trailing_metadata(self) -> Optional[MetadataType]: + return None + + async def code(self) -> grpc.StatusCode: + return grpc.StatusCode.OK + + async def details(self) -> str: + return '' + + async def debug_error_string(self) -> Optional[str]: + return None + + def __await__(self): + if False: # pylint: disable=W0125 + # This code path is never used, but a yield statement is needed + # for telling the interpreter that __await__ is a generator. + yield None + return self._response + + +class UnaryUnaryCancelledError(_base_call.UnaryUnaryCall): + """Final UnaryUnaryCall class finished with an asyncio.CancelledError.""" + + def cancel(self) -> bool: + return False + + def cancelled(self) -> bool: + return True + + def done(self) -> bool: + return True + + def add_done_callback(self, unused_callback) -> None: + raise NotImplementedError() + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + async def initial_metadata(self) -> Optional[MetadataType]: + return None + + async def trailing_metadata(self) -> Optional[MetadataType]: + return None + + async def code(self) -> grpc.StatusCode: + return grpc.StatusCode.CANCELLED + + async def details(self) -> str: + return _LOCAL_CANCELLATION_DETAILS + + async def debug_error_string(self) -> Optional[str]: + return None + + def __await__(self): + raise asyncio.CancelledError() diff --git a/src/python/grpcio/grpc/experimental/aio/_utils.py b/src/python/grpcio/grpc/experimental/aio/_utils.py new file mode 100644 index 00000000000..17fabbb5bff --- /dev/null +++ b/src/python/grpcio/grpc/experimental/aio/_utils.py @@ -0,0 +1,23 @@ +# Copyright 2019 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Internal utilities used by the gRPC Aio module.""" +import asyncio +from typing import Optional + + +def _timeout_to_deadline(loop: asyncio.AbstractEventLoop, + timeout: Optional[float]) -> Optional[float]: + if timeout is None: + return None + return loop.time() + timeout diff --git a/src/python/grpcio_tests/tests_aio/tests.json b/src/python/grpcio_tests/tests_aio/tests.json index 5d6634b8a8e..26545c29dc9 100644 --- a/src/python/grpcio_tests/tests_aio/tests.json +++ b/src/python/grpcio_tests/tests_aio/tests.json @@ -5,5 +5,7 @@ "unit.call_test.TestUnaryUnaryCall", "unit.channel_test.TestChannel", "unit.init_test.TestInsecureChannel", + "unit.interceptor_test.TestInterceptedUnaryUnaryCall", + "unit.interceptor_test.TestUnaryUnaryClientInterceptor", "unit.server_test.TestServer" ] diff --git a/src/python/grpcio_tests/tests_aio/unit/_test_server.py b/src/python/grpcio_tests/tests_aio/unit/_test_server.py index 26b920f6c17..c3a04f29a00 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_test_server.py +++ b/src/python/grpcio_tests/tests_aio/unit/_test_server.py @@ -17,9 +17,9 @@ import logging import datetime from grpc.experimental import aio +from tests.unit.framework.common import test_constants from src.proto.grpc.testing import messages_pb2 from src.proto.grpc.testing import test_pb2_grpc -from tests.unit.framework.common import test_constants class _TestServiceServicer(test_pb2_grpc.TestServiceServicer): diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index 5ef6e0b9a43..ee28b0a966a 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -26,6 +26,7 @@ from src.proto.grpc.testing import test_pb2_grpc from tests.unit.framework.common import test_constants from tests_aio.unit._test_server import start_test_server from tests_aio.unit._test_base import AioTestBase +from src.proto.grpc.testing import messages_pb2 _NUM_STREAM_RESPONSES = 5 _RESPONSE_PAYLOAD_SIZE = 42 @@ -399,5 +400,5 @@ class TestUnaryStreamCall(AioTestBase): if __name__ == '__main__': - logging.basicConfig(level=logging.DEBUG) + logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_test.py index 27b04e6875b..017ad0ae06c 100644 --- a/src/python/grpcio_tests/tests_aio/unit/channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/channel_test.py @@ -25,6 +25,7 @@ from src.proto.grpc.testing import test_pb2_grpc from tests.unit.framework.common import test_constants from tests_aio.unit._test_server import start_test_server from tests_aio.unit._test_base import AioTestBase +from src.proto.grpc.testing import messages_pb2 _UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall' _STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall' diff --git a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py new file mode 100644 index 00000000000..f55c83eb48b --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py @@ -0,0 +1,504 @@ +# Copyright 2019 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import logging +import unittest + +import grpc + +from grpc.experimental import aio +from tests_aio.unit._test_server import start_test_server +from tests_aio.unit._test_base import AioTestBase +from src.proto.grpc.testing import messages_pb2 + +_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' + + +class TestUnaryUnaryClientInterceptor(AioTestBase): + + def test_invalid_interceptor(self): + + class InvalidInterceptor: + """Just an invalid Interceptor""" + + with self.assertRaises(ValueError): + aio.insecure_channel("", interceptors=[InvalidInterceptor()]) + + async def test_executed_right_order(self): + + interceptors_executed = [] + + class Interceptor(aio.UnaryUnaryClientInterceptor): + """Interceptor used for testing if the interceptor is being called""" + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + interceptors_executed.append(self) + call = await continuation(client_call_details, request) + return call + + interceptors = [Interceptor() for i in range(2)] + + server_target, _ = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel( + server_target, interceptors=interceptors) as channel: + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + response = await call + + # Check that all interceptors were executed, and were executed + # in the right order. + self.assertSequenceEqual(interceptors_executed, interceptors) + + self.assertIsInstance(response, messages_pb2.SimpleResponse) + + @unittest.expectedFailure + # TODO(https://github.com/grpc/grpc/issues/20144) Once metadata support is + # implemented in the client-side, this test must be implemented. + def test_modify_metadata(self): + raise NotImplementedError() + + @unittest.expectedFailure + # TODO(https://github.com/grpc/grpc/issues/20532) Once credentials support is + # implemented in the client-side, this test must be implemented. + def test_modify_credentials(self): + raise NotImplementedError() + + async def test_status_code_Ok(self): + + class StatusCodeOkInterceptor(aio.UnaryUnaryClientInterceptor): + """Interceptor used for observing status code Ok returned by the RPC""" + + def __init__(self): + self.status_code_Ok_observed = False + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + code = await call.code() + if code == grpc.StatusCode.OK: + self.status_code_Ok_observed = True + + return call + + interceptor = StatusCodeOkInterceptor() + server_target, server = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel( + server_target, interceptors=[interceptor]) as channel: + + # when no error StatusCode.OK must be observed + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + + await multicallable(messages_pb2.SimpleRequest()) + + self.assertTrue(interceptor.status_code_Ok_observed) + + async def test_add_timeout(self): + + class TimeoutInterceptor(aio.UnaryUnaryClientInterceptor): + """Interceptor used for adding a timeout to the RPC""" + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + new_client_call_details = aio.ClientCallDetails( + method=client_call_details.method, + timeout=0.1, + metadata=client_call_details.metadata, + credentials=client_call_details.credentials) + return await continuation(new_client_call_details, request) + + interceptor = TimeoutInterceptor() + server_target, server = await start_test_server() + + async with aio.insecure_channel( + server_target, interceptors=[interceptor]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + + call = multicallable(messages_pb2.SimpleRequest()) + + await server.stop(None) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + self.assertEqual(exception_context.exception.code(), + grpc.StatusCode.DEADLINE_EXCEEDED) + + self.assertTrue(call.done()) + self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await + call.code()) + + async def test_retry(self): + + class RetryInterceptor(aio.UnaryUnaryClientInterceptor): + """Simulates a Retry Interceptor which ends up by making + two RPC calls.""" + + def __init__(self): + self.calls = [] + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + + new_client_call_details = aio.ClientCallDetails( + method=client_call_details.method, + timeout=0.1, + metadata=client_call_details.metadata, + credentials=client_call_details.credentials) + + try: + call = await continuation(new_client_call_details, request) + await call + except grpc.RpcError: + pass + + self.calls.append(call) + + new_client_call_details = aio.ClientCallDetails( + method=client_call_details.method, + timeout=None, + metadata=client_call_details.metadata, + credentials=client_call_details.credentials) + + call = await continuation(new_client_call_details, request) + self.calls.append(call) + return call + + interceptor = RetryInterceptor() + server_target, server = await start_test_server() + + async with aio.insecure_channel( + server_target, interceptors=[interceptor]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + + call = multicallable(messages_pb2.SimpleRequest()) + + await call + + self.assertEqual(grpc.StatusCode.OK, await call.code()) + + # Check that two calls were made, first one finishing with + # a deadline and second one finishing ok.. + self.assertEqual(len(interceptor.calls), 2) + self.assertEqual(await interceptor.calls[0].code(), + grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertEqual(await interceptor.calls[1].code(), + grpc.StatusCode.OK) + + async def test_rpcerror_raised_when_call_is_awaited(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + """RpcErrors are only seen when the call is awaited""" + + def __init__(self): + self.deadline_seen = False + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + + try: + await call + except aio.AioRpcError as err: + if err.code() == grpc.StatusCode.DEADLINE_EXCEEDED: + self.deadline_seen = True + raise + + # This point should never be reached + raise Exception() + + interceptor_a, interceptor_b = (Interceptor(), Interceptor()) + server_target, server = await start_test_server() + + async with aio.insecure_channel( + server_target, interceptors=[interceptor_a, + interceptor_b]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + + call = multicallable(messages_pb2.SimpleRequest(), timeout=0.1) + + with self.assertRaises(grpc.RpcError) as exception_context: + await call + + # Check that the two interceptors catch the deadline exception + # only when the call was awaited + self.assertTrue(interceptor_a.deadline_seen) + self.assertTrue(interceptor_b.deadline_seen) + + # Check all of the UnaryUnaryCallRpcError attributes + self.assertTrue(call.done()) + self.assertFalse(call.cancel()) + self.assertFalse(call.cancelled()) + self.assertEqual(await call.code(), + grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertEqual(await call.details(), 'Deadline Exceeded') + self.assertEqual(await call.initial_metadata(), None) + self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.debug_error_string(), None) + + async def test_rpcresponse(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + """Raw responses are seen as reegular calls""" + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + response = await call + return call + + class ResponseInterceptor(aio.UnaryUnaryClientInterceptor): + """Return a raw response""" + response = messages_pb2.SimpleResponse() + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + return ResponseInterceptor.response + + interceptor, interceptor_response = Interceptor(), ResponseInterceptor() + server_target, server = await start_test_server() + + async with aio.insecure_channel( + server_target, interceptors=[interceptor, + interceptor_response]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + + call = multicallable(messages_pb2.SimpleRequest()) + response = await call + + # Check that the response returned is the one returned by the + # interceptor + self.assertEqual(id(response), id(ResponseInterceptor.response)) + + # Check all of the UnaryUnaryCallResponse attributes + self.assertTrue(call.done()) + self.assertFalse(call.cancel()) + self.assertFalse(call.cancelled()) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.initial_metadata(), None) + self.assertEqual(await call.trailing_metadata(), None) + self.assertEqual(await call.debug_error_string(), None) + + +class TestInterceptedUnaryUnaryCall(AioTestBase): + + async def test_call_ok(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + return call + + server_target, _ = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel( + server_target, interceptors=[Interceptor()]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + response = await call + + self.assertTrue(call.done()) + self.assertFalse(call.cancelled()) + self.assertEqual(type(response), messages_pb2.SimpleResponse) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.initial_metadata(), ()) + self.assertEqual(await call.trailing_metadata(), ()) + + async def test_cancel_before_rpc(self): + + interceptor_reached = asyncio.Event() + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + interceptor_reached.set() + await asyncio.sleep(0) + + # This line should never be reached + raise Exception() + + server_target, _ = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel( + server_target, interceptors=[Interceptor()]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + + self.assertFalse(call.cancelled()) + self.assertFalse(call.done()) + + await interceptor_reached.wait() + self.assertTrue(call.cancel()) + + with self.assertRaises(asyncio.CancelledError): + await call + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.details(), + _LOCAL_CANCEL_DETAILS_EXPECTATION) + self.assertEqual(await call.initial_metadata(), None) + self.assertEqual(await call.trailing_metadata(), None) + + async def test_cancel_after_rpc(self): + + interceptor_reached = asyncio.Event() + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + await call + interceptor_reached.set() + await asyncio.sleep(0) + + # This line should never be reached + raise Exception() + + server_target, _ = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel( + server_target, interceptors=[Interceptor()]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + + self.assertFalse(call.cancelled()) + self.assertFalse(call.done()) + + await interceptor_reached.wait() + self.assertTrue(call.cancel()) + + with self.assertRaises(asyncio.CancelledError): + await call + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.details(), + _LOCAL_CANCEL_DETAILS_EXPECTATION) + self.assertEqual(await call.initial_metadata(), None) + self.assertEqual(await call.trailing_metadata(), None) + + async def test_cancel_inside_interceptor_after_rpc_awaiting(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + call.cancel() + await call + return call + + server_target, _ = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel( + server_target, interceptors=[Interceptor()]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + + with self.assertRaises(asyncio.CancelledError): + await call + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.details(), + _LOCAL_CANCEL_DETAILS_EXPECTATION) + self.assertEqual(await call.initial_metadata(), None) + self.assertEqual(await call.trailing_metadata(), None) + + async def test_cancel_inside_interceptor_after_rpc_not_awaiting(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + call.cancel() + return call + + server_target, _ = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel( + server_target, interceptors=[Interceptor()]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + + with self.assertRaises(asyncio.CancelledError): + await call + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.details(), + _LOCAL_CANCEL_DETAILS_EXPECTATION) + self.assertEqual(await call.initial_metadata(), tuple()) + self.assertEqual(await call.trailing_metadata(), None) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) From 833df2b6c854d63be6931179d250e2099521b0a6 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Tue, 7 Jan 2020 00:55:59 +0100 Subject: [PATCH 02/12] Apply feedback --- src/python/grpcio/grpc/experimental/aio/_interceptor.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index 205a710b249..2476310fff0 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -129,10 +129,8 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): interceptors: Iterator[UnaryUnaryClientInterceptor], client_call_details: ClientCallDetails, request: RequestType) -> _base_call.UnaryUnaryCall: - try: - interceptor = next(interceptors) - except StopIteration: - interceptor = None + + interceptor = next(interceptors, None) if interceptor: continuation = functools.partial(_run_interceptor, interceptors) From a2667b80c35f44fbd81a20456d58a7a7ce91f64b Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Tue, 7 Jan 2020 01:32:47 +0100 Subject: [PATCH 03/12] make YAPF happy --- .../grpcio/grpc/experimental/aio/__init__.py | 18 ++++---- .../grpcio/grpc/experimental/aio/_channel.py | 10 +++-- .../grpc/experimental/aio/_interceptor.py | 27 ++++++------ .../tests_aio/unit/interceptor_test.py | 41 +++++++++++-------- 4 files changed, 52 insertions(+), 44 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/__init__.py b/src/python/grpcio/grpc/experimental/aio/__init__.py index 2ece2ed3775..9a0b6a6fa64 100644 --- a/src/python/grpcio/grpc/experimental/aio/__init__.py +++ b/src/python/grpcio/grpc/experimental/aio/__init__.py @@ -52,16 +52,16 @@ def insecure_channel( Returns: A Channel. """ - return Channel( - target, () if options is None else options, - None, - compression, - interceptors=interceptors) + return Channel(target, () if options is None else options, + None, + compression, + interceptors=interceptors) ################################### __all__ ################################# -__all__ = ('AioRpcError', 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall', - 'init_grpc_aio', 'Channel', 'UnaryUnaryMultiCallable', - 'ClientCallDetails', 'UnaryUnaryClientInterceptor', - 'InterceptedUnaryUnaryCall', 'insecure_channel', 'server') +__all__ = ('AioRpcError', 'RpcContext', 'Call', 'UnaryUnaryCall', + 'UnaryStreamCall', 'init_grpc_aio', 'Channel', + 'UnaryUnaryMultiCallable', 'ClientCallDetails', + 'UnaryUnaryClientInterceptor', 'InterceptedUnaryUnaryCall', + 'insecure_channel', 'server') diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 9c91101c43d..3aa9fc07360 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -216,7 +216,8 @@ class Channel: else: self._unary_unary_interceptors = list( filter( - lambda interceptor: isinstance(interceptor, UnaryUnaryClientInterceptor), + lambda interceptor: isinstance(interceptor, + UnaryUnaryClientInterceptor), interceptors)) invalid_interceptors = set(interceptors) - set( @@ -249,9 +250,10 @@ class Channel: Returns: A UnaryUnaryMultiCallable value for the named unary-unary method. """ - return UnaryUnaryMultiCallable( - self._channel, _common.encode(method), request_serializer, - response_deserializer, self._unary_unary_interceptors) + return UnaryUnaryMultiCallable(self._channel, _common.encode(method), + request_serializer, + response_deserializer, + self._unary_unary_interceptors) def unary_stream( self, diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index 2476310fff0..00ea17924a4 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -118,11 +118,12 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): def __del__(self): self.cancel() - async def _invoke( - self, interceptors: Sequence[UnaryUnaryClientInterceptor], - method: bytes, timeout: Optional[float], request: RequestType, - request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction) -> UnaryUnaryCall: + async def _invoke(self, interceptors: Sequence[UnaryUnaryClientInterceptor], + method: bytes, timeout: Optional[float], + request: RequestType, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction + ) -> UnaryUnaryCall: """Run the RPC call wrapped in interceptors""" async def _run_interceptor( @@ -154,16 +155,16 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): return UnaryUnaryCallResponse(call_or_response) else: - return UnaryUnaryCall(request, - _timeout_to_deadline( - self._loop, - client_call_details.timeout), - self._channel, client_call_details.method, - request_serializer, response_deserializer) + return UnaryUnaryCall( + request, + _timeout_to_deadline(self._loop, + client_call_details.timeout), + self._channel, client_call_details.method, + request_serializer, response_deserializer) client_call_details = ClientCallDetails(method, timeout, None, None) - return await _run_interceptor( - iter(interceptors), client_call_details, request) + return await _run_interceptor(iter(interceptors), client_call_details, + request) def cancel(self) -> bool: if self._interceptors_task.done(): diff --git a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py index f55c83eb48b..f97fbe171d3 100644 --- a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py @@ -52,8 +52,8 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): server_target, _ = await start_test_server() # pylint: disable=unused-variable - async with aio.insecure_channel( - server_target, interceptors=interceptors) as channel: + async with aio.insecure_channel(server_target, + interceptors=interceptors) as channel: multicallable = channel.unary_unary( '/grpc.testing.TestService/UnaryCall', request_serializer=messages_pb2.SimpleRequest.SerializeToString, @@ -99,8 +99,8 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): interceptor = StatusCodeOkInterceptor() server_target, server = await start_test_server() # pylint: disable=unused-variable - async with aio.insecure_channel( - server_target, interceptors=[interceptor]) as channel: + async with aio.insecure_channel(server_target, + interceptors=[interceptor]) as channel: # when no error StatusCode.OK must be observed multicallable = channel.unary_unary( @@ -129,8 +129,8 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): interceptor = TimeoutInterceptor() server_target, server = await start_test_server() - async with aio.insecure_channel( - server_target, interceptors=[interceptor]) as channel: + async with aio.insecure_channel(server_target, + interceptors=[interceptor]) as channel: multicallable = channel.unary_unary( '/grpc.testing.TestService/UnaryCall', @@ -190,8 +190,8 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): interceptor = RetryInterceptor() server_target, server = await start_test_server() - async with aio.insecure_channel( - server_target, interceptors=[interceptor]) as channel: + async with aio.insecure_channel(server_target, + interceptors=[interceptor]) as channel: multicallable = channel.unary_unary( '/grpc.testing.TestService/UnaryCall', @@ -329,8 +329,9 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): server_target, _ = await start_test_server() # pylint: disable=unused-variable - async with aio.insecure_channel( - server_target, interceptors=[Interceptor()]) as channel: + async with aio.insecure_channel(server_target, + interceptors=[Interceptor() + ]) as channel: multicallable = channel.unary_unary( '/grpc.testing.TestService/UnaryCall', @@ -363,8 +364,9 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): server_target, _ = await start_test_server() # pylint: disable=unused-variable - async with aio.insecure_channel( - server_target, interceptors=[Interceptor()]) as channel: + async with aio.insecure_channel(server_target, + interceptors=[Interceptor() + ]) as channel: multicallable = channel.unary_unary( '/grpc.testing.TestService/UnaryCall', @@ -407,8 +409,9 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): server_target, _ = await start_test_server() # pylint: disable=unused-variable - async with aio.insecure_channel( - server_target, interceptors=[Interceptor()]) as channel: + async with aio.insecure_channel(server_target, + interceptors=[Interceptor() + ]) as channel: multicallable = channel.unary_unary( '/grpc.testing.TestService/UnaryCall', @@ -446,8 +449,9 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): server_target, _ = await start_test_server() # pylint: disable=unused-variable - async with aio.insecure_channel( - server_target, interceptors=[Interceptor()]) as channel: + async with aio.insecure_channel(server_target, + interceptors=[Interceptor() + ]) as channel: multicallable = channel.unary_unary( '/grpc.testing.TestService/UnaryCall', @@ -478,8 +482,9 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): server_target, _ = await start_test_server() # pylint: disable=unused-variable - async with aio.insecure_channel( - server_target, interceptors=[Interceptor()]) as channel: + async with aio.insecure_channel(server_target, + interceptors=[Interceptor() + ]) as channel: multicallable = channel.unary_unary( '/grpc.testing.TestService/UnaryCall', From 33765f5ee54dd9c01c4436983bee87a816d54911 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Tue, 7 Jan 2020 23:31:50 +0100 Subject: [PATCH 04/12] Not mask AioRpcError and CancelledError at interceptor level --- .../grpcio/grpc/experimental/aio/_call.py | 2 +- .../grpc/experimental/aio/_interceptor.py | 158 +++++++----------- .../tests_aio/unit/interceptor_test.py | 156 +++++++++++------ 3 files changed, 160 insertions(+), 156 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index f04c6cdd761..0ff5c0a83f8 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -233,7 +233,7 @@ class Call(_base_call.Call): if self._code is grpc.StatusCode.OK: return _OK_CALL_REPRESENTATION.format( self.__class__.__name__, self._code, - self._status.result().self._status.result().details()) + self._status.result().details()) else: return _NON_OK_CALL_REPRESENTATION.format( self.__class__.__name__, self._code, diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index 00ea17924a4..4c643443dda 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -22,7 +22,7 @@ import grpc from grpc._cython import cygrpc from . import _base_call -from ._call import UnaryUnaryCall +from ._call import UnaryUnaryCall, AioRpcError from ._utils import _timeout_to_deadline from ._typing import (RequestType, SerializingFunction, DeserializingFunction, MetadataType, ResponseType) @@ -135,19 +135,9 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): if interceptor: continuation = functools.partial(_run_interceptor, interceptors) - try: - call_or_response = await interceptor.intercept_unary_unary( - continuation, client_call_details, request) - except grpc.RpcError as err: - # gRPC error is masked inside an artificial call, - # caller will see this error if and only - # if it runs an `await call` operation - return UnaryUnaryCallRpcError(err) - except asyncio.CancelledError: - # Cancellation is masked inside an artificial call, - # caller will see this error if and only - # if it runs an `await call` operation - return UnaryUnaryCancelledError() + + call_or_response = await interceptor.intercept_unary_unary( + continuation, client_call_details, request) if isinstance(call_or_response, _base_call.UnaryUnaryCall): return call_or_response @@ -176,14 +166,25 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): if not self._interceptors_task.done(): return False - call = self._interceptors_task.result() - return call.cancelled() + try: + call = self._interceptors_task.result() + except AioRpcError: + return False + except asyncio.CancelledError: + return True + else: + return call.cancelled() def done(self) -> bool: if not self._interceptors_task.done(): return False - return True + try: + call = self._interceptors_task.result() + except (AioRpcError, asyncio.CancelledError): + return True + else: + return call.done() def add_done_callback(self, unused_callback) -> None: raise NotImplementedError() @@ -192,19 +193,54 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): raise NotImplementedError() async def initial_metadata(self) -> Optional[MetadataType]: - return await (await self._interceptors_task).initial_metadata() + try: + call = await self._interceptors_task + except AioRpcError as err: + return err.initial_metadata() + except asyncio.CancelledError: + return None + else: + return await call.initial_metadata() async def trailing_metadata(self) -> Optional[MetadataType]: - return await (await self._interceptors_task).trailing_metadata() + try: + call = await self._interceptors_task + except AioRpcError as err: + return err.trailing_metadata() + except asyncio.CancelledError: + return None + else: + return await call.trailing_metadata() async def code(self) -> grpc.StatusCode: - return await (await self._interceptors_task).code() + try: + call = await self._interceptors_task + except AioRpcError as err: + return err.code() + except asyncio.CancelledError: + return grpc.StatusCode.CANCELLED + else: + return await call.code() async def details(self) -> str: - return await (await self._interceptors_task).details() + try: + call = await self._interceptors_task + except AioRpcError as err: + return err.details() + except asyncio.CancelledError: + return _LOCAL_CANCELLATION_DETAILS + else: + return await call.details() async def debug_error_string(self) -> Optional[str]: - return await (await self._interceptors_task).debug_error_string() + try: + call = await self._interceptors_task + except AioRpcError as err: + return err.debug_error_string() + except asyncio.CancelledError: + return '' + else: + return await call.debug_error_string() def __await__(self): call = yield from self._interceptors_task.__await__() @@ -212,47 +248,6 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): return response -class UnaryUnaryCallRpcError(_base_call.UnaryUnaryCall): - """Final UnaryUnaryCall class finished with an RpcError.""" - _error: grpc.RpcError - - def __init__(self, error: grpc.RpcError) -> None: - self._error = error - - def cancel(self) -> bool: - return False - - def cancelled(self) -> bool: - return False - - def done(self) -> bool: - return True - - def add_done_callback(self, unused_callback) -> None: - raise NotImplementedError() - - def time_remaining(self) -> Optional[float]: - raise NotImplementedError() - - async def initial_metadata(self) -> Optional[MetadataType]: - return None - - async def trailing_metadata(self) -> Optional[MetadataType]: - return self._error.initial_metadata() - - async def code(self) -> grpc.StatusCode: - return self._error.code() - - async def details(self) -> str: - return self._error.details() - - async def debug_error_string(self) -> Optional[str]: - return self._error.debug_error_string() - - def __await__(self): - raise self._error - - class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall): """Final UnaryUnaryCall class finished with a response.""" _response: ResponseType @@ -296,40 +291,3 @@ class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall): # for telling the interpreter that __await__ is a generator. yield None return self._response - - -class UnaryUnaryCancelledError(_base_call.UnaryUnaryCall): - """Final UnaryUnaryCall class finished with an asyncio.CancelledError.""" - - def cancel(self) -> bool: - return False - - def cancelled(self) -> bool: - return True - - def done(self) -> bool: - return True - - def add_done_callback(self, unused_callback) -> None: - raise NotImplementedError() - - def time_remaining(self) -> Optional[float]: - raise NotImplementedError() - - async def initial_metadata(self) -> Optional[MetadataType]: - return None - - async def trailing_metadata(self) -> Optional[MetadataType]: - return None - - async def code(self) -> grpc.StatusCode: - return grpc.StatusCode.CANCELLED - - async def details(self) -> str: - return _LOCAL_CANCELLATION_DETAILS - - async def debug_error_string(self) -> Optional[str]: - return None - - def __await__(self): - raise asyncio.CancelledError() diff --git a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py index f97fbe171d3..f39360d2f3e 100644 --- a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py @@ -177,6 +177,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): self.calls.append(call) + new_client_call_details = aio.ClientCallDetails( method=client_call_details.method, timeout=None, @@ -212,61 +213,6 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): self.assertEqual(await interceptor.calls[1].code(), grpc.StatusCode.OK) - async def test_rpcerror_raised_when_call_is_awaited(self): - - class Interceptor(aio.UnaryUnaryClientInterceptor): - """RpcErrors are only seen when the call is awaited""" - - def __init__(self): - self.deadline_seen = False - - async def intercept_unary_unary(self, continuation, - client_call_details, request): - call = await continuation(client_call_details, request) - - try: - await call - except aio.AioRpcError as err: - if err.code() == grpc.StatusCode.DEADLINE_EXCEEDED: - self.deadline_seen = True - raise - - # This point should never be reached - raise Exception() - - interceptor_a, interceptor_b = (Interceptor(), Interceptor()) - server_target, server = await start_test_server() - - async with aio.insecure_channel( - server_target, interceptors=[interceptor_a, - interceptor_b]) as channel: - - multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', - request_serializer=messages_pb2.SimpleRequest.SerializeToString, - response_deserializer=messages_pb2.SimpleResponse.FromString) - - call = multicallable(messages_pb2.SimpleRequest(), timeout=0.1) - - with self.assertRaises(grpc.RpcError) as exception_context: - await call - - # Check that the two interceptors catch the deadline exception - # only when the call was awaited - self.assertTrue(interceptor_a.deadline_seen) - self.assertTrue(interceptor_b.deadline_seen) - - # Check all of the UnaryUnaryCallRpcError attributes - self.assertTrue(call.done()) - self.assertFalse(call.cancel()) - self.assertFalse(call.cancelled()) - self.assertEqual(await call.code(), - grpc.StatusCode.DEADLINE_EXCEEDED) - self.assertEqual(await call.details(), 'Deadline Exceeded') - self.assertEqual(await call.initial_metadata(), None) - self.assertEqual(await call.trailing_metadata(), ()) - self.assertEqual(await call.debug_error_string(), None) - async def test_rpcresponse(self): class Interceptor(aio.UnaryUnaryClientInterceptor): @@ -348,6 +294,106 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertEqual(await call.initial_metadata(), ()) self.assertEqual(await call.trailing_metadata(), ()) + async def test_call_ok_awaited(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + await call + return call + + server_target, _ = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel(server_target, + interceptors=[Interceptor() + ]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + call = multicallable(messages_pb2.SimpleRequest()) + response = await call + + self.assertTrue(call.done()) + self.assertFalse(call.cancelled()) + self.assertEqual(type(response), messages_pb2.SimpleResponse) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.initial_metadata(), ()) + self.assertEqual(await call.trailing_metadata(), ()) + + async def test_call_rpcerror(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + return call + + server_target, server = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel(server_target, + interceptors=[Interceptor() + ]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + + await server.stop(None) + + call = multicallable(messages_pb2.SimpleRequest(), timeout=0.1) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + self.assertTrue(call.done()) + self.assertFalse(call.cancelled()) + self.assertEqual(await call.code(), grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertEqual(await call.details(), 'Deadline Exceeded') + self.assertEqual(await call.initial_metadata(), ()) + self.assertEqual(await call.trailing_metadata(), ()) + + async def test_call_rpcerror_awaited(self): + + class Interceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + await call + return call + + server_target, server = await start_test_server() # pylint: disable=unused-variable + + async with aio.insecure_channel(server_target, + interceptors=[Interceptor() + ]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + + await server.stop(None) + + call = multicallable(messages_pb2.SimpleRequest(), timeout=0.1) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + self.assertTrue(call.done()) + self.assertFalse(call.cancelled()) + self.assertEqual(await call.code(), grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertEqual(await call.details(), 'Deadline Exceeded') + self.assertEqual(await call.initial_metadata(), ()) + self.assertEqual(await call.trailing_metadata(), ()) + async def test_cancel_before_rpc(self): interceptor_reached = asyncio.Event() From 5d664a5f5e0d954a72cb3d6cf7c404cc3ad69345 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Tue, 7 Jan 2020 23:38:26 +0100 Subject: [PATCH 05/12] Make YAPF happy --- src/python/grpcio_tests/tests_aio/unit/interceptor_test.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py index f39360d2f3e..fe2a09ee130 100644 --- a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py @@ -177,7 +177,6 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): self.calls.append(call) - new_client_call_details = aio.ClientCallDetails( method=client_call_details.method, timeout=None, @@ -354,7 +353,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertTrue(call.done()) self.assertFalse(call.cancelled()) - self.assertEqual(await call.code(), grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertEqual(await call.code(), + grpc.StatusCode.DEADLINE_EXCEEDED) self.assertEqual(await call.details(), 'Deadline Exceeded') self.assertEqual(await call.initial_metadata(), ()) self.assertEqual(await call.trailing_metadata(), ()) @@ -389,7 +389,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertTrue(call.done()) self.assertFalse(call.cancelled()) - self.assertEqual(await call.code(), grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertEqual(await call.code(), + grpc.StatusCode.DEADLINE_EXCEEDED) self.assertEqual(await call.details(), 'Deadline Exceeded') self.assertEqual(await call.initial_metadata(), ()) self.assertEqual(await call.trailing_metadata(), ()) From 75c858bcef7c67716ba712a4973e544dc19813d1 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Wed, 8 Jan 2020 15:31:52 +0100 Subject: [PATCH 06/12] Change test name --- src/python/grpcio_tests/tests_aio/unit/interceptor_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py index fe2a09ee130..4cbca8e9994 100644 --- a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py @@ -324,7 +324,7 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertEqual(await call.initial_metadata(), ()) self.assertEqual(await call.trailing_metadata(), ()) - async def test_call_rpcerror(self): + async def test_call_rpc_error(self): class Interceptor(aio.UnaryUnaryClientInterceptor): From 2a342b22a7dc5ae25213169a3f54c2dff6da7f28 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Thu, 9 Jan 2020 00:03:42 +0100 Subject: [PATCH 07/12] Fixes bug with deadline --- .../grpc/experimental/aio/_interceptor.py | 32 ++++---- .../grpcio/grpc/experimental/aio/_utils.py | 3 +- .../tests_aio/unit/_test_server.py | 28 ++++++- .../tests_aio/unit/channel_test.py | 21 +++-- .../tests_aio/unit/interceptor_test.py | 79 +++++++++---------- 5 files changed, 97 insertions(+), 66 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index 4c643443dda..dc25f4933c6 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -168,12 +168,12 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): try: call = self._interceptors_task.result() - except AioRpcError: - return False + except AioRpcError as err: + return err.code() == grpc.StatusCode.CANCELLED except asyncio.CancelledError: return True - else: - return call.cancelled() + + return call.cancelled() def done(self) -> bool: if not self._interceptors_task.done(): @@ -183,8 +183,8 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): call = self._interceptors_task.result() except (AioRpcError, asyncio.CancelledError): return True - else: - return call.done() + + return call.done() def add_done_callback(self, unused_callback) -> None: raise NotImplementedError() @@ -199,8 +199,8 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): return err.initial_metadata() except asyncio.CancelledError: return None - else: - return await call.initial_metadata() + + return await call.initial_metadata() async def trailing_metadata(self) -> Optional[MetadataType]: try: @@ -209,8 +209,8 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): return err.trailing_metadata() except asyncio.CancelledError: return None - else: - return await call.trailing_metadata() + + return await call.trailing_metadata() async def code(self) -> grpc.StatusCode: try: @@ -219,8 +219,8 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): return err.code() except asyncio.CancelledError: return grpc.StatusCode.CANCELLED - else: - return await call.code() + + return await call.code() async def details(self) -> str: try: @@ -229,8 +229,8 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): return err.details() except asyncio.CancelledError: return _LOCAL_CANCELLATION_DETAILS - else: - return await call.details() + + return await call.details() async def debug_error_string(self) -> Optional[str]: try: @@ -239,8 +239,8 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): return err.debug_error_string() except asyncio.CancelledError: return '' - else: - return await call.debug_error_string() + + return await call.debug_error_string() def __await__(self): call = yield from self._interceptors_task.__await__() diff --git a/src/python/grpcio/grpc/experimental/aio/_utils.py b/src/python/grpcio/grpc/experimental/aio/_utils.py index 17fabbb5bff..6a1f81a5ed7 100644 --- a/src/python/grpcio/grpc/experimental/aio/_utils.py +++ b/src/python/grpcio/grpc/experimental/aio/_utils.py @@ -13,6 +13,7 @@ # limitations under the License. """Internal utilities used by the gRPC Aio module.""" import asyncio +import time from typing import Optional @@ -20,4 +21,4 @@ def _timeout_to_deadline(loop: asyncio.AbstractEventLoop, timeout: Optional[float]) -> Optional[float]: if timeout is None: return None - return loop.time() + timeout + return time.time() + timeout diff --git a/src/python/grpcio_tests/tests_aio/unit/_test_server.py b/src/python/grpcio_tests/tests_aio/unit/_test_server.py index c3a04f29a00..b126a50a6c4 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_test_server.py +++ b/src/python/grpcio_tests/tests_aio/unit/_test_server.py @@ -16,11 +16,14 @@ import asyncio import logging import datetime +import grpc from grpc.experimental import aio from tests.unit.framework.common import test_constants from src.proto.grpc.testing import messages_pb2 from src.proto.grpc.testing import test_pb2_grpc +UNARY_CALL_WITH_SLEEP_VALUE = 0.2 + class _TestServiceServicer(test_pb2_grpc.TestServiceServicer): @@ -39,11 +42,34 @@ class _TestServiceServicer(test_pb2_grpc.TestServiceServicer): body=b'\x00' * response_parameters.size)) + # Next methods are extra ones that are registred programatically + # when the sever is instantiated. They are not being provided by + # the proto file. + + async def UnaryCallWithSleep(self, request, context): + await asyncio.sleep(UNARY_CALL_WITH_SLEEP_VALUE) + return messages_pb2.SimpleResponse() + async def start_test_server(): server = aio.server(options=(('grpc.so_reuseport', 0),)) - test_pb2_grpc.add_TestServiceServicer_to_server(_TestServiceServicer(), + servicer = _TestServiceServicer() + test_pb2_grpc.add_TestServiceServicer_to_server(servicer, server) + + # Add programatically extra methods not provided by the proto file + # that are used during the tests + rpc_method_handlers = { + 'UnaryCallWithSleep': grpc.unary_unary_rpc_method_handler( + servicer.UnaryCallWithSleep, + request_deserializer=messages_pb2.SimpleRequest.FromString, + response_serializer=messages_pb2.SimpleResponse.SerializeToString + ) + } + extra_handler = grpc.method_handlers_generic_handler( + 'grpc.testing.TestService', rpc_method_handlers) + server.add_generic_rpc_handlers((extra_handler,)) + port = server.add_insecure_port('[::]:0') await server.start() # NOTE(lidizheng) returning the server to prevent it from deallocation diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_test.py index 017ad0ae06c..c079256bb2b 100644 --- a/src/python/grpcio_tests/tests_aio/unit/channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/channel_test.py @@ -23,11 +23,12 @@ from grpc.experimental import aio from src.proto.grpc.testing import messages_pb2 from src.proto.grpc.testing import test_pb2_grpc from tests.unit.framework.common import test_constants -from tests_aio.unit._test_server import start_test_server +from tests_aio.unit._test_server import start_test_server, UNARY_CALL_WITH_SLEEP_VALUE from tests_aio.unit._test_base import AioTestBase from src.proto.grpc.testing import messages_pb2 _UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall' +_UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep' _STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall' _NUM_STREAM_RESPONSES = 5 _RESPONSE_PAYLOAD_SIZE = 42 @@ -52,7 +53,6 @@ class TestChannel(AioTestBase): async def test_unary_unary(self): async with aio.insecure_channel(self._server_target) as channel: - channel = aio.insecure_channel(self._server_target) hi = channel.unary_unary( _UNARY_CALL_METHOD, request_serializer=messages_pb2.SimpleRequest.SerializeToString, @@ -62,15 +62,15 @@ class TestChannel(AioTestBase): self.assertIsInstance(response, messages_pb2.SimpleResponse) async def test_unary_call_times_out(self): - async with aio.insecure_channel(_UNREACHABLE_TARGET) as channel: + async with aio.insecure_channel(self._server_target) as channel: hi = channel.unary_unary( - _UNARY_CALL_METHOD, + _UNARY_CALL_METHOD_WITH_SLEEP, request_serializer=messages_pb2.SimpleRequest.SerializeToString, response_deserializer=messages_pb2.SimpleResponse.FromString, ) with self.assertRaises(grpc.RpcError) as exception_context: - await hi(messages_pb2.SimpleRequest(), timeout=1.0) + await hi(messages_pb2.SimpleRequest(), timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2) _, details = grpc.StatusCode.DEADLINE_EXCEEDED.value # pylint: disable=unused-variable self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, @@ -81,6 +81,17 @@ class TestChannel(AioTestBase): self.assertIsNotNone( exception_context.exception.trailing_metadata()) + async def test_unary_call_does_not_times_out(self): + async with aio.insecure_channel(self._server_target) as channel: + hi = channel.unary_unary( + _UNARY_CALL_METHOD_WITH_SLEEP, + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString, + ) + + call = hi(messages_pb2.SimpleRequest(), timeout=UNARY_CALL_WITH_SLEEP_VALUE * 2) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + async def test_unary_stream(self): channel = aio.insecure_channel(self._server_target) stub = test_pb2_grpc.TestServiceStub(channel) diff --git a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py index 4cbca8e9994..f0b5fcd4ba8 100644 --- a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py @@ -18,15 +18,22 @@ import unittest import grpc from grpc.experimental import aio -from tests_aio.unit._test_server import start_test_server +from tests_aio.unit._test_server import start_test_server, UNARY_CALL_WITH_SLEEP_VALUE from tests_aio.unit._test_base import AioTestBase from src.proto.grpc.testing import messages_pb2 + _LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' class TestUnaryUnaryClientInterceptor(AioTestBase): + async def setUp(self): + self._server_target, self._server = await start_test_server() + + async def tearDown(self): + await self._server.stop(None) + def test_invalid_interceptor(self): class InvalidInterceptor: @@ -50,9 +57,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): interceptors = [Interceptor() for i in range(2)] - server_target, _ = await start_test_server() # pylint: disable=unused-variable - - async with aio.insecure_channel(server_target, + async with aio.insecure_channel(self._server_target, interceptors=interceptors) as channel: multicallable = channel.unary_unary( '/grpc.testing.TestService/UnaryCall', @@ -97,9 +102,8 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): return call interceptor = StatusCodeOkInterceptor() - server_target, server = await start_test_server() # pylint: disable=unused-variable - async with aio.insecure_channel(server_target, + async with aio.insecure_channel(self._server_target, interceptors=[interceptor]) as channel: # when no error StatusCode.OK must be observed @@ -121,26 +125,23 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): client_call_details, request): new_client_call_details = aio.ClientCallDetails( method=client_call_details.method, - timeout=0.1, + timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2, metadata=client_call_details.metadata, credentials=client_call_details.credentials) return await continuation(new_client_call_details, request) interceptor = TimeoutInterceptor() - server_target, server = await start_test_server() - async with aio.insecure_channel(server_target, + async with aio.insecure_channel(self._server_target, interceptors=[interceptor]) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', + '/grpc.testing.TestService/UnaryCallWithSleep', request_serializer=messages_pb2.SimpleRequest.SerializeToString, response_deserializer=messages_pb2.SimpleResponse.FromString) call = multicallable(messages_pb2.SimpleRequest()) - await server.stop(None) - with self.assertRaises(aio.AioRpcError) as exception_context: await call @@ -165,7 +166,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): new_client_call_details = aio.ClientCallDetails( method=client_call_details.method, - timeout=0.1, + timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2, metadata=client_call_details.metadata, credentials=client_call_details.credentials) @@ -188,13 +189,12 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): return call interceptor = RetryInterceptor() - server_target, server = await start_test_server() - async with aio.insecure_channel(server_target, + async with aio.insecure_channel(self._server_target, interceptors=[interceptor]) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', + '/grpc.testing.TestService/UnaryCallWithSleep', request_serializer=messages_pb2.SimpleRequest.SerializeToString, response_deserializer=messages_pb2.SimpleResponse.FromString) @@ -232,10 +232,9 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): return ResponseInterceptor.response interceptor, interceptor_response = Interceptor(), ResponseInterceptor() - server_target, server = await start_test_server() async with aio.insecure_channel( - server_target, interceptors=[interceptor, + self._server_target, interceptors=[interceptor, interceptor_response]) as channel: multicallable = channel.unary_unary( @@ -263,6 +262,12 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): class TestInterceptedUnaryUnaryCall(AioTestBase): + async def setUp(self): + self._server_target, self._server = await start_test_server() + + async def tearDown(self): + await self._server.stop(None) + async def test_call_ok(self): class Interceptor(aio.UnaryUnaryClientInterceptor): @@ -272,9 +277,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): call = await continuation(client_call_details, request) return call - server_target, _ = await start_test_server() # pylint: disable=unused-variable - async with aio.insecure_channel(server_target, + async with aio.insecure_channel(self._server_target, interceptors=[Interceptor() ]) as channel: @@ -303,9 +307,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): await call return call - server_target, _ = await start_test_server() # pylint: disable=unused-variable - async with aio.insecure_channel(server_target, + async with aio.insecure_channel(self._server_target, interceptors=[Interceptor() ]) as channel: @@ -333,20 +336,17 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): call = await continuation(client_call_details, request) return call - server_target, server = await start_test_server() # pylint: disable=unused-variable - async with aio.insecure_channel(server_target, + async with aio.insecure_channel(self._server_target, interceptors=[Interceptor() ]) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', + '/grpc.testing.TestService/UnaryCallWithSleep', request_serializer=messages_pb2.SimpleRequest.SerializeToString, response_deserializer=messages_pb2.SimpleResponse.FromString) - await server.stop(None) - - call = multicallable(messages_pb2.SimpleRequest(), timeout=0.1) + call = multicallable(messages_pb2.SimpleRequest(), timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2) with self.assertRaises(aio.AioRpcError) as exception_context: await call @@ -359,7 +359,7 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): self.assertEqual(await call.initial_metadata(), ()) self.assertEqual(await call.trailing_metadata(), ()) - async def test_call_rpcerror_awaited(self): + async def test_call_rpc_error_awaited(self): class Interceptor(aio.UnaryUnaryClientInterceptor): @@ -369,20 +369,17 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): await call return call - server_target, server = await start_test_server() # pylint: disable=unused-variable - async with aio.insecure_channel(server_target, + async with aio.insecure_channel(self._server_target, interceptors=[Interceptor() ]) as channel: multicallable = channel.unary_unary( - '/grpc.testing.TestService/UnaryCall', + '/grpc.testing.TestService/UnaryCallWithSleep', request_serializer=messages_pb2.SimpleRequest.SerializeToString, response_deserializer=messages_pb2.SimpleResponse.FromString) - await server.stop(None) - - call = multicallable(messages_pb2.SimpleRequest(), timeout=0.1) + call = multicallable(messages_pb2.SimpleRequest(), timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2) with self.assertRaises(aio.AioRpcError) as exception_context: await call @@ -409,9 +406,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): # This line should never be reached raise Exception() - server_target, _ = await start_test_server() # pylint: disable=unused-variable - async with aio.insecure_channel(server_target, + async with aio.insecure_channel(self._server_target, interceptors=[Interceptor() ]) as channel: @@ -454,9 +450,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): # This line should never be reached raise Exception() - server_target, _ = await start_test_server() # pylint: disable=unused-variable - async with aio.insecure_channel(server_target, + async with aio.insecure_channel(self._server_target, interceptors=[Interceptor() ]) as channel: @@ -494,9 +489,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): await call return call - server_target, _ = await start_test_server() # pylint: disable=unused-variable - async with aio.insecure_channel(server_target, + async with aio.insecure_channel(self._server_target, interceptors=[Interceptor() ]) as channel: @@ -527,9 +521,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): call.cancel() return call - server_target, _ = await start_test_server() # pylint: disable=unused-variable - async with aio.insecure_channel(server_target, + async with aio.insecure_channel(self._server_target, interceptors=[Interceptor() ]) as channel: From db54580f20cba0b9ab13f6668c73fe199ecb1059 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Thu, 9 Jan 2020 00:04:44 +0100 Subject: [PATCH 08/12] Make YAPF happy --- .../tests_aio/unit/_test_server.py | 14 +++++++------- .../tests_aio/unit/channel_test.py | 6 ++++-- .../tests_aio/unit/interceptor_test.py | 19 ++++++------------- 3 files changed, 17 insertions(+), 22 deletions(-) diff --git a/src/python/grpcio_tests/tests_aio/unit/_test_server.py b/src/python/grpcio_tests/tests_aio/unit/_test_server.py index b126a50a6c4..c12eb5f4836 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_test_server.py +++ b/src/python/grpcio_tests/tests_aio/unit/_test_server.py @@ -54,17 +54,17 @@ class _TestServiceServicer(test_pb2_grpc.TestServiceServicer): async def start_test_server(): server = aio.server(options=(('grpc.so_reuseport', 0),)) servicer = _TestServiceServicer() - test_pb2_grpc.add_TestServiceServicer_to_server(servicer, - server) + test_pb2_grpc.add_TestServiceServicer_to_server(servicer, server) # Add programatically extra methods not provided by the proto file # that are used during the tests rpc_method_handlers = { - 'UnaryCallWithSleep': grpc.unary_unary_rpc_method_handler( - servicer.UnaryCallWithSleep, - request_deserializer=messages_pb2.SimpleRequest.FromString, - response_serializer=messages_pb2.SimpleResponse.SerializeToString - ) + 'UnaryCallWithSleep': + grpc.unary_unary_rpc_method_handler( + servicer.UnaryCallWithSleep, + request_deserializer=messages_pb2.SimpleRequest.FromString, + response_serializer=messages_pb2.SimpleResponse. + SerializeToString) } extra_handler = grpc.method_handlers_generic_handler( 'grpc.testing.TestService', rpc_method_handlers) diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_test.py index c079256bb2b..934dc8de95c 100644 --- a/src/python/grpcio_tests/tests_aio/unit/channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/channel_test.py @@ -70,7 +70,8 @@ class TestChannel(AioTestBase): ) with self.assertRaises(grpc.RpcError) as exception_context: - await hi(messages_pb2.SimpleRequest(), timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2) + await hi(messages_pb2.SimpleRequest(), + timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2) _, details = grpc.StatusCode.DEADLINE_EXCEEDED.value # pylint: disable=unused-variable self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, @@ -89,7 +90,8 @@ class TestChannel(AioTestBase): response_deserializer=messages_pb2.SimpleResponse.FromString, ) - call = hi(messages_pb2.SimpleRequest(), timeout=UNARY_CALL_WITH_SLEEP_VALUE * 2) + call = hi(messages_pb2.SimpleRequest(), + timeout=UNARY_CALL_WITH_SLEEP_VALUE * 2) self.assertEqual(await call.code(), grpc.StatusCode.OK) async def test_unary_stream(self): diff --git a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py index f0b5fcd4ba8..f461e3fdd65 100644 --- a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py @@ -22,7 +22,6 @@ from tests_aio.unit._test_server import start_test_server, UNARY_CALL_WITH_SLEEP from tests_aio.unit._test_base import AioTestBase from src.proto.grpc.testing import messages_pb2 - _LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' @@ -234,8 +233,8 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): interceptor, interceptor_response = Interceptor(), ResponseInterceptor() async with aio.insecure_channel( - self._server_target, interceptors=[interceptor, - interceptor_response]) as channel: + self._server_target, + interceptors=[interceptor, interceptor_response]) as channel: multicallable = channel.unary_unary( '/grpc.testing.TestService/UnaryCall', @@ -277,7 +276,6 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): call = await continuation(client_call_details, request) return call - async with aio.insecure_channel(self._server_target, interceptors=[Interceptor() ]) as channel: @@ -307,7 +305,6 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): await call return call - async with aio.insecure_channel(self._server_target, interceptors=[Interceptor() ]) as channel: @@ -336,7 +333,6 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): call = await continuation(client_call_details, request) return call - async with aio.insecure_channel(self._server_target, interceptors=[Interceptor() ]) as channel: @@ -346,7 +342,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): request_serializer=messages_pb2.SimpleRequest.SerializeToString, response_deserializer=messages_pb2.SimpleResponse.FromString) - call = multicallable(messages_pb2.SimpleRequest(), timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2) + call = multicallable(messages_pb2.SimpleRequest(), + timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2) with self.assertRaises(aio.AioRpcError) as exception_context: await call @@ -369,7 +366,6 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): await call return call - async with aio.insecure_channel(self._server_target, interceptors=[Interceptor() ]) as channel: @@ -379,7 +375,8 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): request_serializer=messages_pb2.SimpleRequest.SerializeToString, response_deserializer=messages_pb2.SimpleResponse.FromString) - call = multicallable(messages_pb2.SimpleRequest(), timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2) + call = multicallable(messages_pb2.SimpleRequest(), + timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2) with self.assertRaises(aio.AioRpcError) as exception_context: await call @@ -406,7 +403,6 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): # This line should never be reached raise Exception() - async with aio.insecure_channel(self._server_target, interceptors=[Interceptor() ]) as channel: @@ -450,7 +446,6 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): # This line should never be reached raise Exception() - async with aio.insecure_channel(self._server_target, interceptors=[Interceptor() ]) as channel: @@ -489,7 +484,6 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): await call return call - async with aio.insecure_channel(self._server_target, interceptors=[Interceptor() ]) as channel: @@ -521,7 +515,6 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): call.cancel() return call - async with aio.insecure_channel(self._server_target, interceptors=[Interceptor() ]) as channel: From 69884f7d848052747c7f1cb58be7cf822c6ea2dd Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Thu, 9 Jan 2020 00:07:15 +0100 Subject: [PATCH 09/12] Remove unused loop parameter --- src/python/grpcio/grpc/experimental/aio/_channel.py | 4 ++-- src/python/grpcio/grpc/experimental/aio/_interceptor.py | 3 +-- src/python/grpcio/grpc/experimental/aio/_utils.py | 4 +--- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 3aa9fc07360..159c74981ea 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -96,7 +96,7 @@ class UnaryUnaryMultiCallable: if not self._interceptors: return UnaryUnaryCall( request, - _timeout_to_deadline(self._loop, timeout), + _timeout_to_deadline(timeout), self._channel, self._method, self._request_serializer, @@ -166,7 +166,7 @@ class UnaryStreamMultiCallable: if compression: raise NotImplementedError("TODO: compression not implemented yet") - deadline = _timeout_to_deadline(self._loop, timeout) + deadline = _timeout_to_deadline(timeout) return UnaryStreamCall( request, diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index dc25f4933c6..e90e90955ae 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -147,8 +147,7 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): else: return UnaryUnaryCall( request, - _timeout_to_deadline(self._loop, - client_call_details.timeout), + _timeout_to_deadline(client_call_details.timeout), self._channel, client_call_details.method, request_serializer, response_deserializer) diff --git a/src/python/grpcio/grpc/experimental/aio/_utils.py b/src/python/grpcio/grpc/experimental/aio/_utils.py index 6a1f81a5ed7..e5772dce2da 100644 --- a/src/python/grpcio/grpc/experimental/aio/_utils.py +++ b/src/python/grpcio/grpc/experimental/aio/_utils.py @@ -12,13 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """Internal utilities used by the gRPC Aio module.""" -import asyncio import time from typing import Optional -def _timeout_to_deadline(loop: asyncio.AbstractEventLoop, - timeout: Optional[float]) -> Optional[float]: +def _timeout_to_deadline(timeout: Optional[float]) -> Optional[float]: if timeout is None: return None return time.time() + timeout From aa473fa68aaace21d30619ac3e1c2d0a82724092 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Thu, 9 Jan 2020 00:08:01 +0100 Subject: [PATCH 10/12] Make YAPF happy --- src/python/grpcio/grpc/experimental/aio/_interceptor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index e90e90955ae..e77d72bf404 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -146,8 +146,7 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): else: return UnaryUnaryCall( - request, - _timeout_to_deadline(client_call_details.timeout), + request, _timeout_to_deadline(client_call_details.timeout), self._channel, client_call_details.method, request_serializer, response_deserializer) From bb2c94f0e225dd5259bd1c083ebe0d55bbf0272c Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Fri, 10 Jan 2020 14:52:55 +0100 Subject: [PATCH 11/12] Fix random segfaults when gRPC call resources are released Due to the GC work for breaking reference cycles the internal attributes of the _AioCall object were cleared and initialized with None in aims of releasing refcounts for breaking direct or indirect references. Once the clear was done, the deallocation of the _AioCall would fail since the pointer to gRPC call was cast to an invalid object, no longer was a GrcpWrapper object but a None object, returning a NULL as a value for the attribute call. --- .../grpc/_cython/_cygrpc/aio/call.pxd.pxi | 4 +-- .../grpc/_cython/_cygrpc/aio/call.pyx.pxi | 25 ++++++++----------- .../grpc/_cython/_cygrpc/aio/server.pyx.pxi | 1 + 3 files changed, 13 insertions(+), 17 deletions(-) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi index b800cee6028..fce4d1c2dc9 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi @@ -13,11 +13,10 @@ # limitations under the License. -cdef class _AioCall: +cdef class _AioCall(GrpcCallWrapper): cdef: AioChannel _channel list _references - GrpcCallWrapper _grpc_call_wrapper # Caches the picked event loop, so we can avoid the 30ns overhead each # time we need access to the event loop. object _loop @@ -30,4 +29,3 @@ cdef class _AioCall: bint _is_locally_cancelled cdef grpc_call* _create_grpc_call(self, object timeout, bytes method) except * - cdef void _destroy_grpc_call(self) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi index 2d013afe6cb..cdb109d0592 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi @@ -29,15 +29,16 @@ cdef class _AioCall: AioChannel channel, object deadline, bytes method): + self.call = NULL self._channel = channel self._references = [] - self._grpc_call_wrapper = GrpcCallWrapper() self._loop = asyncio.get_event_loop() self._create_grpc_call(deadline, method) self._is_locally_cancelled = False def __dealloc__(self): - self._destroy_grpc_call() + if self.call: + grpc_call_unref(self.call) def __repr__(self): class_name = self.__class__.__name__ @@ -62,7 +63,7 @@ cdef class _AioCall: method, len(method) ) - self._grpc_call_wrapper.call = grpc_channel_create_call( + self.call = grpc_channel_create_call( self._channel.channel, NULL, _EMPTY_MASK, @@ -74,10 +75,6 @@ cdef class _AioCall: ) grpc_slice_unref(method_slice) - cdef void _destroy_grpc_call(self): - """Destroys the corresponding Core object for this RPC.""" - grpc_call_unref(self._grpc_call_wrapper.call) - def cancel(self, AioRpcStatus status): """Cancels the RPC in Core with given RPC status. @@ -98,7 +95,7 @@ cdef class _AioCall: c_details = details # By implementation, grpc_call_cancel_with_status always return OK error = grpc_call_cancel_with_status( - self._grpc_call_wrapper.call, + self.call, status.c_code(), c_details, NULL, @@ -106,7 +103,7 @@ cdef class _AioCall: assert error == GRPC_CALL_OK else: # By implementation, grpc_call_cancel always return OK - error = grpc_call_cancel(self._grpc_call_wrapper.call, NULL) + error = grpc_call_cancel(self.call, NULL) assert error == GRPC_CALL_OK async def unary_unary(self, @@ -141,7 +138,7 @@ cdef class _AioCall: # Executes all operations in one batch. # Might raise CancelledError, handling it in Python UnaryUnaryCall. - await execute_batch(self._grpc_call_wrapper, + await execute_batch(self, ops, self._loop) @@ -164,7 +161,7 @@ cdef class _AioCall: """Handles the status sent by peer once received.""" cdef ReceiveStatusOnClientOperation op = ReceiveStatusOnClientOperation(_EMPTY_FLAGS) cdef tuple ops = (op,) - await execute_batch(self._grpc_call_wrapper, ops, self._loop) + await execute_batch(self, ops, self._loop) # Halts if the RPC is locally cancelled if self._is_locally_cancelled: @@ -187,7 +184,7 @@ cdef class _AioCall: # * The client application cancels; # * The server sends final status. received_message = await _receive_message( - self._grpc_call_wrapper, + self, self._loop ) return received_message @@ -218,12 +215,12 @@ cdef class _AioCall: ) # Sends out the request message. - await execute_batch(self._grpc_call_wrapper, + await execute_batch(self, outbound_ops, self._loop) # Receives initial metadata. initial_metadata_observer( - await _receive_initial_metadata(self._grpc_call_wrapper, + await _receive_initial_metadata(self, self._loop), ) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi index c4da3560ff5..e18134b5e36 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi @@ -30,6 +30,7 @@ cdef class _HandlerCallDetails: cdef class RPCState: def __cinit__(self, AioServer server): + self.call = NULL self.server = server grpc_metadata_array_init(&self.request_metadata) grpc_call_details_init(&self.details) From 17928a43c063e9a89254e6d98f345854523eb93b Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Sat, 11 Jan 2020 15:31:05 +0100 Subject: [PATCH 12/12] Apply review feedback, increase timeout threshold --- .../grpcio_tests/tests_aio/unit/channel_test.py | 2 +- .../grpcio_tests/tests_aio/unit/interceptor_test.py | 12 ++++-------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/python/grpcio_tests/tests_aio/unit/channel_test.py b/src/python/grpcio_tests/tests_aio/unit/channel_test.py index 934dc8de95c..3a17b045c8b 100644 --- a/src/python/grpcio_tests/tests_aio/unit/channel_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/channel_test.py @@ -91,7 +91,7 @@ class TestChannel(AioTestBase): ) call = hi(messages_pb2.SimpleRequest(), - timeout=UNARY_CALL_WITH_SLEEP_VALUE * 2) + timeout=UNARY_CALL_WITH_SLEEP_VALUE * 5) self.assertEqual(await call.code(), grpc.StatusCode.OK) async def test_unary_stream(self): diff --git a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py index f461e3fdd65..9970178d0cd 100644 --- a/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/interceptor_test.py @@ -392,16 +392,14 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): async def test_cancel_before_rpc(self): interceptor_reached = asyncio.Event() + wait_for_ever = self.loop.create_future() class Interceptor(aio.UnaryUnaryClientInterceptor): async def intercept_unary_unary(self, continuation, client_call_details, request): interceptor_reached.set() - await asyncio.sleep(0) - - # This line should never be reached - raise Exception() + await wait_for_ever async with aio.insecure_channel(self._server_target, interceptors=[Interceptor() @@ -433,6 +431,7 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): async def test_cancel_after_rpc(self): interceptor_reached = asyncio.Event() + wait_for_ever = self.loop.create_future() class Interceptor(aio.UnaryUnaryClientInterceptor): @@ -441,10 +440,7 @@ class TestInterceptedUnaryUnaryCall(AioTestBase): call = await continuation(client_call_details, request) await call interceptor_reached.set() - await asyncio.sleep(0) - - # This line should never be reached - raise Exception() + await wait_for_ever async with aio.insecure_channel(self._server_target, interceptors=[Interceptor()