Apply feedback

pull/22713/head
Pau Freixes 5 years ago
parent 2aea5c002b
commit dae80a4977
  1. 5
      src/python/grpcio/grpc/experimental/aio/__init__.py
  2. 20
      src/python/grpcio/grpc/experimental/aio/_channel.py
  3. 46
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  4. 31
      src/python/grpcio_tests/tests_aio/unit/_common.py
  5. 182
      src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py
  6. 24
      src/python/grpcio_tests/tests_aio/unit/done_callback_test.py

@ -30,7 +30,8 @@ from ._base_channel import (Channel, StreamStreamMultiCallable,
StreamUnaryMultiCallable, UnaryStreamMultiCallable,
UnaryUnaryMultiCallable)
from ._call import AioRpcError
from ._interceptor import (ClientCallDetails, InterceptedUnaryUnaryCall,
from ._interceptor import (ClientCallDetails, ClientInterceptor,
InterceptedUnaryUnaryCall,
UnaryUnaryClientInterceptor,
UnaryStreamClientInterceptor, ServerInterceptor)
from ._server import server
@ -57,6 +58,8 @@ __all__ = (
'StreamUnaryMultiCallable',
'StreamStreamMultiCallable',
'ClientCallDetails',
'ClientInterceptor',
'UnaryStreamClientInterceptor',
'UnaryUnaryClientInterceptor',
'InterceptedUnaryUnaryCall',
'ServerInterceptor',

@ -15,7 +15,7 @@
import asyncio
import sys
from typing import Any, Iterable, Optional, Sequence
from typing import Any, Iterable, Optional, Sequence, List
import grpc
from grpc import _common, _compression, _grpcio_metadata
@ -202,8 +202,8 @@ class StreamStreamMultiCallable(_BaseMultiCallable,
class Channel(_base_channel.Channel):
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_unary_unary_interceptors: Sequence[UnaryUnaryClientInterceptor]
_unary_stream_interceptors: Sequence[UnaryStreamClientInterceptor]
_unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
_unary_stream_interceptors: List[UnaryStreamClientInterceptor]
def __init__(self, target: str, options: ChannelArgumentType,
credentials: Optional[grpc.ChannelCredentials],
@ -224,18 +224,16 @@ class Channel(_base_channel.Channel):
self._unary_stream_interceptors = []
if interceptors:
attrs_and_interceptor_classes = [
attrs_and_interceptor_classes = (
(self._unary_unary_interceptors, UnaryUnaryClientInterceptor),
(self._unary_stream_interceptors, UnaryStreamClientInterceptor)
]
)
# pylint: disable=cell-var-from-loop
for attr, interceptor_class in attrs_and_interceptor_classes:
attr.extend(
list(
filter(
lambda interceptor: isinstance(
interceptor, interceptor_class), interceptors)))
[interceptor for interceptor in interceptors if isinstance(interceptor, interceptor_class)]
)
invalid_interceptors = set(interceptors) - set(
self._unary_unary_interceptors) - set(
@ -245,7 +243,7 @@ class Channel(_base_channel.Channel):
raise ValueError(
"Interceptor must be "+\
"UnaryUnaryClientInterceptors or "+\
"UnaryStreamClientInterceptors the following are invalid: {}"\
"UnaryStreamClientInterceptors. The following are invalid: {}"\
.format(invalid_interceptors))
self._loop = asyncio.get_event_loop()
@ -402,7 +400,7 @@ def insecure_channel(
target: str,
options: Optional[ChannelArgumentType] = None,
compression: Optional[grpc.Compression] = None,
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] = None):
interceptors: Optional[Sequence[ClientInterceptor]] = None):
"""Creates an insecure asynchronous Channel to a server.
Args:

@ -126,9 +126,9 @@ class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
@abstractmethod
async def intercept_unary_stream(self, continuation: Callable[[
ClientCallDetails, RequestType, AsyncIterable[ResponseType]
], UnaryStreamCall], client_call_details: ClientCallDetails,
request: RequestType) -> UnaryStreamCall:
ClientCallDetails, RequestType], UnaryStreamCall],
client_call_details: ClientCallDetails,
request: RequestType) -> Union[AsyncIterable[ResponseType], UnaryStreamCall]:
"""Intercepts a unary-stream invocation asynchronously.
Args:
@ -180,31 +180,32 @@ class InterceptedCall:
self._interceptors_task = interceptors_task
self._pending_add_done_callbacks = []
self._interceptors_task.add_done_callback(
self._fire_or_add_pending_add_done_callbacks)
self._fire_or_add_pending_done_callbacks)
def __del__(self):
self.cancel()
def _fire_or_add_pending_add_done_callbacks(self,
def _fire_or_add_pending_done_callbacks(self,
interceptors_task: asyncio.Task
) -> None:
if not self._pending_add_done_callbacks:
return
fire = False
call_completed = False
try:
call = interceptors_task.result()
if call.done():
fire = True
call_completed = True
except (AioRpcError, asyncio.CancelledError):
fire = True
call_completed = True
for callback in self._pending_add_done_callbacks:
if fire:
if call_completed:
for callback in self._pending_add_done_callbacks:
callback(self)
else:
else:
for callback in self._pending_add_done_callbacks:
callback = functools.partial(self._wrap_add_done_callback,
callback)
call.add_done_callback(callback)
@ -415,6 +416,7 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall):
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_response_aiter: AsyncIterable[ResponseType]
_last_returned_call_from_interceptors = Optional[_base_call.UnaryStreamCall]
# pylint: disable=too-many-arguments
def __init__(self, interceptors: Sequence[UnaryStreamClientInterceptor],
@ -429,6 +431,7 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall):
self._channel = channel
self._response_aiter = self._wait_for_interceptor_task_response_iterator(
)
self._last_returned_call_from_interceptors = None
interceptors_task = loop.create_task(
self._invoke(interceptors, method, timeout, metadata, credentials,
wait_for_ready, request, request_serializer,
@ -446,7 +449,6 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall):
) -> UnaryStreamCall:
"""Run the RPC call wrapped in interceptors"""
last_returned_call_from_interceptors = [None]
async def _run_interceptor(
interceptors: Iterator[UnaryStreamClientInterceptor],
@ -462,17 +464,15 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall):
call_or_response_iterator = await interceptor.intercept_unary_stream(
continuation, client_call_details, request)
if call_or_response_iterator is last_returned_call_from_interceptors[
0]:
return call_or_response_iterator
if isinstance(call_or_response_iterator, _base_call.UnaryUnaryCall):
self._last_returned_call_from_interceptors = call_or_response_iterator
else:
last_returned_call_from_interceptors[
0] = UnaryStreamCallResponseIterator(
last_returned_call_from_interceptors[0],
self._last_returned_call_from_interceptors = UnaryStreamCallResponseIterator(
self._last_returned_call_from_interceptors,
call_or_response_iterator)
return last_returned_call_from_interceptors[0]
return self._last_returned_call_from_interceptors
else:
last_returned_call_from_interceptors[0] = UnaryStreamCall(
self._last_returned_call_from_interceptors = UnaryStreamCall(
request, _timeout_to_deadline(client_call_details.timeout),
client_call_details.metadata,
client_call_details.credentials,
@ -480,7 +480,7 @@ class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall):
client_call_details.method, request_serializer,
response_deserializer, self._loop)
return last_returned_call_from_interceptors[0]
return self._last_returned_call_from_interceptors
client_call_details = ClientCallDetails(method, timeout, metadata,
credentials, wait_for_ready)
@ -598,4 +598,6 @@ class UnaryStreamCallResponseIterator(_base_call.UnaryStreamCall):
return await self._call.wait_for_connection()
async def read(self) -> ResponseType:
return await self._call.read()
# Behind the scenes everyting goes through the
# async iterator. So this path should not be reached.
raise Exception()

@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import grpc
from grpc.experimental import aio
from grpc.experimental.aio._typing import MetadataType, MetadatumType
from tests.unit.framework.common import test_constants
def seen_metadata(expected: MetadataType, actual: MetadataType):
return not bool(set(expected) - set(actual))
@ -32,3 +35,31 @@ async def block_until_certain_state(channel: aio.Channel,
while state != expected_state:
await channel.wait_for_state_change(state)
state = channel.get_state()
def inject_callbacks(call):
first_callback_ran = asyncio.Event()
def first_callback(call):
# Validate that all resopnses have been received
# and the call is an end state.
assert call.done()
first_callback_ran.set()
second_callback_ran = asyncio.Event()
def second_callback(call):
# Validate that all resopnses have been received
# and the call is an end state.
assert call.done()
second_callback_ran.set()
call.add_done_callback(first_callback)
call.add_done_callback(second_callback)
async def validation():
await asyncio.wait_for(
asyncio.gather(first_callback_ran.wait(),
second_callback_ran.wait()),
test_constants.SHORT_TIMEOUT)
return validation()

@ -1,4 +1,4 @@
# Copyright 2019 The gRPC Authors.
# Copyright 2020 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -20,67 +20,34 @@ import grpc
from grpc.experimental import aio
from tests_aio.unit._constants import UNREACHABLE_TARGET
from tests_aio.unit._common import inject_callbacks
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase
from tests.unit.framework.common import test_constants
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
_SHORT_TIMEOUT_S = datetime.timedelta(seconds=1).total_seconds()
_SHORT_TIMEOUT_S = 1.0
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
_NUM_STREAM_RESPONSES = 5
_REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 7
_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000)
class _ResponseIterator:
class _CountingResponseIterator:
def __init__(self, response_iterator):
self._response_cnt = 0
self.response_cnt = 0
self._response_iterator = response_iterator
async def _forward_responses(self):
async for response in self._response_iterator:
self._response_cnt += 1
self.response_cnt += 1
yield response
def __aiter__(self):
return self._forward_responses()
@property
def response_cnt(self):
return self._response_cnt
def _inject_callbacks(call):
first_callback_ran = asyncio.Event()
def first_callback(call):
# Validate that all resopnses have been received
# and the call is an end state.
assert call.done()
first_callback_ran.set()
second_callback_ran = asyncio.Event()
def second_callback(call):
# Validate that all resopnses have been received
# and the call is an end state.
assert call.done()
second_callback_ran.set()
call.add_done_callback(first_callback)
call.add_done_callback(second_callback)
async def validation():
await asyncio.wait_for(
asyncio.gather(first_callback_ran.wait(),
second_callback_ran.wait()),
test_constants.SHORT_TIMEOUT)
return validation()
class _UnaryStreamInterceptorEmpty(aio.UnaryStreamClientInterceptor):
@ -89,7 +56,7 @@ class _UnaryStreamInterceptorEmpty(aio.UnaryStreamClientInterceptor):
return await continuation(client_call_details, request)
class _UnaryStreamInterceptorWith_ResponseIterator(
class _UnaryStreamInterceptorWithResponseIterator(
aio.UnaryStreamClientInterceptor):
def __init__(self):
@ -98,7 +65,7 @@ class _UnaryStreamInterceptorWith_ResponseIterator(
async def intercept_unary_stream(self, continuation, client_call_details,
request):
call = await continuation(client_call_details, request)
self.response_iterator = _ResponseIterator(call)
self.response_iterator = _CountingResponseIterator(call)
return self.response_iterator
@ -112,16 +79,15 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
async def test_intercepts(self):
for interceptor_class in (_UnaryStreamInterceptorEmpty,
_UnaryStreamInterceptorWith_ResponseIterator):
_UnaryStreamInterceptorWithResponseIterator):
with self.subTest(name=interceptor_class):
interceptor = interceptor_class()
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE))
request.response_parameters.extend([
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES)
channel = aio.insecure_channel(self._server_target,
interceptors=[interceptor])
@ -138,7 +104,7 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
self.assertTrue(response_cnt, _NUM_STREAM_RESPONSES)
self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
@ -148,31 +114,30 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
self.assertEqual(call.cancelled(), False)
self.assertEqual(call.done(), True)
if interceptor_class == _UnaryStreamInterceptorWith_ResponseIterator:
self.assertTrue(interceptor.response_iterator.response_cnt,
if interceptor_class == _UnaryStreamInterceptorWithResponseIterator:
self.assertEqual(interceptor.response_iterator.response_cnt,
_NUM_STREAM_RESPONSES)
await channel.close()
async def test_add_done_callback(self):
async def test_add_done_callback_interceptor_task_not_finished(self):
for interceptor_class in (_UnaryStreamInterceptorEmpty,
_UnaryStreamInterceptorWith_ResponseIterator):
_UnaryStreamInterceptorWithResponseIterator):
with self.subTest(name=interceptor_class):
interceptor = interceptor_class()
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE))
request.response_parameters.extend([
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES)
channel = aio.insecure_channel(self._server_target,
interceptors=[interceptor])
stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.StreamingOutputCall(request)
validation = _inject_callbacks(call)
validation = inject_callbacks(call)
async for response in call:
pass
@ -181,18 +146,17 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
await channel.close()
async def test_add_done_callback_after_connection(self):
async def test_add_done_callback_interceptor_task_finished(self):
for interceptor_class in (_UnaryStreamInterceptorEmpty,
_UnaryStreamInterceptorWith_ResponseIterator):
_UnaryStreamInterceptorWithResponseIterator):
with self.subTest(name=interceptor_class):
interceptor = interceptor_class()
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE))
request.response_parameters.extend([
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES)
channel = aio.insecure_channel(self._server_target,
interceptors=[interceptor])
@ -204,7 +168,7 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
# pending state list.
await call.wait_for_connection()
validation = _inject_callbacks(call)
validation = inject_callbacks(call)
async for response in call:
pass
@ -214,16 +178,16 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
await channel.close()
async def test_response_iterator_using_read(self):
interceptor = _UnaryStreamInterceptorWith_ResponseIterator()
interceptor = _UnaryStreamInterceptorWithResponseIterator()
channel = aio.insecure_channel(self._server_target,
interceptors=[interceptor])
stub = test_pb2_grpc.TestServiceStub(channel)
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
request.response_parameters.extend([
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES)
call = stub.StreamingOutputCall(request)
@ -235,16 +199,16 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertTrue(response_cnt, _NUM_STREAM_RESPONSES)
self.assertTrue(interceptor.response_iterator.response_cnt,
self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES)
self.assertEqual(interceptor.response_iterator.response_cnt,
_NUM_STREAM_RESPONSES)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
await channel.close()
async def test_mulitple_interceptors_response_iterator(self):
async def test_multiple_interceptors_response_iterator(self):
for interceptor_class in (_UnaryStreamInterceptorEmpty,
_UnaryStreamInterceptorWith_ResponseIterator):
_UnaryStreamInterceptorWithResponseIterator):
with self.subTest(name=interceptor_class):
@ -255,10 +219,9 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
stub = test_pb2_grpc.TestServiceStub(channel)
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE))
request.response_parameters.extend([
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES)
call = stub.StreamingOutputCall(request)
@ -270,14 +233,14 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
self.assertTrue(response_cnt, _NUM_STREAM_RESPONSES)
self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
await channel.close()
async def test_intercepts_response_iterator_rpc_error(self):
for interceptor_class in (_UnaryStreamInterceptorEmpty,
_UnaryStreamInterceptorWith_ResponseIterator):
_UnaryStreamInterceptorWithResponseIterator):
with self.subTest(name=interceptor_class):
@ -329,8 +292,6 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
_LOCAL_CANCEL_DETAILS_EXPECTATION)
self.assertEqual(await call.initial_metadata(), None)
self.assertEqual(await call.trailing_metadata(), None)
await channel.close()
@ -367,23 +328,19 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
_LOCAL_CANCEL_DETAILS_EXPECTATION)
self.assertEqual(await call.initial_metadata(), None)
self.assertEqual(await call.trailing_metadata(), None)
await channel.close()
async def test_cancel_consuming_response_iterator(self):
request = messages_pb2.StreamingOutputCallRequest()
for _ in range(_NUM_STREAM_RESPONSES):
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_RESPONSE_INTERVAL_US))
request.response_parameters.extend([
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE)] * _NUM_STREAM_RESPONSES)
channel = aio.insecure_channel(
self._server_target,
interceptors=[_UnaryStreamInterceptorWith_ResponseIterator()])
interceptors=[_UnaryStreamInterceptorWithResponseIterator()])
stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.StreamingOutputCall(request)
@ -394,10 +351,57 @@ class TestUnaryStreamClientInterceptor(AioTestBase):
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
_LOCAL_CANCEL_DETAILS_EXPECTATION)
await channel.close()
async def test_cancel_by_the_interceptor(self):
class Interceptor(aio.UnaryStreamClientInterceptor):
async def intercept_unary_stream(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
call.cancel()
return call
channel = aio.insecure_channel(UNREACHABLE_TARGET,
interceptors=[Interceptor()])
request = messages_pb2.StreamingOutputCallRequest()
stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.StreamingOutputCall(request)
with self.assertRaises(asyncio.CancelledError):
async for response in call:
pass
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
await channel.close()
async def test_exception_raised_by_interceptor(self):
class InterceptorException(Exception):
pass
class Interceptor(aio.UnaryStreamClientInterceptor):
async def intercept_unary_stream(self, continuation,
client_call_details, request):
raise InterceptorException
channel = aio.insecure_channel(UNREACHABLE_TARGET,
interceptors=[Interceptor()])
request = messages_pb2.StreamingOutputCallRequest()
stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.StreamingOutputCall(request)
with self.assertRaises(InterceptorException):
async for response in call:
pass
await channel.close()
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)

@ -21,6 +21,7 @@ import gc
import grpc
from grpc.experimental import aio
from tests_aio.unit._common import inject_callbacks
from tests_aio.unit._test_base import AioTestBase
from tests.unit.framework.common import test_constants
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
@ -31,29 +32,6 @@ _REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42
def _inject_callbacks(call):
first_callback_ran = asyncio.Event()
def first_callback(unused_call):
first_callback_ran.set()
second_callback_ran = asyncio.Event()
def second_callback(unused_call):
second_callback_ran.set()
call.add_done_callback(first_callback)
call.add_done_callback(second_callback)
async def validation():
await asyncio.wait_for(
asyncio.gather(first_callback_ran.wait(),
second_callback_ran.wait()),
test_constants.SHORT_TIMEOUT)
return validation()
class TestDoneCallback(AioTestBase):
async def setUp(self):

Loading…
Cancel
Save