Merge pull request #22713 from Skyscanner/client_unary_stream_interceptor

[Aio] Implement the Unary Stream client interceptor
pull/22730/head
Pau Freixes 5 years ago committed by GitHub
commit 11e41537a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 8
      src/python/grpcio/grpc/experimental/aio/__init__.py
  2. 3
      src/python/grpcio/grpc/experimental/aio/_call.py
  3. 76
      src/python/grpcio/grpc/experimental/aio/_channel.py
  4. 387
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  5. 38
      src/python/grpcio_tests/tests_aio/health_check/health_servicer_test.py
  6. 5
      src/python/grpcio_tests/tests_aio/tests.json
  7. 32
      src/python/grpcio_tests/tests_aio/unit/_common.py
  8. 25
      src/python/grpcio_tests/tests_aio/unit/call_test.py
  9. 409
      src/python/grpcio_tests/tests_aio/unit/client_unary_stream_interceptor_test.py
  10. 0
      src/python/grpcio_tests/tests_aio/unit/client_unary_unary_interceptor_test.py
  11. 34
      src/python/grpcio_tests/tests_aio/unit/done_callback_test.py

@ -30,8 +30,10 @@ from ._base_channel import (Channel, StreamStreamMultiCallable,
StreamUnaryMultiCallable, UnaryStreamMultiCallable,
UnaryUnaryMultiCallable)
from ._call import AioRpcError
from ._interceptor import (ClientCallDetails, InterceptedUnaryUnaryCall,
UnaryUnaryClientInterceptor, ServerInterceptor)
from ._interceptor import (ClientCallDetails, ClientInterceptor,
InterceptedUnaryUnaryCall,
UnaryUnaryClientInterceptor,
UnaryStreamClientInterceptor, ServerInterceptor)
from ._server import server
from ._base_server import Server, ServicerContext
from ._typing import ChannelArgumentType
@ -56,6 +58,8 @@ __all__ = (
'StreamUnaryMultiCallable',
'StreamStreamMultiCallable',
'ClientCallDetails',
'ClientInterceptor',
'UnaryStreamClientInterceptor',
'UnaryUnaryClientInterceptor',
'InterceptedUnaryUnaryCall',
'ServerInterceptor',

@ -318,6 +318,9 @@ class _StreamResponseMixin(Call):
yield message
message = await self._read()
# If the read operation failed, Core should explain why.
await self._raise_for_status()
def __aiter__(self) -> AsyncIterable[ResponseType]:
self._update_response_style(_APIStyle.ASYNC_GENERATOR)
if self._message_aiter is None:

@ -15,7 +15,7 @@
import asyncio
import sys
from typing import Any, Iterable, Optional, Sequence
from typing import Any, Iterable, Optional, Sequence, List
import grpc
from grpc import _common, _compression, _grpcio_metadata
@ -25,7 +25,9 @@ from . import _base_call, _base_channel
from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall,
UnaryUnaryCall)
from ._interceptor import (InterceptedUnaryUnaryCall,
UnaryUnaryClientInterceptor)
InterceptedUnaryStreamCall, ClientInterceptor,
UnaryUnaryClientInterceptor,
UnaryStreamClientInterceptor)
from ._typing import (ChannelArgumentType, DeserializingFunction, MetadataType,
SerializingFunction, RequestIterableType)
from ._utils import _timeout_to_deadline
@ -65,7 +67,7 @@ class _BaseMultiCallable:
_method: bytes
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
_interceptors: Optional[Sequence[ClientInterceptor]]
_loop: asyncio.AbstractEventLoop
# pylint: disable=too-many-arguments
@ -75,7 +77,7 @@ class _BaseMultiCallable:
method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]],
interceptors: Optional[Sequence[ClientInterceptor]],
loop: asyncio.AbstractEventLoop,
) -> None:
self._loop = loop
@ -134,10 +136,17 @@ class UnaryStreamMultiCallable(_BaseMultiCallable,
deadline = _timeout_to_deadline(timeout)
call = UnaryStreamCall(request, deadline, metadata, credentials,
wait_for_ready, self._channel, self._method,
self._request_serializer,
self._response_deserializer, self._loop)
if not self._interceptors:
call = UnaryStreamCall(request, deadline, metadata, credentials,
wait_for_ready, self._channel, self._method,
self._request_serializer,
self._response_deserializer, self._loop)
else:
call = InterceptedUnaryStreamCall(
self._interceptors, request, deadline, metadata, credentials,
wait_for_ready, self._channel, self._method,
self._request_serializer, self._response_deserializer,
self._loop)
return call
@ -193,12 +202,13 @@ class StreamStreamMultiCallable(_BaseMultiCallable,
class Channel(_base_channel.Channel):
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
_unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
_unary_stream_interceptors: List[UnaryStreamClientInterceptor]
def __init__(self, target: str, options: ChannelArgumentType,
credentials: Optional[grpc.ChannelCredentials],
compression: Optional[grpc.Compression],
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]):
interceptors: Optional[Sequence[ClientInterceptor]]):
"""Constructor.
Args:
@ -210,22 +220,31 @@ class Channel(_base_channel.Channel):
interceptors: An optional list of interceptors that would be used for
intercepting any RPC executed with that channel.
"""
if interceptors is None:
self._unary_unary_interceptors = None
else:
self._unary_unary_interceptors = list(
filter(
lambda interceptor: isinstance(interceptor,
UnaryUnaryClientInterceptor),
interceptors))
self._unary_unary_interceptors = []
self._unary_stream_interceptors = []
if interceptors:
attrs_and_interceptor_classes = ((self._unary_unary_interceptors,
UnaryUnaryClientInterceptor),
(self._unary_stream_interceptors,
UnaryStreamClientInterceptor))
# pylint: disable=cell-var-from-loop
for attr, interceptor_class in attrs_and_interceptor_classes:
attr.extend([
interceptor for interceptor in interceptors
if isinstance(interceptor, interceptor_class)
])
invalid_interceptors = set(interceptors) - set(
self._unary_unary_interceptors)
self._unary_unary_interceptors) - set(
self._unary_stream_interceptors)
if invalid_interceptors:
raise ValueError(
"Interceptor must be "+\
"UnaryUnaryClientInterceptors, the following are invalid: {}"\
"UnaryUnaryClientInterceptors or "+\
"UnaryStreamClientInterceptors. The following are invalid: {}"\
.format(invalid_interceptors))
self._loop = asyncio.get_event_loop()
@ -352,7 +371,9 @@ class Channel(_base_channel.Channel):
) -> UnaryStreamMultiCallable:
return UnaryStreamMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer, None, self._loop)
response_deserializer,
self._unary_stream_interceptors,
self._loop)
def stream_unary(
self,
@ -380,7 +401,7 @@ def insecure_channel(
target: str,
options: Optional[ChannelArgumentType] = None,
compression: Optional[grpc.Compression] = None,
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] = None):
interceptors: Optional[Sequence[ClientInterceptor]] = None):
"""Creates an insecure asynchronous Channel to a server.
Args:
@ -399,12 +420,11 @@ def insecure_channel(
compression, interceptors)
def secure_channel(
target: str,
credentials: grpc.ChannelCredentials,
options: Optional[ChannelArgumentType] = None,
compression: Optional[grpc.Compression] = None,
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] = None):
def secure_channel(target: str,
credentials: grpc.ChannelCredentials,
options: Optional[ChannelArgumentType] = None,
compression: Optional[grpc.Compression] = None,
interceptors: Optional[Sequence[ClientInterceptor]] = None):
"""Creates a secure asynchronous Channel to a server.
Args:

