Merge pull request #21455 from Skyscanner/client_unaryunary_interceptors_3

[Aio] Client Side Interceptor For Unary Calls
pull/21642/head
Lidi Zheng 5 years ago committed by GitHub
commit da6a29dd6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi
  2. 26
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  3. 1
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  4. 2
      src/python/grpcio/grpc/experimental/BUILD.bazel
  5. 24
      src/python/grpcio/grpc/experimental/aio/__init__.py
  6. 2
      src/python/grpcio/grpc/experimental/aio/_call.py
  7. 89
      src/python/grpcio/grpc/experimental/aio/_channel.py
  8. 291
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  9. 22
      src/python/grpcio/grpc/experimental/aio/_utils.py
  10. 2
      src/python/grpcio_tests/tests_aio/tests.json
  11. 32
      src/python/grpcio_tests/tests_aio/unit/_test_server.py
  12. 3
      src/python/grpcio_tests/tests_aio/unit/call_test.py
  13. 24
      src/python/grpcio_tests/tests_aio/unit/channel_test.py
  14. 538
      src/python/grpcio_tests/tests_aio/unit/interceptor_test.py

@ -13,11 +13,10 @@
# limitations under the License. # limitations under the License.
cdef class _AioCall: cdef class _AioCall(GrpcCallWrapper):
cdef: cdef:
AioChannel _channel AioChannel _channel
list _references list _references
GrpcCallWrapper _grpc_call_wrapper
# Caches the picked event loop, so we can avoid the 30ns overhead each # Caches the picked event loop, so we can avoid the 30ns overhead each
# time we need access to the event loop. # time we need access to the event loop.
object _loop object _loop
@ -30,4 +29,3 @@ cdef class _AioCall:
bint _is_locally_cancelled bint _is_locally_cancelled
cdef grpc_call* _create_grpc_call(self, object timeout, bytes method) except * cdef grpc_call* _create_grpc_call(self, object timeout, bytes method) except *
cdef void _destroy_grpc_call(self)

