[Aio] Stream Unary client interceptor

Add support for Stream Unary client interceptor
pull/22821/head
Pau Freixes 5 years ago
parent 23c8cfcfda
commit 32acd9d5fd
  1. 4
      src/python/grpcio/grpc/experimental/aio/__init__.py
  2. 11
      src/python/grpcio/grpc/experimental/aio/_call.py
  3. 29
      src/python/grpcio/grpc/experimental/aio/_channel.py
  4. 211
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  5. 1
      src/python/grpcio_tests/tests_aio/tests.json
  6. 531
      src/python/grpcio_tests/tests_aio/unit/client_stream_unary_interceptor_test.py

@ -33,7 +33,8 @@ from ._call import AioRpcError
from ._interceptor import (ClientCallDetails, ClientInterceptor,
InterceptedUnaryUnaryCall,
UnaryUnaryClientInterceptor,
UnaryStreamClientInterceptor, ServerInterceptor)
UnaryStreamClientInterceptor,
StreamUnaryClientInterceptor, ServerInterceptor)
from ._server import server
from ._base_server import Server, ServicerContext
from ._typing import ChannelArgumentType
@ -61,6 +62,7 @@ __all__ = (
'ClientInterceptor',
'UnaryStreamClientInterceptor',
'UnaryUnaryClientInterceptor',
'StreamUnaryClientInterceptor',
'InterceptedUnaryUnaryCall',
'ServerInterceptor',
'insecure_channel',

@ -35,6 +35,7 @@ _LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!'
_GC_CANCELLATION_DETAILS = 'Cancelled upon garbage collection!'
_RPC_ALREADY_FINISHED_DETAILS = 'RPC already finished.'
_RPC_HALF_CLOSED_DETAILS = 'RPC is half closed after calling "done_writing".'
_API_STYLE_ERROR = 'Please don\'t mix two styles of API for streaming requests'
_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'\tstatus = {}\n'
@ -302,8 +303,7 @@ class _StreamResponseMixin(Call):
if self._response_style is _APIStyle.UNKNOWN:
self._response_style = style
elif self._response_style is not style:
raise cygrpc.UsageError(
'Please don\'t mix two styles of API for streaming responses')
raise cygrpc.UsageError(_API_STYLE_ERROR)
def cancel(self) -> bool:
if super().cancel():
@ -381,8 +381,7 @@ class _StreamRequestMixin(Call):
def _raise_for_different_style(self, style: _APIStyle):
if self._request_style is not style:
raise cygrpc.UsageError(
'Please don\'t mix two styles of API for streaming requests')
raise cygrpc.UsageError(_API_STYLE_ERROR)
def cancel(self) -> bool:
if super().cancel():
@ -399,7 +398,8 @@ class _StreamRequestMixin(Call):
request_iterator: RequestIterableType
) -> None:
try:
if inspect.isasyncgen(request_iterator):
if inspect.isasyncgen(request_iterator) or hasattr(
request_iterator, '__aiter__'):
async for request in request_iterator:
await self._write(request)
else:
@ -426,7 +426,6 @@ class _StreamRequestMixin(Call):
serialized_request = _common.serialize(request,
self._request_serializer)
try:
await self._cython_call.send_serialized_message(serialized_request)
except asyncio.CancelledError:

