Apply feedback

pull/22713/head
Pau Freixes 5 years ago
parent 2aea5c002b
commit dae80a4977
  1. 5
      src/python/grpcio/grpc/experimental/aio/__init__.py
  2. 20
      src/python/grpcio/grpc/experimental/aio/_channel.py
  3. 46
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  4. 31
      src/python/grpcio_tests/tests_aio/unit/_common.py
  5. 182
      src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py
  6. 24
      src/python/grpcio_tests/tests_aio/unit/done_callback_test.py

@ -30,7 +30,8 @@ from ._base_channel import (Channel, StreamStreamMultiCallable,
StreamUnaryMultiCallable, UnaryStreamMultiCallable, StreamUnaryMultiCallable, UnaryStreamMultiCallable,
UnaryUnaryMultiCallable) UnaryUnaryMultiCallable)
from ._call import AioRpcError from ._call import AioRpcError
from ._interceptor import (ClientCallDetails, InterceptedUnaryUnaryCall, from ._interceptor import (ClientCallDetails, ClientInterceptor,
InterceptedUnaryUnaryCall,
UnaryUnaryClientInterceptor, UnaryUnaryClientInterceptor,
UnaryStreamClientInterceptor, ServerInterceptor) UnaryStreamClientInterceptor, ServerInterceptor)
from ._server import server from ._server import server
@ -57,6 +58,8 @@ __all__ = (
'StreamUnaryMultiCallable', 'StreamUnaryMultiCallable',
'StreamStreamMultiCallable', 'StreamStreamMultiCallable',
'ClientCallDetails', 'ClientCallDetails',
'ClientInterceptor',
'UnaryStreamClientInterceptor',
'UnaryUnaryClientInterceptor', 'UnaryUnaryClientInterceptor',
'InterceptedUnaryUnaryCall', 'InterceptedUnaryUnaryCall',
'ServerInterceptor', 'ServerInterceptor',

@ -15,7 +15,7 @@
import asyncio import asyncio
import sys import sys
from typing import Any, Iterable, Optional, Sequence from typing import Any, Iterable, Optional, Sequence, List
import grpc import grpc
from grpc import _common, _compression, _grpcio_metadata from grpc import _common, _compression, _grpcio_metadata
@ -202,8 +202,8 @@ class StreamStreamMultiCallable(_BaseMultiCallable,
class Channel(_base_channel.Channel): class Channel(_base_channel.Channel):
_loop: asyncio.AbstractEventLoop _loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel _channel: cygrpc.AioChannel
_unary_unary_interceptors: Sequence[UnaryUnaryClientInterceptor] _unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
_unary_stream_interceptors: Sequence[UnaryStreamClientInterceptor] _unary_stream_interceptors: List[UnaryStreamClientInterceptor]
def __init__(self, target: str, options: ChannelArgumentType, def __init__(self, target: str, options: ChannelArgumentType,
credentials: Optional[grpc.ChannelCredentials], credentials: Optional[grpc.ChannelCredentials],
@ -224,18 +224,16 @@ class Channel(_base_channel.Channel):
self._unary_stream_interceptors = [] self._unary_stream_interceptors = []
if interceptors: if interceptors:
attrs_and_interceptor_classes = [ attrs_and_interceptor_classes = (
(self._unary_unary_interceptors, UnaryUnaryClientInterceptor), (self._unary_unary_interceptors, UnaryUnaryClientInterceptor),
(self._unary_stream_interceptors, UnaryStreamClientInterceptor) (self._unary_stream_interceptors, UnaryStreamClientInterceptor)
] )
# pylint: disable=cell-var-from-loop # pylint: disable=cell-var-from-loop
for attr, interceptor_class in attrs_and_interceptor_classes: for attr, interceptor_class in attrs_and_interceptor_classes:
attr.extend( attr.extend(
list( [interceptor for interceptor in interceptors if isinstance(interceptor, interceptor_class)]
filter( )
lambda interceptor: isinstance(
interceptor, interceptor_class), interceptors)))
invalid_interceptors = set(interceptors) - set( invalid_interceptors = set(interceptors) - set(
self._unary_unary_interceptors) - set( self._unary_unary_interceptors) - set(
@ -245,7 +243,7 @@ class Channel(_base_channel.Channel):
raise ValueError( raise ValueError(
"Interceptor must be "+\ "Interceptor must be "+\
"UnaryUnaryClientInterceptors or "+\ "UnaryUnaryClientInterceptors or "+\
"UnaryStreamClientInterceptors the following are invalid: {}"\ "UnaryStreamClientInterceptors. The following are invalid: {}"\
.format(invalid_interceptors)) .format(invalid_interceptors))
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
@ -402,7 +400,7 @@ def insecure_channel(
target: str, target: str,
options: Optional[ChannelArgumentType] = None, options: Optional[ChannelArgumentType] = None,
compression: Optional[grpc.Compression] = None, compression: Optional[grpc.Compression] = None,
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] = None): interceptors: Optional[Sequence[ClientInterceptor]] = None):
"""Creates an insecure asynchronous Channel to a server. """Creates an insecure asynchronous Channel to a server.
Args: Args:

@ -126,9 +126,9 @@ class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
@abstractmethod @abstractmethod
async def intercept_unary_stream(self, continuation: Callable[[ async def intercept_unary_stream(self, continuation: Callable[[
ClientCallDetails, RequestType, AsyncIterable[ResponseType] ClientCallDetails, RequestType], UnaryStreamCall],
], UnaryStreamCall], client_call_details: ClientCallDetails, client_call_details: ClientCallDetails,
request: RequestType) -> UnaryStreamCall: request: RequestType) -> Union[AsyncIterable[ResponseType], UnaryStreamCall]:
"""Intercepts a unary-stream invocation asynchronously. """Intercepts a unary-stream invocation asynchronously.
Args: Args:
@ -180,31 +180,32 @@ class InterceptedCall:
self._interceptors_task = interceptors_task self._interceptors_task = interceptors_task
self._pending_add_done_callbacks = [] self._pending_add_done_callbacks = []
self._interceptors_task.add_done_callback( self._interceptors_task.add_done_callback(
self._fire_or_add_pending_add_done_callbacks) self._fire_or_add_pending_done_callbacks)
def __del__(self): def __del__(self):
self.cancel() self.cancel()
def _fire_or_add_pending_add_done_callbacks(self, def _fire_or_add_pending_done_callbacks(self,
interceptors_task: asyncio.Task interceptors_task: asyncio.Task
) -> None: ) -> None:
if not self._pending_add_done_callbacks: if not self._pending_add_done_callbacks:
return return
fire = False call_completed = False
try: try:
call = interceptors_task.result() call = interceptors_task.result()
if call.done(): if call.done():
fire = True call_completed = True
except (AioRpcError, asyncio.CancelledError): except (AioRpcError, asyncio.CancelledError):
fire = True call_completed = True
for callback in self._pending_add_done_callbacks: if call_completed:
if fire: for callback in self._pending_add_done_callbacks:
callback(self) callback(self)
else: else:
for callback in self._pending_add_done_callbacks:
callback = functools.partial(self._wrap_add_done_callback, callback = functools.partial(self._wrap_add_done_callback,
callback) callback)
call.add_done_callback(callback) call.add_done_callback(callback)
@ -415,6 +416,7 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall):
_loop: asyncio.AbstractEventLoop _loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel _channel: cygrpc.AioChannel
_response_aiter: AsyncIterable[ResponseType] _response_aiter: AsyncIterable[ResponseType]
_last_returned_call_from_interceptors = Optional[_base_call.UnaryStreamCall]
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments
def __init__(self, interceptors: Sequence[UnaryStreamClientInterceptor], def __init__(self, interceptors: Sequence[UnaryStreamClientInterceptor],
@ -429,6 +431,7 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall):
self._channel = channel self._channel = channel
self._response_aiter = self._wait_for_interceptor_task_response_iterator( self._response_aiter = self._wait_for_interceptor_task_response_iterator(
) )
self._last_returned_call_from_interceptors = None
interceptors_task = loop.create_task( interceptors_task = loop.create_task(
self._invoke(interceptors, method, timeout, metadata, credentials, self._invoke(interceptors, method, timeout, metadata, credentials,
wait_for_ready, request, request_serializer, wait_for_ready, request, request_serializer,
@ -446,7 +449,6 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall):
) -> UnaryStreamCall: ) -> UnaryStreamCall:
"""Run the RPC call wrapped in interceptors""" """Run the RPC call wrapped in interceptors"""
last_returned_call_from_interceptors = [None]
async def _run_interceptor( async def _run_interceptor(
interceptors: Iterator[UnaryStreamClientInterceptor], interceptors: Iterator[UnaryStreamClientInterceptor],
@ -462,17 +464,15 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall):
call_or_response_iterator = await interceptor.intercept_unary_stream( call_or_response_iterator = await interceptor.intercept_unary_stream(
continuation, client_call_details, request) continuation, client_call_details, request)
if call_or_response_iterator is last_returned_call_from_interceptors[ if isinstance(call_or_response_iterator, _base_call.UnaryUnaryCall):
0]: self._last_returned_call_from_interceptors = call_or_response_iterator
return call_or_response_iterator
else: else:
last_returned_call_from_interceptors[ self._last_returned_call_from_interceptors = UnaryStreamCallResponseIterator(
0] = UnaryStreamCallResponseIterator( self._last_returned_call_from_interceptors,
last_returned_call_from_interceptors[0],
call_or_response_iterator) call_or_response_iterator)
return last_returned_call_from_interceptors[0] return self._last_returned_call_from_interceptors
else: else:
last_returned_call_from_interceptors[0] = UnaryStreamCall( self._last_returned_call_from_interceptors = UnaryStreamCall(
request, _timeout_to_deadline(client_call_details.timeout), request, _timeout_to_deadline(client_call_details.timeout),
client_call_details.metadata, client_call_details.metadata,
client_call_details.credentials, client_call_details.credentials,
@ -480,7 +480,7 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall):
client_call_details.method, request_serializer, client_call_details.method, request_serializer,
response_deserializer, self._loop) response_deserializer, self._loop)
return last_returned_call_from_interceptors[0] return self._last_returned_call_from_interceptors
client_call_details = ClientCallDetails(method, timeout, metadata, client_call_details = ClientCallDetails(method, timeout, metadata,
credentials, wait_for_ready) credentials, wait_for_ready)
@ -598,4 +598,6 @@ class UnaryStreamCallResponseIterator(_base_call.UnaryStreamCall):
return await self._call.wait_for_connection() return await self._call.wait_for_connection()
async def read(self) -> ResponseType: async def read(self) -> ResponseType:
return await self._call.read() # Behind the scenes everyting goes through the
# async iterator. So this path should not be reached.
raise Exception()

@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import grpc import grpc
from grpc.experimental import aio from grpc.experimental import aio
from grpc.experimental.aio._typing import MetadataType, MetadatumType from grpc.experimental.aio._typing import MetadataType, MetadatumType
from tests.unit.framework.common import test_constants
def seen_metadata(expected: MetadataType, actual: MetadataType): def seen_metadata(expected: MetadataType, actual: MetadataType):
return not bool(set(expected) - set(actual)) return not bool(set(expected) - set(actual))
@ -32,3 +35,31 @@ async def block_until_certain_state(channel: aio.Channel,
while state != expected_state: while state != expected_state:
await channel.wait_for_state_change(state) await channel.wait_for_state_change(state)
state = channel.get_state() state = channel.get_state()
def inject_callbacks(call):
first_callback_ran = asyncio.Event()
def first_callback(call):
# Validate that all resopnses have been received
# and the call is an end state.
assert call.done()
first_callback_ran.set()
second_callback_ran = asyncio.Event()
def second_callback(call):
# Validate that all resopnses have been received
# and the call is an end state.
assert call.done()
second_callback_ran.set()
call.add_done_callback(first_callback)
call.add_done_callback(second_callback)
async def validation():
await asyncio.wait_for(
asyncio.gather(first_callback_ran.wait(),
second_callback_ran.wait()),
test_constants.SHORT_TIMEOUT)
return validation()

@ -1,4 +1,4 @@
# Copyright 2019 The gRPC Authors. # Copyright 2020 The gRPC Authors.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -20,67 +20,34 @@ import grpc
from grpc.experimental import aio from grpc.experimental import aio
from tests_aio.unit._constants import UNREACHABLE_TARGET from tests_aio.unit._constants import UNREACHABLE_TARGET
from tests_aio.unit._common import inject_callbacks
from tests_aio.unit._test_server import start_test_server from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_base import AioTestBase
from tests.unit.framework.common import test_constants from tests.unit.framework.common import test_constants
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
_SHORT_TIMEOUT_S = datetime.timedelta(seconds=1).total_seconds() _SHORT_TIMEOUT_S = 1.0
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
_NUM_STREAM_RESPONSES = 5 _NUM_STREAM_RESPONSES = 5
_REQUEST_PAYLOAD_SIZE = 7 _REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 7 _RESPONSE_PAYLOAD_SIZE = 7
_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000) _RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000)
class _ResponseIterator: class _CountingResponseIterator:
def __init__(self, response_iterator): def __init__(self, response_iterator):
self._response_cnt = 0 self.response_cnt = 0
self._response_iterator = response_iterator self._response_iterator = response_iterator
async def _forward_responses(self): async def _forward_responses(self):
async for response in self._response_iterator: async for response in self._response_iterator:
self._response_cnt += 1 self.response_cnt += 1
yield response yield response
def __aiter__(self): def __aiter__(self):
return self._forward_responses() return self._forward_responses()
@property
def response_cnt(self):
return self._response_cnt
def _inject_callbacks(call):
first_callback_ran = asyncio.Event()
def first_callback(call):
# Validate that all resopnses have been received
# and the call is an end state.
assert call.done()
first_callback_ran.set()
second_callback_ran = asyncio.Event()
def second_callback(call):
# Validate that all resopnses have been received
# and the call is an end state.
assert call.done()
second_callback_ran.set()
call.add_done_callback(first_callback)
call.add_done_callback(second_callback)
async def validation():
await asyncio.wait_for(
asyncio.gather(first_callback_ran.wait(),
second_callback_ran.wait()),
test_constants.SHORT_TIMEOUT)
return validation()
class _UnaryStreamInterceptorEmpty(aio.UnaryStreamClientInterceptor): class _UnaryStreamInterceptorEmpty(aio.UnaryStreamClientInterceptor):
@ -89,7 +56,7 @@ class _UnaryStreamInterceptorEmpty(aio.UnaryStreamClientInterceptor):
return await continuation(client_call_details, request) return await continuation(client_call_details, request)
class _UnaryStreamInterceptorWith_ResponseIterator( class _UnaryStreamInterceptorWithResponseIterator(
aio.UnaryStreamClientInterceptor): aio.UnaryStreamClientInterceptor):
def __init__(self): def __init__(self):
@ -98,7 +65,7 @@ class _UnaryStreamInterceptorWith_ResponseIterator(
async def intercept_unary_stream(self, continuation, client_call_details, async def intercept_unary_stream(self, continuation, client_call_details,
request): request):
call = await continuation(client_call_details, request) call = await continuation(client_call_details, request)
self.response_iterator = _ResponseIterator(call) self.response_iterator = _CountingResponseIterator(call)
return self.response_iterator return self.response_iterator
@ -112,16 +79,15 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
async def test_intercepts(self): async def test_intercepts(self):
for interceptor_class in (_UnaryStreamInterceptorEmpty, for interceptor_class in (_UnaryStreamInterceptorEmpty,
_UnaryStreamInterceptorWith_ResponseIterator): _UnaryStreamInterceptorWithResponseIterator):
with self.subTest(name=interceptor_class): with self.subTest(name=interceptor_class):
interceptor = interceptor_class() interceptor = interceptor_class()
request = messages_pb2.StreamingOutputCallRequest() request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES): request.response_parameters.extend([
request.response_parameters.append( messages_pb2.ResponseParameters(
messages_pb2.ResponseParameters( size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES)
size=_RESPONSE_PAYLOAD_SIZE))
channel = aio.insecure_channel(self._server_target, channel = aio.insecure_channel(self._server_target,
interceptors=[interceptor]) interceptors=[interceptor])
@ -138,7 +104,7 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body)) len(response.payload.body))
self.assertTrue(response_cnt, _NUM_STREAM_RESPONSES) self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES)
self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.initial_metadata(), ()) self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ()) self.assertEqual(await call.trailing_metadata(), ())
@ -148,31 +114,30 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
self.assertEqual(call.cancelled(), False) self.assertEqual(call.cancelled(), False)
self.assertEqual(call.done(), True) self.assertEqual(call.done(), True)
if interceptor_class == _UnaryStreamInterceptorWith_ResponseIterator: if interceptor_class == _UnaryStreamInterceptorWithResponseIterator:
self.assertTrue(interceptor.response_iterator.response_cnt, self.assertEqual(interceptor.response_iterator.response_cnt,
_NUM_STREAM_RESPONSES) _NUM_STREAM_RESPONSES)
await channel.close() await channel.close()
async def test_add_done_callback(self): async def test_add_done_callback_interceptor_task_not_finished(self):
for interceptor_class in (_UnaryStreamInterceptorEmpty, for interceptor_class in (_UnaryStreamInterceptorEmpty,
_UnaryStreamInterceptorWith_ResponseIterator): _UnaryStreamInterceptorWithResponseIterator):
with self.subTest(name=interceptor_class): with self.subTest(name=interceptor_class):
interceptor = interceptor_class() interceptor = interceptor_class()
request = messages_pb2.StreamingOutputCallRequest() request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES): request.response_parameters.extend([
request.response_parameters.append( messages_pb2.ResponseParameters(
messages_pb2.ResponseParameters( size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES)
size=_RESPONSE_PAYLOAD_SIZE))
channel = aio.insecure_channel(self._server_target, channel = aio.insecure_channel(self._server_target,
interceptors=[interceptor]) interceptors=[interceptor])
stub = test_pb2_grpc.TestServiceStub(channel) stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.StreamingOutputCall(request) call = stub.StreamingOutputCall(request)
validation = _inject_callbacks(call) validation = inject_callbacks(call)
async for response in call: async for response in call:
pass pass
@ -181,18 +146,17 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
await channel.close() await channel.close()
async def test_add_done_callback_after_connection(self): async def test_add_done_callback_interceptor_task_finished(self):
for interceptor_class in (_UnaryStreamInterceptorEmpty, for interceptor_class in (_UnaryStreamInterceptorEmpty,
_UnaryStreamInterceptorWith_ResponseIterator): _UnaryStreamInterceptorWithResponseIterator):
with self.subTest(name=interceptor_class): with self.subTest(name=interceptor_class):
interceptor = interceptor_class() interceptor = interceptor_class()
request = messages_pb2.StreamingOutputCallRequest() request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES): request.response_parameters.extend([
request.response_parameters.append( messages_pb2.ResponseParameters(
messages_pb2.ResponseParameters( size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES)
size=_RESPONSE_PAYLOAD_SIZE))
channel = aio.insecure_channel(self._server_target, channel = aio.insecure_channel(self._server_target,
interceptors=[interceptor]) interceptors=[interceptor])
@ -204,7 +168,7 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
# pending state list. # pending state list.
await call.wait_for_connection() await call.wait_for_connection()
validation = _inject_callbacks(call) validation = inject_callbacks(call)
async for response in call: async for response in call:
pass pass
@ -214,16 +178,16 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
await channel.close() await channel.close()
async def test_response_iterator_using_read(self): async def test_response_iterator_using_read(self):
interceptor = _UnaryStreamInterceptorWith_ResponseIterator() interceptor = _UnaryStreamInterceptorWithResponseIterator()
channel = aio.insecure_channel(self._server_target, channel = aio.insecure_channel(self._server_target,
interceptors=[interceptor]) interceptors=[interceptor])
stub = test_pb2_grpc.TestServiceStub(channel) stub = test_pb2_grpc.TestServiceStub(channel)
request = messages_pb2.StreamingOutputCallRequest() request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES): request.response_parameters.extend([
request.response_parameters.append( messages_pb2.ResponseParameters(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES)
call = stub.StreamingOutputCall(request) call = stub.StreamingOutputCall(request)
@ -235,16 +199,16 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
messages_pb2.StreamingOutputCallResponse) messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertTrue(response_cnt, _NUM_STREAM_RESPONSES) self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES)
self.assertTrue(interceptor.response_iterator.response_cnt, self.assertEqual(interceptor.response_iterator.response_cnt,
_NUM_STREAM_RESPONSES) _NUM_STREAM_RESPONSES)
self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.code(), grpc.StatusCode.OK)
await channel.close() await channel.close()
async def test_mulitple_interceptors_response_iterator(self): async def test_multiple_interceptors_response_iterator(self):
for interceptor_class in (_UnaryStreamInterceptorEmpty, for interceptor_class in (_UnaryStreamInterceptorEmpty,
_UnaryStreamInterceptorWith_ResponseIterator): _UnaryStreamInterceptorWithResponseIterator):
with self.subTest(name=interceptor_class): with self.subTest(name=interceptor_class):
@ -255,10 +219,9 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
stub = test_pb2_grpc.TestServiceStub(channel) stub = test_pb2_grpc.TestServiceStub(channel)
request = messages_pb2.StreamingOutputCallRequest() request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES): request.response_parameters.extend([
request.response_parameters.append( messages_pb2.ResponseParameters(
messages_pb2.ResponseParameters( size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES)
size=_RESPONSE_PAYLOAD_SIZE))
call = stub.StreamingOutputCall(request) call = stub.StreamingOutputCall(request)
@ -270,14 +233,14 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body)) len(response.payload.body))
self.assertTrue(response_cnt, _NUM_STREAM_RESPONSES) self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES)
self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.code(), grpc.StatusCode.OK)
await channel.close() await channel.close()
async def test_intercepts_response_iterator_rpc_error(self): async def test_intercepts_response_iterator_rpc_error(self):
for interceptor_class in (_UnaryStreamInterceptorEmpty, for interceptor_class in (_UnaryStreamInterceptorEmpty,
_UnaryStreamInterceptorWith_ResponseIterator): _UnaryStreamInterceptorWithResponseIterator):
with self.subTest(name=interceptor_class): with self.subTest(name=interceptor_class):
@ -329,8 +292,6 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
self.assertTrue(call.cancelled()) self.assertTrue(call.cancelled())
self.assertTrue(call.done()) self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
_LOCAL_CANCEL_DETAILS_EXPECTATION)
self.assertEqual(await call.initial_metadata(), None) self.assertEqual(await call.initial_metadata(), None)
self.assertEqual(await call.trailing_metadata(), None) self.assertEqual(await call.trailing_metadata(), None)
await channel.close() await channel.close()
@ -367,23 +328,19 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
self.assertTrue(call.cancelled()) self.assertTrue(call.cancelled())
self.assertTrue(call.done()) self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
_LOCAL_CANCEL_DETAILS_EXPECTATION)
self.assertEqual(await call.initial_metadata(), None) self.assertEqual(await call.initial_metadata(), None)
self.assertEqual(await call.trailing_metadata(), None) self.assertEqual(await call.trailing_metadata(), None)
await channel.close() await channel.close()
async def test_cancel_consuming_response_iterator(self): async def test_cancel_consuming_response_iterator(self):
request = messages_pb2.StreamingOutputCallRequest() request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES): request.response_parameters.extend([
request.response_parameters.append( messages_pb2.ResponseParameters(
messages_pb2.ResponseParameters( size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES)
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US))
channel = aio.insecure_channel( channel = aio.insecure_channel(
self._server_target, self._server_target,
interceptors=[_UnaryStreamInterceptorWith_ResponseIterator()]) interceptors=[_UnaryStreamInterceptorWithResponseIterator()])
stub = test_pb2_grpc.TestServiceStub(channel) stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.StreamingOutputCall(request) call = stub.StreamingOutputCall(request)
@ -394,10 +351,57 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
self.assertTrue(call.cancelled()) self.assertTrue(call.cancelled())
self.assertTrue(call.done()) self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
_LOCAL_CANCEL_DETAILS_EXPECTATION)
await channel.close() await channel.close()
async def test_cancel_by_the_interceptor(self):
class Interceptor(aio.UnaryStreamClientInterceptor):
async def intercept_unary_stream(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
call.cancel()
return call
channel = aio.insecure_channel(UNREACHABLE_TARGET,
interceptors=[Interceptor()])
request = messages_pb2.StreamingOutputCallRequest()
stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.StreamingOutputCall(request)
with self.assertRaises(asyncio.CancelledError):
async for response in call:
pass
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
await channel.close()
async def test_exception_raised_by_interceptor(self):
class InterceptorException(Exception):
pass
class Interceptor(aio.UnaryStreamClientInterceptor):
async def intercept_unary_stream(self, continuation,
client_call_details, request):
raise InterceptorException
channel = aio.insecure_channel(UNREACHABLE_TARGET,
interceptors=[Interceptor()])
request = messages_pb2.StreamingOutputCallRequest()
stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.StreamingOutputCall(request)
with self.assertRaises(InterceptorException):
async for response in call:
pass
await channel.close()
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)

@ -21,6 +21,7 @@ import gc
import grpc import grpc
from grpc.experimental import aio from grpc.experimental import aio
from tests_aio.unit._common import inject_callbacks
from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_base import AioTestBase
from tests.unit.framework.common import test_constants from tests.unit.framework.common import test_constants
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
@ -31,29 +32,6 @@ _REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42 _RESPONSE_PAYLOAD_SIZE = 42
def _inject_callbacks(call):
first_callback_ran = asyncio.Event()
def first_callback(unused_call):
first_callback_ran.set()
second_callback_ran = asyncio.Event()
def second_callback(unused_call):
second_callback_ran.set()
call.add_done_callback(first_callback)
call.add_done_callback(second_callback)
async def validation():
await asyncio.wait_for(
asyncio.gather(first_callback_ran.wait(),
second_callback_ran.wait()),
test_constants.SHORT_TIMEOUT)
return validation()
class TestDoneCallback(AioTestBase): class TestDoneCallback(AioTestBase):
async def setUp(self): async def setUp(self):

Loading…
Cancel
Save