Add stream stream client interceptor support

This was the last missing arity which did not have support yet for
the interceptors in the client side for the Aio package. This commit
adds specific support for this interceptro which allows the deveveloper
to intercept the request iterator and the response iterator.
pull/23092/head
Pau Freixes 5 years ago
parent c9ed65f469
commit b3425f6dbf
  1. 4
      src/python/grpcio/grpc/experimental/aio/__init__.py
  2. 41
      src/python/grpcio/grpc/experimental/aio/_channel.py
  3. 344
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  4. 1
      src/python/grpcio_tests/tests_aio/tests.json
  5. 30
      src/python/grpcio_tests/tests_aio/unit/_common.py
  6. 202
      src/python/grpcio_tests/tests_aio/unit/client_stream_stream_interceptor_test.py
  7. 18
      src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py
  8. 18
      src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py

@ -34,7 +34,8 @@ from ._interceptor import (ClientCallDetails, ClientInterceptor,
InterceptedUnaryUnaryCall,
UnaryUnaryClientInterceptor,
UnaryStreamClientInterceptor,
StreamUnaryClientInterceptor, ServerInterceptor)
StreamUnaryClientInterceptor,
StreamStreamClientInterceptor, ServerInterceptor)
from ._server import server
from ._base_server import Server, ServicerContext
from ._typing import ChannelArgumentType
@ -63,6 +64,7 @@ __all__ = (
'UnaryStreamClientInterceptor',
'UnaryUnaryClientInterceptor',
'StreamUnaryClientInterceptor',
'StreamStreamClientInterceptor',
'InterceptedUnaryUnaryCall',
'ServerInterceptor',
'insecure_channel',

@ -24,12 +24,11 @@ from grpc._cython import cygrpc
from . import _base_call, _base_channel
from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall,
UnaryUnaryCall)
from ._interceptor import (InterceptedUnaryUnaryCall,
InterceptedUnaryStreamCall,
InterceptedStreamUnaryCall, ClientInterceptor,
UnaryUnaryClientInterceptor,
UnaryStreamClientInterceptor,
StreamUnaryClientInterceptor)
from ._interceptor import (
InterceptedUnaryUnaryCall, InterceptedUnaryStreamCall,
InterceptedStreamUnaryCall, InterceptedStreamStreamCall, ClientInterceptor,
UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor,
StreamUnaryClientInterceptor, StreamStreamClientInterceptor)
from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType,
SerializingFunction, RequestIterableType)
from ._utils import _timeout_to_deadline
@ -200,10 +199,17 @@ class StreamStreamMultiCallable(_BaseMultiCallable,
deadline = _timeout_to_deadline(timeout)
call = StreamStreamCall(request_iterator, deadline, metadata,
credentials, wait_for_ready, self._channel,
self._method, self._request_serializer,
self._response_deserializer, self._loop)
if not self._interceptors:
call = StreamStreamCall(request_iterator, deadline, metadata,
credentials, wait_for_ready, self._channel,
self._method, self._request_serializer,
self._response_deserializer, self._loop)
else:
call = InterceptedStreamStreamCall(
self._interceptors, request_iterator, deadline, metadata,
credentials, wait_for_ready, self._channel, self._method,
self._request_serializer, self._response_deserializer,
self._loop)
return call
@ -214,6 +220,7 @@ class Channel(_base_channel.Channel):
_unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
_unary_stream_interceptors: List[UnaryStreamClientInterceptor]
_stream_unary_interceptors: List[StreamUnaryClientInterceptor]
_stream_stream_interceptors: List[StreamStreamClientInterceptor]
def __init__(self, target: str, options: ChannelArgumentType,
credentials: Optional[grpc.ChannelCredentials],
@ -233,6 +240,7 @@ class Channel(_base_channel.Channel):
self._unary_unary_interceptors = []
self._unary_stream_interceptors = []
self._stream_unary_interceptors = []
self._stream_stream_interceptors = []
if interceptors:
attrs_and_interceptor_classes = ((self._unary_unary_interceptors,
@ -240,7 +248,9 @@ class Channel(_base_channel.Channel):
(self._unary_stream_interceptors,
UnaryStreamClientInterceptor),
(self._stream_unary_interceptors,
StreamUnaryClientInterceptor))
StreamUnaryClientInterceptor),
(self._stream_stream_interceptors,
StreamStreamClientInterceptor))
# pylint: disable=cell-var-from-loop
for attr, interceptor_class in attrs_and_interceptor_classes:
@ -252,14 +262,16 @@ class Channel(_base_channel.Channel):
invalid_interceptors = set(interceptors) - set(
self._unary_unary_interceptors) - set(
self._unary_stream_interceptors) - set(
self._stream_unary_interceptors)
self._stream_unary_interceptors) - set(
self._stream_stream_interceptors)
if invalid_interceptors:
raise ValueError(
"Interceptor must be " +
"{} or ".format(UnaryUnaryClientInterceptor.__name__) +
"{} or ".format(UnaryStreamClientInterceptor.__name__) +
"{}. ".format(StreamUnaryClientInterceptor.__name__) +
"{} or ".format(StreamUnaryClientInterceptor.__name__) +
"{}. ".format(StreamStreamClientInterceptor.__name__) +
"The following are invalid: {}".format(invalid_interceptors)
)
@ -411,7 +423,8 @@ class Channel(_base_channel.Channel):
) -> StreamStreamMultiCallable:
return StreamStreamMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer, None,
response_deserializer,
self._stream_stream_interceptors,
self._loop)