@ -25,9 +25,11 @@ from . import _base_call, _base_channel
from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall,
UnaryUnaryCall)
from ._interceptor import (InterceptedUnaryUnaryCall,
InterceptedUnaryStreamCall, ClientInterceptor,
InterceptedUnaryStreamCall,
InterceptedStreamUnaryCall, ClientInterceptor,
UnaryUnaryClientInterceptor,
UnaryStreamClientInterceptor)
UnaryStreamClientInterceptor,
StreamUnaryClientInterceptor)
from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType,
SerializingFunction, RequestIterableType)
from ._utils import _timeout_to_deadline
@ -167,10 +169,17 @@ class StreamUnaryMultiCallable(_BaseMultiCallable,
deadline = _timeout_to_deadline(timeout)
if not self._interceptors:
call = StreamUnaryCall(request_iterator, deadline, metadata,
credentials, wait_for_ready, self._channel,
self._method, self._request_serializer,
self._response_deserializer, self._loop)
else:
call = InterceptedStreamUnaryCall(
self._interceptors, request_iterator, deadline, metadata,
credentials, wait_for_ready, self._channel, self._method,
self._request_serializer, self._response_deserializer,
self._loop)
return call
@ -204,6 +213,7 @@ class Channel(_base_channel.Channel):
_channel: cygrpc.AioChannel
_unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
_unary_stream_interceptors: List[UnaryStreamClientInterceptor]
_stream_unary_interceptors: List[StreamUnaryClientInterceptor]
def __init__(self, target: str, options: ChannelArgumentType,
credentials: Optional[grpc.ChannelCredentials],
@ -222,12 +232,15 @@ class Channel(_base_channel.Channel):
"""
self._unary_unary_interceptors = []
self._unary_stream_interceptors = []
self._stream_unary_interceptors = []
if interceptors:
attrs_and_interceptor_classes = ((self._unary_unary_interceptors,
UnaryUnaryClientInterceptor),
(self._unary_stream_interceptors,
UnaryStreamClientInterceptor))
UnaryStreamClientInterceptor),
(self._stream_unary_interceptors,
StreamUnaryClientInterceptor))
# pylint: disable=cell-var-from-loop
for attr, interceptor_class in attrs_and_interceptor_classes:
@ -238,13 +251,15 @@ class Channel(_base_channel.Channel):
invalid_interceptors = set(interceptors) - set(
self._unary_unary_interceptors) - set(
self._unary_stream_interceptors)
self._unary_stream_interceptors) - set(
self._stream_unary_interceptors)
if invalid_interceptors:
raise ValueError(
"Interceptor must be "+\
"UnaryUnaryClientInterceptors or "+\
"UnaryStreamClientInterceptors. The following are invalid: {}"\
"UnaryUnaryClientInterceptors or "+\
"StreamUnaryClientInterceptors. The following are invalid: {}"\
.format(invalid_interceptors))
self._loop = asyncio.get_event_loop()
@ -383,7 +398,9 @@ class Channel(_base_channel.Channel):
) -> StreamUnaryMultiCallable:
return StreamUnaryMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer, None, self._loop)
response_deserializer,
self._stream_unary_interceptors,
self._loop)
def stream_stream(
self,

@ -22,10 +22,13 @@ import grpc
from grpc._cython import cygrpc
from . import _base_call
from ._call import UnaryUnaryCall, UnaryStreamCall, AioRpcError
from ._call import UnaryUnaryCall, UnaryStreamCall, StreamUnaryCall, AioRpcError
from ._call import _RPC_ALREADY_FINISHED_DETAILS, _RPC_HALF_CLOSED_DETAILS
from ._call import _API_STYLE_ERROR
from ._utils import _timeout_to_deadline
from ._typing import (RequestType, SerializingFunction, DeserializingFunction,
MetadataType, ResponseType, DoneCallbackType)
MetadataType, ResponseType, DoneCallbackType,
RequestIterableType)
_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!'
@ -132,13 +135,17 @@ class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
) -> Union[AsyncIterable[ResponseType], UnaryStreamCall]:
"""Intercepts a unary-stream invocation asynchronously.
The function could return the call object or an asynchronous
iterator, in case of being an asyncrhonous iterator this will
become the source of the reads done by the caller.
Args:
continuation: A coroutine that proceeds with the invocation by
executing the next interceptor in chain or invoking the
actual RPC on the underlying Channel. It is the interceptor's
responsibility to call it if it decides to move the RPC forward.
The interceptor can use
`call = await continuation(client_call_details, request, response_iterator))`
`call = await continuation(client_call_details, request)`
to continue with the RPC. `continuation` returns the call to the
RPC.
client_call_details: A ClientCallDetails object describing the
@ -154,6 +161,42 @@ class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
"""
class StreamUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
"""Affords intercepting stream-unary invocations."""
@abstractmethod
async def intercept_stream_unary(
self,
continuation: Callable[[ClientCallDetails, RequestType],
UnaryStreamCall],
client_call_details: ClientCallDetails,
request_iterator: RequestIterableType,
) -> StreamUnaryCall:
"""Intercepts a stream-unary invocation asynchronously.
Args:
continuation: A coroutine that proceeds with the invocation by
executing the next interceptor in chain or invoking the
actual RPC on the underlying Channel. It is the interceptor's
responsibility to call it if it decides to move the RPC forward.
The interceptor can use
`call = await continuation(client_call_details, request_iterator)`
to continue with the RPC. `continuation` returns the call to the
RPC.
client_call_details: A ClientCallDetails object describing the
outgoing RPC.
request_iterator: The request iterator that will produce requests
for the RPC.
Returns:
The RPC Call.
Raises:
AioRpcError: Indicating that the RPC terminated with non-OK status.
asyncio.CancelledError: Indicating that the RPC was canceled.
"""
class InterceptedCall:
"""Base implementation for all intecepted call arities.
@ -332,7 +375,16 @@ class InterceptedCall:
return await call.wait_for_connection()
class InterceptedUnaryUnaryCall(InterceptedCall, _base_call.UnaryUnaryCall):
class _InterceptedUnaryResponseMixin:
def __await__(self):
call = yield from self._interceptors_task.__await__()
response = yield from call.__await__()
return response
class InterceptedUnaryUnaryCall(_InterceptedUnaryResponseMixin, InterceptedCall,
_base_call.UnaryUnaryCall):
"""Used for running a `UnaryUnaryCall` wrapped by interceptors.
For the `__await__` method is it is proxied to the intercepted call only when
@ -402,11 +454,6 @@ class InterceptedUnaryUnaryCall(InterceptedCall, _base_call.UnaryUnaryCall):
return await _run_interceptor(iter(interceptors), client_call_details,
request)
def __await__(self):
call = yield from self._interceptors_task.__await__()
response = yield from call.__await__()
return response
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
@ -504,6 +551,152 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall):
raise NotImplementedError()
class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin,
InterceptedCall, _base_call.StreamUnaryCall):
"""Used for running a `StreamUnaryCall` wrapped by interceptors.
For the `__await__` method is it is proxied to the intercepted call only when
the interceptor task is finished.
"""
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_write_to_iterator_async_gen: Optional[AsyncIterable[RequestType]]
_write_to_iterator_queue: Optional[asyncio.Queue]
_FINISH_ITERATOR_SENTINEL = tuple()
# pylint: disable=too-many-arguments
def __init__(self, interceptors: Sequence[StreamUnaryClientInterceptor],
request_iterator: Optional[RequestIterableType],
timeout: Optional[float], metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._channel = channel
if not request_iterator:
# We provide our own request iterator which is a proxy
# of the future wries done by the caller. This iterator
# will use internally a queue for consuming messages produced
# by the write method.
self._write_to_iterator_queue = asyncio.Queue(maxsize=1)
self._write_to_iterator_async_gen = self._proxies_writes_as_a_request_iteerator(
)
request_iterator = self._write_to_iterator_async_gen
else:
self._write_to_iterator_queue = None
interceptors_task = loop.create_task(
self._invoke(interceptors, method, timeout, metadata, credentials,
wait_for_ready, request_iterator, request_serializer,
response_deserializer))
super().__init__(interceptors_task)
# pylint: disable=too-many-arguments
async def _invoke(
self, interceptors: Sequence[StreamUnaryClientInterceptor],
method: bytes, timeout: Optional[float],
metadata: Optional[MetadataType],
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool],
request_iterator: RequestIterableType,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> StreamUnaryCall:
"""Run the RPC call wrapped in interceptors"""
async def _run_interceptor(
interceptors: Iterator[UnaryUnaryClientInterceptor],
client_call_details: ClientCallDetails,
request_iterator: RequestIterableType
) -> _base_call.StreamUnaryCall:
interceptor = next(interceptors, None)
if interceptor:
continuation = functools.partial(_run_interceptor, interceptors)
return await interceptor.intercept_stream_unary(
continuation, client_call_details, request_iterator)
else:
return StreamUnaryCall(
request_iterator,
_timeout_to_deadline(client_call_details.timeout),
client_call_details.metadata,
client_call_details.credentials,
client_call_details.wait_for_ready, self._channel,
client_call_details.method, request_serializer,
response_deserializer, self._loop)
client_call_details = ClientCallDetails(method, timeout, metadata,
credentials, wait_for_ready)
return await _run_interceptor(iter(interceptors), client_call_details,
request_iterator)
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
async def _proxies_writes_as_a_request_iteerator(self):
await self._interceptors_task
while True:
value = await self._write_to_iterator_queue.get()
if value is InterceptedStreamUnaryCall._FINISH_ITERATOR_SENTINEL:
break
yield value
async def write(self, request: RequestType) -> None:
# If no queue was created it means that requests
# should be expected through an iterators provided
# by the caller.
if self._write_to_iterator_queue is None:
raise cygrpc.UsageError(
'Please don\'t mix two styles of API for streaming requests')
try:
call = await self._interceptors_task
except (asyncio.CancelledError, AioRpcError):
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
if call.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
elif call._done_writing_flag:
raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
# Write might never end up since the call could abrubtly finish,
# we give up on the first awaitable object that finishes..
_, _ = await asyncio.wait(
(self._write_to_iterator_queue.put(request), call),
return_when=asyncio.FIRST_COMPLETED)
if call.done():
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
async def done_writing(self) -> None:
"""Signal peer that client is done writing.
This method is idempotent.
"""
# If no queue was created it means that requests
# should be expected through an iterators provided
# by the caller.
if self._write_to_iterator_queue is None:
raise cygrpc.UsageError(_API_STYLE_ERROR)
try:
call = await self._interceptors_task
except asyncio.CancelledError:
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
# Write might never end up since the call could abrubtly finish,
# we give up on the first awaitable object that finishes.
_, _ = await asyncio.wait((self._write_to_iterator_queue.put(
InterceptedStreamUnaryCall._FINISH_ITERATOR_SENTINEL), call),
return_when=asyncio.FIRST_COMPLETED)
class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
"""Final UnaryUnaryCall class finished with a response."""
_response: ResponseType

@ -13,6 +13,7 @@
"unit.channel_argument_test.TestChannelArgument",
"unit.channel_ready_test.TestChannelReady",
"unit.channel_test.TestChannel",
"unit.client_stream_unary_interceptor_test.TestStreamUnaryClientInterceptor",
"unit.client_unary_stream_interceptor_test.TestUnaryStreamClientInterceptor",
"unit.client_unary_unary_interceptor_test.TestInterceptedUnaryUnaryCall",
"unit.client_unary_unary_interceptor_test.TestUnaryUnaryClientInterceptor",

@ -0,0 +1,531 @@
# Copyright 2020 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import unittest
import datetime
import grpc
from grpc.experimental import aio
from tests_aio.unit._constants import UNREACHABLE_TARGET
from tests_aio.unit._common import inject_callbacks
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase
from tests.unit.framework.common import test_constants
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
_SHORT_TIMEOUT_S = 1.0
_NUM_STREAM_REQUESTS = 5
_REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000)
class _CountingRequestIterator:
def __init__(self, request_iterator):
self.request_cnt = 0
self._request_iterator = request_iterator
async def _forward_requests(self):
async for request in self._request_iterator:
self.request_cnt += 1
yield request
def __aiter__(self):
return self._forward_requests()
class _StreamUnaryInterceptorEmpty(aio.StreamUnaryClientInterceptor):
async def intercept_stream_unary(self, continuation, client_call_details,
request_iterator):
return await continuation(client_call_details, request_iterator)
def assert_in_final_state(self, test: unittest.TestCase):
pass
class _StreamUnaryInterceptorWithRequestIterator(
aio.StreamUnaryClientInterceptor):
async def intercept_stream_unary(self, continuation, client_call_details,
request_iterator):
self.request_iterator = _CountingRequestIterator(request_iterator)
call = await continuation(client_call_details, self.request_iterator)
return call
def assert_in_final_state(self, test: unittest.TestCase):
test.assertEqual(_NUM_STREAM_REQUESTS,
self.request_iterator.request_cnt)
class TestStreamUnaryClientInterceptor(AioTestBase):
async def setUp(self):
self._server_target, self._server = await start_test_server()
async def tearDown(self):
await self._server.stop(None)
async def test_intercepts(self):
for interceptor_class in (_StreamUnaryInterceptorEmpty,
_StreamUnaryInterceptorWithRequestIterator):
with self.subTest(name=interceptor_class):
interceptor = interceptor_class()
channel = aio.insecure_channel(self._server_target,
interceptors=[interceptor])
stub = test_pb2_grpc.TestServiceStub(channel)
payload = messages_pb2.Payload(body=b'\0' *
_REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(
payload=payload)
async def request_iterator():
for _ in range(_NUM_STREAM_REQUESTS):
yield request
call = stub.StreamingInputCall(request_iterator())
response = await call
self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.details(), '')
self.assertEqual(await call.debug_error_string(), '')
self.assertEqual(call.cancel(), False)
self.assertEqual(call.cancelled(), False)
self.assertEqual(call.done(), True)
interceptor.assert_in_final_state(self)
await channel.close()
async def test_intercepts_using_write(self):
for interceptor_class in (_StreamUnaryInterceptorEmpty,
_StreamUnaryInterceptorWithRequestIterator):
with self.subTest(name=interceptor_class):
interceptor = interceptor_class()
channel = aio.insecure_channel(self._server_target,
interceptors=[interceptor])
stub = test_pb2_grpc.TestServiceStub(channel)
payload = messages_pb2.Payload(body=b'\0' *
_REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(
payload=payload)
call = stub.StreamingInputCall()
for _ in range(_NUM_STREAM_REQUESTS):
await call.write(request)
await call.done_writing()
response = await call
self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.details(), '')
self.assertEqual(await call.debug_error_string(), '')
self.assertEqual(call.cancel(), False)
self.assertEqual(call.cancelled(), False)
self.assertEqual(call.done(), True)
interceptor.assert_in_final_state(self)
await channel.close()
async def test_add_done_callback_interceptor_task_not_finished(self):
for interceptor_class in (_StreamUnaryInterceptorEmpty,
_StreamUnaryInterceptorWithRequestIterator):
with self.subTest(name=interceptor_class):
interceptor = interceptor_class()
channel = aio.insecure_channel(self._server_target,
interceptors=[interceptor])
stub = test_pb2_grpc.TestServiceStub(channel)
payload = messages_pb2.Payload(body=b'\0' *
_REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(
payload=payload)
async def request_iterator():
for _ in range(_NUM_STREAM_REQUESTS):
yield request
call = stub.StreamingInputCall(request_iterator())
validation = inject_callbacks(call)
response = await call
await validation
await channel.close()
async def test_add_done_callback_interceptor_task_finished(self):
for interceptor_class in (_StreamUnaryInterceptorEmpty,
_StreamUnaryInterceptorWithRequestIterator):
with self.subTest(name=interceptor_class):
interceptor = interceptor_class()
channel = aio.insecure_channel(self._server_target,
interceptors=[interceptor])
stub = test_pb2_grpc.TestServiceStub(channel)
payload = messages_pb2.Payload(body=b'\0' *
_REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(
payload=payload)
async def request_iterator():
for _ in range(_NUM_STREAM_REQUESTS):
yield request
call = stub.StreamingInputCall(request_iterator())
response = await call
validation = inject_callbacks(call)
await validation
await channel.close()
async def test_multiple_interceptors_request_iterator(self):
for interceptor_class in (_StreamUnaryInterceptorEmpty,
_StreamUnaryInterceptorWithRequestIterator):
with self.subTest(name=interceptor_class):
interceptors = [interceptor_class(), interceptor_class()]
channel = aio.insecure_channel(self._server_target,
interceptors=interceptors)
stub = test_pb2_grpc.TestServiceStub(channel)
payload = messages_pb2.Payload(body=b'\0' *
_REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(
payload=payload)
async def request_iterator():
for _ in range(_NUM_STREAM_REQUESTS):
yield request
call = stub.StreamingInputCall(request_iterator())
response = await call
self.assertEqual(_NUM_STREAM_REQUESTS * _REQUEST_PAYLOAD_SIZE,
response.aggregated_payload_size)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.details(), '')
self.assertEqual(await call.debug_error_string(), '')
self.assertEqual(call.cancel(), False)
self.assertEqual(call.cancelled(), False)
self.assertEqual(call.done(), True)
for interceptor in interceptors:
interceptor.assert_in_final_state(self)
await channel.close()
async def test_intercepts_request_iterator_rpc_error(self):
for interceptor_class in (_StreamUnaryInterceptorEmpty,
_StreamUnaryInterceptorWithRequestIterator):
with self.subTest(name=interceptor_class):
channel = aio.insecure_channel(
UNREACHABLE_TARGET, interceptors=[interceptor_class()])
stub = test_pb2_grpc.TestServiceStub(channel)
payload = messages_pb2.Payload(body=b'\0' *
_REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(
payload=payload)
# When there is an error the request iterator is no longer
# consumed.
async def request_iterator():
for _ in range(_NUM_STREAM_REQUESTS):
yield request
call = stub.StreamingInputCall(request_iterator())
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
self.assertEqual(grpc.StatusCode.UNAVAILABLE,
exception_context.exception.code())
self.assertTrue(call.done())
self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())
await channel.close()
async def test_intercepts_request_iterator_rpc_error_using_write(self):
for interceptor_class in (_StreamUnaryInterceptorEmpty,
_StreamUnaryInterceptorWithRequestIterator):
with self.subTest(name=interceptor_class):
channel = aio.insecure_channel(
UNREACHABLE_TARGET, interceptors=[interceptor_class()])
stub = test_pb2_grpc.TestServiceStub(channel)
payload = messages_pb2.Payload(body=b'\0' *
_REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(
payload=payload)
call = stub.StreamingInputCall()
# When there is an error during the write, exception is raised.
with self.assertRaises(asyncio.InvalidStateError):
for _ in range(_NUM_STREAM_REQUESTS):
await call.write(request)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
self.assertEqual(grpc.StatusCode.UNAVAILABLE,
exception_context.exception.code())
self.assertTrue(call.done())
self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())
await channel.close()
async def test_cancel_before_rpc(self):
interceptor_reached = asyncio.Event()
wait_for_ever = self.loop.create_future()
class Interceptor(aio.StreamUnaryClientInterceptor):
async def intercept_stream_unary(self, continuation,
client_call_details,
request_iterator):
interceptor_reached.set()
await wait_for_ever
channel = aio.insecure_channel(self._server_target,
interceptors=[Interceptor()])
stub = test_pb2_grpc.TestServiceStub(channel)
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
call = stub.StreamingInputCall()
self.assertFalse(call.cancelled())
self.assertFalse(call.done())
await interceptor_reached.wait()
self.assertTrue(call.cancel())
# When there is an error during the write, exception is raised.
with self.assertRaises(asyncio.InvalidStateError):
for _ in range(_NUM_STREAM_REQUESTS):
await call.write(request)
with self.assertRaises(asyncio.CancelledError):
await call
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.initial_metadata(), None)
self.assertEqual(await call.trailing_metadata(), None)
await channel.close()
async def test_cancel_after_rpc(self):
interceptor_reached = asyncio.Event()
wait_for_ever = self.loop.create_future()
class Interceptor(aio.StreamUnaryClientInterceptor):
async def intercept_stream_unary(self, continuation,
client_call_details,
request_iterator):
call = await continuation(client_call_details, request_iterator)
interceptor_reached.set()
await wait_for_ever
channel = aio.insecure_channel(self._server_target,
interceptors=[Interceptor()])
stub = test_pb2_grpc.TestServiceStub(channel)
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
call = stub.StreamingInputCall()
self.assertFalse(call.cancelled())
self.assertFalse(call.done())
await interceptor_reached.wait()
self.assertTrue(call.cancel())
# When there is an error during the write, exception is raised.
with self.assertRaises(asyncio.InvalidStateError):
for _ in range(_NUM_STREAM_REQUESTS):
await call.write(request)
with self.assertRaises(asyncio.CancelledError):
await call
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.initial_metadata(), None)
self.assertEqual(await call.trailing_metadata(), None)
await channel.close()
async def test_cancel_while_writing(self):
# Test cancelation before making any write or after doing at least 1
for num_writes_before_cancel in (0, 1):
with self.subTest(name="Num writes before cancel: {}".format(
num_writes_before_cancel)):
channel = aio.insecure_channel(
UNREACHABLE_TARGET,
interceptors=[_StreamUnaryInterceptorWithRequestIterator()])
stub = test_pb2_grpc.TestServiceStub(channel)
payload = messages_pb2.Payload(body=b'\0' *
_REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(
payload=payload)
call = stub.StreamingInputCall()
with self.assertRaises(asyncio.InvalidStateError):
for i in range(_NUM_STREAM_REQUESTS):
if i == num_writes_before_cancel:
self.assertTrue(call.cancel())
await call.write(request)
with self.assertRaises(asyncio.CancelledError):
await call
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
await channel.close()
async def test_cancel_by_the_interceptor(self):
class Interceptor(aio.StreamUnaryClientInterceptor):
async def intercept_stream_unary(self, continuation,
client_call_details,
request_iterator):
call = await continuation(client_call_details, request_iterator)
call.cancel()
return call
channel = aio.insecure_channel(UNREACHABLE_TARGET,
interceptors=[Interceptor()])
stub = test_pb2_grpc.TestServiceStub(channel)
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
call = stub.StreamingInputCall()
with self.assertRaises(asyncio.InvalidStateError):
for i in range(_NUM_STREAM_REQUESTS):
await call.write(request)
with self.assertRaises(asyncio.CancelledError):
await call
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
await channel.close()
async def test_exception_raised_by_interceptor(self):
class InterceptorException(Exception):
pass
class Interceptor(aio.StreamUnaryClientInterceptor):
async def intercept_stream_unary(self, continuation,
client_call_details,
request_iterator):
raise InterceptorException
channel = aio.insecure_channel(UNREACHABLE_TARGET,
interceptors=[Interceptor()])
stub = test_pb2_grpc.TestServiceStub(channel)
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
call = stub.StreamingInputCall()
with self.assertRaises(InterceptorException):
for i in range(_NUM_STREAM_REQUESTS):
await call.write(request)
with self.assertRaises(InterceptorException):
await call
await channel.close()
async def test_intercepts_prohibit_mixing_style(self):
channel = aio.insecure_channel(
self._server_target, interceptors=[_StreamUnaryInterceptorEmpty()])
stub = test_pb2_grpc.TestServiceStub(channel)
payload = messages_pb2.Payload(body=b'\0' * _REQUEST_PAYLOAD_SIZE)
request = messages_pb2.StreamingInputCallRequest(payload=payload)
async def request_iterator():
for _ in range(_NUM_STREAM_REQUESTS):
yield request
call = stub.StreamingInputCall(request_iterator())
with self.assertRaises(grpc._cython.cygrpc.UsageError):
await call.write(request)
with self.assertRaises(grpc._cython.cygrpc.UsageError):
await call.done_writing()
await channel.close()
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)
Loading…
Cancel
Save