diff --git a/src/python/grpcio/grpc/experimental/aio/__init__.py b/src/python/grpcio/grpc/experimental/aio/__init__.py index d0b6a58a149..c0fc17eeb92 100644 --- a/src/python/grpcio/grpc/experimental/aio/__init__.py +++ b/src/python/grpcio/grpc/experimental/aio/__init__.py @@ -33,7 +33,8 @@ from ._call import AioRpcError from ._interceptor import (ClientCallDetails, ClientInterceptor, InterceptedUnaryUnaryCall, UnaryUnaryClientInterceptor, - UnaryStreamClientInterceptor, ServerInterceptor) + UnaryStreamClientInterceptor, + StreamUnaryClientInterceptor, ServerInterceptor) from ._server import server from ._base_server import Server, ServicerContext from ._typing import ChannelArgumentType @@ -61,6 +62,7 @@ __all__ = ( 'ClientInterceptor', 'UnaryStreamClientInterceptor', 'UnaryUnaryClientInterceptor', + 'StreamUnaryClientInterceptor', 'InterceptedUnaryUnaryCall', 'ServerInterceptor', 'insecure_channel', diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index 00778184658..a0693921461 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -35,6 +35,7 @@ _LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!' _GC_CANCELLATION_DETAILS = 'Cancelled upon garbage collection!' _RPC_ALREADY_FINISHED_DETAILS = 'RPC already finished.' _RPC_HALF_CLOSED_DETAILS = 'RPC is half closed after calling "done_writing".' +_API_STYLE_ERROR = 'The iterator and read/write APIs may not be mixed on a single RPC.' _OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n' '\tstatus = {}\n' @@ -302,8 +303,7 @@ class _StreamResponseMixin(Call): if self._response_style is _APIStyle.UNKNOWN: self._response_style = style elif self._response_style is not style: - raise cygrpc.UsageError( - 'Please don\'t mix two styles of API for streaming responses') + raise cygrpc.UsageError(_API_STYLE_ERROR) def cancel(self) -> bool: if super().cancel(): @@ -381,8 +381,7 @@ class _StreamRequestMixin(Call): def _raise_for_different_style(self, style: _APIStyle): if self._request_style is not style: - raise cygrpc.UsageError( - 'Please don\'t mix two styles of API for streaming requests') + raise cygrpc.UsageError(_API_STYLE_ERROR) def cancel(self) -> bool: if super().cancel(): @@ -399,7 +398,8 @@ class _StreamRequestMixin(Call): request_iterator: RequestIterableType ) -> None: try: - if inspect.isasyncgen(request_iterator): + if inspect.isasyncgen(request_iterator) or hasattr( + request_iterator, '__aiter__'): async for request in request_iterator: await self._write(request) else: @@ -426,7 +426,6 @@ class _StreamRequestMixin(Call): serialized_request = _common.serialize(request, self._request_serializer) - try: await self._cython_call.send_serialized_message(serialized_request) except asyncio.CancelledError: diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index afa9cf30630..97ffa833bbf 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -25,9 +25,11 @@ from . import _base_call, _base_channel from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall, UnaryUnaryCall) from ._interceptor import (InterceptedUnaryUnaryCall, - InterceptedUnaryStreamCall, ClientInterceptor, + InterceptedUnaryStreamCall, + InterceptedStreamUnaryCall, ClientInterceptor, UnaryUnaryClientInterceptor, - UnaryStreamClientInterceptor) + UnaryStreamClientInterceptor, + StreamUnaryClientInterceptor) from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType, SerializingFunction, RequestIterableType) from ._utils import _timeout_to_deadline @@ -167,10 +169,17 @@ class StreamUnaryMultiCallable(_BaseMultiCallable, deadline = _timeout_to_deadline(timeout) - call = StreamUnaryCall(request_iterator, deadline, metadata, - credentials, wait_for_ready, self._channel, - self._method, self._request_serializer, - self._response_deserializer, self._loop) + if not self._interceptors: + call = StreamUnaryCall(request_iterator, deadline, metadata, + credentials, wait_for_ready, self._channel, + self._method, self._request_serializer, + self._response_deserializer, self._loop) + else: + call = InterceptedStreamUnaryCall( + self._interceptors, request_iterator, deadline, metadata, + credentials, wait_for_ready, self._channel, self._method, + self._request_serializer, self._response_deserializer, + self._loop) return call @@ -204,6 +213,7 @@ class Channel(_base_channel.Channel): _channel: cygrpc.AioChannel _unary_unary_interceptors: List[UnaryUnaryClientInterceptor] _unary_stream_interceptors: List[UnaryStreamClientInterceptor] + _stream_unary_interceptors: List[StreamUnaryClientInterceptor] def __init__(self, target: str, options: ChannelArgumentType, credentials: Optional[grpc.ChannelCredentials], @@ -222,12 +232,15 @@ class Channel(_base_channel.Channel): """ self._unary_unary_interceptors = [] self._unary_stream_interceptors = [] + self._stream_unary_interceptors = [] if interceptors: attrs_and_interceptor_classes = ((self._unary_unary_interceptors, UnaryUnaryClientInterceptor), (self._unary_stream_interceptors, - UnaryStreamClientInterceptor)) + UnaryStreamClientInterceptor), + (self._stream_unary_interceptors, + StreamUnaryClientInterceptor)) # pylint: disable=cell-var-from-loop for attr, interceptor_class in attrs_and_interceptor_classes: @@ -238,14 +251,17 @@ class Channel(_base_channel.Channel): invalid_interceptors = set(interceptors) - set( self._unary_unary_interceptors) - set( - self._unary_stream_interceptors) + self._unary_stream_interceptors) - set( + self._stream_unary_interceptors) if invalid_interceptors: raise ValueError( - "Interceptor must be "+\ - "UnaryUnaryClientInterceptors or "+\ - "UnaryStreamClientInterceptors. The following are invalid: {}"\ - .format(invalid_interceptors)) + "Interceptor must be " + + "{} or ".format(UnaryUnaryClientInterceptor.__name__) + + "{} or ".format(UnaryStreamClientInterceptor.__name__) + + "{}. ".format(StreamUnaryClientInterceptor.__name__) + + "The following are invalid: {}".format(invalid_interceptors) + ) self._loop = asyncio.get_event_loop() self._channel = cygrpc.AioChannel( @@ -383,7 +399,9 @@ class Channel(_base_channel.Channel): ) -> StreamUnaryMultiCallable: return StreamUnaryMultiCallable(self._channel, _common.encode(method), request_serializer, - response_deserializer, None, self._loop) + response_deserializer, + self._stream_unary_interceptors, + self._loop) def stream_stream( self, diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index b9f786e5522..e4969ddb4a5 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -22,10 +22,13 @@ import grpc from grpc._cython import cygrpc from . import _base_call -from ._call import UnaryUnaryCall, UnaryStreamCall, AioRpcError +from ._call import UnaryUnaryCall, UnaryStreamCall, StreamUnaryCall, AioRpcError +from ._call import _RPC_ALREADY_FINISHED_DETAILS, _RPC_HALF_CLOSED_DETAILS +from ._call import _API_STYLE_ERROR from ._utils import _timeout_to_deadline from ._typing import (RequestType, SerializingFunction, DeserializingFunction, - MetadataType, ResponseType, DoneCallbackType) + MetadataType, ResponseType, DoneCallbackType, + RequestIterableType) _LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!' @@ -101,7 +104,7 @@ class UnaryUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta): Args: continuation: A coroutine that proceeds with the invocation by - executing the next interceptor in chain or invoking the + executing the next interceptor in the chain or invoking the actual RPC on the underlying Channel. It is the interceptor's responsibility to call it if it decides to move the RPC forward. The interceptor can use @@ -132,13 +135,17 @@ class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta): ) -> Union[AsyncIterable[ResponseType], UnaryStreamCall]: """Intercepts a unary-stream invocation asynchronously. + The function could return the call object or an asynchronous + iterator, in case of being an asyncrhonous iterator this will + become the source of the reads done by the caller. + Args: continuation: A coroutine that proceeds with the invocation by - executing the next interceptor in chain or invoking the + executing the next interceptor in the chain or invoking the actual RPC on the underlying Channel. It is the interceptor's responsibility to call it if it decides to move the RPC forward. The interceptor can use - `call = await continuation(client_call_details, request, response_iterator))` + `call = await continuation(client_call_details, request)` to continue with the RPC. `continuation` returns the call to the RPC. client_call_details: A ClientCallDetails object describing the @@ -154,6 +161,47 @@ class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta): """ +class StreamUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta): + """Affords intercepting stream-unary invocations.""" + + @abstractmethod + async def intercept_stream_unary( + self, + continuation: Callable[[ClientCallDetails, RequestType], + UnaryStreamCall], + client_call_details: ClientCallDetails, + request_iterator: RequestIterableType, + ) -> StreamUnaryCall: + """Intercepts a stream-unary invocation asynchronously. + + Within the interceptor the usage of the call methods like `write` or + even awaiting the call should be done carefully, since the caller + could be expecting an untouched call, for example for start writing + messages to it. + + Args: + continuation: A coroutine that proceeds with the invocation by + executing the next interceptor in the chain or invoking the + actual RPC on the underlying Channel. It is the interceptor's + responsibility to call it if it decides to move the RPC forward. + The interceptor can use + `call = await continuation(client_call_details, request_iterator)` + to continue with the RPC. `continuation` returns the call to the + RPC. + client_call_details: A ClientCallDetails object describing the + outgoing RPC. + request_iterator: The request iterator that will produce requests + for the RPC. + + Returns: + The RPC Call. + + Raises: + AioRpcError: Indicating that the RPC terminated with non-OK status. + asyncio.CancelledError: Indicating that the RPC was canceled. + """ + + class InterceptedCall: """Base implementation for all intecepted call arities. @@ -332,7 +380,16 @@ class InterceptedCall: return await call.wait_for_connection() -class InterceptedUnaryUnaryCall(InterceptedCall, _base_call.UnaryUnaryCall): +class _InterceptedUnaryResponseMixin: + + def __await__(self): + call = yield from self._interceptors_task.__await__() + response = yield from call.__await__() + return response + + +class InterceptedUnaryUnaryCall(_InterceptedUnaryResponseMixin, InterceptedCall, + _base_call.UnaryUnaryCall): """Used for running a `UnaryUnaryCall` wrapped by interceptors. For the `__await__` method is it is proxied to the intercepted call only when @@ -402,11 +459,6 @@ class InterceptedUnaryUnaryCall(InterceptedCall, _base_call.UnaryUnaryCall): return await _run_interceptor(iter(interceptors), client_call_details, request) - def __await__(self): - call = yield from self._interceptors_task.__await__() - response = yield from call.__await__() - return response - def time_remaining(self) -> Optional[float]: raise NotImplementedError() @@ -504,6 +556,149 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall): raise NotImplementedError() +class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin, + InterceptedCall, _base_call.StreamUnaryCall): + """Used for running a `StreamUnaryCall` wrapped by interceptors. + + For the `__await__` method is it is proxied to the intercepted call only when + the interceptor task is finished. + """ + + _loop: asyncio.AbstractEventLoop + _channel: cygrpc.AioChannel + _write_to_iterator_async_gen: Optional[AsyncIterable[RequestType]] + _write_to_iterator_queue: Optional[asyncio.Queue] + + _FINISH_ITERATOR_SENTINEL = object() + + # pylint: disable=too-many-arguments + def __init__(self, interceptors: Sequence[StreamUnaryClientInterceptor], + request_iterator: Optional[RequestIterableType], + timeout: Optional[float], metadata: MetadataType, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, + method: bytes, request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop) -> None: + self._loop = loop + self._channel = channel + if request_iterator is None: + # We provide our own request iterator which is a proxy + # of the futures writes that will be done by the caller. + self._write_to_iterator_queue = asyncio.Queue(maxsize=1) + self._write_to_iterator_async_gen = self._proxy_writes_as_request_iterator( + ) + request_iterator = self._write_to_iterator_async_gen + else: + self._write_to_iterator_queue = None + + interceptors_task = loop.create_task( + self._invoke(interceptors, method, timeout, metadata, credentials, + wait_for_ready, request_iterator, request_serializer, + response_deserializer)) + super().__init__(interceptors_task) + + # pylint: disable=too-many-arguments + async def _invoke( + self, interceptors: Sequence[StreamUnaryClientInterceptor], + method: bytes, timeout: Optional[float], + metadata: Optional[MetadataType], + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + request_iterator: RequestIterableType, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction) -> StreamUnaryCall: + """Run the RPC call wrapped in interceptors""" + + async def _run_interceptor( + interceptors: Iterator[UnaryUnaryClientInterceptor], + client_call_details: ClientCallDetails, + request_iterator: RequestIterableType + ) -> _base_call.StreamUnaryCall: + + interceptor = next(interceptors, None) + + if interceptor: + continuation = functools.partial(_run_interceptor, interceptors) + + return await interceptor.intercept_stream_unary( + continuation, client_call_details, request_iterator) + else: + return StreamUnaryCall( + request_iterator, + _timeout_to_deadline(client_call_details.timeout), + client_call_details.metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, self._channel, + client_call_details.method, request_serializer, + response_deserializer, self._loop) + + client_call_details = ClientCallDetails(method, timeout, metadata, + credentials, wait_for_ready) + return await _run_interceptor(iter(interceptors), client_call_details, + request_iterator) + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + async def _proxy_writes_as_request_iterator(self): + await self._interceptors_task + + while True: + value = await self._write_to_iterator_queue.get() + if value is InterceptedStreamUnaryCall._FINISH_ITERATOR_SENTINEL: + break + yield value + + async def write(self, request: RequestType) -> None: + # If no queue was created it means that requests + # should be expected through an iterators provided + # by the caller. + if self._write_to_iterator_queue is None: + raise cygrpc.UsageError(_API_STYLE_ERROR) + + try: + call = await self._interceptors_task + except (asyncio.CancelledError, AioRpcError): + raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + + if call.done(): + raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + elif call._done_writing_flag: + raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) + + # Write might never end up since the call could abrubtly finish, + # we give up on the first awaitable object that finishes.. + _, _ = await asyncio.wait( + (self._write_to_iterator_queue.put(request), call), + return_when=asyncio.FIRST_COMPLETED) + + if call.done(): + raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + + async def done_writing(self) -> None: + """Signal peer that client is done writing. + + This method is idempotent. + """ + # If no queue was created it means that requests + # should be expected through an iterators provided + # by the caller. + if self._write_to_iterator_queue is None: + raise cygrpc.UsageError(_API_STYLE_ERROR) + + try: + call = await self._interceptors_task + except asyncio.CancelledError: + raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + + # Write might never end up since the call could abrubtly finish, + # we give up on the first awaitable object that finishes. + _, _ = await asyncio.wait((self._write_to_iterator_queue.put( + InterceptedStreamUnaryCall._FINISH_ITERATOR_SENTINEL), call), + return_when=asyncio.FIRST_COMPLETED) + + class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall): """Final UnaryUnaryCall class finished with a response.""" _response: ResponseType diff --git a/src/python/grpcio_tests/tests_aio/tests.json b/src/python/grpcio_tests/tests_aio/tests.json index e657cfa7e4d..20fa4fb588d 100644 --- a/src/python/grpcio_tests/tests_aio/tests.json +++ b/src/python/grpcio_tests/tests_aio/tests.json @@ -14,6 +14,7 @@ "unit.channel_argument_test.TestChannelArgument", "unit.channel_ready_test.TestChannelReady", "unit.channel_test.TestChannel", + "unit.client_stream_unary_interceptor_test.TestStreamUnaryClientInterceptor", "unit.client_unary_stream_interceptor_test.TestUnaryStreamClientInterceptor", "unit.client_unary_unary_interceptor_test.TestInterceptedUnaryUnaryCall", "unit.client_unary_unary_interceptor_test.TestUnaryUnaryClientInterceptor", diff --git a/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py new file mode 100644 index 00000000000..318a117ffe4 --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py @@ -0,0 +1,531 @@ +# Copyright 2020 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import logging +import unittest +import datetime + +import grpc + +from grpc.experimental import aio +from tests_aio.unit._constants import UNREACHABLE_TARGET +from tests_aio.unit._common import inject_callbacks +from tests_aio.unit._test_server import start_test_server +from tests_aio.unit._test_base import AioTestBase +from tests.unit.framework.common import test_constants +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc + +_SHORT_TIMEOUT_S = 1.0 + +_NUM_STREAM_REQUESTS = 5 +_REQUEST_PAYLOAD_SIZE = 7 +_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000) + + +class _CountingRequestIterator: + + def __init__(self, request_iterator): + self.request_cnt = 0 + self._request_iterator = request_iterator + + async def _forward_requests(self): + async for request in self._request_iterator: + self.request_cnt += 1 + yield request + + def __aiter__(self): + return self._forward_requests() + + +class _StreamUnaryInterceptorEmpty(aio.StreamUnaryClientInterceptor): + + async def intercept_stream_unary(self, continuation, client_call_details, + request_iterator): + return await continuation(client_call_details, request_iterator) + + def assert_in_final_state(self, test: unittest.TestCase): + pass + + +class _StreamUnaryInterceptorWithRequestIterator( + aio.StreamUnaryClientInterceptor): + + async def intercept_stream_unary(self, continuation, client_call_details, + request_iterator): + self.request_iterator = _CountingRequestIterator(request_iterator) + call = await continuation(client_call_details, self.request_iterator) + return call + + def assert_in_final_state(self, test: unittest.TestCase): + test.assertEqual(_NUM_STREAM_REQUESTS, + self.request_iterator.request_cnt) + + +class TestStreamUnaryClientInterceptor(AioTestBase): + + async def setUp(self): + self._server_target, self._server = await start_test_server() + + async def tearDown(self): + await self._server.stop(None) + + async def test_intercepts(self): + for interceptor_class in (_StreamUnaryInterceptorEmpty, + _StreamUnaryInterceptorWithRequestIterator): + + with self.subTest(name=interceptor_class): + interceptor = interceptor_class() + channel = aio.insecure_channel(self._server_target, + interceptors=[interceptor]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * + _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest( + payload=payload) + + async def request_iterator(): + for _ in range(_NUM_STREAM_REQUESTS): + yield request + + call = stub.StreamingInputCall(request_iterator()) + + response = await call + + self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.initial_metadata(), ()) + self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.debug_error_string(), '') + self.assertEqual(call.cancel(), False) + self.assertEqual(call.cancelled(), False) + self.assertEqual(call.done(), True) + + interceptor.assert_in_final_state(self) + + await channel.close() + + async def test_intercepts_using_write(self): + for interceptor_class in (_StreamUnaryInterceptorEmpty, + _StreamUnaryInterceptorWithRequestIterator): + + with self.subTest(name=interceptor_class): + interceptor = interceptor_class() + channel = aio.insecure_channel(self._server_target, + interceptors=[interceptor]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * + _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest( + payload=payload) + + call = stub.StreamingInputCall() + + for _ in range(_NUM_STREAM_REQUESTS): + await call.write(request) + + await call.done_writing() + + response = await call + + self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.initial_metadata(), ()) + self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.debug_error_string(), '') + self.assertEqual(call.cancel(), False) + self.assertEqual(call.cancelled(), False) + self.assertEqual(call.done(), True) + + interceptor.assert_in_final_state(self) + + await channel.close() + + async def test_add_done_callback_interceptor_task_not_finished(self): + for interceptor_class in (_StreamUnaryInterceptorEmpty, + _StreamUnaryInterceptorWithRequestIterator): + + with self.subTest(name=interceptor_class): + interceptor = interceptor_class() + + channel = aio.insecure_channel(self._server_target, + interceptors=[interceptor]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * + _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest( + payload=payload) + + async def request_iterator(): + for _ in range(_NUM_STREAM_REQUESTS): + yield request + + call = stub.StreamingInputCall(request_iterator()) + + validation = inject_callbacks(call) + + response = await call + + await validation + + await channel.close() + + async def test_add_done_callback_interceptor_task_finished(self): + for interceptor_class in (_StreamUnaryInterceptorEmpty, + _StreamUnaryInterceptorWithRequestIterator): + + with self.subTest(name=interceptor_class): + interceptor = interceptor_class() + + channel = aio.insecure_channel(self._server_target, + interceptors=[interceptor]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * + _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest( + payload=payload) + + async def request_iterator(): + for _ in range(_NUM_STREAM_REQUESTS): + yield request + + call = stub.StreamingInputCall(request_iterator()) + + response = await call + + validation = inject_callbacks(call) + + await validation + + await channel.close() + + async def test_multiple_interceptors_request_iterator(self): + for interceptor_class in (_StreamUnaryInterceptorEmpty, + _StreamUnaryInterceptorWithRequestIterator): + + with self.subTest(name=interceptor_class): + + interceptors = [interceptor_class(), interceptor_class()] + channel = aio.insecure_channel(self._server_target, + interceptors=interceptors) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * + _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest( + payload=payload) + + async def request_iterator(): + for _ in range(_NUM_STREAM_REQUESTS): + yield request + + call = stub.StreamingInputCall(request_iterator()) + + response = await call + + self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE, + response.aggregated_payload_size) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.initial_metadata(), ()) + self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.debug_error_string(), '') + self.assertEqual(call.cancel(), False) + self.assertEqual(call.cancelled(), False) + self.assertEqual(call.done(), True) + + for interceptor in interceptors: + interceptor.assert_in_final_state(self) + + await channel.close() + + async def test_intercepts_request_iterator_rpc_error(self): + for interceptor_class in (_StreamUnaryInterceptorEmpty, + _StreamUnaryInterceptorWithRequestIterator): + + with self.subTest(name=interceptor_class): + channel = aio.insecure_channel( + UNREACHABLE_TARGET, interceptors=[interceptor_class()]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * + _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest( + payload=payload) + + # When there is an error the request iterator is no longer + # consumed. + async def request_iterator(): + for _ in range(_NUM_STREAM_REQUESTS): + yield request + + call = stub.StreamingInputCall(request_iterator()) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + self.assertEqual(grpc.StatusCode.UNAVAILABLE, + exception_context.exception.code()) + self.assertTrue(call.done()) + self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code()) + + await channel.close() + + async def test_intercepts_request_iterator_rpc_error_using_write(self): + for interceptor_class in (_StreamUnaryInterceptorEmpty, + _StreamUnaryInterceptorWithRequestIterator): + + with self.subTest(name=interceptor_class): + channel = aio.insecure_channel( + UNREACHABLE_TARGET, interceptors=[interceptor_class()]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * + _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest( + payload=payload) + + call = stub.StreamingInputCall() + + # When there is an error during the write, exception is raised. + with self.assertRaises(asyncio.InvalidStateError): + for _ in range(_NUM_STREAM_REQUESTS): + await call.write(request) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + self.assertEqual(grpc.StatusCode.UNAVAILABLE, + exception_context.exception.code()) + self.assertTrue(call.done()) + self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code()) + + await channel.close() + + async def test_cancel_before_rpc(self): + + interceptor_reached = asyncio.Event() + wait_for_ever = self.loop.create_future() + + class Interceptor(aio.StreamUnaryClientInterceptor): + + async def intercept_stream_unary(self, continuation, + client_call_details, + request_iterator): + interceptor_reached.set() + await wait_for_ever + + channel = aio.insecure_channel(self._server_target, + interceptors=[Interceptor()]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + call = stub.StreamingInputCall() + + self.assertFalse(call.cancelled()) + self.assertFalse(call.done()) + + await interceptor_reached.wait() + self.assertTrue(call.cancel()) + + # When there is an error during the write, exception is raised. + with self.assertRaises(asyncio.InvalidStateError): + for _ in range(_NUM_STREAM_REQUESTS): + await call.write(request) + + with self.assertRaises(asyncio.CancelledError): + await call + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.initial_metadata(), None) + self.assertEqual(await call.trailing_metadata(), None) + await channel.close() + + async def test_cancel_after_rpc(self): + + interceptor_reached = asyncio.Event() + wait_for_ever = self.loop.create_future() + + class Interceptor(aio.StreamUnaryClientInterceptor): + + async def intercept_stream_unary(self, continuation, + client_call_details, + request_iterator): + call = await continuation(client_call_details, request_iterator) + interceptor_reached.set() + await wait_for_ever + + channel = aio.insecure_channel(self._server_target, + interceptors=[Interceptor()]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + call = stub.StreamingInputCall() + + self.assertFalse(call.cancelled()) + self.assertFalse(call.done()) + + await interceptor_reached.wait() + self.assertTrue(call.cancel()) + + # When there is an error during the write, exception is raised. + with self.assertRaises(asyncio.InvalidStateError): + for _ in range(_NUM_STREAM_REQUESTS): + await call.write(request) + + with self.assertRaises(asyncio.CancelledError): + await call + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + self.assertEqual(await call.initial_metadata(), None) + self.assertEqual(await call.trailing_metadata(), None) + await channel.close() + + async def test_cancel_while_writing(self): + # Test cancelation before making any write or after doing at least 1 + for num_writes_before_cancel in (0, 1): + with self.subTest(name="Num writes before cancel: {}".format( + num_writes_before_cancel)): + + channel = aio.insecure_channel( + UNREACHABLE_TARGET, + interceptors=[_StreamUnaryInterceptorWithRequestIterator()]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * + _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest( + payload=payload) + + call = stub.StreamingInputCall() + + with self.assertRaises(asyncio.InvalidStateError): + for i in range(_NUM_STREAM_REQUESTS): + if i == num_writes_before_cancel: + self.assertTrue(call.cancel()) + await call.write(request) + + with self.assertRaises(asyncio.CancelledError): + await call + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + + await channel.close() + + async def test_cancel_by_the_interceptor(self): + + class Interceptor(aio.StreamUnaryClientInterceptor): + + async def intercept_stream_unary(self, continuation, + client_call_details, + request_iterator): + call = await continuation(client_call_details, request_iterator) + call.cancel() + return call + + channel = aio.insecure_channel(UNREACHABLE_TARGET, + interceptors=[Interceptor()]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + call = stub.StreamingInputCall() + + with self.assertRaises(asyncio.InvalidStateError): + for i in range(_NUM_STREAM_REQUESTS): + await call.write(request) + + with self.assertRaises(asyncio.CancelledError): + await call + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + + await channel.close() + + async def test_exception_raised_by_interceptor(self): + + class InterceptorException(Exception): + pass + + class Interceptor(aio.StreamUnaryClientInterceptor): + + async def intercept_stream_unary(self, continuation, + client_call_details, + request_iterator): + raise InterceptorException + + channel = aio.insecure_channel(UNREACHABLE_TARGET, + interceptors=[Interceptor()]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + call = stub.StreamingInputCall() + + with self.assertRaises(InterceptorException): + for i in range(_NUM_STREAM_REQUESTS): + await call.write(request) + + with self.assertRaises(InterceptorException): + await call + + await channel.close() + + async def test_intercepts_prohibit_mixing_style(self): + channel = aio.insecure_channel( + self._server_target, interceptors=[_StreamUnaryInterceptorEmpty()]) + stub = test_pb2_grpc.TestServiceStub(channel) + + payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE) + request = messages_pb2.StreamingInputCallRequest(payload=payload) + + async def request_iterator(): + for _ in range(_NUM_STREAM_REQUESTS): + yield request + + call = stub.StreamingInputCall(request_iterator()) + + with self.assertRaises(grpc._cython.cygrpc.UsageError): + await call.write(request) + + with self.assertRaises(grpc._cython.cygrpc.UsageError): + await call.done_writing() + + await channel.close() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2)