From dae80a4977c9afcabd7e104c317cfd8e259c7814 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Mon, 20 Apr 2020 23:24:54 +0200 Subject: [PATCH] Apply feedback --- .../grpcio/grpc/experimental/aio/__init__.py | 5 +- .../grpcio/grpc/experimental/aio/_channel.py | 20 +- .../grpc/experimental/aio/_interceptor.py | 46 ++--- .../grpcio_tests/tests_aio/unit/_common.py | 31 +++ .../client_unary_stream_interceptor_test.py | 182 +++++++++--------- .../tests_aio/unit/done_callback_test.py | 24 +-- 6 files changed, 162 insertions(+), 146 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/__init__.py b/src/python/grpcio/grpc/experimental/aio/__init__.py index 649d29588d4..d0b6a58a149 100644 --- a/src/python/grpcio/grpc/experimental/aio/__init__.py +++ b/src/python/grpcio/grpc/experimental/aio/__init__.py @@ -30,7 +30,8 @@ from ._base_channel import (Channel, StreamStreamMultiCallable, StreamUnaryMultiCallable, UnaryStreamMultiCallable, UnaryUnaryMultiCallable) from ._call import AioRpcError -from ._interceptor import (ClientCallDetails, InterceptedUnaryUnaryCall, +from ._interceptor import (ClientCallDetails, ClientInterceptor, + InterceptedUnaryUnaryCall, UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, ServerInterceptor) from ._server import server @@ -57,6 +58,8 @@ __all__ = ( 'StreamUnaryMultiCallable', 'StreamStreamMultiCallable', 'ClientCallDetails', + 'ClientInterceptor', + 'UnaryStreamClientInterceptor', 'UnaryUnaryClientInterceptor', 'InterceptedUnaryUnaryCall', 'ServerInterceptor', diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index 03d8aa075e0..f783999dd2a 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -15,7 +15,7 @@ import asyncio import sys -from typing import Any, Iterable, Optional, Sequence +from typing import Any, Iterable, Optional, Sequence, List import grpc from grpc import _common, _compression, _grpcio_metadata @@ -202,8 +202,8 @@ class StreamStreamMultiCallable(_BaseMultiCallable, class Channel(_base_channel.Channel): _loop: asyncio.AbstractEventLoop _channel: cygrpc.AioChannel - _unary_unary_interceptors: Sequence[UnaryUnaryClientInterceptor] - _unary_stream_interceptors: Sequence[UnaryStreamClientInterceptor] + _unary_unary_interceptors: List[UnaryUnaryClientInterceptor] + _unary_stream_interceptors: List[UnaryStreamClientInterceptor] def __init__(self, target: str, options: ChannelArgumentType, credentials: Optional[grpc.ChannelCredentials], @@ -224,18 +224,16 @@ class Channel(_base_channel.Channel): self._unary_stream_interceptors = [] if interceptors: - attrs_and_interceptor_classes = [ + attrs_and_interceptor_classes = ( (self._unary_unary_interceptors, UnaryUnaryClientInterceptor), (self._unary_stream_interceptors, UnaryStreamClientInterceptor) - ] + ) # pylint: disable=cell-var-from-loop for attr, interceptor_class in attrs_and_interceptor_classes: attr.extend( - list( - filter( - lambda interceptor: isinstance( - interceptor, interceptor_class), interceptors))) + [interceptor for interceptor in interceptors if isinstance(interceptor, interceptor_class)] + ) invalid_interceptors = set(interceptors) - set( self._unary_unary_interceptors) - set( @@ -245,7 +243,7 @@ class Channel(_base_channel.Channel): raise ValueError( "Interceptor must be "+\ "UnaryUnaryClientInterceptors or "+\ - "UnaryStreamClientInterceptors the following are invalid: {}"\ + "UnaryStreamClientInterceptors. The following are invalid: {}"\ .format(invalid_interceptors)) self._loop = asyncio.get_event_loop() @@ -402,7 +400,7 @@ def insecure_channel( target: str, options: Optional[ChannelArgumentType] = None, compression: Optional[grpc.Compression] = None, - interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] = None): + interceptors: Optional[Sequence[ClientInterceptor]] = None): """Creates an insecure asynchronous Channel to a server. Args: diff --git a/src/python/grpcio/grpc/experimental/aio/_interceptor.py b/src/python/grpcio/grpc/experimental/aio/_interceptor.py index 3ce0ddd904d..469585de51a 100644 --- a/src/python/grpcio/grpc/experimental/aio/_interceptor.py +++ b/src/python/grpcio/grpc/experimental/aio/_interceptor.py @@ -126,9 +126,9 @@ class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta): @abstractmethod async def intercept_unary_stream(self, continuation: Callable[[ - ClientCallDetails, RequestType, AsyncIterable[ResponseType] - ], UnaryStreamCall], client_call_details: ClientCallDetails, - request: RequestType) -> UnaryStreamCall: + ClientCallDetails, RequestType], UnaryStreamCall], + client_call_details: ClientCallDetails, + request: RequestType) -> Union[AsyncIterable[ResponseType], UnaryStreamCall]: """Intercepts a unary-stream invocation asynchronously. Args: @@ -180,31 +180,32 @@ class InterceptedCall: self._interceptors_task = interceptors_task self._pending_add_done_callbacks = [] self._interceptors_task.add_done_callback( - self._fire_or_add_pending_add_done_callbacks) + self._fire_or_add_pending_done_callbacks) def __del__(self): self.cancel() - def _fire_or_add_pending_add_done_callbacks(self, + def _fire_or_add_pending_done_callbacks(self, interceptors_task: asyncio.Task ) -> None: if not self._pending_add_done_callbacks: return - fire = False + call_completed = False try: call = interceptors_task.result() if call.done(): - fire = True + call_completed = True except (AioRpcError, asyncio.CancelledError): - fire = True + call_completed = True - for callback in self._pending_add_done_callbacks: - if fire: + if call_completed: + for callback in self._pending_add_done_callbacks: callback(self) - else: + else: + for callback in self._pending_add_done_callbacks: callback = functools.partial(self._wrap_add_done_callback, callback) call.add_done_callback(callback) @@ -415,6 +416,7 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall): _loop: asyncio.AbstractEventLoop _channel: cygrpc.AioChannel _response_aiter: AsyncIterable[ResponseType] + _last_returned_call_from_interceptors = Optional[_base_call.UnaryStreamCall] # pylint: disable=too-many-arguments def __init__(self, interceptors: Sequence[UnaryStreamClientInterceptor], @@ -429,6 +431,7 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall): self._channel = channel self._response_aiter = self._wait_for_interceptor_task_response_iterator( ) + self._last_returned_call_from_interceptors = None interceptors_task = loop.create_task( self._invoke(interceptors, method, timeout, metadata, credentials, wait_for_ready, request, request_serializer, @@ -446,7 +449,6 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall): ) -> UnaryStreamCall: """Run the RPC call wrapped in interceptors""" - last_returned_call_from_interceptors = [None] async def _run_interceptor( interceptors: Iterator[UnaryStreamClientInterceptor], @@ -462,17 +464,15 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall): call_or_response_iterator = await interceptor.intercept_unary_stream( continuation, client_call_details, request) - if call_or_response_iterator is last_returned_call_from_interceptors[ - 0]: - return call_or_response_iterator + if isinstance(call_or_response_iterator, _base_call.UnaryUnaryCall): + self._last_returned_call_from_interceptors = call_or_response_iterator else: - last_returned_call_from_interceptors[ - 0] = UnaryStreamCallResponseIterator( - last_returned_call_from_interceptors[0], + self._last_returned_call_from_interceptors = UnaryStreamCallResponseIterator( + self._last_returned_call_from_interceptors, call_or_response_iterator) - return last_returned_call_from_interceptors[0] + return self._last_returned_call_from_interceptors else: - last_returned_call_from_interceptors[0] = UnaryStreamCall( + self._last_returned_call_from_interceptors = UnaryStreamCall( request, _timeout_to_deadline(client_call_details.timeout), client_call_details.metadata, client_call_details.credentials, @@ -480,7 +480,7 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall): client_call_details.method, request_serializer, response_deserializer, self._loop) - return last_returned_call_from_interceptors[0] + return self._last_returned_call_from_interceptors client_call_details = ClientCallDetails(method, timeout, metadata, credentials, wait_for_ready) @@ -598,4 +598,6 @@ class UnaryStreamCallResponseIterator(_base_call.UnaryStreamCall): return await self._call.wait_for_connection() async def read(self) -> ResponseType: - return await self._call.read() + # Behind the scenes everyting goes through the + # async iterator. So this path should not be reached. + raise Exception() diff --git a/src/python/grpcio_tests/tests_aio/unit/_common.py b/src/python/grpcio_tests/tests_aio/unit/_common.py index 1b5a4d909fa..e820a18dd77 100644 --- a/src/python/grpcio_tests/tests_aio/unit/_common.py +++ b/src/python/grpcio_tests/tests_aio/unit/_common.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import grpc from grpc.experimental import aio from grpc.experimental.aio._typing import MetadataType, MetadatumType +from tests.unit.framework.common import test_constants + def seen_metadata(expected: MetadataType, actual: MetadataType): return not bool(set(expected) - set(actual)) @@ -32,3 +35,31 @@ async def block_until_certain_state(channel: aio.Channel, while state != expected_state: await channel.wait_for_state_change(state) state = channel.get_state() + +def inject_callbacks(call): + first_callback_ran = asyncio.Event() + + def first_callback(call): + # Validate that all resopnses have been received + # and the call is an end state. + assert call.done() + first_callback_ran.set() + + second_callback_ran = asyncio.Event() + + def second_callback(call): + # Validate that all resopnses have been received + # and the call is an end state. + assert call.done() + second_callback_ran.set() + + call.add_done_callback(first_callback) + call.add_done_callback(second_callback) + + async def validation(): + await asyncio.wait_for( + asyncio.gather(first_callback_ran.wait(), + second_callback_ran.wait()), + test_constants.SHORT_TIMEOUT) + + return validation() diff --git a/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py index 3de4d054469..fc9c8d81ad0 100644 --- a/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py @@ -1,4 +1,4 @@ -# Copyright 2019 The gRPC Authors. +# Copyright 2020 The gRPC Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,67 +20,34 @@ import grpc from grpc.experimental import aio from tests_aio.unit._constants import UNREACHABLE_TARGET +from tests_aio.unit._common import inject_callbacks from tests_aio.unit._test_server import start_test_server from tests_aio.unit._test_base import AioTestBase from tests.unit.framework.common import test_constants from src.proto.grpc.testing import messages_pb2, test_pb2_grpc -_SHORT_TIMEOUT_S = datetime.timedelta(seconds=1).total_seconds() +_SHORT_TIMEOUT_S = 1.0 -_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' _NUM_STREAM_RESPONSES = 5 _REQUEST_PAYLOAD_SIZE = 7 _RESPONSE_PAYLOAD_SIZE = 7 _RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000) -class _ResponseIterator: +class _CountingResponseIterator: def __init__(self, response_iterator): - self._response_cnt = 0 + self.response_cnt = 0 self._response_iterator = response_iterator async def _forward_responses(self): async for response in self._response_iterator: - self._response_cnt += 1 + self.response_cnt += 1 yield response def __aiter__(self): return self._forward_responses() - @property - def response_cnt(self): - return self._response_cnt - - -def _inject_callbacks(call): - first_callback_ran = asyncio.Event() - - def first_callback(call): - # Validate that all resopnses have been received - # and the call is an end state. - assert call.done() - first_callback_ran.set() - - second_callback_ran = asyncio.Event() - - def second_callback(call): - # Validate that all resopnses have been received - # and the call is an end state. - assert call.done() - second_callback_ran.set() - - call.add_done_callback(first_callback) - call.add_done_callback(second_callback) - - async def validation(): - await asyncio.wait_for( - asyncio.gather(first_callback_ran.wait(), - second_callback_ran.wait()), - test_constants.SHORT_TIMEOUT) - - return validation() - class _UnaryStreamInterceptorEmpty(aio.UnaryStreamClientInterceptor): @@ -89,7 +56,7 @@ class _UnaryStreamInterceptorEmpty(aio.UnaryStreamClientInterceptor): return await continuation(client_call_details, request) -class _UnaryStreamInterceptorWith_ResponseIterator( +class _UnaryStreamInterceptorWithResponseIterator( aio.UnaryStreamClientInterceptor): def __init__(self): @@ -98,7 +65,7 @@ class _UnaryStreamInterceptorWith_ResponseIterator( async def intercept_unary_stream(self, continuation, client_call_details, request): call = await continuation(client_call_details, request) - self.response_iterator = _ResponseIterator(call) + self.response_iterator = _CountingResponseIterator(call) return self.response_iterator @@ -112,16 +79,15 @@ class TestUnaryStreamClientInterceptor(AioTestBase): async def test_intercepts(self): for interceptor_class in (_UnaryStreamInterceptorEmpty, - _UnaryStreamInterceptorWith_ResponseIterator): + _UnaryStreamInterceptorWithResponseIterator): with self.subTest(name=interceptor_class): interceptor = interceptor_class() request = messages_pb2.StreamingOutputCallRequest() - for _ in range(_NUM_STREAM_RESPONSES): - request.response_parameters.append( - messages_pb2.ResponseParameters( - size=_RESPONSE_PAYLOAD_SIZE)) + request.response_parameters.extend([ + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES) channel = aio.insecure_channel(self._server_target, interceptors=[interceptor]) @@ -138,7 +104,7 @@ class TestUnaryStreamClientInterceptor(AioTestBase): self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) - self.assertTrue(response_cnt, _NUM_STREAM_RESPONSES) + self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) self.assertEqual(await call.code(), grpc.StatusCode.OK) self.assertEqual(await call.initial_metadata(), ()) self.assertEqual(await call.trailing_metadata(), ()) @@ -148,31 +114,30 @@ class TestUnaryStreamClientInterceptor(AioTestBase): self.assertEqual(call.cancelled(), False) self.assertEqual(call.done(), True) - if interceptor_class == _UnaryStreamInterceptorWith_ResponseIterator: - self.assertTrue(interceptor.response_iterator.response_cnt, + if interceptor_class == _UnaryStreamInterceptorWithResponseIterator: + self.assertEqual(interceptor.response_iterator.response_cnt, _NUM_STREAM_RESPONSES) await channel.close() - async def test_add_done_callback(self): + async def test_add_done_callback_interceptor_task_not_finished(self): for interceptor_class in (_UnaryStreamInterceptorEmpty, - _UnaryStreamInterceptorWith_ResponseIterator): + _UnaryStreamInterceptorWithResponseIterator): with self.subTest(name=interceptor_class): interceptor = interceptor_class() request = messages_pb2.StreamingOutputCallRequest() - for _ in range(_NUM_STREAM_RESPONSES): - request.response_parameters.append( - messages_pb2.ResponseParameters( - size=_RESPONSE_PAYLOAD_SIZE)) + request.response_parameters.extend([ + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES) channel = aio.insecure_channel(self._server_target, interceptors=[interceptor]) stub = test_pb2_grpc.TestServiceStub(channel) call = stub.StreamingOutputCall(request) - validation = _inject_callbacks(call) + validation = inject_callbacks(call) async for response in call: pass @@ -181,18 +146,17 @@ class TestUnaryStreamClientInterceptor(AioTestBase): await channel.close() - async def test_add_done_callback_after_connection(self): + async def test_add_done_callback_interceptor_task_finished(self): for interceptor_class in (_UnaryStreamInterceptorEmpty, - _UnaryStreamInterceptorWith_ResponseIterator): + _UnaryStreamInterceptorWithResponseIterator): with self.subTest(name=interceptor_class): interceptor = interceptor_class() request = messages_pb2.StreamingOutputCallRequest() - for _ in range(_NUM_STREAM_RESPONSES): - request.response_parameters.append( - messages_pb2.ResponseParameters( - size=_RESPONSE_PAYLOAD_SIZE)) + request.response_parameters.extend([ + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES) channel = aio.insecure_channel(self._server_target, interceptors=[interceptor]) @@ -204,7 +168,7 @@ class TestUnaryStreamClientInterceptor(AioTestBase): # pending state list. await call.wait_for_connection() - validation = _inject_callbacks(call) + validation = inject_callbacks(call) async for response in call: pass @@ -214,16 +178,16 @@ class TestUnaryStreamClientInterceptor(AioTestBase): await channel.close() async def test_response_iterator_using_read(self): - interceptor = _UnaryStreamInterceptorWith_ResponseIterator() + interceptor = _UnaryStreamInterceptorWithResponseIterator() channel = aio.insecure_channel(self._server_target, interceptors=[interceptor]) stub = test_pb2_grpc.TestServiceStub(channel) request = messages_pb2.StreamingOutputCallRequest() - for _ in range(_NUM_STREAM_RESPONSES): - request.response_parameters.append( - messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)) + request.response_parameters.extend([ + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES) call = stub.StreamingOutputCall(request) @@ -235,16 +199,16 @@ class TestUnaryStreamClientInterceptor(AioTestBase): messages_pb2.StreamingOutputCallResponse) self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) - self.assertTrue(response_cnt, _NUM_STREAM_RESPONSES) - self.assertTrue(interceptor.response_iterator.response_cnt, + self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) + self.assertEqual(interceptor.response_iterator.response_cnt, _NUM_STREAM_RESPONSES) self.assertEqual(await call.code(), grpc.StatusCode.OK) await channel.close() - async def test_mulitple_interceptors_response_iterator(self): + async def test_multiple_interceptors_response_iterator(self): for interceptor_class in (_UnaryStreamInterceptorEmpty, - _UnaryStreamInterceptorWith_ResponseIterator): + _UnaryStreamInterceptorWithResponseIterator): with self.subTest(name=interceptor_class): @@ -255,10 +219,9 @@ class TestUnaryStreamClientInterceptor(AioTestBase): stub = test_pb2_grpc.TestServiceStub(channel) request = messages_pb2.StreamingOutputCallRequest() - for _ in range(_NUM_STREAM_RESPONSES): - request.response_parameters.append( - messages_pb2.ResponseParameters( - size=_RESPONSE_PAYLOAD_SIZE)) + request.response_parameters.extend([ + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES) call = stub.StreamingOutputCall(request) @@ -270,14 +233,14 @@ class TestUnaryStreamClientInterceptor(AioTestBase): self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body)) - self.assertTrue(response_cnt, _NUM_STREAM_RESPONSES) + self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES) self.assertEqual(await call.code(), grpc.StatusCode.OK) await channel.close() async def test_intercepts_response_iterator_rpc_error(self): for interceptor_class in (_UnaryStreamInterceptorEmpty, - _UnaryStreamInterceptorWith_ResponseIterator): + _UnaryStreamInterceptorWithResponseIterator): with self.subTest(name=interceptor_class): @@ -329,8 +292,6 @@ class TestUnaryStreamClientInterceptor(AioTestBase): self.assertTrue(call.cancelled()) self.assertTrue(call.done()) self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) - self.assertEqual(await call.details(), - _LOCAL_CANCEL_DETAILS_EXPECTATION) self.assertEqual(await call.initial_metadata(), None) self.assertEqual(await call.trailing_metadata(), None) await channel.close() @@ -367,23 +328,19 @@ class TestUnaryStreamClientInterceptor(AioTestBase): self.assertTrue(call.cancelled()) self.assertTrue(call.done()) self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) - self.assertEqual(await call.details(), - _LOCAL_CANCEL_DETAILS_EXPECTATION) self.assertEqual(await call.initial_metadata(), None) self.assertEqual(await call.trailing_metadata(), None) await channel.close() async def test_cancel_consuming_response_iterator(self): request = messages_pb2.StreamingOutputCallRequest() - for _ in range(_NUM_STREAM_RESPONSES): - request.response_parameters.append( - messages_pb2.ResponseParameters( - size=_RESPONSE_PAYLOAD_SIZE, - interval_us=_RESPONSE_INTERVAL_US)) + request.response_parameters.extend([ + messages_pb2.ResponseParameters( + size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES) channel = aio.insecure_channel( self._server_target, - interceptors=[_UnaryStreamInterceptorWith_ResponseIterator()]) + interceptors=[_UnaryStreamInterceptorWithResponseIterator()]) stub = test_pb2_grpc.TestServiceStub(channel) call = stub.StreamingOutputCall(request) @@ -394,10 +351,57 @@ class TestUnaryStreamClientInterceptor(AioTestBase): self.assertTrue(call.cancelled()) self.assertTrue(call.done()) self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) - self.assertEqual(await call.details(), - _LOCAL_CANCEL_DETAILS_EXPECTATION) await channel.close() + async def test_cancel_by_the_interceptor(self): + + class Interceptor(aio.UnaryStreamClientInterceptor): + + async def intercept_unary_stream(self, continuation, + client_call_details, request): + call = await continuation(client_call_details, request) + call.cancel() + return call + + channel = aio.insecure_channel(UNREACHABLE_TARGET, + interceptors=[Interceptor()]) + request = messages_pb2.StreamingOutputCallRequest() + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.StreamingOutputCall(request) + + with self.assertRaises(asyncio.CancelledError): + async for response in call: + pass + + self.assertTrue(call.cancelled()) + self.assertTrue(call.done()) + self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) + await channel.close() + + async def test_exception_raised_by_interceptor(self): + + class InterceptorException(Exception): + pass + + class Interceptor(aio.UnaryStreamClientInterceptor): + + async def intercept_unary_stream(self, continuation, + client_call_details, request): + raise InterceptorException + + channel = aio.insecure_channel(UNREACHABLE_TARGET, + interceptors=[Interceptor()]) + request = messages_pb2.StreamingOutputCallRequest() + stub = test_pb2_grpc.TestServiceStub(channel) + call = stub.StreamingOutputCall(request) + + with self.assertRaises(InterceptorException): + async for response in call: + pass + + await channel.close() + + if __name__ == '__main__': logging.basicConfig(level=logging.DEBUG) diff --git a/src/python/grpcio_tests/tests_aio/unit/done_callback_test.py b/src/python/grpcio_tests/tests_aio/unit/done_callback_test.py index a312e45711f..d4adf965512 100644 --- a/src/python/grpcio_tests/tests_aio/unit/done_callback_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/done_callback_test.py @@ -21,6 +21,7 @@ import gc import grpc from grpc.experimental import aio +from tests_aio.unit._common import inject_callbacks from tests_aio.unit._test_base import AioTestBase from tests.unit.framework.common import test_constants from src.proto.grpc.testing import messages_pb2, test_pb2_grpc @@ -31,29 +32,6 @@ _REQUEST_PAYLOAD_SIZE = 7 _RESPONSE_PAYLOAD_SIZE = 42 -def _inject_callbacks(call): - first_callback_ran = asyncio.Event() - - def first_callback(unused_call): - first_callback_ran.set() - - second_callback_ran = asyncio.Event() - - def second_callback(unused_call): - second_callback_ran.set() - - call.add_done_callback(first_callback) - call.add_done_callback(second_callback) - - async def validation(): - await asyncio.wait_for( - asyncio.gather(first_callback_ran.wait(), - second_callback_ran.wait()), - test_constants.SHORT_TIMEOUT) - - return validation() - - class TestDoneCallback(AioTestBase): async def setUp(self):