@ -16,13 +16,13 @@ import asyncio
import collections
import functools
from abc import ABCMeta, abstractmethod
from typing import Callable, Optional, Iterator, Sequence, Union, Awaitable
from typing import Callable, Optional, Iterator, Sequence, Union, Awaitable, AsyncIterable
import grpc
from grpc._cython import cygrpc
from . import _base_call
from ._call import UnaryUnaryCall, AioRpcError
from ._call import UnaryUnaryCall, UnaryStreamCall, AioRpcError
from ._utils import _timeout_to_deadline
from ._typing import (RequestType, SerializingFunction, DeserializingFunction,
MetadataType, ResponseType, DoneCallbackType)
@ -84,7 +84,11 @@ class ClientCallDetails(
wait_for_ready: Optional[bool]
class UnaryUnaryClientInterceptor(metaclass=ABCMeta):
class ClientInterceptor(metaclass=ABCMeta):
"""Base class used for all Aio Client Interceptor classes"""
class UnaryUnaryClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
"""Affords intercepting unary-unary invocations."""
@abstractmethod
@ -101,8 +105,8 @@ class UnaryUnaryClientInterceptor(metaclass=ABCMeta):
actual RPC on the underlying Channel. It is the interceptor's
responsibility to call it if it decides to move the RPC forward.
The interceptor can use
`response_future = await continuation(client_call_details, request)`
to continue with the RPC. `continuation` returns the response of the
`call = await continuation(client_call_details, request)`
to continue with the RPC. `continuation` returns the call to the
RPC.
client_call_details: A ClientCallDetails object describing the
outgoing RPC.
@ -117,8 +121,41 @@ class UnaryUnaryClientInterceptor(metaclass=ABCMeta):
"""
class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
"""Used for running a `UnaryUnaryCall` wrapped by interceptors.
class UnaryStreamClientInterceptor(ClientInterceptor, metaclass=ABCMeta):
"""Affords intercepting unary-stream invocations."""
@abstractmethod
async def intercept_unary_stream(
self, continuation: Callable[[ClientCallDetails, RequestType],
UnaryStreamCall],
client_call_details: ClientCallDetails, request: RequestType
) -> Union[AsyncIterable[ResponseType], UnaryStreamCall]:
"""Intercepts a unary-stream invocation asynchronously.
Args:
continuation: A coroutine that proceeds with the invocation by
executing the next interceptor in chain or invoking the
actual RPC on the underlying Channel. It is the interceptor's
responsibility to call it if it decides to move the RPC forward.
The interceptor can use
`call = await continuation(client_call_details, request, response_iterator))`
to continue with the RPC. `continuation` returns the call to the
RPC.
client_call_details: A ClientCallDetails object describing the
outgoing RPC.
request: The request value for the RPC.
Returns:
The RPC Call.
Raises:
AioRpcError: Indicating that the RPC terminated with non-OK status.
asyncio.CancelledError: Indicating that the RPC was canceled.
"""
class InterceptedCall:
"""Base implementation for all intecepted call arities.
Interceptors might have some work to do before the RPC invocation with
the capacity of changing the invocation parameters, and some work to do
@ -133,103 +170,68 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
intercepted call, being at the same time the same call returned to the
interceptors.
For most of the methods, like `initial_metadata()` the caller does not need
to wait until the interceptors task is finished, once the RPC is done the
caller will have the freedom for accessing to the results.
For the `__await__` method is it is proxied to the intercepted call only when
the interceptor task is finished.
As a base class for all of the interceptors implements the logic around
final status, metadata and cancellation.
"""
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_cancelled_before_rpc: bool
_intercepted_call: Optional[_base_call.UnaryUnaryCall]
_intercepted_call_created: asyncio.Event
_interceptors_task: asyncio.Task
_pending_add_done_callbacks: Sequence[DoneCallbackType]
# pylint: disable=too-many-arguments
def __init__(self, interceptors: Sequence[UnaryUnaryClientInterceptor],
request: RequestType, timeout: Optional[float],
metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
self._channel = channel
self._loop = loop
self._interceptors_task = loop.create_task(
self._invoke(interceptors, method, timeout, metadata, credentials,
wait_for_ready, request, request_serializer,
response_deserializer))
def __init__(self, interceptors_task: asyncio.Task) -> None:
self._interceptors_task = interceptors_task
self._pending_add_done_callbacks = []
self._interceptors_task.add_done_callback(
self._fire_pending_add_done_callbacks)
self._fire_or_add_pending_done_callbacks)
def __del__(self):
self.cancel()
# pylint: disable=too-many-arguments
async def _invoke(self, interceptors: Sequence[UnaryUnaryClientInterceptor],
method: bytes, timeout: Optional[float],
metadata: Optional[MetadataType],
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], request: RequestType,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction
) -> UnaryUnaryCall:
"""Run the RPC call wrapped in interceptors"""
async def _run_interceptor(
interceptors: Iterator[UnaryUnaryClientInterceptor],
client_call_details: ClientCallDetails,
request: RequestType) -> _base_call.UnaryUnaryCall:
interceptor = next(interceptors, None)
if interceptor:
continuation = functools.partial(_run_interceptor, interceptors)
def _fire_or_add_pending_done_callbacks(self,
interceptors_task: asyncio.Task
) -> None:
call_or_response = await interceptor.intercept_unary_unary(
continuation, client_call_details, request)
if isinstance(call_or_response, _base_call.UnaryUnaryCall):
return call_or_response
else:
return UnaryUnaryCallResponse(call_or_response)
if not self._pending_add_done_callbacks:
return
else:
return UnaryUnaryCall(
request, _timeout_to_deadline(client_call_details.timeout),
client_call_details.metadata,
client_call_details.credentials,
client_call_details.wait_for_ready, self._channel,
client_call_details.method, request_serializer,
response_deserializer, self._loop)
call_completed = False
client_call_details = ClientCallDetails(method, timeout, metadata,
credentials, wait_for_ready)
return await _run_interceptor(iter(interceptors), client_call_details,
request)
try:
call = interceptors_task.result()
if call.done():
call_completed = True
except (AioRpcError, asyncio.CancelledError):
call_completed = True
def _fire_pending_add_done_callbacks(self,
unused_task: asyncio.Task) -> None:
for callback in self._pending_add_done_callbacks:
callback(self)
if call_completed:
for callback in self._pending_add_done_callbacks:
callback(self)
else:
for callback in self._pending_add_done_callbacks:
callback = functools.partial(self._wrap_add_done_callback,
callback)
call.add_done_callback(callback)
self._pending_add_done_callbacks = []
def _wrap_add_done_callback(self, callback: DoneCallbackType,
unused_task: asyncio.Task) -> None:
unused_call: _base_call.Call) -> None:
callback(self)
def cancel(self) -> bool:
if self._interceptors_task.done():
if not self._interceptors_task.done():
# There is no yet the intercepted call available,
# Trying to cancel it by using the generic Asyncio
# cancellation method.
return self._interceptors_task.cancel()
try:
call = self._interceptors_task.result()
except AioRpcError:
return False
except asyncio.CancelledError:
return False
return self._interceptors_task.cancel()
return call.cancel()
def cancelled(self) -> bool:
if not self._interceptors_task.done():
@ -270,7 +272,7 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
callback(self)
else:
callback = functools.partial(self._wrap_add_done_callback, callback)
call.add_done_callback(self._wrap_add_done_callback)
call.add_done_callback(callback)
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
@ -325,14 +327,181 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
return await call.debug_error_string()
async def wait_for_connection(self) -> None:
call = await self._interceptors_task
return await call.wait_for_connection()
class InterceptedUnaryUnaryCall(InterceptedCall, _base_call.UnaryUnaryCall):
"""Used for running a `UnaryUnaryCall` wrapped by interceptors.
For the `__await__` method is it is proxied to the intercepted call only when
the interceptor task is finished.
"""
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
# pylint: disable=too-many-arguments
def __init__(self, interceptors: Sequence[UnaryUnaryClientInterceptor],
request: RequestType, timeout: Optional[float],
metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._channel = channel
interceptors_task = loop.create_task(
self._invoke(interceptors, method, timeout, metadata, credentials,
wait_for_ready, request, request_serializer,
response_deserializer))
super().__init__(interceptors_task)
# pylint: disable=too-many-arguments
async def _invoke(self, interceptors: Sequence[UnaryUnaryClientInterceptor],
method: bytes, timeout: Optional[float],
metadata: Optional[MetadataType],
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], request: RequestType,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction
) -> UnaryUnaryCall:
"""Run the RPC call wrapped in interceptors"""
async def _run_interceptor(
interceptors: Iterator[UnaryUnaryClientInterceptor],
client_call_details: ClientCallDetails,
request: RequestType) -> _base_call.UnaryUnaryCall:
interceptor = next(interceptors, None)
if interceptor:
continuation = functools.partial(_run_interceptor, interceptors)
call_or_response = await interceptor.intercept_unary_unary(
continuation, client_call_details, request)
if isinstance(call_or_response, _base_call.UnaryUnaryCall):
return call_or_response
else:
return UnaryUnaryCallResponse(call_or_response)
else:
return UnaryUnaryCall(
request, _timeout_to_deadline(client_call_details.timeout),
client_call_details.metadata,
client_call_details.credentials,
client_call_details.wait_for_ready, self._channel,
client_call_details.method, request_serializer,
response_deserializer, self._loop)
client_call_details = ClientCallDetails(method, timeout, metadata,
credentials, wait_for_ready)
return await _run_interceptor(iter(interceptors), client_call_details,
request)
def __await__(self):
call = yield from self._interceptors_task.__await__()
response = yield from call.__await__()
return response
async def wait_for_connection(self) -> None:
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
class InterceptedUnaryStreamCall(InterceptedCall, _base_call.UnaryStreamCall):
"""Used for running a `UnaryStreamCall` wrapped by interceptors."""
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_response_aiter: AsyncIterable[ResponseType]
_last_returned_call_from_interceptors = Optional[_base_call.UnaryStreamCall]
# pylint: disable=too-many-arguments
def __init__(self, interceptors: Sequence[UnaryStreamClientInterceptor],
request: RequestType, timeout: Optional[float],
metadata: MetadataType,
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
method: bytes, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction,
loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._channel = channel
self._response_aiter = self._wait_for_interceptor_task_response_iterator(
)
self._last_returned_call_from_interceptors = None
interceptors_task = loop.create_task(
self._invoke(interceptors, method, timeout, metadata, credentials,
wait_for_ready, request, request_serializer,
response_deserializer))
super().__init__(interceptors_task)
# pylint: disable=too-many-arguments
async def _invoke(self, interceptors: Sequence[UnaryUnaryClientInterceptor],
method: bytes, timeout: Optional[float],
metadata: Optional[MetadataType],
credentials: Optional[grpc.CallCredentials],
wait_for_ready: Optional[bool], request: RequestType,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction
) -> UnaryStreamCall:
"""Run the RPC call wrapped in interceptors"""
async def _run_interceptor(
interceptors: Iterator[UnaryStreamClientInterceptor],
client_call_details: ClientCallDetails,
request: RequestType,
) -> _base_call.UnaryUnaryCall:
interceptor = next(interceptors, None)
if interceptor:
continuation = functools.partial(_run_interceptor, interceptors)
call_or_response_iterator = await interceptor.intercept_unary_stream(
continuation, client_call_details, request)
if isinstance(call_or_response_iterator,
_base_call.UnaryUnaryCall):
self._last_returned_call_from_interceptors = call_or_response_iterator
else:
self._last_returned_call_from_interceptors = UnaryStreamCallResponseIterator(
self._last_returned_call_from_interceptors,
call_or_response_iterator)
return self._last_returned_call_from_interceptors
else:
self._last_returned_call_from_interceptors = UnaryStreamCall(
request, _timeout_to_deadline(client_call_details.timeout),
client_call_details.metadata,
client_call_details.credentials,
client_call_details.wait_for_ready, self._channel,
client_call_details.method, request_serializer,
response_deserializer, self._loop)
return self._last_returned_call_from_interceptors
client_call_details = ClientCallDetails(method, timeout, metadata,
credentials, wait_for_ready)
return await _run_interceptor(iter(interceptors), client_call_details,
request)
async def _wait_for_interceptor_task_response_iterator(self
) -> ResponseType:
call = await self._interceptors_task
return await call.wait_for_connection()
async for response in call:
yield response
def __aiter__(self) -> AsyncIterable[ResponseType]:
return self._response_aiter
async def read(self) -> ResponseType:
return await self._response_aiter.asend(None)
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
@ -381,3 +550,55 @@ class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
async def wait_for_connection(self) -> None:
pass
class UnaryStreamCallResponseIterator(_base_call.UnaryStreamCall):
"""UnaryStreamCall class wich uses an alternative response iterator."""
_call: _base_call.UnaryStreamCall
_response_iterator: AsyncIterable[ResponseType]
def __init__(self, call: _base_call.UnaryStreamCall,
response_iterator: AsyncIterable[ResponseType]) -> None:
self._response_iterator = response_iterator
self._call = call
def cancel(self) -> bool:
return self._call.cancel()
def cancelled(self) -> bool:
return self._call.cancelled()
def done(self) -> bool:
return self._call.done()
def add_done_callback(self, callback) -> None:
self._call.add_done_callback(callback)
def time_remaining(self) -> Optional[float]:
return self._call.time_remaining()
async def initial_metadata(self) -> Optional[MetadataType]:
return await self._call.initial_metadata()
async def trailing_metadata(self) -> Optional[MetadataType]:
return await self._call.trailing_metadata()
async def code(self) -> grpc.StatusCode:
return await self._call.code()
async def details(self) -> str:
return await self._call.details()
async def debug_error_string(self) -> Optional[str]:
return await self._call.debug_error_string()
def __aiter__(self):
return self._response_iterator.__aiter__()
async def wait_for_connection(self) -> None:
return await self._call.wait_for_connection()
async def read(self) -> ResponseType:
# Behind the scenes everyting goes through the
# async iterator. So this path should not be reached.
raise Exception()

@ -108,7 +108,10 @@ class HealthServicerTest(AioTestBase):
(await queue.get()).status)
call.cancel()
await task
with self.assertRaises(asyncio.CancelledError):
await task
self.assertTrue(queue.empty())
async def test_watch_new_service(self):
@ -131,7 +134,10 @@ class HealthServicerTest(AioTestBase):
(await queue.get()).status)
call.cancel()
await task
with self.assertRaises(asyncio.CancelledError):
await task
self.assertTrue(queue.empty())
async def test_watch_service_isolation(self):
@ -151,7 +157,10 @@ class HealthServicerTest(AioTestBase):
await asyncio.wait_for(queue.get(), test_constants.SHORT_TIMEOUT)
call.cancel()
await task
with self.assertRaises(asyncio.CancelledError):
await task
self.assertTrue(queue.empty())
async def test_two_watchers(self):
@ -177,8 +186,13 @@ class HealthServicerTest(AioTestBase):
call1.cancel()
call2.cancel()
await task1
await task2
with self.assertRaises(asyncio.CancelledError):
await task1
with self.assertRaises(asyncio.CancelledError):
await task2
self.assertTrue(queue1.empty())
self.assertTrue(queue2.empty())
@ -194,7 +208,9 @@ class HealthServicerTest(AioTestBase):
call.cancel()
await self._servicer.set(_WATCH_SERVICE,
health_pb2.HealthCheckResponse.SERVING)
await task
with self.assertRaises(asyncio.CancelledError):
await task
# Wait for the serving coroutine to process client cancellation.
timeout = time.monotonic() + test_constants.TIME_ALLOWANCE
@ -226,7 +242,10 @@ class HealthServicerTest(AioTestBase):
resp.status)
call.cancel()
await task
with self.assertRaises(asyncio.CancelledError):
await task
self.assertTrue(queue.empty())
async def test_no_duplicate_status(self):
@ -251,7 +270,10 @@ class HealthServicerTest(AioTestBase):
last_status = status
call.cancel()
await task
with self.assertRaises(asyncio.CancelledError):
await task
self.assertTrue(queue.empty())

