Client unary unary interceptor

Implements the unary unary interceptor for the client-side. Interceptors
can be now installed by passing them as a new parameter of the `Channel`
constructor or by giving them as part of the `insecure_channel`
function.

Interceptors are executed within an Asyncio task for making some work before
the RPC invocation, and after for accessing to the intercepted call that has
been invoked.
pull/21455/head
Pau Freixes 5 years ago
parent a29b3b7304
commit 77df7f5f17
  1. 1
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  2. 2
      src/python/grpcio/grpc/experimental/BUILD.bazel
  3. 24
      src/python/grpcio/grpc/experimental/aio/__init__.py
  4. 75
      src/python/grpcio/grpc/experimental/aio/_channel.py
  5. 336
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  6. 23
      src/python/grpcio/grpc/experimental/aio/_utils.py
  7. 2
      src/python/grpcio_tests/tests_aio/tests.json
  8. 2
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  9. 3
      src/python/grpcio_tests/tests_aio/unit/call_test.py
  10. 1
      src/python/grpcio_tests/tests_aio/unit/channel_test.py
  11. 504
      src/python/grpcio_tests/tests_aio/unit/interceptor_test.py

@ -15,6 +15,7 @@
cimport cpython
import grpc
_EMPTY_FLAGS = 0
_EMPTY_MASK = 0
_EMPTY_METADATA = None

