diff --git a/src/python/grpcio/grpc/aio/_interceptor.py b/src/python/grpcio/grpc/aio/_interceptor.py index 85db01fa08c..05f166e3b0b 100644 --- a/src/python/grpcio/grpc/aio/_interceptor.py +++ b/src/python/grpcio/grpc/aio/_interceptor.py @@ -17,8 +17,8 @@ from abc import abstractmethod import asyncio import collections import functools -from typing import (AsyncIterable, Awaitable, Callable, Iterator, Optional, - Sequence, Union) +from typing import (AsyncIterable, Awaitable, Callable, Iterator, List, + Optional, Sequence, Union) import grpc from grpc._cython import cygrpc @@ -599,16 +599,14 @@ class InterceptedUnaryUnaryCall(_InterceptedUnaryResponseMixin, InterceptedCall, """Run the RPC call wrapped in interceptors""" async def _run_interceptor( - interceptors: Iterator[UnaryUnaryClientInterceptor], + interceptors: List[UnaryUnaryClientInterceptor], client_call_details: ClientCallDetails, request: RequestType) -> _base_call.UnaryUnaryCall: - interceptor = next(interceptors, None) - - if interceptor: - continuation = functools.partial(_run_interceptor, interceptors) - - call_or_response = await interceptor.intercept_unary_unary( + if interceptors: + continuation = functools.partial(_run_interceptor, + interceptors[1:]) + call_or_response = await interceptors[0].intercept_unary_unary( continuation, client_call_details, request) if isinstance(call_or_response, _base_call.UnaryUnaryCall): @@ -627,7 +625,7 @@ class InterceptedUnaryUnaryCall(_InterceptedUnaryResponseMixin, InterceptedCall, client_call_details = ClientCallDetails(method, timeout, metadata, credentials, wait_for_ready) - return await _run_interceptor(iter(interceptors), client_call_details, + return await _run_interceptor(list(interceptors), client_call_details, request) def time_remaining(self) -> Optional[float]: @@ -673,18 +671,18 @@ class InterceptedUnaryStreamCall(_InterceptedStreamResponseMixin, """Run the RPC call wrapped in interceptors""" async def _run_interceptor( - interceptors: Iterator[UnaryStreamClientInterceptor], + interceptors: List[UnaryStreamClientInterceptor], client_call_details: ClientCallDetails, request: RequestType, ) -> _base_call.UnaryUnaryCall: - interceptor = next(interceptors, None) + if interceptors: + continuation = functools.partial(_run_interceptor, + interceptors[1:]) - if interceptor: - continuation = functools.partial(_run_interceptor, interceptors) - - call_or_response_iterator = await interceptor.intercept_unary_stream( - continuation, client_call_details, request) + call_or_response_iterator = await interceptors[ + 0].intercept_unary_stream(continuation, client_call_details, + request) if isinstance(call_or_response_iterator, _base_call.UnaryStreamCall): @@ -707,7 +705,7 @@ class InterceptedUnaryStreamCall(_InterceptedStreamResponseMixin, client_call_details = ClientCallDetails(method, timeout, metadata, credentials, wait_for_ready) - return await _run_interceptor(iter(interceptors), client_call_details, + return await _run_interceptor(list(interceptors), client_call_details, request) def time_remaining(self) -> Optional[float]: @@ -762,12 +760,11 @@ class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin, request_iterator: RequestIterableType ) -> _base_call.StreamUnaryCall: - interceptor = next(interceptors, None) + if interceptors: + continuation = functools.partial(_run_interceptor, + interceptors[1:]) - if interceptor: - continuation = functools.partial(_run_interceptor, interceptors) - - return await interceptor.intercept_stream_unary( + return await interceptors[0].intercept_stream_unary( continuation, client_call_details, request_iterator) else: return StreamUnaryCall( @@ -781,7 +778,7 @@ class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin, client_call_details = ClientCallDetails(method, timeout, metadata, credentials, wait_for_ready) - return await _run_interceptor(iter(interceptors), client_call_details, + return await _run_interceptor(list(interceptors), client_call_details, request_iterator) def time_remaining(self) -> Optional[float]: @@ -830,18 +827,19 @@ class InterceptedStreamStreamCall(_InterceptedStreamResponseMixin, """Run the RPC call wrapped in interceptors""" async def _run_interceptor( - interceptors: Iterator[StreamStreamClientInterceptor], + interceptors: List[StreamStreamClientInterceptor], client_call_details: ClientCallDetails, request_iterator: RequestIterableType ) -> _base_call.StreamStreamCall: - interceptor = next(interceptors, None) - - if interceptor: - continuation = functools.partial(_run_interceptor, interceptors) + if interceptors: + continuation = functools.partial(_run_interceptor, + interceptors[1:]) - call_or_response_iterator = await interceptor.intercept_stream_stream( - continuation, client_call_details, request_iterator) + call_or_response_iterator = await interceptors[ + 0].intercept_stream_stream(continuation, + client_call_details, + request_iterator) if isinstance(call_or_response_iterator, _base_call.StreamStreamCall): @@ -864,7 +862,7 @@ class InterceptedStreamStreamCall(_InterceptedStreamResponseMixin, client_call_details = ClientCallDetails(method, timeout, metadata, credentials, wait_for_ready) - return await _run_interceptor(iter(interceptors), client_call_details, + return await _run_interceptor(list(interceptors), client_call_details, request_iterator) def time_remaining(self) -> Optional[float]: diff --git a/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py b/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py index 7367c454fad..bcb5df54819 100644 --- a/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py @@ -224,6 +224,50 @@ class TestUnaryUnaryClientInterceptor(AioTestBase): self.assertEqual(await interceptor.calls[1].code(), grpc.StatusCode.OK) + async def test_retry_with_multiple_interceptors(self): + + class RetryInterceptor(aio.UnaryUnaryClientInterceptor): + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + # Simulate retry twice + for _ in range(2): + call = await continuation(client_call_details, request) + result = await call + return result + + class AnotherInterceptor(aio.UnaryUnaryClientInterceptor): + + def __init__(self): + self.called_times = 0 + + async def intercept_unary_unary(self, continuation, + client_call_details, request): + self.called_times += 1 + call = await continuation(client_call_details, request) + result = await call + return result + + # Create two interceptors, the retry interceptor will call another interceptor. + retry_interceptor = RetryInterceptor() + another_interceptor = AnotherInterceptor() + async with aio.insecure_channel( + self._server_target, + interceptors=[retry_interceptor, + another_interceptor]) as channel: + + multicallable = channel.unary_unary( + '/grpc.testing.TestService/UnaryCallWithSleep', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + + call = multicallable(messages_pb2.SimpleRequest()) + + await call + + self.assertEqual(grpc.StatusCode.OK, await call.code()) + self.assertEqual(another_interceptor.called_times, 2) + async def test_rpcresponse(self): class Interceptor(aio.UnaryUnaryClientInterceptor):