Fix an issue that AIO interceptors can only be called once. (#32641)

Prior to this change, we invoke aycnio interceptors by converting them
to `iterator` first, which means we can only call them once before
they're exhausted.

This PR changes the implementation to use `list`, thus the interceptors
can be called multiple times.
pull/32651/head
Xuan Wang 2 years ago committed by GitHub
parent 7293016afc
commit 0011f7090f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 62
      src/python/grpcio/grpc/aio/_interceptor.py
  2. 44
      src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py

@ -17,8 +17,8 @@ from abc import abstractmethod
import asyncio
import collections
import functools
from typing import (AsyncIterable, Awaitable, Callable, Iterator, Optional,
Sequence, Union)
from typing import (AsyncIterable, Awaitable, Callable, Iterator, List,
Optional, Sequence, Union)
import grpc
from grpc._cython import cygrpc
@ -599,16 +599,14 @@ class InterceptedUnaryUnaryCall(_InterceptedUnaryResponseMixin, InterceptedCall,
"""Run the RPC call wrapped in interceptors"""
async def _run_interceptor(
interceptors: Iterator[UnaryUnaryClientInterceptor],
interceptors: List[UnaryUnaryClientInterceptor],
client_call_details: ClientCallDetails,
request: RequestType) -> _base_call.UnaryUnaryCall:
interceptor = next(interceptors, None)
if interceptor:
continuation = functools.partial(_run_interceptor, interceptors)
call_or_response = await interceptor.intercept_unary_unary(
if interceptors:
continuation = functools.partial(_run_interceptor,
interceptors[1:])
call_or_response = await interceptors[0].intercept_unary_unary(
continuation, client_call_details, request)
if isinstance(call_or_response, _base_call.UnaryUnaryCall):
@ -627,7 +625,7 @@ class InterceptedUnaryUnaryCall(_InterceptedUnaryResponseMixin, InterceptedCall,
client_call_details = ClientCallDetails(method, timeout, metadata,
credentials, wait_for_ready)
return await _run_interceptor(iter(interceptors), client_call_details,
return await _run_interceptor(list(interceptors), client_call_details,
request)
def time_remaining(self) -> Optional[float]:
@ -673,18 +671,18 @@ class InterceptedUnaryStreamCall(_InterceptedStreamResponseMixin,
"""Run the RPC call wrapped in interceptors"""
async def _run_interceptor(
interceptors: Iterator[UnaryStreamClientInterceptor],
interceptors: List[UnaryStreamClientInterceptor],
client_call_details: ClientCallDetails,
request: RequestType,
) -> _base_call.UnaryUnaryCall:
interceptor = next(interceptors, None)
if interceptors:
continuation = functools.partial(_run_interceptor,
interceptors[1:])
if interceptor:
continuation = functools.partial(_run_interceptor, interceptors)
call_or_response_iterator = await interceptor.intercept_unary_stream(
continuation, client_call_details, request)
call_or_response_iterator = await interceptors[
0].intercept_unary_stream(continuation, client_call_details,
request)
if isinstance(call_or_response_iterator,
_base_call.UnaryStreamCall):
@ -707,7 +705,7 @@ class InterceptedUnaryStreamCall(_InterceptedStreamResponseMixin,
client_call_details = ClientCallDetails(method, timeout, metadata,
credentials, wait_for_ready)
return await _run_interceptor(iter(interceptors), client_call_details,
return await _run_interceptor(list(interceptors), client_call_details,
request)
def time_remaining(self) -> Optional[float]:
@ -762,12 +760,11 @@ class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin,
request_iterator: RequestIterableType
) -> _base_call.StreamUnaryCall:
interceptor = next(interceptors, None)
if interceptors:
continuation = functools.partial(_run_interceptor,
interceptors[1:])
if interceptor:
continuation = functools.partial(_run_interceptor, interceptors)
return await interceptor.intercept_stream_unary(
return await interceptors[0].intercept_stream_unary(
continuation, client_call_details, request_iterator)
else:
return StreamUnaryCall(
@ -781,7 +778,7 @@ class InterceptedStreamUnaryCall(_InterceptedUnaryResponseMixin,
client_call_details = ClientCallDetails(method, timeout, metadata,
credentials, wait_for_ready)
return await _run_interceptor(iter(interceptors), client_call_details,
return await _run_interceptor(list(interceptors), client_call_details,
request_iterator)
def time_remaining(self) -> Optional[float]:
@ -830,18 +827,19 @@ class InterceptedStreamStreamCall(_InterceptedStreamResponseMixin,
"""Run the RPC call wrapped in interceptors"""
async def _run_interceptor(
interceptors: Iterator[StreamStreamClientInterceptor],
interceptors: List[StreamStreamClientInterceptor],
client_call_details: ClientCallDetails,
request_iterator: RequestIterableType
) -> _base_call.StreamStreamCall:
interceptor = next(interceptors, None)
if interceptor:
continuation = functools.partial(_run_interceptor, interceptors)
if interceptors:
continuation = functools.partial(_run_interceptor,
interceptors[1:])
call_or_response_iterator = await interceptor.intercept_stream_stream(
continuation, client_call_details, request_iterator)
call_or_response_iterator = await interceptors[
0].intercept_stream_stream(continuation,
client_call_details,
request_iterator)
if isinstance(call_or_response_iterator,
_base_call.StreamStreamCall):
@ -864,7 +862,7 @@ class InterceptedStreamStreamCall(_InterceptedStreamResponseMixin,
client_call_details = ClientCallDetails(method, timeout, metadata,
credentials, wait_for_ready)
return await _run_interceptor(iter(interceptors), client_call_details,
return await _run_interceptor(list(interceptors), client_call_details,
request_iterator)
def time_remaining(self) -> Optional[float]:

@ -224,6 +224,50 @@ class TestUnaryUnaryClientInterceptor(AioTestBase):
self.assertEqual(await interceptor.calls[1].code(),
grpc.StatusCode.OK)
async def test_retry_with_multiple_interceptors(self):
class RetryInterceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
# Simulate retry twice
for _ in range(2):
call = await continuation(client_call_details, request)
result = await call
return result
class AnotherInterceptor(aio.UnaryUnaryClientInterceptor):
def __init__(self):
self.called_times = 0
async def intercept_unary_unary(self, continuation,
client_call_details, request):
self.called_times += 1
call = await continuation(client_call_details, request)
result = await call
return result
# Create two interceptors, the retry interceptor will call another interceptor.
retry_interceptor = RetryInterceptor()
another_interceptor = AnotherInterceptor()
async with aio.insecure_channel(
self._server_target,
interceptors=[retry_interceptor,
another_interceptor]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCallWithSleep',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
await call
self.assertEqual(grpc.StatusCode.OK, await call.code())
self.assertEqual(another_interceptor.called_times, 2)
async def test_rpcresponse(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):

Loading…
Cancel
Save