@ -7,8 +7,10 @@ py_library(
"aio/_base_call.py",
"aio/_call.py",
"aio/_channel.py",
"aio/_interceptor.py",
"aio/_server.py",
"aio/_typing.py",
"aio/_utils.py",
],
deps = [
"//src/python/grpcio/grpc/_cython:cygrpc",

@ -18,18 +18,26 @@ created. AsyncIO doesn't provide thread safety for most of its APIs.
"""
import abc
from typing import Any, Optional, Sequence, Text, Tuple
import six
import grpc
from grpc._cython.cygrpc import init_grpc_aio
from ._base_call import RpcContext, Call, UnaryUnaryCall, UnaryStreamCall
from ._call import AioRpcError
from ._channel import Channel
from ._channel import UnaryUnaryMultiCallable
from ._interceptor import ClientCallDetails, UnaryUnaryClientInterceptor
from ._interceptor import InterceptedUnaryUnaryCall
from ._server import server
def insecure_channel(target, options=None, compression=None):
def insecure_channel(
target: Text,
options: Optional[Sequence[Tuple[Text, Any]]] = None,
compression: Optional[grpc.Compression] = None,
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] = None):
"""Creates an insecure asynchronous Channel to a server.
Args:
@ -38,16 +46,22 @@ def insecure_channel(target, options=None, compression=None):
in gRPC Core runtime) to configure the channel.
compression: An optional value indicating the compression method to be
used over the lifetime of the channel. This is an EXPERIMENTAL option.
interceptors: An optional sequence of interceptors that will be executed for
any call executed with this channel.
Returns:
A Channel.
"""
return Channel(target, () if options is None else options, None,
compression)
return Channel(
target, () if options is None else options,
None,
compression,
interceptors=interceptors)
################################### __all__ #################################
__all__ = ('RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall',
__all__ = ('AioRpcError', 'RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall',
'init_grpc_aio', 'Channel', 'UnaryUnaryMultiCallable',
'insecure_channel', 'server')
'ClientCallDetails', 'UnaryUnaryClientInterceptor',
'InterceptedUnaryUnaryCall', 'insecure_channel', 'server')

@ -18,29 +18,35 @@ from typing import Any, Optional, Sequence, Text, Tuple
import grpc
from grpc import _common
from grpc._cython import cygrpc
from . import _base_call
from ._call import UnaryUnaryCall, UnaryStreamCall
from ._interceptor import UnaryUnaryClientInterceptor, InterceptedUnaryUnaryCall
from ._typing import (DeserializingFunction, MetadataType, SerializingFunction)
def _timeout_to_deadline(loop: asyncio.AbstractEventLoop,
timeout: Optional[float]) -> Optional[float]:
if timeout is None:
return None
return loop.time() + timeout
from ._utils import _timeout_to_deadline
class UnaryUnaryMultiCallable:
"""Factory an asynchronous unary-unary RPC stub call from client-side."""
_channel: cygrpc.AioChannel
_method: bytes
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
_loop: asyncio.AbstractEventLoop
def __init__(self, channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
response_deserializer: DeserializingFunction,
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
) -> None:
self._loop = asyncio.get_event_loop()
self._channel = channel
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._interceptors = interceptors
def __call__(self,
request: Any,
@ -74,7 +80,6 @@ class UnaryUnaryMultiCallable:
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if metadata:
raise NotImplementedError("TODO: metadata not implemented yet")
@ -88,11 +93,20 @@ class UnaryUnaryMultiCallable:
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
deadline = _timeout_to_deadline(self._loop, timeout)
if not self._interceptors:
return UnaryUnaryCall(
request,
deadline,
_timeout_to_deadline(self._loop, timeout),
self._channel,
self._method,
self._request_serializer,
self._response_deserializer,
)
else:
return InterceptedUnaryUnaryCall(
self._interceptors,
request,
timeout,
self._channel,
self._method,
self._request_serializer,
@ -138,13 +152,7 @@ class UnaryStreamMultiCallable:
Returns:
A Call object instance which is an awaitable object.
Raises:
RpcError: Indicating that the RPC terminated with non-OK status. The
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
"""
if metadata:
raise NotImplementedError("TODO: metadata not implemented yet")
@ -175,11 +183,14 @@ class Channel:
A cygrpc.AioChannel-backed implementation.
"""
_channel: cygrpc.AioChannel
_unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
def __init__(self, target: Text,
options: Optional[Sequence[Tuple[Text, Any]]],
credentials: Optional[grpc.ChannelCredentials],
compression: Optional[grpc.Compression]):
compression: Optional[grpc.Compression],
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]):
"""Constructor.
Args:
@ -188,8 +199,9 @@ class Channel:
credentials: A cygrpc.ChannelCredentials or None.
compression: An optional value indicating the compression method to be
used over the lifetime of the channel.
interceptors: An optional list of interceptors that would be used for
intercepting any RPC executed with that channel.
"""
if options:
raise NotImplementedError("TODO: options not implemented yet")
@ -199,6 +211,23 @@ class Channel:
if compression:
raise NotImplementedError("TODO: compression not implemented yet")
if interceptors is None:
self._unary_unary_interceptors = None
else:
self._unary_unary_interceptors = list(
filter(
lambda interceptor: isinstance(interceptor, UnaryUnaryClientInterceptor),
interceptors))
invalid_interceptors = set(interceptors) - set(
self._unary_unary_interceptors)
if invalid_interceptors:
raise ValueError(
"Interceptor must be "+\
"UnaryUnaryClientInterceptors, the following are invalid: {}"\
.format(invalid_interceptors))
self._channel = cygrpc.AioChannel(_common.encode(target))
def unary_unary(
@ -220,9 +249,9 @@ class Channel:
Returns:
A UnaryUnaryMultiCallable value for the named unary-unary method.
"""
return UnaryUnaryMultiCallable(self._channel, _common.encode(method),
request_serializer,
response_deserializer)
return UnaryUnaryMultiCallable(
self._channel, _common.encode(method), request_serializer,
response_deserializer, self._unary_unary_interceptors)
def unary_stream(
self,

@ -0,0 +1,336 @@
# Copyright 2019 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interceptors implementation of gRPC Asyncio Python."""
import asyncio
import collections
import functools
from abc import ABCMeta, abstractmethod
from typing import Callable, Optional, Iterator, Sequence, Text, Union
import grpc
from grpc._cython import cygrpc
from . import _base_call
from ._call import UnaryUnaryCall
from ._utils import _timeout_to_deadline
from ._typing import (RequestType, SerializingFunction, DeserializingFunction,
MetadataType, ResponseType)
_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!'
class ClientCallDetails(
collections.namedtuple(
'ClientCallDetails',
('method', 'timeout', 'metadata', 'credentials')),
grpc.ClientCallDetails):
method: Text
timeout: Optional[float]
metadata: Optional[MetadataType]
credentials: Optional[grpc.CallCredentials]
class UnaryUnaryClientInterceptor(metaclass=ABCMeta):
"""Affords intercepting unary-unary invocations."""
@abstractmethod
async def intercept_unary_unary(
self, continuation: Callable[[ClientCallDetails, RequestType],
UnaryUnaryCall],
client_call_details: ClientCallDetails,
request: RequestType) -> Union[UnaryUnaryCall, ResponseType]:
"""Intercepts a unary-unary invocation asynchronously.
Args:
continuation: A coroutine that proceeds with the invocation by
executing the next interceptor in chain or invoking the
actual RPC on the underlying Channel. It is the interceptor's
responsibility to call it if it decides to move the RPC forward.
The interceptor can use
`response_future = await continuation(client_call_details, request)`
to continue with the RPC. `continuation` returns the response of the
RPC.
client_call_details: A ClientCallDetails object describing the
outgoing RPC.
request: The request value for the RPC.
Returns:
An object with the RPC response.
Raises:
AioRpcError: Indicating that the RPC terminated with non-OK status.
asyncio.CancelledError: Indicating that the RPC was canceled.
"""
class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
"""Used for running a `UnaryUnaryCall` wrapped by interceptors.
Interceptors might have some work to do before the RPC invocation with
the capacity of changing the invocation parameters, and some work to do
after the RPC invocation with the capacity for accessing to the wrapped
`UnaryUnaryCall`.
It handles also early and later cancellations, when the RPC has not even
started and the execution is still held by the interceptors or when the
RPC has finished but again the execution is still held by the interceptors.
Once the RPC is finally executed, all methods are finally done against the
intercepted call, being at the same time the same call returned to the
interceptors.
For most of the methods, like `initial_metadata()` the caller does not need
to wait until the interceptors task is finished, once the RPC is done the
caller will have the freedom for accessing to the results.
For the `__await__` method is it is proxied to the intercepted call only when
the interceptor task is finished.
"""
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_cancelled_before_rpc: bool
_intercepted_call: Optional[_base_call.UnaryUnaryCall]
_intercepted_call_created: asyncio.Event
_interceptors_task: asyncio.Task
def __init__( # pylint: disable=R0913
self, interceptors: Sequence[UnaryUnaryClientInterceptor],
request: RequestType, timeout: Optional[float],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
self._channel = channel
self._loop = asyncio.get_event_loop()
self._interceptors_task = asyncio.ensure_future(
self._invoke(interceptors, method, timeout, request,
request_serializer, response_deserializer))
def __del__(self):
self.cancel()
async def _invoke(
self, interceptors: Sequence[UnaryUnaryClientInterceptor],
method: bytes, timeout: Optional[float], request: RequestType,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> UnaryUnaryCall:
"""Run the RPC call wrapped in interceptors"""
async def _run_interceptor(
interceptors: Iterator[UnaryUnaryClientInterceptor],
client_call_details: ClientCallDetails,
request: RequestType) -> _base_call.UnaryUnaryCall:
try:
interceptor = next(interceptors)
except StopIteration:
interceptor = None
if interceptor:
continuation = functools.partial(_run_interceptor, interceptors)
try:
call_or_response = await interceptor.intercept_unary_unary(
continuation, client_call_details, request)
except grpc.RpcError as err:
# gRPC error is masked inside an artificial call,
# caller will see this error if and only
# if it runs an `await call` operation
return UnaryUnaryCallRpcError(err)
except asyncio.CancelledError:
# Cancellation is masked inside an artificial call,
# caller will see this error if and only
# if it runs an `await call` operation
return UnaryUnaryCancelledError()
if isinstance(call_or_response, _base_call.UnaryUnaryCall):
return call_or_response
else:
return UnaryUnaryCallResponse(call_or_response)
else:
return UnaryUnaryCall(request,
_timeout_to_deadline(
self._loop,
client_call_details.timeout),
self._channel, client_call_details.method,
request_serializer, response_deserializer)
client_call_details = ClientCallDetails(method, timeout, None, None)
return await _run_interceptor(
iter(interceptors), client_call_details, request)
def cancel(self) -> bool:
if self._interceptors_task.done():
return False
return self._interceptors_task.cancel()
def cancelled(self) -> bool:
if not self._interceptors_task.done():
return False
call = self._interceptors_task.result()
return call.cancelled()
def done(self) -> bool:
if not self._interceptors_task.done():
return False
return True
def add_done_callback(self, unused_callback) -> None:
raise NotImplementedError()
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
async def initial_metadata(self) -> Optional[MetadataType]:
return await (await self._interceptors_task).initial_metadata()
async def trailing_metadata(self) -> Optional[MetadataType]:
return await (await self._interceptors_task).trailing_metadata()
async def code(self) -> grpc.StatusCode:
return await (await self._interceptors_task).code()
async def details(self) -> str:
return await (await self._interceptors_task).details()
async def debug_error_string(self) -> Optional[str]:
return await (await self._interceptors_task).debug_error_string()
def __await__(self):
call = yield from self._interceptors_task.__await__()
response = yield from call.__await__()
return response
class UnaryUnaryCallRpcError(_base_call.UnaryUnaryCall):
"""Final UnaryUnaryCall class finished with an RpcError."""
_error: grpc.RpcError
def __init__(self, error: grpc.RpcError) -> None:
self._error = error
def cancel(self) -> bool:
return False
def cancelled(self) -> bool:
return False
def done(self) -> bool:
return True
def add_done_callback(self, unused_callback) -> None:
raise NotImplementedError()
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
async def initial_metadata(self) -> Optional[MetadataType]:
return None
async def trailing_metadata(self) -> Optional[MetadataType]:
return self._error.initial_metadata()
async def code(self) -> grpc.StatusCode:
return self._error.code()
async def details(self) -> str:
return self._error.details()
async def debug_error_string(self) -> Optional[str]:
return self._error.debug_error_string()
def __await__(self):
raise self._error
class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
"""Final UnaryUnaryCall class finished with a response."""
_response: ResponseType
def __init__(self, response: ResponseType) -> None:
self._response = response
def cancel(self) -> bool:
return False
def cancelled(self) -> bool:
return False
def done(self) -> bool:
return True
def add_done_callback(self, unused_callback) -> None:
raise NotImplementedError()
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
async def initial_metadata(self) -> Optional[MetadataType]:
return None
async def trailing_metadata(self) -> Optional[MetadataType]:
return None
async def code(self) -> grpc.StatusCode:
return grpc.StatusCode.OK
async def details(self) -> str:
return ''
async def debug_error_string(self) -> Optional[str]:
return None
def __await__(self):
if False: # pylint: disable=W0125
# This code path is never used, but a yield statement is needed
# for telling the interpreter that __await__ is a generator.
yield None
return self._response
class UnaryUnaryCancelledError(_base_call.UnaryUnaryCall):
"""Final UnaryUnaryCall class finished with an asyncio.CancelledError."""
def cancel(self) -> bool:
return False
def cancelled(self) -> bool:
return True
def done(self) -> bool:
return True
def add_done_callback(self, unused_callback) -> None:
raise NotImplementedError()
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
async def initial_metadata(self) -> Optional[MetadataType]:
return None
async def trailing_metadata(self) -> Optional[MetadataType]:
return None
async def code(self) -> grpc.StatusCode:
return grpc.StatusCode.CANCELLED
async def details(self) -> str:
return _LOCAL_CANCELLATION_DETAILS
async def debug_error_string(self) -> Optional[str]:
return None
def __await__(self):
raise asyncio.CancelledError()

@ -0,0 +1,23 @@
# Copyright 2019 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Internal utilities used by the gRPC Aio module."""
import asyncio
from typing import Optional
def _timeout_to_deadline(loop: asyncio.AbstractEventLoop,
timeout: Optional[float]) -> Optional[float]:
if timeout is None:
return None
return loop.time() + timeout

@ -5,5 +5,7 @@
"unit.call_test.TestUnaryUnaryCall",
"unit.channel_test.TestChannel",
"unit.init_test.TestInsecureChannel",
"unit.interceptor_test.TestInterceptedUnaryUnaryCall",
"unit.interceptor_test.TestUnaryUnaryClientInterceptor",
"unit.server_test.TestServer"
]

@ -17,9 +17,9 @@ import logging
import datetime
from grpc.experimental import aio
from tests.unit.framework.common import test_constants
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc
from tests.unit.framework.common import test_constants
class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):

@ -26,6 +26,7 @@ from src.proto.grpc.testing import test_pb2_grpc
from tests.unit.framework.common import test_constants
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase
from src.proto.grpc.testing import messages_pb2
_NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42
@ -399,5 +400,5 @@ class TestUnaryStreamCall(AioTestBase):
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
logging.basicConfig()
unittest.main(verbosity=2)

@ -25,6 +25,7 @@ from src.proto.grpc.testing import test_pb2_grpc
from tests.unit.framework.common import test_constants
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase
from src.proto.grpc.testing import messages_pb2
_UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall'
_STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall'

@ -0,0 +1,504 @@
# Copyright 2019 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import unittest
import grpc
from grpc.experimental import aio
from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase
from src.proto.grpc.testing import messages_pb2
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
class TestUnaryUnaryClientInterceptor(AioTestBase):
def test_invalid_interceptor(self):
class InvalidInterceptor:
"""Just an invalid Interceptor"""
with self.assertRaises(ValueError):
aio.insecure_channel("", interceptors=[InvalidInterceptor()])
async def test_executed_right_order(self):
interceptors_executed = []
class Interceptor(aio.UnaryUnaryClientInterceptor):
"""Interceptor used for testing if the interceptor is being called"""
async def intercept_unary_unary(self, continuation,
client_call_details, request):
interceptors_executed.append(self)
call = await continuation(client_call_details, request)
return call
interceptors = [Interceptor() for i in range(2)]
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(
server_target, interceptors=interceptors) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
response = await call
# Check that all interceptors were executed, and were executed
# in the right order.
self.assertSequenceEqual(interceptors_executed, interceptors)
self.assertIsInstance(response, messages_pb2.SimpleResponse)
@unittest.expectedFailure
# TODO(https://github.com/grpc/grpc/issues/20144) Once metadata support is
# implemented in the client-side, this test must be implemented.
def test_modify_metadata(self):
raise NotImplementedError()
@unittest.expectedFailure
# TODO(https://github.com/grpc/grpc/issues/20532) Once credentials support is
# implemented in the client-side, this test must be implemented.
def test_modify_credentials(self):
raise NotImplementedError()
async def test_status_code_Ok(self):
class StatusCodeOkInterceptor(aio.UnaryUnaryClientInterceptor):
"""Interceptor used for observing status code Ok returned by the RPC"""
def __init__(self):
self.status_code_Ok_observed = False
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
code = await call.code()
if code == grpc.StatusCode.OK:
self.status_code_Ok_observed = True
return call
interceptor = StatusCodeOkInterceptor()
server_target, server = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(
server_target, interceptors=[interceptor]) as channel:
# when no error StatusCode.OK must be observed
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
await multicallable(messages_pb2.SimpleRequest())
self.assertTrue(interceptor.status_code_Ok_observed)
async def test_add_timeout(self):
class TimeoutInterceptor(aio.UnaryUnaryClientInterceptor):
"""Interceptor used for adding a timeout to the RPC"""
async def intercept_unary_unary(self, continuation,
client_call_details, request):
new_client_call_details = aio.ClientCallDetails(
method=client_call_details.method,
timeout=0.1,
metadata=client_call_details.metadata,
credentials=client_call_details.credentials)
return await continuation(new_client_call_details, request)
interceptor = TimeoutInterceptor()
server_target, server = await start_test_server()
async with aio.insecure_channel(
server_target, interceptors=[interceptor]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
await server.stop(None)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
self.assertEqual(exception_context.exception.code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertTrue(call.done())
self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await
call.code())
async def test_retry(self):
class RetryInterceptor(aio.UnaryUnaryClientInterceptor):
"""Simulates a Retry Interceptor which ends up by making
two RPC calls."""
def __init__(self):
self.calls = []
async def intercept_unary_unary(self, continuation,
client_call_details, request):
new_client_call_details = aio.ClientCallDetails(
method=client_call_details.method,
timeout=0.1,
metadata=client_call_details.metadata,
credentials=client_call_details.credentials)
try:
call = await continuation(new_client_call_details, request)
await call
except grpc.RpcError:
pass
self.calls.append(call)
new_client_call_details = aio.ClientCallDetails(
method=client_call_details.method,
timeout=None,
metadata=client_call_details.metadata,
credentials=client_call_details.credentials)
call = await continuation(new_client_call_details, request)
self.calls.append(call)
return call
interceptor = RetryInterceptor()
server_target, server = await start_test_server()
async with aio.insecure_channel(
server_target, interceptors=[interceptor]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
await call
self.assertEqual(grpc.StatusCode.OK, await call.code())
# Check that two calls were made, first one finishing with
# a deadline and second one finishing ok..
self.assertEqual(len(interceptor.calls), 2)
self.assertEqual(await interceptor.calls[0].code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertEqual(await interceptor.calls[1].code(),
grpc.StatusCode.OK)
async def test_rpcerror_raised_when_call_is_awaited(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
"""RpcErrors are only seen when the call is awaited"""
def __init__(self):
self.deadline_seen = False
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
try:
await call
except aio.AioRpcError as err:
if err.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
self.deadline_seen = True
raise
# This point should never be reached
raise Exception()
interceptor_a, interceptor_b = (Interceptor(), Interceptor())
server_target, server = await start_test_server()
async with aio.insecure_channel(
server_target, interceptors=[interceptor_a,
interceptor_b]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest(), timeout=0.1)
with self.assertRaises(grpc.RpcError) as exception_context:
await call
# Check that the two interceptors catch the deadline exception
# only when the call was awaited
self.assertTrue(interceptor_a.deadline_seen)
self.assertTrue(interceptor_b.deadline_seen)
# Check all of the UnaryUnaryCallRpcError attributes
self.assertTrue(call.done())
self.assertFalse(call.cancel())
self.assertFalse(call.cancelled())
self.assertEqual(await call.code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertEqual(await call.details(), 'Deadline Exceeded')
self.assertEqual(await call.initial_metadata(), None)
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.debug_error_string(), None)
async def test_rpcresponse(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
"""Raw responses are seen as reegular calls"""
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
response = await call
return call
class ResponseInterceptor(aio.UnaryUnaryClientInterceptor):
"""Return a raw response"""
response = messages_pb2.SimpleResponse()
async def intercept_unary_unary(self, continuation,
client_call_details, request):
return ResponseInterceptor.response
interceptor, interceptor_response = Interceptor(), ResponseInterceptor()
server_target, server = await start_test_server()
async with aio.insecure_channel(
server_target, interceptors=[interceptor,
interceptor_response]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
response = await call
# Check that the response returned is the one returned by the
# interceptor
self.assertEqual(id(response), id(ResponseInterceptor.response))
# Check all of the UnaryUnaryCallResponse attributes
self.assertTrue(call.done())
self.assertFalse(call.cancel())
self.assertFalse(call.cancelled())
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.details(), '')
self.assertEqual(await call.initial_metadata(), None)
self.assertEqual(await call.trailing_metadata(), None)
self.assertEqual(await call.debug_error_string(), None)
class TestInterceptedUnaryUnaryCall(AioTestBase):
async def test_call_ok(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
return call
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(
server_target, interceptors=[Interceptor()]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
response = await call
self.assertTrue(call.done())
self.assertFalse(call.cancelled())
self.assertEqual(type(response), messages_pb2.SimpleResponse)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.details(), '')
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
async def test_cancel_before_rpc(self):
interceptor_reached = asyncio.Event()
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
interceptor_reached.set()
await asyncio.sleep(0)
# This line should never be reached
raise Exception()
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(
server_target, interceptors=[Interceptor()]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
self.assertFalse(call.cancelled())
self.assertFalse(call.done())
await interceptor_reached.wait()
self.assertTrue(call.cancel())
with self.assertRaises(asyncio.CancelledError):
await call
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
_LOCAL_CANCEL_DETAILS_EXPECTATION)
self.assertEqual(await call.initial_metadata(), None)
self.assertEqual(await call.trailing_metadata(), None)
async def test_cancel_after_rpc(self):
interceptor_reached = asyncio.Event()
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
await call
interceptor_reached.set()
await asyncio.sleep(0)
# This line should never be reached
raise Exception()
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(
server_target, interceptors=[Interceptor()]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
self.assertFalse(call.cancelled())
self.assertFalse(call.done())
await interceptor_reached.wait()
self.assertTrue(call.cancel())
with self.assertRaises(asyncio.CancelledError):
await call
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
_LOCAL_CANCEL_DETAILS_EXPECTATION)
self.assertEqual(await call.initial_metadata(), None)
self.assertEqual(await call.trailing_metadata(), None)
async def test_cancel_inside_interceptor_after_rpc_awaiting(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
call.cancel()
await call
return call
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(
server_target, interceptors=[Interceptor()]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
with self.assertRaises(asyncio.CancelledError):
await call
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
_LOCAL_CANCEL_DETAILS_EXPECTATION)
self.assertEqual(await call.initial_metadata(), None)
self.assertEqual(await call.trailing_metadata(), None)
async def test_cancel_inside_interceptor_after_rpc_not_awaiting(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
call.cancel()
return call
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(
server_target, interceptors=[Interceptor()]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
with self.assertRaises(asyncio.CancelledError):
await call
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
_LOCAL_CANCEL_DETAILS_EXPECTATION)
self.assertEqual(await call.initial_metadata(), tuple())
self.assertEqual(await call.trailing_metadata(), None)
if __name__ == '__main__':
logging.basicConfig()
unittest.main(verbosity=2)
Loading…
Cancel
Save