@ -13,8 +13,9 @@
"unit.channel_argument_test.TestChannelArgument",
"unit.channel_ready_test.TestChannelReady",
"unit.channel_test.TestChannel",
"unit.client_interceptor_test.TestInterceptedUnaryUnaryCall",
"unit.client_interceptor_test.TestUnaryUnaryClientInterceptor",
"unit.client_unary_stream_interceptor_test.TestUnaryStreamClientInterceptor",
"unit.client_unary_unary_interceptor_test.TestInterceptedUnaryUnaryCall",
"unit.client_unary_unary_interceptor_test.TestUnaryUnaryClientInterceptor",
"unit.close_channel_test.TestCloseChannel",
"unit.compatibility_test.TestCompatibility",
"unit.compression_test.TestCompression",

@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import grpc
from grpc.experimental import aio
from grpc.experimental.aio._typing import MetadataType, MetadatumType
from tests.unit.framework.common import test_constants
def seen_metadata(expected: MetadataType, actual: MetadataType):
return not bool(set(expected) - set(actual))
@ -32,3 +35,32 @@ async def block_until_certain_state(channel: aio.Channel,
while state != expected_state:
await channel.wait_for_state_change(state)
state = channel.get_state()
def inject_callbacks(call):
first_callback_ran = asyncio.Event()
def first_callback(call):
# Validate that all resopnses have been received
# and the call is an end state.
assert call.done()
first_callback_ran.set()
second_callback_ran = asyncio.Event()
def second_callback(call):
# Validate that all resopnses have been received
# and the call is an end state.
assert call.done()
second_callback_ran.set()
call.add_done_callback(first_callback)
call.add_done_callback(second_callback)
async def validation():
await asyncio.wait_for(
asyncio.gather(first_callback_ran.wait(),
second_callback_ran.wait()),
test_constants.SHORT_TIMEOUT)
return validation()

@ -217,6 +217,23 @@ class TestUnaryUnaryCall(_MulticallableTestMixin, AioTestBase):
class TestUnaryStreamCall(_MulticallableTestMixin, AioTestBase):
async def test_call_rpc_error(self):
channel = aio.insecure_channel(UNREACHABLE_TARGET)
request = messages_pb2.StreamingOutputCallRequest()
stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.StreamingOutputCall(request)
with self.assertRaises(aio.AioRpcError) as exception_context:
async for response in call:
pass
self.assertEqual(grpc.StatusCode.UNAVAILABLE,
exception_context.exception.code())
self.assertTrue(call.done())
self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())
await channel.close()
async def test_cancel_unary_stream(self):
# Prepares the request
request = messages_pb2.StreamingOutputCallRequest()
@ -550,7 +567,6 @@ class TestStreamUnaryCall(_MulticallableTestMixin, AioTestBase):
cancel_later_task = self.loop.create_task(cancel_later())
# No exceptions here
with self.assertRaises(asyncio.CancelledError):
await call
@ -772,9 +788,10 @@ class TestStreamStreamCall(_MulticallableTestMixin, AioTestBase):
cancel_later_task = self.loop.create_task(cancel_later())
# No exceptions here
async for response in call:
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
with self.assertRaises(asyncio.CancelledError):
async for response in call:
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
await request_iterator_received_the_exception.wait()