@ -15,6 +15,7 @@
cimport cpython cimport cpython
import grpc import grpc
_EMPTY_FLAGS = 0 _EMPTY_FLAGS = 0
_EMPTY_MASK = 0 _EMPTY_MASK = 0
_EMPTY_METADATA = None _EMPTY_METADATA = None
@ -28,15 +29,16 @@ cdef class _AioCall:
AioChannel channel, AioChannel channel,
object deadline, object deadline,
bytes method): bytes method):
self.call = NULL
self._channel = channel self._channel = channel
self._references = [] self._references = []
self._grpc_call_wrapper = GrpcCallWrapper()
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
self._create_grpc_call(deadline, method) self._create_grpc_call(deadline, method)
self._is_locally_cancelled = False self._is_locally_cancelled = False
def __dealloc__(self): def __dealloc__(self):
self._destroy_grpc_call() if self.call:
grpc_call_unref(self.call)
def __repr__(self): def __repr__(self):
class_name = self.__class__.__name__ class_name = self.__class__.__name__
@ -61,7 +63,7 @@ cdef class _AioCall:
<const char *> method, <const char *> method,
<size_t> len(method) <size_t> len(method)
) )
self._grpc_call_wrapper.call = grpc_channel_create_call( self.call = grpc_channel_create_call(
self._channel.channel, self._channel.channel,
NULL, NULL,
_EMPTY_MASK, _EMPTY_MASK,
@ -73,10 +75,6 @@ cdef class _AioCall:
) )
grpc_slice_unref(method_slice) grpc_slice_unref(method_slice)
cdef void _destroy_grpc_call(self):
"""Destroys the corresponding Core object for this RPC."""
grpc_call_unref(self._grpc_call_wrapper.call)
def cancel(self, AioRpcStatus status): def cancel(self, AioRpcStatus status):
"""Cancels the RPC in Core with given RPC status. """Cancels the RPC in Core with given RPC status.
@ -97,7 +95,7 @@ cdef class _AioCall:
c_details = <char *>details c_details = <char *>details
# By implementation, grpc_call_cancel_with_status always return OK # By implementation, grpc_call_cancel_with_status always return OK
error = grpc_call_cancel_with_status( error = grpc_call_cancel_with_status(
self._grpc_call_wrapper.call, self.call,
status.c_code(), status.c_code(),
c_details, c_details,
NULL, NULL,
@ -105,7 +103,7 @@ cdef class _AioCall:
assert error == GRPC_CALL_OK assert error == GRPC_CALL_OK
else: else:
# By implementation, grpc_call_cancel always return OK # By implementation, grpc_call_cancel always return OK
error = grpc_call_cancel(self._grpc_call_wrapper.call, NULL) error = grpc_call_cancel(self.call, NULL)
assert error == GRPC_CALL_OK assert error == GRPC_CALL_OK
async def unary_unary(self, async def unary_unary(self,
@ -140,7 +138,7 @@ cdef class _AioCall:
# Executes all operations in one batch. # Executes all operations in one batch.
# Might raise CancelledError, handling it in Python UnaryUnaryCall. # Might raise CancelledError, handling it in Python UnaryUnaryCall.
await execute_batch(self._grpc_call_wrapper, await execute_batch(self,
ops, ops,
self._loop) self._loop)
@ -163,7 +161,7 @@ cdef class _AioCall:
"""Handles the status sent by peer once received.""" """Handles the status sent by peer once received."""
cdef ReceiveStatusOnClientOperation op = ReceiveStatusOnClientOperation(_EMPTY_FLAGS) cdef ReceiveStatusOnClientOperation op = ReceiveStatusOnClientOperation(_EMPTY_FLAGS)
cdef tuple ops = (op,) cdef tuple ops = (op,)
await execute_batch(self._grpc_call_wrapper, ops, self._loop) await execute_batch(self, ops, self._loop)
# Halts if the RPC is locally cancelled # Halts if the RPC is locally cancelled
if self._is_locally_cancelled: if self._is_locally_cancelled:
@ -186,7 +184,7 @@ cdef class _AioCall:
# * The client application cancels; # * The client application cancels;
# * The server sends final status. # * The server sends final status.
received_message = await _receive_message( received_message = await _receive_message(
self._grpc_call_wrapper, self,
self._loop self._loop
) )
return received_message return received_message
@ -217,12 +215,12 @@ cdef class _AioCall:
) )
# Sends out the request message. # Sends out the request message.
await execute_batch(self._grpc_call_wrapper, await execute_batch(self,
outbound_ops, outbound_ops,
self._loop) self._loop)
# Receives initial metadata. # Receives initial metadata.
initial_metadata_observer( initial_metadata_observer(
await _receive_initial_metadata(self._grpc_call_wrapper, await _receive_initial_metadata(self,
self._loop), self._loop),
) )

@ -30,6 +30,7 @@ cdef class _HandlerCallDetails:
cdef class RPCState: cdef class RPCState:
def __cinit__(self, AioServer server): def __cinit__(self, AioServer server):
self.call = NULL
self.server = server self.server = server
grpc_metadata_array_init(&self.request_metadata) grpc_metadata_array_init(&self.request_metadata)
grpc_call_details_init(&self.details) grpc_call_details_init(&self.details)

@ -7,8 +7,10 @@ py_library(
"aio/_base_call.py", "aio/_base_call.py",
"aio/_call.py", "aio/_call.py",
"aio/_channel.py", "aio/_channel.py",
"aio/_interceptor.py",
"aio/_server.py", "aio/_server.py",
"aio/_typing.py", "aio/_typing.py",
"aio/_utils.py",
], ],
deps = [ deps = [
"//src/python/grpcio/grpc/_cython:cygrpc", "//src/python/grpcio/grpc/_cython:cygrpc",

@ -18,18 +18,26 @@ created. AsyncIO doesn't provide thread safety for most of its APIs.
""" """
import abc import abc
from typing import Any, Optional, Sequence, Text, Tuple
import six import six
import grpc import grpc
from grpc._cython.cygrpc import init_grpc_aio from grpc._cython.cygrpc import init_grpc_aio
from ._base_call import RpcContext, Call, UnaryUnaryCall, UnaryStreamCall from ._base_call import RpcContext, Call, UnaryUnaryCall, UnaryStreamCall
from ._call import AioRpcError
from ._channel import Channel from ._channel import Channel
from ._channel import UnaryUnaryMultiCallable from ._channel import UnaryUnaryMultiCallable
from ._interceptor import ClientCallDetails, UnaryUnaryClientInterceptor
from ._interceptor import InterceptedUnaryUnaryCall
from ._server import server from ._server import server
def insecure_channel(target, options=None, compression=None): def insecure_channel(
target: Text,
options: Optional[Sequence[Tuple[Text, Any]]] = None,
compression: Optional[grpc.Compression] = None,
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]] = None):
"""Creates an insecure asynchronous Channel to a server. """Creates an insecure asynchronous Channel to a server.
Args: Args:
@ -38,16 +46,22 @@ def insecure_channel(target, options=None, compression=None):
in gRPC Core runtime) to configure the channel. in gRPC Core runtime) to configure the channel.
compression: An optional value indicating the compression method to be compression: An optional value indicating the compression method to be
used over the lifetime of the channel. This is an EXPERIMENTAL option. used over the lifetime of the channel. This is an EXPERIMENTAL option.
interceptors: An optional sequence of interceptors that will be executed for
any call executed with this channel.
Returns: Returns:
A Channel. A Channel.
""" """
return Channel(target, () if options is None else options, None, return Channel(target, () if options is None else options,
compression) None,
compression,
interceptors=interceptors)
################################### __all__ ################################# ################################### __all__ #################################
__all__ = ('RpcContext', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall', __all__ = ('AioRpcError', 'RpcContext', 'Call', 'UnaryUnaryCall',
'init_grpc_aio', 'Channel', 'UnaryUnaryMultiCallable', 'UnaryStreamCall', 'init_grpc_aio', 'Channel',
'UnaryUnaryMultiCallable', 'ClientCallDetails',
'UnaryUnaryClientInterceptor', 'InterceptedUnaryUnaryCall',
'insecure_channel', 'server') 'insecure_channel', 'server')

