diff --git a/src/python/grpcio/grpc/experimental/aio/__init__.py b/src/python/grpcio/grpc/experimental/aio/__init__.py index 634795e7cc3..d0b6a58a149 100644 --- a/src/python/grpcio/grpc/experimental/aio/__init__.py +++ b/src/python/grpcio/grpc/experimental/aio/__init__.py @@ -30,8 +30,10 @@ from ._base_channel import (Channel, StreamStreamMultiCallable, StreamUnaryMultiCallable, UnaryStreamMultiCallable, UnaryUnaryMultiCallable) from ._call import AioRpcError -from ._interceptor import (ClientCallDetails, InterceptedUnaryUnaryCall, - UnaryUnaryClientInterceptor, ServerInterceptor) +from ._interceptor import (ClientCallDetails, ClientInterceptor, + InterceptedUnaryUnaryCall, + UnaryUnaryClientInterceptor, + UnaryStreamClientInterceptor, ServerInterceptor) from ._server import server from ._base_server import Server, ServicerContext from ._typing import ChannelArgumentType @@ -56,6 +58,8 @@ __all__ = ( 'StreamUnaryMultiCallable', 'StreamStreamMultiCallable', 'ClientCallDetails', + 'ClientInterceptor', + 'UnaryStreamClientInterceptor', 'UnaryUnaryClientInterceptor', 'InterceptedUnaryUnaryCall', 'ServerInterceptor', diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index 3d1d19fd3fa..00778184658 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -318,6 +318,9 @@ class _StreamResponseMixin(Call): yield message message = await self._read() + # If the read operation failed, Core should explain why. + await self._raise_for_status() + def __aiter__(self) -> AsyncIterable[ResponseType]: self._update_response_style(_APIStyle.ASYNC_GENERATOR) if self._message_aiter is None: diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 5e669e1a3f5..89c556c997e 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -15,7 +15,7 @@ import asyncio import sys -from typing import Any, Iterable, Optional, Sequence +from typing import Any, Iterable, Optional, Sequence, List import grpc from grpc import _common, _compression, _grpcio_metadata @@ -25,7 +25,9 @@ from . import _base_call, _base_channel from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall, UnaryUnaryCall) from ._interceptor import (InterceptedUnaryUnaryCall, - UnaryUnaryClientInterceptor) + InterceptedUnaryStreamCall, ClientInterceptor, + UnaryUnaryClientInterceptor, + UnaryStreamClientInterceptor) from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType, SerializingFunction, RequestIterableType) from ._utils import _timeout_to_deadline @@ -65,7 +67,7 @@ class _BaseMultiCallable: _method: bytes _request_serializer: SerializingFunction _response_deserializer: DeserializingFunction - _interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] + _interceptors: Optional[Sequence[ClientInterceptor]] _loop: asyncio.AbstractEventLoop # pylint: disable=too-many-arguments @@ -75,7 +77,7 @@ class _BaseMultiCallable: method: bytes, request_serializer: SerializingFunction, response_deserializer: DeserializingFunction, - interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]], + interceptors: Optional[Sequence[ClientInterceptor]], loop: asyncio.AbstractEventLoop, ) -> None: self._loop = loop @@ -134,10 +136,17 @@ class UnaryStreamMultiCallable(_BaseMultiCallable, deadline = _timeout_to_deadline(timeout) - call = UnaryStreamCall(request, deadline, metadata, credentials, - wait_for_ready, self._channel, self._method, - self._request_serializer, - self._response_deserializer, self._loop) + if not self._interceptors: + call = UnaryStreamCall(request, deadline, metadata, credentials, + wait_for_ready, self._channel, self._method, + self._request_serializer, + self._response_deserializer, self._loop) + else: + call = InterceptedUnaryStreamCall( + self._interceptors, request, deadline, metadata, credentials, + wait_for_ready, self._channel, self._method, + self._request_serializer, self._response_deserializer, + self._loop) return call @@ -193,12 +202,13 @@ class StreamStreamMultiCallable(_BaseMultiCallable, class Channel(_base_channel.Channel): _loop: asyncio.AbstractEventLoop _channel: cygrpc.AioChannel - _unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] + _unary_unary_interceptors: List[UnaryUnaryClientInterceptor] + _unary_stream_interceptors: List[UnaryStreamClientInterceptor] def __init__(self, target: str, options: ChannelArgumentType, credentials: Optional[grpc.ChannelCredentials], compression: Optional[grpc.Compression], - interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]): + interceptors: Optional[Sequence[ClientInterceptor]]): """Constructor. Args: @@ -210,22 +220,31 @@ class Channel(_base_channel.Channel): interceptors: An optional list of interceptors that would be used for intercepting any RPC executed with that channel. """ - if interceptors is None: - self._unary_unary_interceptors = None - else: - self._unary_unary_interceptors = list( - filter( - lambda interceptor: isinstance(interceptor, - UnaryUnaryClientInterceptor), - interceptors)) + self._unary_unary_interceptors = [] + self._unary_stream_interceptors = [] + + if interceptors: + attrs_and_interceptor_classes = ((self._unary_unary_interceptors, + UnaryUnaryClientInterceptor), + (self._unary_stream_interceptors, + UnaryStreamClientInterceptor)) + + # pylint: disable=cell-var-from-loop + for attr, interceptor_class in attrs_and_interceptor_classes: + attr.extend([ + interceptor for interceptor in interceptors + if isinstance(interceptor, interceptor_class) + ]) invalid_interceptors = set(interceptors) - set( - self._unary_unary_interceptors) + self._unary_unary_interceptors) - set( + self._unary_stream_interceptors) if invalid_interceptors: raise ValueError( "Interceptor must be "+\ - "UnaryUnaryClientInterceptors, the following are invalid: {}"\ + "UnaryUnaryClientInterceptors or "+\ + "UnaryStreamClientInterceptors. The following are invalid: {}"\ .format(invalid_interceptors)) self._loop = asyncio.get_event_loop() @@ -352,7 +371,9 @@ class Channel(_base_channel.Channel): ) -> UnaryStreamMultiCallable: return UnaryStreamMultiCallable(self._channel, _common.encode(method), request_serializer, - response_deserializer, None, self._loop) + response_deserializer, + self._unary_stream_interceptors, + self._loop) def stream_unary( self, @@ -380,7 +401,7 @@ def insecure_channel( target: str, options: Optional[ChannelArgumentType] = None, compression: Optional[grpc.Compression] = None, - interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] = None): + interceptors: Optional[Sequence[ClientInterceptor]] = None): """Creates an insecure asynchronous Channel to a server. Args: @@ -399,12 +420,11 @@ def insecure_channel( compression, interceptors) -def secure_channel( - target: str, - credentials: grpc.ChannelCredentials, - options: Optional[ChannelArgumentType] = None, - compression: Optional[grpc.Compression] = None, - interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] = None): +def secure_channel(target: str, + credentials: grpc.ChannelCredentials, + options: Optional[ChannelArgumentType] = None, + compression: Optional[grpc.Compression] = None, + interceptors: Optional[Sequence[ClientInterceptor]] = None): """Creates a secure asynchronous Channel to a server. Args: diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index d4aca3ae0fc..80d17e04ce9 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -16,13 +16,13 @@ import asyncio import collections import functools from abc import ABCMeta, abstractmethod -from typing import Callable, Optional, Iterator, Sequence, Union, Awaitable +from typing import Callable, Optional, Iterator, Sequence, Union, Awaitable, AsyncIterable import grpc from grpc._cython import cygrpc from . import _base_call -from ._call import UnaryUnaryCall, AioRpcError +from ._call import UnaryUnaryCall, UnaryStreamCall, AioRpcError from ._utils import _timeout_to_deadline from ._typing import (RequestType, SerializingFunction, DeserializingFunction, MetadataType, ResponseType, DoneCallbackType) @@ -84,7 +84,11 @@ class ClientCallDetails( wait_for_ready: Optional[bool] -class UnaryUnaryClientInterceptor(metaclass=ABCMeta): +class ClientInterceptor(metaclass=ABCMeta): + """Base class used for all Aio Client Interceptor classes""" + + +class UnaryUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta): """Affords intercepting unary-unary invocations.""" @abstractmethod @@ -101,8 +105,8 @@ class UnaryUnaryClientInterceptor(metaclass=ABCMeta): actual RPC on the underlying Channel. It is the interceptor's responsibility to call it if it decides to move the RPC forward. The interceptor can use - `response_future = await continuation(client_call_details, request)` - to continue with the RPC. `continuation` returns the response of the + `call = await continuation(client_call_details, request)` + to continue with the RPC. `continuation` returns the call to the RPC. client_call_details: A ClientCallDetails object describing the outgoing RPC. @@ -117,8 +121,41 @@ class UnaryUnaryClientInterceptor(metaclass=ABCMeta): """ -class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): - """Used for running a `UnaryUnaryCall` wrapped by interceptors. +class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta): + """Affords intercepting unary-stream invocations.""" + + @abstractmethod + async def intercept_unary_stream( + self, continuation: Callable[[ClientCallDetails, RequestType], + UnaryStreamCall], + client_call_details: ClientCallDetails, request: RequestType + ) -> Union[AsyncIterable[ResponseType], UnaryStreamCall]: + """Intercepts a unary-stream invocation asynchronously. + + Args: + continuation: A coroutine that proceeds with the invocation by + executing the next interceptor in chain or invoking the + actual RPC on the underlying Channel. It is the interceptor's + responsibility to call it if it decides to move the RPC forward. + The interceptor can use + `call = await continuation(client_call_details, request, response_iterator))` + to continue with the RPC. `continuation` returns the call to the + RPC. + client_call_details: A ClientCallDetails object describing the + outgoing RPC. + request: The request value for the RPC. + + Returns: + The RPC Call. + + Raises: + AioRpcError: Indicating that the RPC terminated with non-OK status. + asyncio.CancelledError: Indicating that the RPC was canceled. + """ + + +class InterceptedCall: + """Base implementation for all intecepted call arities. Interceptors might have some work to do before the RPC invocation with the capacity of changing the invocation parameters, and some work to do @@ -133,103 +170,68 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): intercepted call, being at the same time the same call returned to the interceptors. - For most of the methods, like `initial_metadata()` the caller does not need - to wait until the interceptors task is finished, once the RPC is done the - caller will have the freedom for accessing to the results. - - For the `__await__` method is it is proxied to the intercepted call only when - the interceptor task is finished. + As a base class for all of the interceptors implements the logic around + final status, metadata and cancellation. """ - _loop: asyncio.AbstractEventLoop - _channel: cygrpc.AioChannel - _cancelled_before_rpc: bool - _intercepted_call: Optional[_base_call.UnaryUnaryCall] - _intercepted_call_created: asyncio.Event _interceptors_task: asyncio.Task _pending_add_done_callbacks: Sequence[DoneCallbackType] - # pylint: disable=too-many-arguments - def __init__(self, interceptors: Sequence[UnaryUnaryClientInterceptor], - request: RequestType, timeout: Optional[float], - metadata: MetadataType, - credentials: Optional[grpc.CallCredentials], - wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, - method: bytes, request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction, - loop: asyncio.AbstractEventLoop) -> None: - self._channel = channel - self._loop = loop - self._interceptors_task = loop.create_task( - self._invoke(interceptors, method, timeout, metadata, credentials, - wait_for_ready, request, request_serializer, - response_deserializer)) + def __init__(self, interceptors_task: asyncio.Task) -> None: + self._interceptors_task = interceptors_task self._pending_add_done_callbacks = [] self._interceptors_task.add_done_callback( - self._fire_pending_add_done_callbacks) + self._fire_or_add_pending_done_callbacks) def __del__(self): self.cancel() - # pylint: disable=too-many-arguments - async def _invoke(self, interceptors: Sequence[UnaryUnaryClientInterceptor], - method: bytes, timeout: Optional[float], - metadata: Optional[MetadataType], - credentials: Optional[grpc.CallCredentials], - wait_for_ready: Optional[bool], request: RequestType, - request_serializer: SerializingFunction, - response_deserializer: DeserializingFunction - ) -> UnaryUnaryCall: - """Run the RPC call wrapped in interceptors""" - - async def _run_interceptor( - interceptors: Iterator[UnaryUnaryClientInterceptor], - client_call_details: ClientCallDetails, - request: RequestType) -> _base_call.UnaryUnaryCall: - - interceptor = next(interceptors, None) - - if interceptor: - continuation = functools.partial(_run_interceptor, interceptors) + def _fire_or_add_pending_done_callbacks(self, + interceptors_task: asyncio.Task + ) -> None: - call_or_response = await interceptor.intercept_unary_unary( - continuation, client_call_details, request) - - if isinstance(call_or_response, _base_call.UnaryUnaryCall): - return call_or_response - else: - return UnaryUnaryCallResponse(call_or_response) + if not self._pending_add_done_callbacks: + return - else: - return UnaryUnaryCall( - request, _timeout_to_deadline(client_call_details.timeout), - client_call_details.metadata, - client_call_details.credentials, - client_call_details.wait_for_ready, self._channel, - client_call_details.method, request_serializer, - response_deserializer, self._loop) + call_completed = False - client_call_details = ClientCallDetails(method, timeout, metadata, - credentials, wait_for_ready) - return await _run_interceptor(iter(interceptors), client_call_details, - request) + try: + call = interceptors_task.result() + if call.done(): + call_completed = True + except (AioRpcError, asyncio.CancelledError): + call_completed = True - def _fire_pending_add_done_callbacks(self, - unused_task: asyncio.Task) -> None: - for callback in self._pending_add_done_callbacks: - callback(self) + if call_completed: + for callback in self._pending_add_done_callbacks: + callback(self) + else: + for callback in self._pending_add_done_callbacks: + callback = functools.partial(self._wrap_add_done_callback, + callback) + call.add_done_callback(callback) self._pending_add_done_callbacks = [] def _wrap_add_done_callback(self, callback: DoneCallbackType, - unused_task: asyncio.Task) -> None: + unused_call: _base_call.Call) -> None: callback(self) def cancel(self) -> bool: - if self._interceptors_task.done(): + if not self._interceptors_task.done(): + # There is no yet the intercepted call available, + # Trying to cancel it by using the generic Asyncio + # cancellation method. + return self._interceptors_task.cancel() + + try: + call = self._interceptors_task.result() + except AioRpcError: + return False + except asyncio.CancelledError: return False - return self._interceptors_task.cancel() + return call.cancel() def cancelled(self) -> bool: if not self._interceptors_task.done(): @@ -270,7 +272,7 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): callback(self) else: callback = functools.partial(self._wrap_add_done_callback, callback) - call.add_done_callback(self._wrap_add_done_callback) + call.add_done_callback(callback) def time_remaining(self) -> Optional[float]: raise NotImplementedError() @@ -325,14 +327,181 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): return await call.debug_error_string() + async def wait_for_connection(self) -> None: + call = await self._interceptors_task + return await call.wait_for_connection() + + +class InterceptedUnaryUnaryCall(InterceptedCall, _base_call.UnaryUnaryCall): + """Used for running a `UnaryUnaryCall` wrapped by interceptors. + + For the `__await__` method is it is proxied to the intercepted call only when + the interceptor task is finished. + """ + + _loop: asyncio.AbstractEventLoop + _channel: cygrpc.AioChannel + + # pylint: disable=too-many-arguments + def __init__(self, interceptors: Sequence[UnaryUnaryClientInterceptor], + request: RequestType, timeout: Optional[float], + metadata: MetadataType, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, + method: bytes, request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop) -> None: + self._loop = loop + self._channel = channel + interceptors_task = loop.create_task( + self._invoke(interceptors, method, timeout, metadata, credentials, + wait_for_ready, request, request_serializer, + response_deserializer)) + super().__init__(interceptors_task) + + # pylint: disable=too-many-arguments + async def _invoke(self, interceptors: Sequence[UnaryUnaryClientInterceptor], + method: bytes, timeout: Optional[float], + metadata: Optional[MetadataType], + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], request: RequestType, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction + ) -> UnaryUnaryCall: + """Run the RPC call wrapped in interceptors""" + + async def _run_interceptor( + interceptors: Iterator[UnaryUnaryClientInterceptor], + client_call_details: ClientCallDetails, + request: RequestType) -> _base_call.UnaryUnaryCall: + + interceptor = next(interceptors, None) + + if interceptor: + continuation = functools.partial(_run_interceptor, interceptors) + + call_or_response = await interceptor.intercept_unary_unary( + continuation, client_call_details, request) + + if isinstance(call_or_response, _base_call.UnaryUnaryCall): + return call_or_response + else: + return UnaryUnaryCallResponse(call_or_response) + + else: + return UnaryUnaryCall( + request, _timeout_to_deadline(client_call_details.timeout), + client_call_details.metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, self._channel, + client_call_details.method, request_serializer, + response_deserializer, self._loop) + + client_call_details = ClientCallDetails(method, timeout, metadata, + credentials, wait_for_ready) + return await _run_interceptor(iter(interceptors), client_call_details, + request) + def __await__(self): call = yield from self._interceptors_task.__await__() response = yield from call.__await__() return response - async def wait_for_connection(self) -> None: + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + +class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall): + """Used for running a `UnaryStreamCall` wrapped by interceptors.""" + + _loop: asyncio.AbstractEventLoop + _channel: cygrpc.AioChannel + _response_aiter: AsyncIterable[ResponseType] + _last_returned_call_from_interceptors = Optional[_base_call.UnaryStreamCall] + + # pylint: disable=too-many-arguments + def __init__(self, interceptors: Sequence[UnaryStreamClientInterceptor], + request: RequestType, timeout: Optional[float], + metadata: MetadataType, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, + method: bytes, request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop) -> None: + self._loop = loop + self._channel = channel + self._response_aiter = self._wait_for_interceptor_task_response_iterator( + ) + self._last_returned_call_from_interceptors = None + interceptors_task = loop.create_task( + self._invoke(interceptors, method, timeout, metadata, credentials, + wait_for_ready, request, request_serializer, + response_deserializer)) + super().__init__(interceptors_task) + + # pylint: disable=too-many-arguments + async def _invoke(self, interceptors: Sequence[UnaryUnaryClientInterceptor], + method: bytes, timeout: Optional[float], + metadata: Optional[MetadataType], + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], request: RequestType, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction + ) -> UnaryStreamCall: + """Run the RPC call wrapped in interceptors""" + + async def _run_interceptor( + interceptors: Iterator[UnaryStreamClientInterceptor], + client_call_details: ClientCallDetails, + request: RequestType, + ) -> _base_call.UnaryUnaryCall: + + interceptor = next(interceptors, None) + + if interceptor: + continuation = functools.partial(_run_interceptor, interceptors) + + call_or_response_iterator = await interceptor.intercept_unary_stream( + continuation, client_call_details, request) + + if isinstance(call_or_response_iterator, + _base_call.UnaryUnaryCall): + self._last_returned_call_from_interceptors = call_or_response_iterator + else: + self._last_returned_call_from_interceptors = UnaryStreamCallResponseIterator( + self._last_returned_call_from_interceptors, + call_or_response_iterator) + return self._last_returned_call_from_interceptors + else: + self._last_returned_call_from_interceptors = UnaryStreamCall( + request, _timeout_to_deadline(client_call_details.timeout), + client_call_details.metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, self._channel, + client_call_details.method, request_serializer, + response_deserializer, self._loop) + + return self._last_returned_call_from_interceptors + + client_call_details = ClientCallDetails(method, timeout, metadata, + credentials, wait_for_ready) + return await _run_interceptor(iter(interceptors), client_call_details, + request) + + async def _wait_for_interceptor_task_response_iterator(self + ) -> ResponseType: call = await self._interceptors_task - return await call.wait_for_connection() + async for response in call: + yield response + + def __aiter__(self) -> AsyncIterable[ResponseType]: + return self._response_aiter + + async def read(self) -> ResponseType: + return await self._response_aiter.asend(None) + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall): @@ -381,3 +550,55 @@ class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall): async def wait_for_connection(self) -> None: pass + + +class UnaryStreamCallResponseIterator(_base_call.UnaryStreamCall): + """UnaryStreamCall class wich uses an alternative response iterator.""" + _call: _base_call.UnaryStreamCall + _response_iterator: AsyncIterable[ResponseType] + + def __init__(self, call: _base_call.UnaryStreamCall, + response_iterator: AsyncIterable[ResponseType]) -> None: + self._response_iterator = response_iterator + self._call = call + + def cancel(self) -> bool: + return self._call.cancel() + + def cancelled(self) -> bool: + return self._call.cancelled() + + def done(self) -> bool: + return self._call.done() + + def add_done_callback(self, callback) -> None: + self._call.add_done_callback(callback) + + def time_remaining(self) -> Optional[float]: + return self._call.time_remaining() + + async def initial_metadata(self) -> Optional[MetadataType]: + return await self._call.initial_metadata() + + async def trailing_metadata(self) -> Optional[MetadataType]: + return await self._call.trailing_metadata() + + async def code(self) -> grpc.StatusCode: + return await self._call.code() + + async def details(self) -> str: + return await self._call.details() + + async def debug_error_string(self) -> Optional[str]: + return await self._call.debug_error_string() + + def __aiter__(self): + return self._response_iterator.__aiter__() + + async def wait_for_connection(self) -> None: + return await self._call.wait_for_connection() + + async def read(self) -> ResponseType: + # Behind the scenes everyting goes through the + # async iterator. So this path should not be reached. + raise Exception() diff --git a/src/python/grpcio_tests/tests_aio/health_check/health_servicer_test.py b/src/python/grpcio_tests/tests_aio/health_check/health_servicer_test.py index ac6c84b2f54..a539dbf1409 100644 --- a/src/python/grpcio_tests/tests_aio/health_check/health_servicer_test.py +++ b/src/python/grpcio_tests/tests_aio/health_check/health_servicer_test.py @@ -108,7 +108,10 @@ class HealthServicerTest(AioTestBase): (await queue.get()).status) call.cancel() - await task + + with self.assertRaises(asyncio.CancelledError): + await task + self.assertTrue(queue.empty()) async def test_watch_new_service(self): @@ -131,7 +134,10 @@ class HealthServicerTest(AioTestBase): (await queue.get()).status) call.cancel() - await task + + with self.assertRaises(asyncio.CancelledError): + await task + self.assertTrue(queue.empty()) async def test_watch_service_isolation(self): @@ -151,7 +157,10 @@ class HealthServicerTest(AioTestBase): await asyncio.wait_for(queue.get(), test_constants.SHORT_TIMEOUT) call.cancel() - await task + + with self.assertRaises(asyncio.CancelledError): + await task + self.assertTrue(queue.empty()) async def test_two_watchers(self): @@ -177,8 +186,13 @@ class HealthServicerTest(AioTestBase): call1.cancel() call2.cancel() - await task1 - await task2 + + with self.assertRaises(asyncio.CancelledError): + await task1 + + with self.assertRaises(asyncio.CancelledError): + await task2 + self.assertTrue(queue1.empty()) self.assertTrue(queue2.empty()) @@ -194,7 +208,9 @@ class HealthServicerTest(AioTestBase): call.cancel() await self._servicer.set(_WATCH_SERVICE, health_pb2.HealthCheckResponse.SERVING) - await task + + with self.assertRaises(asyncio.CancelledError): + await task # Wait for the serving coroutine to process client cancellation. timeout = time.monotonic() + test_constants.TIME_ALLOWANCE @@ -226,7 +242,10 @@ class HealthServicerTest(AioTestBase): resp.status) call.cancel() - await task + + with self.assertRaises(asyncio.CancelledError): + await task + self.assertTrue(queue.empty()) async def test_no_duplicate_status(self): @@ -251,7 +270,10 @@ class HealthServicerTest(AioTestBase): last_status = status call.cancel() - await task + + with self.assertRaises(asyncio.CancelledError): + await task + self.assertTrue(queue.empty()) diff --git a/src/python/grpcio_tests/tests_aio/tests.json b/src/python/grpcio_tests/tests_aio/tests.json index 71f8733f5f9..0bdd1f72e50 100644 --- a/src/python/grpcio_tests/tests_aio/tests.json +++ b/src/python/grpcio_tests/tests_aio/tests.json @@ -13,8 +13,9 @@ "unit.channel_argument_test.TestChannelArgument", "unit.channel_ready_test.TestChannelReady", "unit.channel_test.TestChannel", - "unit.client_interceptor_test.TestInterceptedUnaryUnaryCall", - "unit.client_interceptor_test.TestUnaryUnaryClientInterceptor", + "unit.client_unary_stream_interceptor_test.TestUnaryStreamClientInterceptor", + "unit.client_unary_unary_interceptor_test.TestInterceptedUnaryUnaryCall", + "unit.client_unary_unary_interceptor_test.TestUnaryUnaryClientInterceptor", "unit.close_channel_test.TestCloseChannel", "unit.compatibility_test.TestCompatibility", "unit.compression_test.TestCompression", diff --git a/src/python/grpcio_tests/tests_aio/unit/_common.py b/src/python/grpcio_tests/tests_aio/unit/_common.py index 1b5a4d909fa..97cbe759ed0 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_common.py +++ b/src/python/grpcio_tests/tests_aio/unit/_common.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import grpc from grpc.experimental import aio from grpc.experimental.aio._typing import MetadataType, MetadatumType +from tests.unit.framework.common import test_constants + def seen_metadata(expected: MetadataType, actual: MetadataType): return not bool(set(expected) - set(actual)) @@ -32,3 +35,32 @@ async def block_until_certain_state(channel: aio.Channel, while state != expected_state: await channel.wait_for_state_change(state) state = channel.get_state() + + +def inject_callbacks(call): + first_callback_ran = asyncio.Event() + + def first_callback(call): + # Validate that all resopnses have been received + # and the call is an end state. + assert call.done() + first_callback_ran.set() + + second_callback_ran = asyncio.Event() + + def second_callback(call): + # Validate that all resopnses have been received + # and the call is an end state. + assert call.done() + second_callback_ran.set() + + call.add_done_callback(first_callback) + call.add_done_callback(second_callback) + + async def validation(): + await asyncio.wait_for( + asyncio.gather(first_callback_ran.wait(), + second_callback_ran.wait()), + test_constants.SHORT_TIMEOUT) + + return validation() diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index 2548e777783..3ce2a2f7b52 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -217,6 +217,23 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase): class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase): + async def test_call_rpc_error(self): + channel = aio.insecure_channel(UNREACHABLE_TARGET) + request = messages_pb2.StreamingOutputCallRequest() + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.StreamingOutputCall(request) + + with self.assertRaises(aio.AioRpcError) as exception_context: + async for response in call: + pass + + self.assertEqual(grpc.StatusCode.UNAVAILABLE, + exception_context.exception.code()) + + self.assertTrue(call.done()) + self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code()) + await channel.close() + async def test_cancel_unary_stream(self): # Prepares the request request = messages_pb2.StreamingOutputCallRequest() @@ -550,7 +567,6 @@ class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase): cancel_later_task = self.loop.create_task(cancel_later()) - # No exceptions here with self.assertRaises(asyncio.CancelledError): await call @@ -772,9 +788,10 @@ class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase): cancel_later_task = self.loop.create_task(cancel_later()) - # No exceptions here - async for response in call: - self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + with self.assertRaises(asyncio.CancelledError): + async for response in call: + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, + len(response.payload.body)) await request_iterator_received_the_exception.wait() diff --git a/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py new file mode 100644 index 00000000000..cfb3ead226f --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py @@ -0,0 +1,409 @@ +# Copyright 2020 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import logging +import unittest +import datetime + +import grpc + +from grpc.experimental import aio +from tests_aio.unit._constants import UNREACHABLE_TARGET +from tests_aio.unit._common import inject_callbacks +from tests_aio.unit._test_server import start_test_server +from tests_aio.unit._test_base import AioTestBase +from tests.unit.framework.common import test_constants +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc + +_SHORT_TIMEOUT_S = 1.0 + +_NUM_STREAM_RESPONSES = 5 +_REQUEST_PAYLOAD_SIZE = 7 +_RESPONSE_PAYLOAD_SIZE = 7 +_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000) + + +class _CountingResponseIterator: + + def __init__(self, response_iterator): + self.response_cnt = 0 + self._response_iterator = response_iterator + + async def _forward_responses(self): + async for response in self._response_iterator: + self.response_cnt += 1 + yield response + + def __aiter__(self): + return self._forward_responses() + + +class _UnaryStreamInterceptorEmpty(aio.UnaryStreamClientInterceptor): + + async def intercept_unary_stream(self, continuation, client_call_details, + request): + return await continuation(client_call_details, request) + + def assert_in_final_state(self, test: unittest.TestCase): + pass + + +class _UnaryStreamInterceptorWithResponseIterator( + aio.UnaryStreamClientInterceptor): + + async def intercept_unary_stream(self, continuation, client_call_details, + request): + call = await continuation(client_call_details, request) + self.response_iterator = _CountingResponseIterator(call) + return self.response_iterator + + def assert_in_final_state(self, test: unittest.TestCase): + test.assertEqual(_NUM_STREAM_RESPONSES, + self.response_iterator.response_cnt) + + +class TestUnaryStreamClientInterceptor(AioTestBase): + + async def setUp(self): + self._server_target, self._server = await start_test_server() + + async def tearDown(self): + await self._server.stop(None) + + async def test_intercepts(self): + for interceptor_class in (_UnaryStreamInterceptorEmpty, + _UnaryStreamInterceptorWithResponseIterator): + + with self.subTest(name=interceptor_class): + interceptor = interceptor_class() + + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.extend([ + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) + ] * _NUM_STREAM_RESPONSES) + + channel = aio.insecure_channel(self._server_target, + interceptors=[interceptor]) + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.StreamingOutputCall(request) + + await call.wait_for_connection() + + response_cnt = 0 + async for response in call: + response_cnt += 1 + self.assertIs(type(response), + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, + len(response.payload.body)) + + self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.initial_metadata(), ()) + self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.debug_error_string(), '') + self.assertEqual(call.cancel(), False) + self.assertEqual(call.cancelled(), False) + self.assertEqual(call.done(), True) + + interceptor.assert_in_final_state(self) + + await channel.close() + + async def test_add_done_callback_interceptor_task_not_finished(self): + for interceptor_class in (_UnaryStreamInterceptorEmpty, + _UnaryStreamInterceptorWithResponseIterator): + + with self.subTest(name=interceptor_class): + interceptor = interceptor_class() + + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.extend([ + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) + ] * _NUM_STREAM_RESPONSES) + + channel = aio.insecure_channel(self._server_target, + interceptors=[interceptor]) + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.StreamingOutputCall(request) + + validation = inject_callbacks(call) + + async for response in call: + pass + + await validation + + await channel.close() + + async def test_add_done_callback_interceptor_task_finished(self): + for interceptor_class in (_UnaryStreamInterceptorEmpty, + _UnaryStreamInterceptorWithResponseIterator): + + with self.subTest(name=interceptor_class): + interceptor = interceptor_class() + + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.extend([ + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) + ] * _NUM_STREAM_RESPONSES) + + channel = aio.insecure_channel(self._server_target, + interceptors=[interceptor]) + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.StreamingOutputCall(request) + + # This ensures that the callbacks will be registered + # with the intercepted call rather than saving in the + # pending state list. + await call.wait_for_connection() + + validation = inject_callbacks(call) + + async for response in call: + pass + + await validation + + await channel.close() + + async def test_response_iterator_using_read(self): + interceptor = _UnaryStreamInterceptorWithResponseIterator() + + channel = aio.insecure_channel(self._server_target, + interceptors=[interceptor]) + stub = test_pb2_grpc.TestServiceStub(channel) + + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.extend( + [messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)] * + _NUM_STREAM_RESPONSES) + + call = stub.StreamingOutputCall(request) + + response_cnt = 0 + for response in range(_NUM_STREAM_RESPONSES): + response = await call.read() + response_cnt += 1 + self.assertIs(type(response), + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) + + self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) + self.assertEqual(interceptor.response_iterator.response_cnt, + _NUM_STREAM_RESPONSES) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + await channel.close() + + async def test_multiple_interceptors_response_iterator(self): + for interceptor_class in (_UnaryStreamInterceptorEmpty, + _UnaryStreamInterceptorWithResponseIterator): + + with self.subTest(name=interceptor_class): + + interceptors = [interceptor_class(), interceptor_class()] + + channel = aio.insecure_channel(self._server_target, + interceptors=interceptors) + stub = test_pb2_grpc.TestServiceStub(channel) + + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.extend([ + messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE) + ] * _NUM_STREAM_RESPONSES) + + call = stub.StreamingOutputCall(request) + + response_cnt = 0 + async for response in call: + response_cnt += 1 + self.assertIs(type(response), + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, + len(response.payload.body)) + + self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + + await channel.close() + + async def test_intercepts_response_iterator_rpc_error(self): + for interceptor_class in (_UnaryStreamInterceptorEmpty, + _UnaryStreamInterceptorWithResponseIterator): + + with self.subTest(name=interceptor_class): + + channel = aio.insecure_channel( + UNREACHABLE_TARGET, interceptors=[interceptor_class()]) + request = messages_pb2.StreamingOutputCallRequest() + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.StreamingOutputCall(request) + + with self.assertRaises(aio.AioRpcError) as exception_context: + async for response in call: + pass + + self.assertEqual(grpc.StatusCode.UNAVAILABLE, + exception_context.exception.code()) + + self.assertTrue(call.done()) + self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code()) + await channel.close() + + async def test_cancel_before_rpc(self): + + interceptor_reached = asyncio.Event() + wait_for_ever = self.loop.create_future() + + class Interceptor(aio.UnaryStreamClientInterceptor): + + async def intercept_unary_stream(self, continuation, + client_call_details, request): + interceptor_reached.set() + await wait_for_ever + + channel = aio.insecure_channel(UNREACHABLE_TARGET, + interceptors=[Interceptor()]) + request = messages_pb2.StreamingOutputCallRequest() + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.StreamingOutputCall(request) + + self.assertFalse(call.cancelled()) + self.assertFalse(call.done()) + + await interceptor_reached.wait() + self.assertTrue(call.cancel()) + + with self.assertRaises(asyncio.CancelledError): + async for response in call: + pass + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.initial_metadata(), None) + self.assertEqual(await call.trailing_metadata(), None) + await channel.close() + + async def test_cancel_after_rpc(self): + + interceptor_reached = asyncio.Event() + wait_for_ever = self.loop.create_future() + + class Interceptor(aio.UnaryStreamClientInterceptor): + + async def intercept_unary_stream(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + interceptor_reached.set() + await wait_for_ever + + channel = aio.insecure_channel(UNREACHABLE_TARGET, + interceptors=[Interceptor()]) + request = messages_pb2.StreamingOutputCallRequest() + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.StreamingOutputCall(request) + + self.assertFalse(call.cancelled()) + self.assertFalse(call.done()) + + await interceptor_reached.wait() + self.assertTrue(call.cancel()) + + with self.assertRaises(asyncio.CancelledError): + async for response in call: + pass + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.initial_metadata(), None) + self.assertEqual(await call.trailing_metadata(), None) + await channel.close() + + async def test_cancel_consuming_response_iterator(self): + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.extend( + [messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)] * + _NUM_STREAM_RESPONSES) + + channel = aio.insecure_channel( + self._server_target, + interceptors=[_UnaryStreamInterceptorWithResponseIterator()]) + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.StreamingOutputCall(request) + + with self.assertRaises(asyncio.CancelledError): + async for response in call: + call.cancel() + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + await channel.close() + + async def test_cancel_by_the_interceptor(self): + + class Interceptor(aio.UnaryStreamClientInterceptor): + + async def intercept_unary_stream(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + call.cancel() + return call + + channel = aio.insecure_channel(UNREACHABLE_TARGET, + interceptors=[Interceptor()]) + request = messages_pb2.StreamingOutputCallRequest() + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.StreamingOutputCall(request) + + with self.assertRaises(asyncio.CancelledError): + async for response in call: + pass + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + await channel.close() + + async def test_exception_raised_by_interceptor(self): + + class InterceptorException(Exception): + pass + + class Interceptor(aio.UnaryStreamClientInterceptor): + + async def intercept_unary_stream(self, continuation, + client_call_details, request): + raise InterceptorException + + channel = aio.insecure_channel(UNREACHABLE_TARGET, + interceptors=[Interceptor()]) + request = messages_pb2.StreamingOutputCallRequest() + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.StreamingOutputCall(request) + + with self.assertRaises(InterceptorException): + async for response in call: + pass + + await channel.close() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/client_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py similarity index 100% rename from src/python/grpcio_tests/tests_aio/unit/client_interceptor_test.py rename to src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py diff --git a/src/python/grpcio_tests/tests_aio/unit/done_callback_test.py b/src/python/grpcio_tests/tests_aio/unit/done_callback_test.py index a312e45711f..481bafd5679 100644 --- a/src/python/grpcio_tests/tests_aio/unit/done_callback_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/done_callback_test.py @@ -21,6 +21,7 @@ import gc import grpc from grpc.experimental import aio +from tests_aio.unit._common import inject_callbacks from tests_aio.unit._test_base import AioTestBase from tests.unit.framework.common import test_constants from src.proto.grpc.testing import messages_pb2, test_pb2_grpc @@ -31,29 +32,6 @@ _REQUEST_PAYLOAD_SIZE = 7 _RESPONSE_PAYLOAD_SIZE = 42 -def _inject_callbacks(call): - first_callback_ran = asyncio.Event() - - def first_callback(unused_call): - first_callback_ran.set() - - second_callback_ran = asyncio.Event() - - def second_callback(unused_call): - second_callback_ran.set() - - call.add_done_callback(first_callback) - call.add_done_callback(second_callback) - - async def validation(): - await asyncio.wait_for( - asyncio.gather(first_callback_ran.wait(), - second_callback_ran.wait()), - test_constants.SHORT_TIMEOUT) - - return validation() - - class TestDoneCallback(AioTestBase): async def setUp(self): @@ -69,12 +47,12 @@ class TestDoneCallback(AioTestBase): call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) self.assertEqual(grpc.StatusCode.OK, await call.code()) - validation = _inject_callbacks(call) + validation = inject_callbacks(call) await validation async def test_unary_unary(self): call = self._stub.UnaryCall(messages_pb2.SimpleRequest()) - validation = _inject_callbacks(call) + validation = inject_callbacks(call) self.assertEqual(grpc.StatusCode.OK, await call.code()) @@ -87,7 +65,7 @@ class TestDoneCallback(AioTestBase): messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) call = self._stub.StreamingOutputCall(request) - validation = _inject_callbacks(call) + validation = inject_callbacks(call) response_cnt = 0 async for response in call: @@ -110,7 +88,7 @@ class TestDoneCallback(AioTestBase): yield request call = self._stub.StreamingInputCall(gen()) - validation = _inject_callbacks(call) + validation = inject_callbacks(call) response = await call self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse) @@ -122,7 +100,7 @@ class TestDoneCallback(AioTestBase): async def test_stream_stream(self): call = self._stub.FullDuplexCall() - validation = _inject_callbacks(call) + validation = inject_callbacks(call) request = messages_pb2.StreamingOutputCallRequest() request.response_parameters.append(