@ -0,0 +1,409 @@
# Copyright 2020 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import unittest
import datetime
import grpc
from grpc.experimental import aio
from tests_aio.unit._constants import UNREACHABLE_TARGET
from tests_aio.unit._common import inject_callbacks
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase
from tests.unit.framework.common import test_constants
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
_SHORT_TIMEOUT_S = 1.0
_NUM_STREAM_RESPONSES = 5
_REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 7
_RESPONSE_INTERVAL_US = int(_SHORT_TIMEOUT_S * 1000 * 1000)
class _CountingResponseIterator:
def __init__(self, response_iterator):
self.response_cnt = 0
self._response_iterator = response_iterator
async def _forward_responses(self):
async for response in self._response_iterator:
self.response_cnt += 1
yield response
def __aiter__(self):
return self._forward_responses()
class _UnaryStreamInterceptorEmpty(aio.UnaryStreamClientInterceptor):
async def intercept_unary_stream(self, continuation, client_call_details,
request):
return await continuation(client_call_details, request)
def assert_in_final_state(self, test: unittest.TestCase):
pass
class _UnaryStreamInterceptorWithResponseIterator(
aio.UnaryStreamClientInterceptor):
async def intercept_unary_stream(self, continuation, client_call_details,
request):
call = await continuation(client_call_details, request)
self.response_iterator = _CountingResponseIterator(call)
return self.response_iterator
def assert_in_final_state(self, test: unittest.TestCase):
test.assertEqual(_NUM_STREAM_RESPONSES,
self.response_iterator.response_cnt)
class TestUnaryStreamClientInterceptor(AioTestBase):
async def setUp(self):
self._server_target, self._server = await start_test_server()
async def tearDown(self):
await self._server.stop(None)
async def test_intercepts(self):
for interceptor_class in (_UnaryStreamInterceptorEmpty,
_UnaryStreamInterceptorWithResponseIterator):
with self.subTest(name=interceptor_class):
interceptor = interceptor_class()
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.extend([
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)
] * _NUM_STREAM_RESPONSES)
channel = aio.insecure_channel(self._server_target,
interceptors=[interceptor])
stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.StreamingOutputCall(request)
await call.wait_for_connection()
response_cnt = 0
async for response in call:
response_cnt += 1
self.assertIs(type(response),
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.details(), '')
self.assertEqual(await call.debug_error_string(), '')
self.assertEqual(call.cancel(), False)
self.assertEqual(call.cancelled(), False)
self.assertEqual(call.done(), True)
interceptor.assert_in_final_state(self)
await channel.close()
async def test_add_done_callback_interceptor_task_not_finished(self):
for interceptor_class in (_UnaryStreamInterceptorEmpty,
_UnaryStreamInterceptorWithResponseIterator):
with self.subTest(name=interceptor_class):
interceptor = interceptor_class()
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.extend([
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)
] * _NUM_STREAM_RESPONSES)
channel = aio.insecure_channel(self._server_target,
interceptors=[interceptor])
stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.StreamingOutputCall(request)
validation = inject_callbacks(call)
async for response in call:
pass
await validation
await channel.close()
async def test_add_done_callback_interceptor_task_finished(self):
for interceptor_class in (_UnaryStreamInterceptorEmpty,
_UnaryStreamInterceptorWithResponseIterator):
with self.subTest(name=interceptor_class):
interceptor = interceptor_class()
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.extend([
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)
] * _NUM_STREAM_RESPONSES)
channel = aio.insecure_channel(self._server_target,
interceptors=[interceptor])
stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.StreamingOutputCall(request)
# This ensures that the callbacks will be registered
# with the intercepted call rather than saving in the
# pending state list.
await call.wait_for_connection()
validation = inject_callbacks(call)
async for response in call:
pass
await validation
await channel.close()
async def test_response_iterator_using_read(self):
interceptor = _UnaryStreamInterceptorWithResponseIterator()
channel = aio.insecure_channel(self._server_target,
interceptors=[interceptor])
stub = test_pb2_grpc.TestServiceStub(channel)
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.extend(
[messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)] *
_NUM_STREAM_RESPONSES)
call = stub.StreamingOutputCall(request)
response_cnt = 0
for response in range(_NUM_STREAM_RESPONSES):
response = await call.read()
response_cnt += 1
self.assertIs(type(response),
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE, len(response.payload.body))
self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES)
self.assertEqual(interceptor.response_iterator.response_cnt,
_NUM_STREAM_RESPONSES)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
await channel.close()
async def test_multiple_interceptors_response_iterator(self):
for interceptor_class in (_UnaryStreamInterceptorEmpty,
_UnaryStreamInterceptorWithResponseIterator):
with self.subTest(name=interceptor_class):
interceptors = [interceptor_class(), interceptor_class()]
channel = aio.insecure_channel(self._server_target,
interceptors=interceptors)
stub = test_pb2_grpc.TestServiceStub(channel)
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.extend([
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)
] * _NUM_STREAM_RESPONSES)
call = stub.StreamingOutputCall(request)
response_cnt = 0
async for response in call:
response_cnt += 1
self.assertIs(type(response),
messages_pb2.StreamingOutputCallResponse)
self.assertEqual(_RESPONSE_PAYLOAD_SIZE,
len(response.payload.body))
self.assertEqual(response_cnt, _NUM_STREAM_RESPONSES)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
await channel.close()
async def test_intercepts_response_iterator_rpc_error(self):
for interceptor_class in (_UnaryStreamInterceptorEmpty,
_UnaryStreamInterceptorWithResponseIterator):
with self.subTest(name=interceptor_class):
channel = aio.insecure_channel(
UNREACHABLE_TARGET, interceptors=[interceptor_class()])
request = messages_pb2.StreamingOutputCallRequest()
stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.StreamingOutputCall(request)
with self.assertRaises(aio.AioRpcError) as exception_context:
async for response in call:
pass
self.assertEqual(grpc.StatusCode.UNAVAILABLE,
exception_context.exception.code())
self.assertTrue(call.done())
self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())
await channel.close()
async def test_cancel_before_rpc(self):
interceptor_reached = asyncio.Event()
wait_for_ever = self.loop.create_future()
class Interceptor(aio.UnaryStreamClientInterceptor):
async def intercept_unary_stream(self, continuation,
client_call_details, request):
interceptor_reached.set()
await wait_for_ever
channel = aio.insecure_channel(UNREACHABLE_TARGET,
interceptors=[Interceptor()])
request = messages_pb2.StreamingOutputCallRequest()
stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled())
self.assertFalse(call.done())
await interceptor_reached.wait()
self.assertTrue(call.cancel())
with self.assertRaises(asyncio.CancelledError):
async for response in call:
pass
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.initial_metadata(), None)
self.assertEqual(await call.trailing_metadata(), None)
await channel.close()
async def test_cancel_after_rpc(self):
interceptor_reached = asyncio.Event()
wait_for_ever = self.loop.create_future()
class Interceptor(aio.UnaryStreamClientInterceptor):
async def intercept_unary_stream(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
interceptor_reached.set()
await wait_for_ever
channel = aio.insecure_channel(UNREACHABLE_TARGET,
interceptors=[Interceptor()])
request = messages_pb2.StreamingOutputCallRequest()
stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.StreamingOutputCall(request)
self.assertFalse(call.cancelled())
self.assertFalse(call.done())
await interceptor_reached.wait()
self.assertTrue(call.cancel())
with self.assertRaises(asyncio.CancelledError):
async for response in call:
pass
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.initial_metadata(), None)
self.assertEqual(await call.trailing_metadata(), None)
await channel.close()
async def test_cancel_consuming_response_iterator(self):
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.extend(
[messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE)] *
_NUM_STREAM_RESPONSES)
channel = aio.insecure_channel(
self._server_target,
interceptors=[_UnaryStreamInterceptorWithResponseIterator()])
stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.StreamingOutputCall(request)
with self.assertRaises(asyncio.CancelledError):
async for response in call:
call.cancel()
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
await channel.close()
async def test_cancel_by_the_interceptor(self):
class Interceptor(aio.UnaryStreamClientInterceptor):
async def intercept_unary_stream(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
call.cancel()
return call
channel = aio.insecure_channel(UNREACHABLE_TARGET,
interceptors=[Interceptor()])
request = messages_pb2.StreamingOutputCallRequest()
stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.StreamingOutputCall(request)
with self.assertRaises(asyncio.CancelledError):
async for response in call:
pass
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
await channel.close()
async def test_exception_raised_by_interceptor(self):
class InterceptorException(Exception):
pass
class Interceptor(aio.UnaryStreamClientInterceptor):
async def intercept_unary_stream(self, continuation,
client_call_details, request):
raise InterceptorException
channel = aio.insecure_channel(UNREACHABLE_TARGET,
interceptors=[Interceptor()])
request = messages_pb2.StreamingOutputCallRequest()
stub = test_pb2_grpc.TestServiceStub(channel)
call = stub.StreamingOutputCall(request)
with self.assertRaises(InterceptorException):
async for response in call:
pass
await channel.close()
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)

