diff --git a/src/python/grpcio/grpc/experimental/aio/__init__.py b/src/python/grpcio/grpc/experimental/aio/__init__.py index c0fc17eeb92..2933aa5a45e 100644 --- a/src/python/grpcio/grpc/experimental/aio/__init__.py +++ b/src/python/grpcio/grpc/experimental/aio/__init__.py @@ -34,7 +34,8 @@ from ._interceptor import (ClientCallDetails, ClientInterceptor, InterceptedUnaryUnaryCall, UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, - StreamUnaryClientInterceptor, ServerInterceptor) + StreamUnaryClientInterceptor, + StreamStreamClientInterceptor, ServerInterceptor) from ._server import server from ._base_server import Server, ServicerContext from ._typing import ChannelArgumentType @@ -63,6 +64,7 @@ __all__ = ( 'UnaryStreamClientInterceptor', 'UnaryUnaryClientInterceptor', 'StreamUnaryClientInterceptor', + 'StreamStreamClientInterceptor', 'InterceptedUnaryUnaryCall', 'ServerInterceptor', 'insecure_channel', diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 97ffa833bbf..9a70466d157 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -24,12 +24,11 @@ from grpc._cython import cygrpc from . import _base_call, _base_channel from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall, UnaryUnaryCall) -from ._interceptor import (InterceptedUnaryUnaryCall, - InterceptedUnaryStreamCall, - InterceptedStreamUnaryCall, ClientInterceptor, - UnaryUnaryClientInterceptor, - UnaryStreamClientInterceptor, - StreamUnaryClientInterceptor) +from ._interceptor import ( + InterceptedUnaryUnaryCall, InterceptedUnaryStreamCall, + InterceptedStreamUnaryCall, InterceptedStreamStreamCall, ClientInterceptor, + UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, + StreamUnaryClientInterceptor, StreamStreamClientInterceptor) from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType, SerializingFunction, RequestIterableType) from ._utils import _timeout_to_deadline @@ -200,10 +199,17 @@ class StreamStreamMultiCallable(_BaseMultiCallable, deadline = _timeout_to_deadline(timeout) - call = StreamStreamCall(request_iterator, deadline, metadata, - credentials, wait_for_ready, self._channel, - self._method, self._request_serializer, - self._response_deserializer, self._loop) + if not self._interceptors: + call = StreamStreamCall(request_iterator, deadline, metadata, + credentials, wait_for_ready, self._channel, + self._method, self._request_serializer, + self._response_deserializer, self._loop) + else: + call = InterceptedStreamStreamCall( + self._interceptors, request_iterator, deadline, metadata, + credentials, wait_for_ready, self._channel, self._method, + self._request_serializer, self._response_deserializer, + self._loop) return call @@ -214,6 +220,7 @@ class Channel(_base_channel.Channel): _unary_unary_interceptors: List[UnaryUnaryClientInterceptor] _unary_stream_interceptors: List[UnaryStreamClientInterceptor] _stream_unary_interceptors: List[StreamUnaryClientInterceptor] + _stream_stream_interceptors: List[StreamStreamClientInterceptor] def __init__(self, target: str, options: ChannelArgumentType, credentials: Optional[grpc.ChannelCredentials], @@ -233,6 +240,7 @@ class Channel(_base_channel.Channel): self._unary_unary_interceptors = [] self._unary_stream_interceptors = [] self._stream_unary_interceptors = [] + self._stream_stream_interceptors = [] if interceptors: attrs_and_interceptor_classes = ((self._unary_unary_interceptors, @@ -240,7 +248,9 @@ class Channel(_base_channel.Channel): (self._unary_stream_interceptors, UnaryStreamClientInterceptor), (self._stream_unary_interceptors, - StreamUnaryClientInterceptor)) + StreamUnaryClientInterceptor), + (self._stream_stream_interceptors, + StreamStreamClientInterceptor)) # pylint: disable=cell-var-from-loop for attr, interceptor_class in attrs_and_interceptor_classes: @@ -252,14 +262,16 @@ class Channel(_base_channel.Channel): invalid_interceptors = set(interceptors) - set( self._unary_unary_interceptors) - set( self._unary_stream_interceptors) - set( - self._stream_unary_interceptors) + self._stream_unary_interceptors) - set( + self._stream_stream_interceptors) if invalid_interceptors: raise ValueError( "Interceptor must be " + "{} or ".format(UnaryUnaryClientInterceptor.__name__) + "{} or ".format(UnaryStreamClientInterceptor.__name__) + - "{}. ".format(StreamUnaryClientInterceptor.__name__) + + "{} or ".format(StreamUnaryClientInterceptor.__name__) + + "{}. ".format(StreamStreamClientInterceptor.__name__) + "The following are invalid: {}".format(invalid_interceptors) ) @@ -411,7 +423,8 @@ class Channel(_base_channel.Channel): ) -> StreamStreamMultiCallable: return StreamStreamMultiCallable(self._channel, _common.encode(method), request_serializer, - response_deserializer, None, + response_deserializer, + self._stream_stream_interceptors, self._loop) diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index e4969ddb4a5..4ce1f2306cf 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -22,7 +22,7 @@ import grpc from grpc._cython import cygrpc from . import _base_call -from ._call import UnaryUnaryCall, UnaryStreamCall, StreamUnaryCall, AioRpcError +from ._call import UnaryUnaryCall, UnaryStreamCall, StreamUnaryCall, StreamStreamCall, AioRpcError from ._call import _RPC_ALREADY_FINISHED_DETAILS, _RPC_HALF_CLOSED_DETAILS from ._call import _API_STYLE_ERROR from ._utils import _timeout_to_deadline @@ -153,7 +153,7 @@ class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta): request: The request value for the RPC. Returns: - The RPC Call. + The RPC Call or an asynchronous iterator. Raises: AioRpcError: Indicating that the RPC terminated with non-OK status. @@ -202,6 +202,51 @@ class StreamUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta): """ +class StreamStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta): + """Affords intercepting stream-stream invocations.""" + + @abstractmethod + async def intercept_stream_stream( + self, + continuation: Callable[[ClientCallDetails, RequestType], + UnaryStreamCall], + client_call_details: ClientCallDetails, + request_iterator: RequestIterableType, + ) -> Union[AsyncIterable[ResponseType], StreamStreamCall]: + """Intercepts a stream-stream invocation asynchronously. + + Within the interceptor the usage of the call methods like `write` or + even awaiting the call should be done carefully, since the caller + could be expecting an untouched call, for example for start writing + messages to it. + + The function could return the call object or an asynchronous + iterator, in case of being an asyncrhonous iterator this will + become the source of the reads done by the caller. + + Args: + continuation: A coroutine that proceeds with the invocation by + executing the next interceptor in the chain or invoking the + actual RPC on the underlying Channel. It is the interceptor's + responsibility to call it if it decides to move the RPC forward. + The interceptor can use + `call = await continuation(client_call_details, request_iterator)` + to continue with the RPC. `continuation` returns the call to the + RPC. + client_call_details: A ClientCallDetails object describing the + outgoing RPC. + request_iterator: The request iterator that will produce requests + for the RPC. + + Returns: + The RPC Call or an asynchronous iterator. + + Raises: + AioRpcError: Indicating that the RPC terminated with non-OK status. + asyncio.CancelledError: Indicating that the RPC was canceled. + """ + + class InterceptedCall: """Base implementation for all intecepted call arities. @@ -388,6 +433,111 @@ class _InterceptedUnaryResponseMixin: return response +class _InterceptedStreamResponseMixin: + _response_aiter: AsyncIterable[ResponseType] + + def _init_stream_response_mixin(self) -> None: + self._response_aiter = self._wait_for_interceptor_task_response_iterator( + ) + + async def _wait_for_interceptor_task_response_iterator(self + ) -> ResponseType: + call = await self._interceptors_task + async for response in call: + yield response + + def __aiter__(self) -> AsyncIterable[ResponseType]: + return self._response_aiter + + async def read(self) -> ResponseType: + return await self._response_aiter.asend(None) + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() + + +class _InterceptedStreamRequestMixin: + + _write_to_iterator_async_gen: Optional[AsyncIterable[RequestType]] + _write_to_iterator_queue: Optional[asyncio.Queue] + + _FINISH_ITERATOR_SENTINEL = object() + + def _init_stream_request_mixin( + self, request_iterator: Optional[RequestIterableType] + ) -> RequestIterableType: + + if request_iterator is None: + # We provide our own request iterator which is a proxy + # of the futures writes that will be done by the caller. + self._write_to_iterator_queue = asyncio.Queue(maxsize=1) + self._write_to_iterator_async_gen = self._proxy_writes_as_request_iterator( + ) + request_iterator = self._write_to_iterator_async_gen + else: + self._write_to_iterator_queue = None + + return request_iterator + + async def _proxy_writes_as_request_iterator(self): + await self._interceptors_task + + while True: + value = await self._write_to_iterator_queue.get() + if value is _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL: + break + yield value + + async def write(self, request: RequestType) -> None: + # If no queue was created it means that requests + # should be expected through an iterators provided + # by the caller. + if self._write_to_iterator_queue is None: + raise cygrpc.UsageError(_API_STYLE_ERROR) + + try: + call = await self._interceptors_task + except (asyncio.CancelledError, AioRpcError): + raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + + if call.done(): + raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + elif call._done_writing_flag: + raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) + + # Write might never end up since the call could abrubtly finish, + # we give up on the first awaitable object that finishes. + _, _ = await asyncio.wait( + (self._write_to_iterator_queue.put(request), call.code()), + return_when=asyncio.FIRST_COMPLETED) + + if call.done(): + raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + + async def done_writing(self) -> None: + """Signal peer that client is done writing. + + This method is idempotent. + """ + # If no queue was created it means that requests + # should be expected through an iterators provided + # by the caller. + if self._write_to_iterator_queue is None: + raise cygrpc.UsageError(_API_STYLE_ERROR) + + try: + call = await self._interceptors_task + except asyncio.CancelledError: + raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + + # Write might never end up since the call could abrubtly finish, + # we give up on the first awaitable object that finishes. + _, _ = await asyncio.wait((self._write_to_iterator_queue.put( + _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL), + call.code()), + return_when=asyncio.FIRST_COMPLETED) + + class InterceptedUnaryUnaryCall(_InterceptedUnaryResponseMixin, InterceptedCall, _base_call.UnaryUnaryCall): """Used for running a `UnaryUnaryCall` wrapped by interceptors. @@ -463,12 +613,12 @@ class InterceptedUnaryUnaryCall(_InterceptedUnaryResponseMixin, InterceptedCall, raise NotImplementedError() -class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall): +class InterceptedUnaryStreamCall(_InterceptedStreamResponseMixin, + InterceptedCall, _base_call.UnaryStreamCall): """Used for running a `UnaryStreamCall` wrapped by interceptors.""" _loop: asyncio.AbstractEventLoop _channel: cygrpc.AioChannel - _response_aiter: AsyncIterable[ResponseType] _last_returned_call_from_interceptors = Optional[_base_call.UnaryStreamCall] # pylint: disable=too-many-arguments @@ -482,8 +632,7 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall): loop: asyncio.AbstractEventLoop) -> None: self._loop = loop self._channel = channel - self._response_aiter = self._wait_for_interceptor_task_response_iterator( - ) + self._init_stream_response_mixin() self._last_returned_call_from_interceptors = None interceptors_task = loop.create_task( self._invoke(interceptors, method, timeout, metadata, credentials, @@ -517,7 +666,7 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall): continuation, client_call_details, request) if isinstance(call_or_response_iterator, - _base_call.UnaryUnaryCall): + _base_call.UnaryStreamCall): self._last_returned_call_from_interceptors = call_or_response_iterator else: self._last_returned_call_from_interceptors = UnaryStreamCallResponseIterator( @@ -540,23 +689,12 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall): return await _run_interceptor(iter(interceptors), client_call_details, request) - async def _wait_for_interceptor_task_response_iterator(self - ) -> ResponseType: - call = await self._interceptors_task - async for response in call: - yield response - - def __aiter__(self) -> AsyncIterable[ResponseType]: - return self._response_aiter - - async def read(self) -> ResponseType: - return await self._response_aiter.asend(None) - def time_remaining(self) -> Optional[float]: raise NotImplementedError() class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin, + _InterceptedStreamRequestMixin, InterceptedCall, _base_call.StreamUnaryCall): """Used for running a `StreamUnaryCall` wrapped by interceptors. @@ -566,10 +704,6 @@ class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin, _loop: asyncio.AbstractEventLoop _channel: cygrpc.AioChannel - _write_to_iterator_async_gen: Optional[AsyncIterable[RequestType]] - _write_to_iterator_queue: Optional[asyncio.Queue] - - _FINISH_ITERATOR_SENTINEL = object() # pylint: disable=too-many-arguments def __init__(self, interceptors: Sequence[StreamUnaryClientInterceptor], @@ -582,16 +716,7 @@ class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin, loop: asyncio.AbstractEventLoop) -> None: self._loop = loop self._channel = channel - if request_iterator is None: - # We provide our own request iterator which is a proxy - # of the futures writes that will be done by the caller. - self._write_to_iterator_queue = asyncio.Queue(maxsize=1) - self._write_to_iterator_async_gen = self._proxy_writes_as_request_iterator( - ) - request_iterator = self._write_to_iterator_async_gen - else: - self._write_to_iterator_queue = None - + request_iterator = self._init_stream_request_mixin(request_iterator) interceptors_task = loop.create_task( self._invoke(interceptors, method, timeout, metadata, credentials, wait_for_ready, request_iterator, request_serializer, @@ -641,62 +766,88 @@ class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin, def time_remaining(self) -> Optional[float]: raise NotImplementedError() - async def _proxy_writes_as_request_iterator(self): - await self._interceptors_task - while True: - value = await self._write_to_iterator_queue.get() - if value is InterceptedStreamUnaryCall._FINISH_ITERATOR_SENTINEL: - break - yield value +class InterceptedStreamStreamCall(_InterceptedStreamResponseMixin, + _InterceptedStreamRequestMixin, + InterceptedCall, _base_call.StreamStreamCall): + """Used for running a `StreamStreamCall` wrapped by interceptors.""" - async def write(self, request: RequestType) -> None: - # If no queue was created it means that requests - # should be expected through an iterators provided - # by the caller. - if self._write_to_iterator_queue is None: - raise cygrpc.UsageError(_API_STYLE_ERROR) + _loop: asyncio.AbstractEventLoop + _channel: cygrpc.AioChannel + _last_returned_call_from_interceptors = Optional[_base_call.UnaryStreamCall] - try: - call = await self._interceptors_task - except (asyncio.CancelledError, AioRpcError): - raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + # pylint: disable=too-many-arguments + def __init__(self, interceptors: Sequence[StreamStreamClientInterceptor], + request_iterator: Optional[RequestIterableType], + timeout: Optional[float], metadata: MetadataType, + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], channel: cygrpc.AioChannel, + method: bytes, request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction, + loop: asyncio.AbstractEventLoop) -> None: + self._loop = loop + self._channel = channel + self._init_stream_response_mixin() + request_iterator = self._init_stream_request_mixin(request_iterator) + self._last_returned_call_from_interceptors = None + interceptors_task = loop.create_task( + self._invoke(interceptors, method, timeout, metadata, credentials, + wait_for_ready, request_iterator, request_serializer, + response_deserializer)) + super().__init__(interceptors_task) - if call.done(): - raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) - elif call._done_writing_flag: - raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS) + # pylint: disable=too-many-arguments + async def _invoke( + self, interceptors: Sequence[StreamStreamClientInterceptor], + method: bytes, timeout: Optional[float], + metadata: Optional[MetadataType], + credentials: Optional[grpc.CallCredentials], + wait_for_ready: Optional[bool], + request_iterator: RequestIterableType, + request_serializer: SerializingFunction, + response_deserializer: DeserializingFunction) -> StreamStreamCall: + """Run the RPC call wrapped in interceptors""" - # Write might never end up since the call could abrubtly finish, - # we give up on the first awaitable object that finishes.. - _, _ = await asyncio.wait( - (self._write_to_iterator_queue.put(request), call), - return_when=asyncio.FIRST_COMPLETED) + async def _run_interceptor( + interceptors: Iterator[StreamStreamClientInterceptor], + client_call_details: ClientCallDetails, + request_iterator: RequestIterableType + ) -> _base_call.StreamStreamCall: - if call.done(): - raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + interceptor = next(interceptors, None) - async def done_writing(self) -> None: - """Signal peer that client is done writing. + if interceptor: + continuation = functools.partial(_run_interceptor, interceptors) - This method is idempotent. - """ - # If no queue was created it means that requests - # should be expected through an iterators provided - # by the caller. - if self._write_to_iterator_queue is None: - raise cygrpc.UsageError(_API_STYLE_ERROR) + call_or_response_iterator = await interceptor.intercept_stream_stream( + continuation, client_call_details, request_iterator) - try: - call = await self._interceptors_task - except asyncio.CancelledError: - raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) + if isinstance(call_or_response_iterator, + _base_call.StreamStreamCall): + self._last_returned_call_from_interceptors = call_or_response_iterator + else: + self._last_returned_call_from_interceptors = StreamStreamCallResponseIterator( + self._last_returned_call_from_interceptors, + call_or_response_iterator) + return self._last_returned_call_from_interceptors + else: + self._last_returned_call_from_interceptors = StreamStreamCall( + request_iterator, + _timeout_to_deadline(client_call_details.timeout), + client_call_details.metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, self._channel, + client_call_details.method, request_serializer, + response_deserializer, self._loop) + return self._last_returned_call_from_interceptors - # Write might never end up since the call could abrubtly finish, - # we give up on the first awaitable object that finishes. - _, _ = await asyncio.wait((self._write_to_iterator_queue.put( - InterceptedStreamUnaryCall._FINISH_ITERATOR_SENTINEL), call), - return_when=asyncio.FIRST_COMPLETED) + client_call_details = ClientCallDetails(method, timeout, metadata, + credentials, wait_for_ready) + return await _run_interceptor(iter(interceptors), client_call_details, + request_iterator) + + def time_remaining(self) -> Optional[float]: + raise NotImplementedError() class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall): @@ -747,12 +898,13 @@ class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall): pass -class UnaryStreamCallResponseIterator(_base_call.UnaryStreamCall): - """UnaryStreamCall class wich uses an alternative response iterator.""" - _call: _base_call.UnaryStreamCall +class _StreamCallResponseIterator: + + _call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall] _response_iterator: AsyncIterable[ResponseType] - def __init__(self, call: _base_call.UnaryStreamCall, + def __init__(self, call: Union[_base_call.UnaryStreamCall, _base_call. + StreamStreamCall], response_iterator: AsyncIterable[ResponseType]) -> None: self._response_iterator = response_iterator self._call = call @@ -797,3 +949,29 @@ class UnaryStreamCallResponseIterator(_base_call.UnaryStreamCall): # Behind the scenes everyting goes through the # async iterator. So this path should not be reached. raise Exception() + + +class UnaryStreamCallResponseIterator(_StreamCallResponseIterator, + _base_call.UnaryStreamCall): + """UnaryStreamCall class wich uses an alternative response iterator.""" + + +class StreamStreamCallResponseIterator(_StreamCallResponseIterator, + _base_call.StreamStreamCall): + """UnaryStreamCall class wich uses an alternative response iterator.""" + + async def write(self, request: RequestType) -> None: + # Behind the scenes everyting goes through the + # async iterator provided by the InterceptedStreamStreamCall. + # So this path should not be reached. + raise Exception() + + async def done_writing(self) -> None: + # Behind the scenes everyting goes through the + # async iterator provided by the InterceptedStreamStreamCall. + # So this path should not be reached. + raise Exception() + + @property + def _done_writing_flag(self) -> bool: + return self._call._done_writing_flag diff --git a/src/python/grpcio_tests/tests_aio/tests.json b/src/python/grpcio_tests/tests_aio/tests.json index 87cfce80ae8..f01d7d0570d 100644 --- a/src/python/grpcio_tests/tests_aio/tests.json +++ b/src/python/grpcio_tests/tests_aio/tests.json @@ -16,6 +16,7 @@ "unit.channel_argument_test.TestChannelArgument", "unit.channel_ready_test.TestChannelReady", "unit.channel_test.TestChannel", + "unit.client_stream_stream_interceptor_test.TestStreamStreamClientInterceptor", "unit.client_stream_unary_interceptor_test.TestStreamUnaryClientInterceptor", "unit.client_unary_stream_interceptor_test.TestUnaryStreamClientInterceptor", "unit.client_unary_unary_interceptor_test.TestInterceptedUnaryUnaryCall", diff --git a/src/python/grpcio_tests/tests_aio/unit/_common.py b/src/python/grpcio_tests/tests_aio/unit/_common.py index 97cbe759ed0..5c1e49bd497 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_common.py +++ b/src/python/grpcio_tests/tests_aio/unit/_common.py @@ -64,3 +64,33 @@ def inject_callbacks(call): test_constants.SHORT_TIMEOUT) return validation() + + +class CountingRequestIterator: + + def __init__(self, request_iterator): + self.request_cnt = 0 + self._request_iterator = request_iterator + + async def _forward_requests(self): + async for request in self._request_iterator: + self.request_cnt += 1 + yield request + + def __aiter__(self): + return self._forward_requests() + + +class CountingResponseIterator: + + def __init__(self, response_iterator): + self.response_cnt = 0 + self._response_iterator = response_iterator + + async def _forward_responses(self): + async for response in self._response_iterator: + self.response_cnt += 1 + yield response + + def __aiter__(self): + return self._forward_responses() diff --git a/src/python/grpcio_tests/tests_aio/unit/client_stream_stream_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_stream_stream_interceptor_test.py new file mode 100644 index 00000000000..ed924558fbc --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/unit/client_stream_stream_interceptor_test.py @@ -0,0 +1,202 @@ +# Copyright 2020 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import unittest + +import grpc + +from grpc.experimental import aio +from tests_aio.unit._common import CountingResponseIterator, CountingRequestIterator +from tests_aio.unit._test_server import start_test_server +from tests_aio.unit._test_base import AioTestBase +from src.proto.grpc.testing import messages_pb2, test_pb2_grpc + +_NUM_STREAM_RESPONSES = 5 +_NUM_STREAM_REQUESTS = 5 +_RESPONSE_PAYLOAD_SIZE = 7 + + +class _StreamStreamInterceptorEmpty(aio.StreamStreamClientInterceptor): + + async def intercept_stream_stream(self, continuation, client_call_details, + request_iterator): + return await continuation(client_call_details, request_iterator) + + def assert_in_final_state(self, test: unittest.TestCase): + pass + + +class _StreamStreamInterceptorWithRequestAndResponseIterator( + aio.StreamStreamClientInterceptor): + + async def intercept_stream_stream(self, continuation, client_call_details, + request_iterator): + self.request_iterator = CountingRequestIterator(request_iterator) + call = await continuation(client_call_details, self.request_iterator) + self.response_iterator = CountingResponseIterator(call) + return self.response_iterator + + def assert_in_final_state(self, test: unittest.TestCase): + test.assertEqual(_NUM_STREAM_REQUESTS, + self.request_iterator.request_cnt) + test.assertEqual(_NUM_STREAM_RESPONSES, + self.response_iterator.response_cnt) + + +class TestStreamStreamClientInterceptor(AioTestBase): + + async def setUp(self): + self._server_target, self._server = await start_test_server() + + async def tearDown(self): + await self._server.stop(None) + + async def test_intercepts(self): + + for interceptor_class in ( + _StreamStreamInterceptorEmpty, + _StreamStreamInterceptorWithRequestAndResponseIterator): + + with self.subTest(name=interceptor_class): + interceptor = interceptor_class() + channel = aio.insecure_channel(self._server_target, + interceptors=[interceptor]) + stub = test_pb2_grpc.TestServiceStub(channel) + + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE)) + + async def request_iterator(): + for _ in range(_NUM_STREAM_REQUESTS): + yield request + + call = stub.FullDuplexCall(request_iterator()) + + await call.wait_for_connection() + + response_cnt = 0 + async for response in call: + response_cnt += 1 + self.assertIs(type(response), + messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, + len(response.payload.body)) + + self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.initial_metadata(), ()) + self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.debug_error_string(), '') + self.assertEqual(call.cancel(), False) + self.assertEqual(call.cancelled(), False) + self.assertEqual(call.done(), True) + + interceptor.assert_in_final_state(self) + + await channel.close() + + async def test_intercepts_using_write_and_read(self): + for interceptor_class in ( + _StreamStreamInterceptorEmpty, + _StreamStreamInterceptorWithRequestAndResponseIterator): + + with self.subTest(name=interceptor_class): + interceptor = interceptor_class() + channel = aio.insecure_channel(self._server_target, + interceptors=[interceptor]) + stub = test_pb2_grpc.TestServiceStub(channel) + + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE)) + + call = stub.FullDuplexCall() + + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + response = await call.read() + self.assertIsInstance( + response, messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, + len(response.payload.body)) + + await call.done_writing() + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.initial_metadata(), ()) + self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.debug_error_string(), '') + self.assertEqual(call.cancel(), False) + self.assertEqual(call.cancelled(), False) + self.assertEqual(call.done(), True) + + interceptor.assert_in_final_state(self) + + await channel.close() + + async def test_multiple_interceptors_request_iterator(self): + for interceptor_class in ( + _StreamStreamInterceptorEmpty, + _StreamStreamInterceptorWithRequestAndResponseIterator): + + with self.subTest(name=interceptor_class): + + interceptors = [interceptor_class(), interceptor_class()] + channel = aio.insecure_channel(self._server_target, + interceptors=interceptors) + stub = test_pb2_grpc.TestServiceStub(channel) + + # Prepares the request + request = messages_pb2.StreamingOutputCallRequest() + request.response_parameters.append( + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE)) + + call = stub.FullDuplexCall() + + for _ in range(_NUM_STREAM_RESPONSES): + await call.write(request) + response = await call.read() + self.assertIsInstance( + response, messages_pb2.StreamingOutputCallResponse) + self.assertEqual(_RESPONSE_PAYLOAD_SIZE, + len(response.payload.body)) + + await call.done_writing() + + self.assertEqual(await call.code(), grpc.StatusCode.OK) + self.assertEqual(await call.initial_metadata(), ()) + self.assertEqual(await call.trailing_metadata(), ()) + self.assertEqual(await call.details(), '') + self.assertEqual(await call.debug_error_string(), '') + self.assertEqual(call.cancel(), False) + self.assertEqual(call.cancelled(), False) + self.assertEqual(call.done(), True) + + for interceptor in interceptors: + interceptor.assert_in_final_state(self) + + await channel.close() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py index 318a117ffe4..2d58cff0e41 100644 --- a/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py @@ -21,6 +21,7 @@ import grpc from grpc.experimental import aio from tests_aio.unit._constants import UNREACHABLE_TARGET from tests_aio.unit._common import inject_callbacks +from tests_aio.unit._common import CountingRequestIterator from tests_aio.unit._test_server import start_test_server from tests_aio.unit._test_base import AioTestBase from tests.unit.framework.common import test_constants @@ -33,21 +34,6 @@ _REQUEST_PAYLOAD_SIZE = 7 _RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000) -class _CountingRequestIterator: - - def __init__(self, request_iterator): - self.request_cnt = 0 - self._request_iterator = request_iterator - - async def _forward_requests(self): - async for request in self._request_iterator: - self.request_cnt += 1 - yield request - - def __aiter__(self): - return self._forward_requests() - - class _StreamUnaryInterceptorEmpty(aio.StreamUnaryClientInterceptor): async def intercept_stream_unary(self, continuation, client_call_details, @@ -63,7 +49,7 @@ class _StreamUnaryInterceptorWithRequestIterator( async def intercept_stream_unary(self, continuation, client_call_details, request_iterator): - self.request_iterator = _CountingRequestIterator(request_iterator) + self.request_iterator = CountingRequestIterator(request_iterator) call = await continuation(client_call_details, self.request_iterator) return call diff --git a/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py index cfb3ead226f..6137538ffca 100644 --- a/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py @@ -21,6 +21,7 @@ import grpc from grpc.experimental import aio from tests_aio.unit._constants import UNREACHABLE_TARGET from tests_aio.unit._common import inject_callbacks +from tests_aio.unit._common import CountingResponseIterator from tests_aio.unit._test_server import start_test_server from tests_aio.unit._test_base import AioTestBase from tests.unit.framework.common import test_constants @@ -34,21 +35,6 @@ _RESPONSE_PAYLOAD_SIZE = 7 _RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000) -class _CountingResponseIterator: - - def __init__(self, response_iterator): - self.response_cnt = 0 - self._response_iterator = response_iterator - - async def _forward_responses(self): - async for response in self._response_iterator: - self.response_cnt += 1 - yield response - - def __aiter__(self): - return self._forward_responses() - - class _UnaryStreamInterceptorEmpty(aio.UnaryStreamClientInterceptor): async def intercept_unary_stream(self, continuation, client_call_details, @@ -65,7 +51,7 @@ class _UnaryStreamInterceptorWithResponseIterator( async def intercept_unary_stream(self, continuation, client_call_details, request): call = await continuation(client_call_details, request) - self.response_iterator = _CountingResponseIterator(call) + self.response_iterator = CountingResponseIterator(call) return self.response_iterator def assert_in_final_state(self, test: unittest.TestCase):