@ -22,7 +22,7 @@ import grpc
from grpc._cython import cygrpc
from . import _base_call
from ._call import UnaryUnaryCall, UnaryStreamCall, StreamUnaryCall, AioRpcError
from ._call import UnaryUnaryCall, UnaryStreamCall, StreamUnaryCall, StreamStreamCall, AioRpcError
from ._call import _RPC_ALREADY_FINISHED_DETAILS, _RPC_HALF_CLOSED_DETAILS
from ._call import _API_STYLE_ERROR
from ._utils import _timeout_to_deadline
@ -153,7 +153,7 @@ class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
request: The request value for the RPC.
Returns:
The RPC Call.
The RPC Call or an asynchronous iterator.
Raises:
AioRpcError: Indicating that the RPC terminated with non-OK status.
@ -202,6 +202,51 @@ class StreamUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
"""
class StreamStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
"""Affords intercepting stream-stream invocations."""
@abstractmethod
async def intercept_stream_stream(
self,
continuation: Callable[[ClientCallDetails, RequestType],
UnaryStreamCall],
client_call_details: ClientCallDetails,
request_iterator: RequestIterableType,
) -> Union[AsyncIterable[ResponseType], StreamStreamCall]:
"""Intercepts a stream-stream invocation asynchronously.
Within the interceptor the usage of the call methods like `write` or
even awaiting the call should be done carefully, since the caller
could be expecting an untouched call, for example for start writing
messages to it.
The function could return the call object or an asynchronous
iterator, in case of being an asyncrhonous iterator this will
become the source of the reads done by the caller.
Args:
continuation: A coroutine that proceeds with the invocation by
executing the next interceptor in the chain or invoking the
actual RPC on the underlying Channel. It is the interceptor's
responsibility to call it if it decides to move the RPC forward.
The interceptor can use
`call = await continuation(client_call_details, request_iterator)`
to continue with the RPC. `continuation` returns the call to the
RPC.
client_call_details: A ClientCallDetails object describing the
outgoing RPC.
request_iterator: The request iterator that will produce requests
for the RPC.
Returns:
The RPC Call or an asynchronous iterator.
Raises:
AioRpcError: Indicating that the RPC terminated with non-OK status.
asyncio.CancelledError: Indicating that the RPC was canceled.
"""
class InterceptedCall:
"""Base implementation for all intecepted call arities.
@ -388,6 +433,111 @@ class _InterceptedUnaryResponseMixin:
return response
class _InterceptedStreamResponseMixin:
_response_aiter: AsyncIterable[ResponseType]
def _init_stream_response_mixin(self) -> None:
self._response_aiter = self._wait_for_interceptor_task_response_iterator(
)
async def _wait_for_interceptor_task_response_iterator(self
) -> ResponseType:
call = await self._interceptors_task
async for response in call:
yield response
def __aiter__(self) -> AsyncIterable[ResponseType]:
return self._response_aiter
async def read(self) -> ResponseType:
return await self._response_aiter.asend(None)
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
class _InterceptedStreamRequestMixin:
_write_to_iterator_async_gen: Optional[AsyncIterable[RequestType]]
_write_to_iterator_queue: Optional[asyncio.Queue]
_FINISH_ITERATOR_SENTINEL = object()
def _init_stream_request_mixin(
self, request_iterator: Optional[RequestIterableType]
) -> RequestIterableType:
if request_iterator is None:
# We provide our own request iterator which is a proxy
# of the futures writes that will be done by the caller.
self._write_to_iterator_queue = asyncio.Queue(maxsize=1)
self._write_to_iterator_async_gen = self._proxy_writes_as_request_iterator(
)
request_iterator = self._write_to_iterator_async_gen
else:
self._write_to_iterator_queue = None
return request_iterator
async def _proxy_writes_as_request_iterator(self):
await self._interceptors_task
while True:
value = await self._write_to_iterator_queue.get()
if value is _InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL:
break
yield value
async def write(self, request: RequestType) -> None:
# If no queue was created it means that requests
# should be expected through an iterators provided
# by the caller.
if self._write_to_iterator_queue is None:
raise cygrpc.UsageError(_API_STYLE_ERROR)
try:
call = await self._interceptors_task
except (asyncio.CancelledError, AioRpcError):
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
if call.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
elif call._done_writing_flag:
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
# Write might never end up since the call could abrubtly finish,
# we give up on the first awaitable object that finishes.
_, _ = await asyncio.wait(
(self._write_to_iterator_queue.put(request), call.code()),
return_when=asyncio.FIRST_COMPLETED)
if call.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
async def done_writing(self) -> None:
"""Signal peer that client is done writing.
This method is idempotent.
"""
# If no queue was created it means that requests
# should be expected through an iterators provided
# by the caller.
if self._write_to_iterator_queue is None:
raise cygrpc.UsageError(_API_STYLE_ERROR)
try:
call = await self._interceptors_task
except asyncio.CancelledError:
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
# Write might never end up since the call could abrubtly finish,
# we give up on the first awaitable object that finishes.
_, _ = await asyncio.wait((self._write_to_iterator_queue.put(
_InterceptedStreamRequestMixin._FINISH_ITERATOR_SENTINEL),
call.code()),
return_when=asyncio.FIRST_COMPLETED)
class InterceptedUnaryUnaryCall(_InterceptedUnaryResponseMixin, InterceptedCall,
_base_call.UnaryUnaryCall):
"""Used for running a `UnaryUnaryCall` wrapped by interceptors.
@ -463,12 +613,12 @@ class InterceptedUnaryUnaryCall(_InterceptedUnaryResponseMixin, InterceptedCall,
raise NotImplementedError()
class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall):
class InterceptedUnaryStreamCall(_InterceptedStreamResponseMixin,
InterceptedCall, _base_call.UnaryStreamCall):
"""Used for running a `UnaryStreamCall` wrapped by interceptors."""
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_response_aiter: AsyncIterable[ResponseType]
_last_returned_call_from_interceptors = Optional[_base_call.UnaryStreamCall]
# pylint: disable=too-many-arguments
@ -482,8 +632,7 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall):
loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._channel = channel
self._response_aiter = self._wait_for_interceptor_task_response_iterator(
)
self._init_stream_response_mixin()
self._last_returned_call_from_interceptors = None
interceptors_task = loop.create_task(
self._invoke(interceptors, method, timeout, metadata, credentials,
@ -517,7 +666,7 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall):
continuation, client_call_details, request)
if isinstance(call_or_response_iterator,
_base_call.UnaryUnaryCall):
_base_call.UnaryStreamCall):
self._last_returned_call_from_interceptors = call_or_response_iterator
else:
self._last_returned_call_from_interceptors = UnaryStreamCallResponseIterator(
@ -540,23 +689,12 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall):
return await _run_interceptor(iter(interceptors), client_call_details,
request)
async def _wait_for_interceptor_task_response_iterator(self
) -> ResponseType:
call = await self._interceptors_task
async for response in call:
yield response
def __aiter__(self) -> AsyncIterable[ResponseType]:
return self._response_aiter
async def read(self) -> ResponseType:
return await self._response_aiter.asend(None)
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin,
_InterceptedStreamRequestMixin,
InterceptedCall, _base_call.StreamUnaryCall):
"""Used for running a `StreamUnaryCall` wrapped by interceptors.
@ -566,10 +704,6 @@ class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin,
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_write_to_iterator_async_gen: Optional[AsyncIterable[RequestType]]
_write_to_iterator_queue: Optional[asyncio.Queue]
_FINISH_ITERATOR_SENTINEL = object()
# pylint: disable=too-many-arguments
def __init__(self, interceptors: Sequence[StreamUnaryClientInterceptor],
@ -582,16 +716,7 @@ class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin,
loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._channel = channel
if request_iterator is None:
# We provide our own request iterator which is a proxy
# of the futures writes that will be done by the caller.
self._write_to_iterator_queue = asyncio.Queue(maxsize=1)
self._write_to_iterator_async_gen = self._proxy_writes_as_request_iterator(
)
request_iterator = self._write_to_iterator_async_gen
else:
self._write_to_iterator_queue = None
request_iterator = self._init_stream_request_mixin(request_iterator)
interceptors_task = loop.create_task(
self._invoke(interceptors, method, timeout, metadata, credentials,
wait_for_ready, request_iterator, request_serializer,
@ -641,62 +766,88 @@ class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin,
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
async def _proxy_writes_as_request_iterator(self):
await self._interceptors_task
while True:
value = await self._write_to_iterator_queue.get()
if value is InterceptedStreamUnaryCall._FINISH_ITERATOR_SENTINEL:
break
yield value
class InterceptedStreamStreamCall(_InterceptedStreamResponseMixin,
_InterceptedStreamRequestMixin,
InterceptedCall, _base_call.StreamStreamCall):
"""Used for running a `StreamStreamCall` wrapped by interceptors."""
async def write(self, request: RequestType) -> None:
# If no queue was created it means that requests
# should be expected through an iterators provided
# by the caller.
if self._write_to_iterator_queue is None:
raise cygrpc.UsageError(_API_STYLE_ERROR)
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_last_returned_call_from_interceptors = Optional[_base_call.UnaryStreamCall]
try:
call = await self._interceptors_task
except (asyncio.CancelledError, AioRpcError):
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
# pylint: disable=too-many-arguments
def __init__(self, interceptors: Sequence[StreamStreamClientInterceptor],
request_iterator: Optional[RequestIterableType],
timeout: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._channel = channel
self._init_stream_response_mixin()
request_iterator = self._init_stream_request_mixin(request_iterator)
self._last_returned_call_from_interceptors = None
interceptors_task = loop.create_task(
self._invoke(interceptors, method, timeout, metadata, credentials,
wait_for_ready, request_iterator, request_serializer,
response_deserializer))
super().__init__(interceptors_task)
if call.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
elif call._done_writing_flag:
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
# pylint: disable=too-many-arguments
async def _invoke(
self, interceptors: Sequence[StreamStreamClientInterceptor],
method: bytes, timeout: Optional[float],
metadata: Optional[MetadataType],
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool],
request_iterator: RequestIterableType,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> StreamStreamCall:
"""Run the RPC call wrapped in interceptors"""
# Write might never end up since the call could abrubtly finish,
# we give up on the first awaitable object that finishes..
_, _ = await asyncio.wait(
(self._write_to_iterator_queue.put(request), call),
return_when=asyncio.FIRST_COMPLETED)
async def _run_interceptor(
interceptors: Iterator[StreamStreamClientInterceptor],
client_call_details: ClientCallDetails,
request_iterator: RequestIterableType
) -> _base_call.StreamStreamCall:
if call.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
interceptor = next(interceptors, None)
async def done_writing(self) -> None:
"""Signal peer that client is done writing.
if interceptor:
continuation = functools.partial(_run_interceptor, interceptors)
This method is idempotent.
"""
# If no queue was created it means that requests
# should be expected through an iterators provided
# by the caller.
if self._write_to_iterator_queue is None:
raise cygrpc.UsageError(_API_STYLE_ERROR)
call_or_response_iterator = await interceptor.intercept_stream_stream(
continuation, client_call_details, request_iterator)
try:
call = await self._interceptors_task
except asyncio.CancelledError:
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
if isinstance(call_or_response_iterator,
_base_call.StreamStreamCall):
self._last_returned_call_from_interceptors = call_or_response_iterator
else:
self._last_returned_call_from_interceptors = StreamStreamCallResponseIterator(
self._last_returned_call_from_interceptors,
call_or_response_iterator)
return self._last_returned_call_from_interceptors
else:
self._last_returned_call_from_interceptors = StreamStreamCall(
request_iterator,
_timeout_to_deadline(client_call_details.timeout),
client_call_details.metadata,
client_call_details.credentials,
client_call_details.wait_for_ready, self._channel,
client_call_details.method, request_serializer,
response_deserializer, self._loop)
return self._last_returned_call_from_interceptors
# Write might never end up since the call could abrubtly finish,
# we give up on the first awaitable object that finishes.
_, _ = await asyncio.wait((self._write_to_iterator_queue.put(
InterceptedStreamUnaryCall._FINISH_ITERATOR_SENTINEL), call),
return_when=asyncio.FIRST_COMPLETED)
client_call_details = ClientCallDetails(method, timeout, metadata,
credentials, wait_for_ready)
return await _run_interceptor(iter(interceptors), client_call_details,
request_iterator)
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
@ -747,12 +898,13 @@ class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
pass
class UnaryStreamCallResponseIterator(_base_call.UnaryStreamCall):
"""UnaryStreamCall class wich uses an alternative response iterator."""
_call: _base_call.UnaryStreamCall
class _StreamCallResponseIterator:
_call: Union[_base_call.UnaryStreamCall, _base_call.StreamStreamCall]
_response_iterator: AsyncIterable[ResponseType]
def __init__(self, call: _base_call.UnaryStreamCall,
def __init__(self, call: Union[_base_call.UnaryStreamCall, _base_call.
StreamStreamCall],
response_iterator: AsyncIterable[ResponseType]) -> None:
self._response_iterator = response_iterator
self._call = call
@ -797,3 +949,29 @@ class UnaryStreamCallResponseIterator(_base_call.UnaryStreamCall):
# Behind the scenes everyting goes through the
# async iterator. So this path should not be reached.
raise Exception()
class UnaryStreamCallResponseIterator(_StreamCallResponseIterator,
_base_call.UnaryStreamCall):
"""UnaryStreamCall class wich uses an alternative response iterator."""
class StreamStreamCallResponseIterator(_StreamCallResponseIterator,
_base_call.StreamStreamCall):
"""UnaryStreamCall class wich uses an alternative response iterator."""
async def write(self, request: RequestType) -> None:
# Behind the scenes everyting goes through the
# async iterator provided by the InterceptedStreamStreamCall.
# So this path should not be reached.
raise Exception()
async def done_writing(self) -> None:
# Behind the scenes everyting goes through the
# async iterator provided by the InterceptedStreamStreamCall.
# So this path should not be reached.
raise Exception()
@property
def _done_writing_flag(self) -> bool:
return self._call._done_writing_flag

@ -16,6 +16,7 @@
"unit.channel_argument_test.TestChannelArgument",
"unit.channel_ready_test.TestChannelReady",
"unit.channel_test.TestChannel",
"unit.client_stream_stream_interceptor_test.TestStreamStreamClientInterceptor",
"unit.client_stream_unary_interceptor_test.TestStreamUnaryClientInterceptor",
"unit.client_unary_stream_interceptor_test.TestUnaryStreamClientInterceptor",
"unit.client_unary_unary_interceptor_test.TestInterceptedUnaryUnaryCall",

@ -64,3 +64,33 @@ def inject_callbacks(call):
test_constants.SHORT_TIMEOUT)
return validation()
class CountingRequestIterator:
def __init__(self, request_iterator):
self.request_cnt = 0
self._request_iterator = request_iterator
async def _forward_requests(self):
async for request in self._request_iterator:
self.request_cnt += 1
yield request
def __aiter__(self):
return self._forward_requests()
class CountingResponseIterator:
def __init__(self, response_iterator):
self.response_cnt = 0
self._response_iterator = response_iterator
async def _forward_responses(self):
async for response in self._response_iterator:
self.response_cnt += 1
yield response
def __aiter__(self):
return self._forward_responses()

@ -0,0 +1,202 @@
# Copyright 2020 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import unittest
import grpc
from grpc.experimental import aio
from tests_aio.unit._common import CountingResponseIterator, CountingRequestIterator
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
_NUM_STREAM_RESPONSES = 5
_NUM_STREAM_REQUESTS = 5
_RESPONSE_PAYLOAD_SIZE = 7
class _StreamStreamInterceptorEmpty(aio.StreamStreamClientInterceptor):
async def intercept_stream_stream(self, continuation, client_call_details,
request_iterator):
return await continuation(client_call_details, request_iterator)
def assert_in_final_state(self, test: unittest.TestCase):
pass
class _StreamStreamInterceptorWithRequestAndResponseIterator(
aio.StreamStreamClientInterceptor):
async def intercept_stream_stream(self, continuation, client_call_details,
request_iterator):
self.request_iterator = CountingRequestIterator(request_iterator)
call = await continuation(client_call_details, self.request_iterator)
self.response_iterator = CountingResponseIterator(call)
return self.response_iterator
def assert_in_final_state(self, test: unittest.TestCase):
test.assertEqual(_NUM_STREAM_REQUESTS,
self.request_iterator.request_cnt)
test.assertEqual(_NUM_STREAM_RESPONSES,
self.response_iterator.response_cnt)
class TestStreamStreamClientInterceptor(AioTestBase):
async def setUp(self):
self._server_target, self._server = await start_test_server()
async def tearDown(self):
await self._server.stop(None)
async def test_intercepts(self):
for interceptor_class in (
_StreamStreamInterceptorEmpty,
_StreamStreamInterceptorWithRequestAndResponseIterator):
with self.subTest(name=interceptor_class):
interceptor = interceptor_class()
channel = aio.insecure_channel(self._server_target,
interceptors=[interceptor])
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE))
async def request_iterator():
for _ in range(_NUM_STREAM_REQUESTS):
yield request
call = stub.FullDuplexCall(request_iterator())
await call.wait_for_connection()
response_cnt = 0
async for response in call:
response_cnt += 1
self.assertIs(type(response),
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.details(), '')
self.assertEqual(await call.debug_error_string(), '')
self.assertEqual(call.cancel(), False)
self.assertEqual(call.cancelled(), False)
self.assertEqual(call.done(), True)
interceptor.assert_in_final_state(self)
await channel.close()
async def test_intercepts_using_write_and_read(self):
for interceptor_class in (
_StreamStreamInterceptorEmpty,
_StreamStreamInterceptorWithRequestAndResponseIterator):
with self.subTest(name=interceptor_class):
interceptor = interceptor_class()
channel = aio.insecure_channel(self._server_target,
interceptors=[interceptor])
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE))
call = stub.FullDuplexCall()
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(request)
response = await call.read()
self.assertIsInstance(
response, messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
await call.done_writing()
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.details(), '')
self.assertEqual(await call.debug_error_string(), '')
self.assertEqual(call.cancel(), False)
self.assertEqual(call.cancelled(), False)
self.assertEqual(call.done(), True)
interceptor.assert_in_final_state(self)
await channel.close()
async def test_multiple_interceptors_request_iterator(self):
for interceptor_class in (
_StreamStreamInterceptorEmpty,
_StreamStreamInterceptorWithRequestAndResponseIterator):
with self.subTest(name=interceptor_class):
interceptors = [interceptor_class(), interceptor_class()]
channel = aio.insecure_channel(self._server_target,
interceptors=interceptors)
stub = test_pb2_grpc.TestServiceStub(channel)
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE))
call = stub.FullDuplexCall()
for _ in range(_NUM_STREAM_RESPONSES):
await call.write(request)
response = await call.read()
self.assertIsInstance(
response, messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
await call.done_writing()
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.details(), '')
self.assertEqual(await call.debug_error_string(), '')
self.assertEqual(call.cancel(), False)
self.assertEqual(call.cancelled(), False)
self.assertEqual(call.done(), True)
for interceptor in interceptors:
interceptor.assert_in_final_state(self)
await channel.close()
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)