@ -21,6 +21,7 @@ import gc
import grpc
from grpc.experimental import aio
from tests_aio.unit._common import inject_callbacks
from tests_aio.unit._test_base import AioTestBase
from tests.unit.framework.common import test_constants
from src.proto.grpc.testing import messages_pb2, test_pb2_grpc
@ -31,29 +32,6 @@ _REQUEST_PAYLOAD_SIZE = 7
_RESPONSE_PAYLOAD_SIZE = 42
def _inject_callbacks(call):
first_callback_ran = asyncio.Event()
def first_callback(unused_call):
first_callback_ran.set()
second_callback_ran = asyncio.Event()
def second_callback(unused_call):
second_callback_ran.set()
call.add_done_callback(first_callback)
call.add_done_callback(second_callback)
async def validation():
await asyncio.wait_for(
asyncio.gather(first_callback_ran.wait(),
second_callback_ran.wait()),
test_constants.SHORT_TIMEOUT)
return validation()
class TestDoneCallback(AioTestBase):
async def setUp(self):
@ -69,12 +47,12 @@ class TestDoneCallback(AioTestBase):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
self.assertEqual(grpc.StatusCode.OK, await call.code())
validation = _inject_callbacks(call)
validation = inject_callbacks(call)
await validation
async def test_unary_unary(self):
call = self._stub.UnaryCall(messages_pb2.SimpleRequest())
validation = _inject_callbacks(call)
validation = inject_callbacks(call)
self.assertEqual(grpc.StatusCode.OK, await call.code())
@ -87,7 +65,7 @@ class TestDoneCallback(AioTestBase):
messages_pb2.ResponseParameters(size=_RESPONSE_PAYLOAD_SIZE))
call = self._stub.StreamingOutputCall(request)
validation = _inject_callbacks(call)
validation = inject_callbacks(call)
response_cnt = 0
async for response in call:
@ -110,7 +88,7 @@ class TestDoneCallback(AioTestBase):
yield request
call = self._stub.StreamingInputCall(gen())
validation = _inject_callbacks(call)
validation = inject_callbacks(call)
response = await call
self.assertIsInstance(response, messages_pb2.StreamingInputCallResponse)
@ -122,7 +100,7 @@ class TestDoneCallback(AioTestBase):
async def test_stream_stream(self):
call = self._stub.FullDuplexCall()
validation = _inject_callbacks(call)
validation = inject_callbacks(call)
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(

Loading…
Cancel
Save