@ -233,7 +233,7 @@ class Call(_base_call.Call):
if self._code is grpc.StatusCode.OK: if self._code is grpc.StatusCode.OK:
return _OK_CALL_REPRESENTATION.format( return _OK_CALL_REPRESENTATION.format(
self.__class__.__name__, self._code, self.__class__.__name__, self._code,
self._status.result().self._status.result().details()) self._status.result().details())
else: else:
return _NON_OK_CALL_REPRESENTATION.format( return _NON_OK_CALL_REPRESENTATION.format(
self.__class__.__name__, self._code, self.__class__.__name__, self._code,

@ -18,29 +18,35 @@ from typing import Any, Optional, Sequence, Text, Tuple
import grpc import grpc
from grpc import _common from grpc import _common
from grpc._cython import cygrpc from grpc._cython import cygrpc
from . import _base_call from . import _base_call
from ._call import UnaryUnaryCall, UnaryStreamCall from ._call import UnaryUnaryCall, UnaryStreamCall
from ._interceptor import UnaryUnaryClientInterceptor, InterceptedUnaryUnaryCall
from ._typing import (DeserializingFunction, MetadataType, SerializingFunction) from ._typing import (DeserializingFunction, MetadataType, SerializingFunction)
from ._utils import _timeout_to_deadline
def _timeout_to_deadline(loop: asyncio.AbstractEventLoop,
timeout: Optional[float]) -> Optional[float]:
if timeout is None:
return None
return loop.time() + timeout
class UnaryUnaryMultiCallable: class UnaryUnaryMultiCallable:
"""Factory an asynchronous unary-unary RPC stub call from client-side.""" """Factory an asynchronous unary-unary RPC stub call from client-side."""
_channel: cygrpc.AioChannel
_method: bytes
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
_loop: asyncio.AbstractEventLoop
def __init__(self, channel: cygrpc.AioChannel, method: bytes, def __init__(self, channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction, request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None: response_deserializer: DeserializingFunction,
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
) -> None:
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
self._channel = channel self._channel = channel
self._method = method self._method = method
self._request_serializer = request_serializer self._request_serializer = request_serializer
self._response_deserializer = response_deserializer self._response_deserializer = response_deserializer
self._interceptors = interceptors
def __call__(self, def __call__(self,
request: Any, request: Any,
@ -74,7 +80,6 @@ class UnaryUnaryMultiCallable:
raised RpcError will also be a Call for the RPC affording the RPC's raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details. metadata, status code, and details.
""" """
if metadata: if metadata:
raise NotImplementedError("TODO: metadata not implemented yet") raise NotImplementedError("TODO: metadata not implemented yet")
@ -88,16 +93,25 @@ class UnaryUnaryMultiCallable:
if compression: if compression:
raise NotImplementedError("TODO: compression not implemented yet") raise NotImplementedError("TODO: compression not implemented yet")
deadline = _timeout_to_deadline(self._loop, timeout) if not self._interceptors:
return UnaryUnaryCall(
return UnaryUnaryCall( request,
request, _timeout_to_deadline(timeout),
deadline, self._channel,
self._channel, self._method,
self._method, self._request_serializer,
self._request_serializer, self._response_deserializer,
self._response_deserializer, )
) else:
return InterceptedUnaryUnaryCall(
self._interceptors,
request,
timeout,
self._channel,
self._method,
self._request_serializer,
self._response_deserializer,
)
class UnaryStreamMultiCallable: class UnaryStreamMultiCallable:
@ -138,13 +152,7 @@ class UnaryStreamMultiCallable:
Returns: Returns:
A Call object instance which is an awaitable object. A Call object instance which is an awaitable object.
Raises:
RpcError: Indicating that the RPC terminated with non-OK status. The
raised RpcError will also be a Call for the RPC affording the RPC's
metadata, status code, and details.
""" """
if metadata: if metadata:
raise NotImplementedError("TODO: metadata not implemented yet") raise NotImplementedError("TODO: metadata not implemented yet")
@ -158,7 +166,7 @@ class UnaryStreamMultiCallable:
if compression: if compression:
raise NotImplementedError("TODO: compression not implemented yet") raise NotImplementedError("TODO: compression not implemented yet")
deadline = _timeout_to_deadline(self._loop, timeout) deadline = _timeout_to_deadline(timeout)
return UnaryStreamCall( return UnaryStreamCall(
request, request,
@ -175,11 +183,14 @@ class Channel:
A cygrpc.AioChannel-backed implementation. A cygrpc.AioChannel-backed implementation.
""" """
_channel: cygrpc.AioChannel
_unary_unary_interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]
def __init__(self, target: Text, def __init__(self, target: Text,
options: Optional[Sequence[Tuple[Text, Any]]], options: Optional[Sequence[Tuple[Text, Any]]],
credentials: Optional[grpc.ChannelCredentials], credentials: Optional[grpc.ChannelCredentials],
compression: Optional[grpc.Compression]): compression: Optional[grpc.Compression],
interceptors: Optional[Sequence[UnaryUnaryClientInterceptor]]):
"""Constructor. """Constructor.
Args: Args:
@ -188,8 +199,9 @@ class Channel:
credentials: A cygrpc.ChannelCredentials or None. credentials: A cygrpc.ChannelCredentials or None.
compression: An optional value indicating the compression method to be compression: An optional value indicating the compression method to be
used over the lifetime of the channel. used over the lifetime of the channel.
interceptors: An optional list of interceptors that would be used for
intercepting any RPC executed with that channel.
""" """
if options: if options:
raise NotImplementedError("TODO: options not implemented yet") raise NotImplementedError("TODO: options not implemented yet")
@ -199,6 +211,24 @@ class Channel:
if compression: if compression:
raise NotImplementedError("TODO: compression not implemented yet") raise NotImplementedError("TODO: compression not implemented yet")
if interceptors is None:
self._unary_unary_interceptors = None
else:
self._unary_unary_interceptors = list(
filter(
lambda interceptor: isinstance(interceptor,
UnaryUnaryClientInterceptor),
interceptors))
invalid_interceptors = set(interceptors) - set(
self._unary_unary_interceptors)
if invalid_interceptors:
raise ValueError(
"Interceptor must be "+\
"UnaryUnaryClientInterceptors, the following are invalid: {}"\
.format(invalid_interceptors))
self._channel = cygrpc.AioChannel(_common.encode(target)) self._channel = cygrpc.AioChannel(_common.encode(target))
def unary_unary( def unary_unary(
@ -222,7 +252,8 @@ class Channel:
""" """
return UnaryUnaryMultiCallable(self._channel, _common.encode(method), return UnaryUnaryMultiCallable(self._channel, _common.encode(method),
request_serializer, request_serializer,
response_deserializer) response_deserializer,
self._unary_unary_interceptors)
def unary_stream( def unary_stream(
self, self,

@ -0,0 +1,291 @@
# Copyright 2019 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interceptors implementation of gRPC Asyncio Python."""
import asyncio
import collections
import functools
from abc import ABCMeta, abstractmethod
from typing import Callable, Optional, Iterator, Sequence, Text, Union
import grpc
from grpc._cython import cygrpc
from . import _base_call
from ._call import UnaryUnaryCall, AioRpcError
from ._utils import _timeout_to_deadline
from ._typing import (RequestType, SerializingFunction, DeserializingFunction,
MetadataType, ResponseType)
_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!'
class ClientCallDetails(
collections.namedtuple(
'ClientCallDetails',
('method', 'timeout', 'metadata', 'credentials')),
grpc.ClientCallDetails):
method: Text
timeout: Optional[float]
metadata: Optional[MetadataType]
credentials: Optional[grpc.CallCredentials]
class UnaryUnaryClientInterceptor(metaclass=ABCMeta):
"""Affords intercepting unary-unary invocations."""
@abstractmethod
async def intercept_unary_unary(
self, continuation: Callable[[ClientCallDetails, RequestType],
UnaryUnaryCall],
client_call_details: ClientCallDetails,
request: RequestType) -> Union[UnaryUnaryCall, ResponseType]:
"""Intercepts a unary-unary invocation asynchronously.
Args:
continuation: A coroutine that proceeds with the invocation by
executing the next interceptor in chain or invoking the
actual RPC on the underlying Channel. It is the interceptor's
responsibility to call it if it decides to move the RPC forward.
The interceptor can use
`response_future = await continuation(client_call_details, request)`
to continue with the RPC. `continuation` returns the response of the
RPC.
client_call_details: A ClientCallDetails object describing the
outgoing RPC.
request: The request value for the RPC.
Returns:
An object with the RPC response.
Raises:
AioRpcError: Indicating that the RPC terminated with non-OK status.
asyncio.CancelledError: Indicating that the RPC was canceled.
"""
class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
"""Used for running a `UnaryUnaryCall` wrapped by interceptors.
Interceptors might have some work to do before the RPC invocation with
the capacity of changing the invocation parameters, and some work to do
after the RPC invocation with the capacity for accessing to the wrapped
`UnaryUnaryCall`.
It handles also early and later cancellations, when the RPC has not even
started and the execution is still held by the interceptors or when the
RPC has finished but again the execution is still held by the interceptors.
Once the RPC is finally executed, all methods are finally done against the
intercepted call, being at the same time the same call returned to the
interceptors.
For most of the methods, like `initial_metadata()` the caller does not need
to wait until the interceptors task is finished, once the RPC is done the
caller will have the freedom for accessing to the results.
For the `__await__` method is it is proxied to the intercepted call only when
the interceptor task is finished.
"""
_loop: asyncio.AbstractEventLoop
_channel: cygrpc.AioChannel
_cancelled_before_rpc: bool
_intercepted_call: Optional[_base_call.UnaryUnaryCall]
_intercepted_call_created: asyncio.Event
_interceptors_task: asyncio.Task
def __init__( # pylint: disable=R0913
self, interceptors: Sequence[UnaryUnaryClientInterceptor],
request: RequestType, timeout: Optional[float],
channel: cygrpc.AioChannel, method: bytes,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction) -> None:
self._channel = channel
self._loop = asyncio.get_event_loop()
self._interceptors_task = asyncio.ensure_future(
self._invoke(interceptors, method, timeout, request,
request_serializer, response_deserializer))
def __del__(self):
self.cancel()
async def _invoke(self, interceptors: Sequence[UnaryUnaryClientInterceptor],
method: bytes, timeout: Optional[float],
request: RequestType,
request_serializer: SerializingFunction,
response_deserializer: DeserializingFunction
) -> UnaryUnaryCall:
"""Run the RPC call wrapped in interceptors"""
async def _run_interceptor(
interceptors: Iterator[UnaryUnaryClientInterceptor],
client_call_details: ClientCallDetails,
request: RequestType) -> _base_call.UnaryUnaryCall:
interceptor = next(interceptors, None)
if interceptor:
continuation = functools.partial(_run_interceptor, interceptors)
call_or_response = await interceptor.intercept_unary_unary(
continuation, client_call_details, request)
if isinstance(call_or_response, _base_call.UnaryUnaryCall):
return call_or_response
else:
return UnaryUnaryCallResponse(call_or_response)
else:
return UnaryUnaryCall(
request, _timeout_to_deadline(client_call_details.timeout),
self._channel, client_call_details.method,
request_serializer, response_deserializer)
client_call_details = ClientCallDetails(method, timeout, None, None)
return await _run_interceptor(iter(interceptors), client_call_details,
request)
def cancel(self) -> bool:
if self._interceptors_task.done():
return False
return self._interceptors_task.cancel()
def cancelled(self) -> bool:
if not self._interceptors_task.done():
return False
try:
call = self._interceptors_task.result()
except AioRpcError as err:
return err.code() == grpc.StatusCode.CANCELLED
except asyncio.CancelledError:
return True
return call.cancelled()
def done(self) -> bool:
if not self._interceptors_task.done():
return False
try:
call = self._interceptors_task.result()
except (AioRpcError, asyncio.CancelledError):
return True
return call.done()
def add_done_callback(self, unused_callback) -> None:
raise NotImplementedError()
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
async def initial_metadata(self) -> Optional[MetadataType]:
try:
call = await self._interceptors_task
except AioRpcError as err:
return err.initial_metadata()
except asyncio.CancelledError:
return None
return await call.initial_metadata()
async def trailing_metadata(self) -> Optional[MetadataType]:
try:
call = await self._interceptors_task
except AioRpcError as err:
return err.trailing_metadata()
except asyncio.CancelledError:
return None
return await call.trailing_metadata()
async def code(self) -> grpc.StatusCode:
try:
call = await self._interceptors_task
except AioRpcError as err:
return err.code()
except asyncio.CancelledError:
return grpc.StatusCode.CANCELLED
return await call.code()
async def details(self) -> str:
try:
call = await self._interceptors_task
except AioRpcError as err:
return err.details()
except asyncio.CancelledError:
return _LOCAL_CANCELLATION_DETAILS
return await call.details()
async def debug_error_string(self) -> Optional[str]:
try:
call = await self._interceptors_task
except AioRpcError as err:
return err.debug_error_string()
except asyncio.CancelledError:
return ''
return await call.debug_error_string()
def __await__(self):
call = yield from self._interceptors_task.__await__()
response = yield from call.__await__()
return response
class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
"""Final UnaryUnaryCall class finished with a response."""
_response: ResponseType
def __init__(self, response: ResponseType) -> None:
self._response = response
def cancel(self) -> bool:
return False
def cancelled(self) -> bool:
return False
def done(self) -> bool:
return True
def add_done_callback(self, unused_callback) -> None:
raise NotImplementedError()
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
async def initial_metadata(self) -> Optional[MetadataType]:
return None
async def trailing_metadata(self) -> Optional[MetadataType]:
return None
async def code(self) -> grpc.StatusCode:
return grpc.StatusCode.OK
async def details(self) -> str:
return ''
async def debug_error_string(self) -> Optional[str]:
return None
def __await__(self):
if False: # pylint: disable=W0125
# This code path is never used, but a yield statement is needed
# for telling the interpreter that __await__ is a generator.
yield None
return self._response

@ -0,0 +1,22 @@
# Copyright 2019 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Internal utilities used by the gRPC Aio module."""
import time
from typing import Optional
def _timeout_to_deadline(timeout: Optional[float]) -> Optional[float]:
if timeout is None:
return None
return time.time() + timeout

@ -5,5 +5,7 @@
"unit.call_test.TestUnaryUnaryCall", "unit.call_test.TestUnaryUnaryCall",
"unit.channel_test.TestChannel", "unit.channel_test.TestChannel",
"unit.init_test.TestInsecureChannel", "unit.init_test.TestInsecureChannel",
"unit.interceptor_test.TestInterceptedUnaryUnaryCall",
"unit.interceptor_test.TestUnaryUnaryClientInterceptor",
"unit.server_test.TestServer" "unit.server_test.TestServer"
] ]

@ -16,10 +16,13 @@ import asyncio
import logging import logging
import datetime import datetime
import grpc
from grpc.experimental import aio from grpc.experimental import aio
from tests.unit.framework.common import test_constants
from src.proto.grpc.testing import messages_pb2 from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc from src.proto.grpc.testing import test_pb2_grpc
from tests.unit.framework.common import test_constants
UNARY_CALL_WITH_SLEEP_VALUE = 0.2
class _TestServiceServicer(test_pb2_grpc.TestServiceServicer): class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):
@ -39,11 +42,34 @@ class _TestServiceServicer(test_pb2_grpc.TestServiceServicer):
body=b'\x00' * body=b'\x00' *
response_parameters.size)) response_parameters.size))
# Next methods are extra ones that are registred programatically
# when the sever is instantiated. They are not being provided by
# the proto file.
async def UnaryCallWithSleep(self, request, context):
await asyncio.sleep(UNARY_CALL_WITH_SLEEP_VALUE)
return messages_pb2.SimpleResponse()
async def start_test_server(): async def start_test_server():
server = aio.server(options=(('grpc.so_reuseport', 0),)) server = aio.server(options=(('grpc.so_reuseport', 0),))
test_pb2_grpc.add_TestServiceServicer_to_server(_TestServiceServicer(), servicer = _TestServiceServicer()
server) test_pb2_grpc.add_TestServiceServicer_to_server(servicer, server)
# Add programatically extra methods not provided by the proto file
# that are used during the tests
rpc_method_handlers = {
'UnaryCallWithSleep':
grpc.unary_unary_rpc_method_handler(
servicer.UnaryCallWithSleep,
request_deserializer=messages_pb2.SimpleRequest.FromString,
response_serializer=messages_pb2.SimpleResponse.
SerializeToString)
}
extra_handler = grpc.method_handlers_generic_handler(
'grpc.testing.TestService', rpc_method_handlers)
server.add_generic_rpc_handlers((extra_handler,))
port = server.add_insecure_port('[::]:0') port = server.add_insecure_port('[::]:0')
await server.start() await server.start()
# NOTE(lidizheng) returning the server to prevent it from deallocation # NOTE(lidizheng) returning the server to prevent it from deallocation

@ -26,6 +26,7 @@ from src.proto.grpc.testing import test_pb2_grpc
from tests.unit.framework.common import test_constants from tests.unit.framework.common import test_constants
from tests_aio.unit._test_server import start_test_server from tests_aio.unit._test_server import start_test_server
from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_base import AioTestBase
from src.proto.grpc.testing import messages_pb2
_NUM_STREAM_RESPONSES = 5 _NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42 _RESPONSE_PAYLOAD_SIZE = 42
@ -399,5 +400,5 @@ class TestUnaryStreamCall(AioTestBase):
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG) logging.basicConfig()
unittest.main(verbosity=2) unittest.main(verbosity=2)

@ -23,10 +23,12 @@ from grpc.experimental import aio
from src.proto.grpc.testing import messages_pb2 from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc from src.proto.grpc.testing import test_pb2_grpc
from tests.unit.framework.common import test_constants from tests.unit.framework.common import test_constants
from tests_aio.unit._test_server import start_test_server from tests_aio.unit._test_server import start_test_server, UNARY_CALL_WITH_SLEEP_VALUE
from tests_aio.unit._test_base import AioTestBase from tests_aio.unit._test_base import AioTestBase
from src.proto.grpc.testing import messages_pb2
_UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall' _UNARY_CALL_METHOD = '/grpc.testing.TestService/UnaryCall'
_UNARY_CALL_METHOD_WITH_SLEEP = '/grpc.testing.TestService/UnaryCallWithSleep'
_STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall' _STREAMING_OUTPUT_CALL_METHOD = '/grpc.testing.TestService/StreamingOutputCall'
_NUM_STREAM_RESPONSES = 5 _NUM_STREAM_RESPONSES = 5
_RESPONSE_PAYLOAD_SIZE = 42 _RESPONSE_PAYLOAD_SIZE = 42
@ -51,7 +53,6 @@ class TestChannel(AioTestBase):
async def test_unary_unary(self): async def test_unary_unary(self):
async with aio.insecure_channel(self._server_target) as channel: async with aio.insecure_channel(self._server_target) as channel:
channel = aio.insecure_channel(self._server_target)
hi = channel.unary_unary( hi = channel.unary_unary(
_UNARY_CALL_METHOD, _UNARY_CALL_METHOD,
request_serializer=messages_pb2.SimpleRequest.SerializeToString, request_serializer=messages_pb2.SimpleRequest.SerializeToString,
@ -61,15 +62,16 @@ class TestChannel(AioTestBase):
self.assertIsInstance(response, messages_pb2.SimpleResponse) self.assertIsInstance(response, messages_pb2.SimpleResponse)
async def test_unary_call_times_out(self): async def test_unary_call_times_out(self):
async with aio.insecure_channel(_UNREACHABLE_TARGET) as channel: async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary( hi = channel.unary_unary(
_UNARY_CALL_METHOD, _UNARY_CALL_METHOD_WITH_SLEEP,
request_serializer=messages_pb2.SimpleRequest.SerializeToString, request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString, response_deserializer=messages_pb2.SimpleResponse.FromString,
) )
with self.assertRaises(grpc.RpcError) as exception_context: with self.assertRaises(grpc.RpcError) as exception_context:
await hi(messages_pb2.SimpleRequest(), timeout=1.0) await hi(messages_pb2.SimpleRequest(),
timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2)
_, details = grpc.StatusCode.DEADLINE_EXCEEDED.value # pylint: disable=unused-variable _, details = grpc.StatusCode.DEADLINE_EXCEEDED.value # pylint: disable=unused-variable
self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED,
@ -80,6 +82,18 @@ class TestChannel(AioTestBase):
self.assertIsNotNone( self.assertIsNotNone(
exception_context.exception.trailing_metadata()) exception_context.exception.trailing_metadata())
async def test_unary_call_does_not_times_out(self):
async with aio.insecure_channel(self._server_target) as channel:
hi = channel.unary_unary(
_UNARY_CALL_METHOD_WITH_SLEEP,
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString,
)
call = hi(messages_pb2.SimpleRequest(),
timeout=UNARY_CALL_WITH_SLEEP_VALUE * 5)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_unary_stream(self): async def test_unary_stream(self):
channel = aio.insecure_channel(self._server_target) channel = aio.insecure_channel(self._server_target)
stub = test_pb2_grpc.TestServiceStub(channel) stub = test_pb2_grpc.TestServiceStub(channel)

@ -0,0 +1,538 @@
# Copyright 2019 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import unittest
import grpc
from grpc.experimental import aio
from tests_aio.unit._test_server import start_test_server, UNARY_CALL_WITH_SLEEP_VALUE
from tests_aio.unit._test_base import AioTestBase
from src.proto.grpc.testing import messages_pb2
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
class TestUnaryUnaryClientInterceptor(AioTestBase):
async def setUp(self):
self._server_target, self._server = await start_test_server()
async def tearDown(self):
await self._server.stop(None)
def test_invalid_interceptor(self):
class InvalidInterceptor:
"""Just an invalid Interceptor"""
with self.assertRaises(ValueError):
aio.insecure_channel("", interceptors=[InvalidInterceptor()])
async def test_executed_right_order(self):
interceptors_executed = []
class Interceptor(aio.UnaryUnaryClientInterceptor):
"""Interceptor used for testing if the interceptor is being called"""
async def intercept_unary_unary(self, continuation,
client_call_details, request):
interceptors_executed.append(self)
call = await continuation(client_call_details, request)
return call
interceptors = [Interceptor() for i in range(2)]
async with aio.insecure_channel(self._server_target,
interceptors=interceptors) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
response = await call
# Check that all interceptors were executed, and were executed
# in the right order.
self.assertSequenceEqual(interceptors_executed, interceptors)
self.assertIsInstance(response, messages_pb2.SimpleResponse)
@unittest.expectedFailure
# TODO(https://github.com/grpc/grpc/issues/20144) Once metadata support is
# implemented in the client-side, this test must be implemented.
def test_modify_metadata(self):
raise NotImplementedError()
@unittest.expectedFailure
# TODO(https://github.com/grpc/grpc/issues/20532) Once credentials support is
# implemented in the client-side, this test must be implemented.
def test_modify_credentials(self):
raise NotImplementedError()
async def test_status_code_Ok(self):
class StatusCodeOkInterceptor(aio.UnaryUnaryClientInterceptor):
"""Interceptor used for observing status code Ok returned by the RPC"""
def __init__(self):
self.status_code_Ok_observed = False
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
code = await call.code()
if code == grpc.StatusCode.OK:
self.status_code_Ok_observed = True
return call
interceptor = StatusCodeOkInterceptor()
async with aio.insecure_channel(self._server_target,
interceptors=[interceptor]) as channel:
# when no error StatusCode.OK must be observed
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
await multicallable(messages_pb2.SimpleRequest())
self.assertTrue(interceptor.status_code_Ok_observed)
async def test_add_timeout(self):
class TimeoutInterceptor(aio.UnaryUnaryClientInterceptor):
"""Interceptor used for adding a timeout to the RPC"""
async def intercept_unary_unary(self, continuation,
client_call_details, request):
new_client_call_details = aio.ClientCallDetails(
method=client_call_details.method,
timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2,
metadata=client_call_details.metadata,
credentials=client_call_details.credentials)
return await continuation(new_client_call_details, request)
interceptor = TimeoutInterceptor()
async with aio.insecure_channel(self._server_target,
interceptors=[interceptor]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCallWithSleep',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
self.assertEqual(exception_context.exception.code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertTrue(call.done())
self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await
call.code())
async def test_retry(self):
class RetryInterceptor(aio.UnaryUnaryClientInterceptor):
"""Simulates a Retry Interceptor which ends up by making
two RPC calls."""
def __init__(self):
self.calls = []
async def intercept_unary_unary(self, continuation,
client_call_details, request):
new_client_call_details = aio.ClientCallDetails(
method=client_call_details.method,
timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2,
metadata=client_call_details.metadata,
credentials=client_call_details.credentials)
try:
call = await continuation(new_client_call_details, request)
await call
except grpc.RpcError:
pass
self.calls.append(call)
new_client_call_details = aio.ClientCallDetails(
method=client_call_details.method,
timeout=None,
metadata=client_call_details.metadata,
credentials=client_call_details.credentials)
call = await continuation(new_client_call_details, request)
self.calls.append(call)
return call
interceptor = RetryInterceptor()
async with aio.insecure_channel(self._server_target,
interceptors=[interceptor]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCallWithSleep',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
await call
self.assertEqual(grpc.StatusCode.OK, await call.code())
# Check that two calls were made, first one finishing with
# a deadline and second one finishing ok..
self.assertEqual(len(interceptor.calls), 2)
self.assertEqual(await interceptor.calls[0].code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertEqual(await interceptor.calls[1].code(),
grpc.StatusCode.OK)
async def test_rpcresponse(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
"""Raw responses are seen as reegular calls"""
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
response = await call
return call
class ResponseInterceptor(aio.UnaryUnaryClientInterceptor):
"""Return a raw response"""
response = messages_pb2.SimpleResponse()
async def intercept_unary_unary(self, continuation,
client_call_details, request):
return ResponseInterceptor.response
interceptor, interceptor_response = Interceptor(), ResponseInterceptor()
async with aio.insecure_channel(
self._server_target,
interceptors=[interceptor, interceptor_response]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
response = await call
# Check that the response returned is the one returned by the
# interceptor
self.assertEqual(id(response), id(ResponseInterceptor.response))
# Check all of the UnaryUnaryCallResponse attributes
self.assertTrue(call.done())
self.assertFalse(call.cancel())
self.assertFalse(call.cancelled())
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.details(), '')
self.assertEqual(await call.initial_metadata(), None)
self.assertEqual(await call.trailing_metadata(), None)
self.assertEqual(await call.debug_error_string(), None)
class TestInterceptedUnaryUnaryCall(AioTestBase):
async def setUp(self):
self._server_target, self._server = await start_test_server()
async def tearDown(self):
await self._server.stop(None)
async def test_call_ok(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
return call
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
response = await call
self.assertTrue(call.done())
self.assertFalse(call.cancelled())
self.assertEqual(type(response), messages_pb2.SimpleResponse)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.details(), '')
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
async def test_call_ok_awaited(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
await call
return call
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
response = await call
self.assertTrue(call.done())
self.assertFalse(call.cancelled())
self.assertEqual(type(response), messages_pb2.SimpleResponse)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.details(), '')
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
async def test_call_rpc_error(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
return call
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCallWithSleep',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest(),
timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
self.assertTrue(call.done())
self.assertFalse(call.cancelled())
self.assertEqual(await call.code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertEqual(await call.details(), 'Deadline Exceeded')
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
async def test_call_rpc_error_awaited(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
await call
return call
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCallWithSleep',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest(),
timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
self.assertTrue(call.done())
self.assertFalse(call.cancelled())
self.assertEqual(await call.code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertEqual(await call.details(), 'Deadline Exceeded')
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
async def test_cancel_before_rpc(self):
interceptor_reached = asyncio.Event()
wait_for_ever = self.loop.create_future()
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
interceptor_reached.set()
await wait_for_ever
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
self.assertFalse(call.cancelled())
self.assertFalse(call.done())
await interceptor_reached.wait()
self.assertTrue(call.cancel())
with self.assertRaises(asyncio.CancelledError):
await call
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
_LOCAL_CANCEL_DETAILS_EXPECTATION)
self.assertEqual(await call.initial_metadata(), None)
self.assertEqual(await call.trailing_metadata(), None)
async def test_cancel_after_rpc(self):
interceptor_reached = asyncio.Event()
wait_for_ever = self.loop.create_future()
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
await call
interceptor_reached.set()
await wait_for_ever
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
self.assertFalse(call.cancelled())
self.assertFalse(call.done())
await interceptor_reached.wait()
self.assertTrue(call.cancel())
with self.assertRaises(asyncio.CancelledError):
await call
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
_LOCAL_CANCEL_DETAILS_EXPECTATION)
self.assertEqual(await call.initial_metadata(), None)
self.assertEqual(await call.trailing_metadata(), None)
async def test_cancel_inside_interceptor_after_rpc_awaiting(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
call.cancel()
await call
return call
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
with self.assertRaises(asyncio.CancelledError):
await call
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
_LOCAL_CANCEL_DETAILS_EXPECTATION)
self.assertEqual(await call.initial_metadata(), None)
self.assertEqual(await call.trailing_metadata(), None)
async def test_cancel_inside_interceptor_after_rpc_not_awaiting(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
call.cancel()
return call
async with aio.insecure_channel(self._server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
with self.assertRaises(asyncio.CancelledError):
await call
self.assertTrue(call.cancelled())
self.assertTrue(call.done())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
_LOCAL_CANCEL_DETAILS_EXPECTATION)
self.assertEqual(await call.initial_metadata(), tuple())
self.assertEqual(await call.trailing_metadata(), None)
if __name__ == '__main__':
logging.basicConfig()
unittest.main(verbosity=2)
Loading…
Cancel
Save