@ -21,6 +21,7 @@ import grpc
from grpc.experimental import aio
from tests_aio.unit._constants import UNREACHABLE_TARGET
from tests_aio.unit._common import inject_callbacks
from tests_aio.unit._common import CountingRequestIterator
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase
from tests.unit.framework.common import test_constants
@ -33,21 +34,6 @@ _REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000)
class _CountingRequestIterator:
def __init__(self, request_iterator):
self.request_cnt = 0
self._request_iterator = request_iterator
async def _forward_requests(self):
async for request in self._request_iterator:
self.request_cnt += 1
yield request
def __aiter__(self):
return self._forward_requests()
class _StreamUnaryInterceptorEmpty(aio.StreamUnaryClientInterceptor):
async def intercept_stream_unary(self, continuation, client_call_details,
@ -63,7 +49,7 @@ class _StreamUnaryInterceptorWithRequestIterator(
async def intercept_stream_unary(self, continuation, client_call_details,
request_iterator):
self.request_iterator = _CountingRequestIterator(request_iterator)
self.request_iterator = CountingRequestIterator(request_iterator)
call = await continuation(client_call_details, self.request_iterator)
return call

@ -21,6 +21,7 @@ import grpc
from grpc.experimental import aio
from tests_aio.unit._constants import UNREACHABLE_TARGET
from tests_aio.unit._common import inject_callbacks
from tests_aio.unit._common import CountingResponseIterator
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase
from tests.unit.framework.common import test_constants
@ -34,21 +35,6 @@ _RESPONSE_PAYLOAD_SIZE = 7
_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000)
class _CountingResponseIterator:
def __init__(self, response_iterator):
self.response_cnt = 0
self._response_iterator = response_iterator
async def _forward_responses(self):
async for response in self._response_iterator:
self.response_cnt += 1
yield response
def __aiter__(self):
return self._forward_responses()
class _UnaryStreamInterceptorEmpty(aio.UnaryStreamClientInterceptor):
async def intercept_unary_stream(self, continuation, client_call_details,
@ -65,7 +51,7 @@ class _UnaryStreamInterceptorWithResponseIterator(
async def intercept_unary_stream(self, continuation, client_call_details,
request):
call = await continuation(client_call_details, request)
self.response_iterator = _CountingResponseIterator(call)
self.response_iterator = CountingResponseIterator(call)
return self.response_iterator
def assert_in_final_state(self, test: unittest.TestCase):

Loading…
Cancel
Save