mirror of https://github.com/grpc/grpc.git
Merge pull request #21455 from Skyscanner/client_unaryunary_interceptors_3
[Aio] Client Side Interceptor For Unary Callspull/21642/head
commit
da6a29dd6d
14 changed files with 999 additions and 61 deletions
@ -0,0 +1,291 @@ |
||||
# Copyright 2019 gRPC authors. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
"""Interceptors implementation of gRPC Asyncio Python.""" |
||||
import asyncio |
||||
import collections |
||||
import functools |
||||
from abc import ABCMeta, abstractmethod |
||||
from typing import Callable, Optional, Iterator, Sequence, Text, Union |
||||
|
||||
import grpc |
||||
from grpc._cython import cygrpc |
||||
|
||||
from . import _base_call |
||||
from ._call import UnaryUnaryCall, AioRpcError |
||||
from ._utils import _timeout_to_deadline |
||||
from ._typing import (RequestType, SerializingFunction, DeserializingFunction, |
||||
MetadataType, ResponseType) |
||||
|
||||
_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!' |
||||
|
||||
|
||||
class ClientCallDetails( |
||||
collections.namedtuple( |
||||
'ClientCallDetails', |
||||
('method', 'timeout', 'metadata', 'credentials')), |
||||
grpc.ClientCallDetails): |
||||
|
||||
method: Text |
||||
timeout: Optional[float] |
||||
metadata: Optional[MetadataType] |
||||
credentials: Optional[grpc.CallCredentials] |
||||
|
||||
|
||||
class UnaryUnaryClientInterceptor(metaclass=ABCMeta): |
||||
"""Affords intercepting unary-unary invocations.""" |
||||
|
||||
@abstractmethod |
||||
async def intercept_unary_unary( |
||||
self, continuation: Callable[[ClientCallDetails, RequestType], |
||||
UnaryUnaryCall], |
||||
client_call_details: ClientCallDetails, |
||||
request: RequestType) -> Union[UnaryUnaryCall, ResponseType]: |
||||
"""Intercepts a unary-unary invocation asynchronously. |
||||
Args: |
||||
continuation: A coroutine that proceeds with the invocation by |
||||
executing the next interceptor in chain or invoking the |
||||
actual RPC on the underlying Channel. It is the interceptor's |
||||
responsibility to call it if it decides to move the RPC forward. |
||||
The interceptor can use |
||||
`response_future = await continuation(client_call_details, request)` |
||||
to continue with the RPC. `continuation` returns the response of the |
||||
RPC. |
||||
client_call_details: A ClientCallDetails object describing the |
||||
outgoing RPC. |
||||
request: The request value for the RPC. |
||||
Returns: |
||||
An object with the RPC response. |
||||
Raises: |
||||
AioRpcError: Indicating that the RPC terminated with non-OK status. |
||||
asyncio.CancelledError: Indicating that the RPC was canceled. |
||||
""" |
||||
|
||||
|
||||
class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall): |
||||
"""Used for running a `UnaryUnaryCall` wrapped by interceptors. |
||||
|
||||
Interceptors might have some work to do before the RPC invocation with |
||||
the capacity of changing the invocation parameters, and some work to do |
||||
after the RPC invocation with the capacity for accessing to the wrapped |
||||
`UnaryUnaryCall`. |
||||
|
||||
It handles also early and later cancellations, when the RPC has not even |
||||
started and the execution is still held by the interceptors or when the |
||||
RPC has finished but again the execution is still held by the interceptors. |
||||
|
||||
Once the RPC is finally executed, all methods are finally done against the |
||||
intercepted call, being at the same time the same call returned to the |
||||
interceptors. |
||||
|
||||
For most of the methods, like `initial_metadata()` the caller does not need |
||||
to wait until the interceptors task is finished, once the RPC is done the |
||||
caller will have the freedom for accessing to the results. |
||||
|
||||
For the `__await__` method is it is proxied to the intercepted call only when |
||||
the interceptor task is finished. |
||||
""" |
||||
|
||||
_loop: asyncio.AbstractEventLoop |
||||
_channel: cygrpc.AioChannel |
||||
_cancelled_before_rpc: bool |
||||
_intercepted_call: Optional[_base_call.UnaryUnaryCall] |
||||
_intercepted_call_created: asyncio.Event |
||||
_interceptors_task: asyncio.Task |
||||
|
||||
def __init__( # pylint: disable=R0913 |
||||
self, interceptors: Sequence[UnaryUnaryClientInterceptor], |
||||
request: RequestType, timeout: Optional[float], |
||||
channel: cygrpc.AioChannel, method: bytes, |
||||
request_serializer: SerializingFunction, |
||||
response_deserializer: DeserializingFunction) -> None: |
||||
self._channel = channel |
||||
self._loop = asyncio.get_event_loop() |
||||
self._interceptors_task = asyncio.ensure_future( |
||||
self._invoke(interceptors, method, timeout, request, |
||||
request_serializer, response_deserializer)) |
||||
|
||||
def __del__(self): |
||||
self.cancel() |
||||
|
||||
async def _invoke(self, interceptors: Sequence[UnaryUnaryClientInterceptor], |
||||
method: bytes, timeout: Optional[float], |
||||
request: RequestType, |
||||
request_serializer: SerializingFunction, |
||||
response_deserializer: DeserializingFunction |
||||
) -> UnaryUnaryCall: |
||||
"""Run the RPC call wrapped in interceptors""" |
||||
|
||||
async def _run_interceptor( |
||||
interceptors: Iterator[UnaryUnaryClientInterceptor], |
||||
client_call_details: ClientCallDetails, |
||||
request: RequestType) -> _base_call.UnaryUnaryCall: |
||||
|
||||
interceptor = next(interceptors, None) |
||||
|
||||
if interceptor: |
||||
continuation = functools.partial(_run_interceptor, interceptors) |
||||
|
||||
call_or_response = await interceptor.intercept_unary_unary( |
||||
continuation, client_call_details, request) |
||||
|
||||
if isinstance(call_or_response, _base_call.UnaryUnaryCall): |
||||
return call_or_response |
||||
else: |
||||
return UnaryUnaryCallResponse(call_or_response) |
||||
|
||||
else: |
||||
return UnaryUnaryCall( |
||||
request, _timeout_to_deadline(client_call_details.timeout), |
||||
self._channel, client_call_details.method, |
||||
request_serializer, response_deserializer) |
||||
|
||||
client_call_details = ClientCallDetails(method, timeout, None, None) |
||||
return await _run_interceptor(iter(interceptors), client_call_details, |
||||
request) |
||||
|
||||
def cancel(self) -> bool: |
||||
if self._interceptors_task.done(): |
||||
return False |
||||
|
||||
return self._interceptors_task.cancel() |
||||
|
||||
def cancelled(self) -> bool: |
||||
if not self._interceptors_task.done(): |
||||
return False |
||||
|
||||
try: |
||||
call = self._interceptors_task.result() |
||||
except AioRpcError as err: |
||||
return err.code() == grpc.StatusCode.CANCELLED |
||||
except asyncio.CancelledError: |
||||
return True |
||||
|
||||
return call.cancelled() |
||||
|
||||
def done(self) -> bool: |
||||
if not self._interceptors_task.done(): |
||||
return False |
||||
|
||||
try: |
||||
call = self._interceptors_task.result() |
||||
except (AioRpcError, asyncio.CancelledError): |
||||
return True |
||||
|
||||
return call.done() |
||||
|
||||
def add_done_callback(self, unused_callback) -> None: |
||||
raise NotImplementedError() |
||||
|
||||
def time_remaining(self) -> Optional[float]: |
||||
raise NotImplementedError() |
||||
|
||||
async def initial_metadata(self) -> Optional[MetadataType]: |
||||
try: |
||||
call = await self._interceptors_task |
||||
except AioRpcError as err: |
||||
return err.initial_metadata() |
||||
except asyncio.CancelledError: |
||||
return None |
||||
|
||||
return await call.initial_metadata() |
||||
|
||||
async def trailing_metadata(self) -> Optional[MetadataType]: |
||||
try: |
||||
call = await self._interceptors_task |
||||
except AioRpcError as err: |
||||
return err.trailing_metadata() |
||||
except asyncio.CancelledError: |
||||
return None |
||||
|
||||
return await call.trailing_metadata() |
||||
|
||||
async def code(self) -> grpc.StatusCode: |
||||
try: |
||||
call = await self._interceptors_task |
||||
except AioRpcError as err: |
||||
return err.code() |
||||
except asyncio.CancelledError: |
||||
return grpc.StatusCode.CANCELLED |
||||
|
||||
return await call.code() |
||||
|
||||
async def details(self) -> str: |
||||
try: |
||||
call = await self._interceptors_task |
||||
except AioRpcError as err: |
||||
return err.details() |
||||
except asyncio.CancelledError: |
||||
return _LOCAL_CANCELLATION_DETAILS |
||||
|
||||
return await call.details() |
||||
|
||||
async def debug_error_string(self) -> Optional[str]: |
||||
try: |
||||
call = await self._interceptors_task |
||||
except AioRpcError as err: |
||||
return err.debug_error_string() |
||||
except asyncio.CancelledError: |
||||
return '' |
||||
|
||||
return await call.debug_error_string() |
||||
|
||||
def __await__(self): |
||||
call = yield from self._interceptors_task.__await__() |
||||
response = yield from call.__await__() |
||||
return response |
||||
|
||||
|
||||
class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall): |
||||
"""Final UnaryUnaryCall class finished with a response.""" |
||||
_response: ResponseType |
||||
|
||||
def __init__(self, response: ResponseType) -> None: |
||||
self._response = response |
||||
|
||||
def cancel(self) -> bool: |
||||
return False |
||||
|
||||
def cancelled(self) -> bool: |
||||
return False |
||||
|
||||
def done(self) -> bool: |
||||
return True |
||||
|
||||
def add_done_callback(self, unused_callback) -> None: |
||||
raise NotImplementedError() |
||||
|
||||
def time_remaining(self) -> Optional[float]: |
||||
raise NotImplementedError() |
||||
|
||||
async def initial_metadata(self) -> Optional[MetadataType]: |
||||
return None |
||||
|
||||
async def trailing_metadata(self) -> Optional[MetadataType]: |
||||
return None |
||||
|
||||
async def code(self) -> grpc.StatusCode: |
||||
return grpc.StatusCode.OK |
||||
|
||||
async def details(self) -> str: |
||||
return '' |
||||
|
||||
async def debug_error_string(self) -> Optional[str]: |
||||
return None |
||||
|
||||
def __await__(self): |
||||
if False: # pylint: disable=W0125 |
||||
# This code path is never used, but a yield statement is needed |
||||
# for telling the interpreter that __await__ is a generator. |
||||
yield None |
||||
return self._response |
@ -0,0 +1,22 @@ |
||||
# Copyright 2019 gRPC authors. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
"""Internal utilities used by the gRPC Aio module.""" |
||||
import time |
||||
from typing import Optional |
||||
|
||||
|
||||
def _timeout_to_deadline(timeout: Optional[float]) -> Optional[float]: |
||||
if timeout is None: |
||||
return None |
||||
return time.time() + timeout |
@ -0,0 +1,538 @@ |
||||
# Copyright 2019 The gRPC Authors. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
import asyncio |
||||
import logging |
||||
import unittest |
||||
|
||||
import grpc |
||||
|
||||
from grpc.experimental import aio |
||||
from tests_aio.unit._test_server import start_test_server, UNARY_CALL_WITH_SLEEP_VALUE |
||||
from tests_aio.unit._test_base import AioTestBase |
||||
from src.proto.grpc.testing import messages_pb2 |
||||
|
||||
_LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!' |
||||
|
||||
|
||||
class TestUnaryUnaryClientInterceptor(AioTestBase): |
||||
|
||||
async def setUp(self): |
||||
self._server_target, self._server = await start_test_server() |
||||
|
||||
async def tearDown(self): |
||||
await self._server.stop(None) |
||||
|
||||
def test_invalid_interceptor(self): |
||||
|
||||
class InvalidInterceptor: |
||||
"""Just an invalid Interceptor""" |
||||
|
||||
with self.assertRaises(ValueError): |
||||
aio.insecure_channel("", interceptors=[InvalidInterceptor()]) |
||||
|
||||
async def test_executed_right_order(self): |
||||
|
||||
interceptors_executed = [] |
||||
|
||||
class Interceptor(aio.UnaryUnaryClientInterceptor): |
||||
"""Interceptor used for testing if the interceptor is being called""" |
||||
|
||||
async def intercept_unary_unary(self, continuation, |
||||
client_call_details, request): |
||||
interceptors_executed.append(self) |
||||
call = await continuation(client_call_details, request) |
||||
return call |
||||
|
||||
interceptors = [Interceptor() for i in range(2)] |
||||
|
||||
async with aio.insecure_channel(self._server_target, |
||||
interceptors=interceptors) as channel: |
||||
multicallable = channel.unary_unary( |
||||
'/grpc.testing.TestService/UnaryCall', |
||||
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
||||
response_deserializer=messages_pb2.SimpleResponse.FromString) |
||||
call = multicallable(messages_pb2.SimpleRequest()) |
||||
response = await call |
||||
|
||||
# Check that all interceptors were executed, and were executed |
||||
# in the right order. |
||||
self.assertSequenceEqual(interceptors_executed, interceptors) |
||||
|
||||
self.assertIsInstance(response, messages_pb2.SimpleResponse) |
||||
|
||||
@unittest.expectedFailure |
||||
# TODO(https://github.com/grpc/grpc/issues/20144) Once metadata support is |
||||
# implemented in the client-side, this test must be implemented. |
||||
def test_modify_metadata(self): |
||||
raise NotImplementedError() |
||||
|
||||
@unittest.expectedFailure |
||||
# TODO(https://github.com/grpc/grpc/issues/20532) Once credentials support is |
||||
# implemented in the client-side, this test must be implemented. |
||||
def test_modify_credentials(self): |
||||
raise NotImplementedError() |
||||
|
||||
async def test_status_code_Ok(self): |
||||
|
||||
class StatusCodeOkInterceptor(aio.UnaryUnaryClientInterceptor): |
||||
"""Interceptor used for observing status code Ok returned by the RPC""" |
||||
|
||||
def __init__(self): |
||||
self.status_code_Ok_observed = False |
||||
|
||||
async def intercept_unary_unary(self, continuation, |
||||
client_call_details, request): |
||||
call = await continuation(client_call_details, request) |
||||
code = await call.code() |
||||
if code == grpc.StatusCode.OK: |
||||
self.status_code_Ok_observed = True |
||||
|
||||
return call |
||||
|
||||
interceptor = StatusCodeOkInterceptor() |
||||
|
||||
async with aio.insecure_channel(self._server_target, |
||||
interceptors=[interceptor]) as channel: |
||||
|
||||
# when no error StatusCode.OK must be observed |
||||
multicallable = channel.unary_unary( |
||||
'/grpc.testing.TestService/UnaryCall', |
||||
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
||||
response_deserializer=messages_pb2.SimpleResponse.FromString) |
||||
|
||||
await multicallable(messages_pb2.SimpleRequest()) |
||||
|
||||
self.assertTrue(interceptor.status_code_Ok_observed) |
||||
|
||||
async def test_add_timeout(self): |
||||
|
||||
class TimeoutInterceptor(aio.UnaryUnaryClientInterceptor): |
||||
"""Interceptor used for adding a timeout to the RPC""" |
||||
|
||||
async def intercept_unary_unary(self, continuation, |
||||
client_call_details, request): |
||||
new_client_call_details = aio.ClientCallDetails( |
||||
method=client_call_details.method, |
||||
timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2, |
||||
metadata=client_call_details.metadata, |
||||
credentials=client_call_details.credentials) |
||||
return await continuation(new_client_call_details, request) |
||||
|
||||
interceptor = TimeoutInterceptor() |
||||
|
||||
async with aio.insecure_channel(self._server_target, |
||||
interceptors=[interceptor]) as channel: |
||||
|
||||
multicallable = channel.unary_unary( |
||||
'/grpc.testing.TestService/UnaryCallWithSleep', |
||||
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
||||
response_deserializer=messages_pb2.SimpleResponse.FromString) |
||||
|
||||
call = multicallable(messages_pb2.SimpleRequest()) |
||||
|
||||
with self.assertRaises(aio.AioRpcError) as exception_context: |
||||
await call |
||||
|
||||
self.assertEqual(exception_context.exception.code(), |
||||
grpc.StatusCode.DEADLINE_EXCEEDED) |
||||
|
||||
self.assertTrue(call.done()) |
||||
self.assertEqual(grpc.StatusCode.DEADLINE_EXCEEDED, await |
||||
call.code()) |
||||
|
||||
async def test_retry(self): |
||||
|
||||
class RetryInterceptor(aio.UnaryUnaryClientInterceptor): |
||||
"""Simulates a Retry Interceptor which ends up by making |
||||
two RPC calls.""" |
||||
|
||||
def __init__(self): |
||||
self.calls = [] |
||||
|
||||
async def intercept_unary_unary(self, continuation, |
||||
client_call_details, request): |
||||
|
||||
new_client_call_details = aio.ClientCallDetails( |
||||
method=client_call_details.method, |
||||
timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2, |
||||
metadata=client_call_details.metadata, |
||||
credentials=client_call_details.credentials) |
||||
|
||||
try: |
||||
call = await continuation(new_client_call_details, request) |
||||
await call |
||||
except grpc.RpcError: |
||||
pass |
||||
|
||||
self.calls.append(call) |
||||
|
||||
new_client_call_details = aio.ClientCallDetails( |
||||
method=client_call_details.method, |
||||
timeout=None, |
||||
metadata=client_call_details.metadata, |
||||
credentials=client_call_details.credentials) |
||||
|
||||
call = await continuation(new_client_call_details, request) |
||||
self.calls.append(call) |
||||
return call |
||||
|
||||
interceptor = RetryInterceptor() |
||||
|
||||
async with aio.insecure_channel(self._server_target, |
||||
interceptors=[interceptor]) as channel: |
||||
|
||||
multicallable = channel.unary_unary( |
||||
'/grpc.testing.TestService/UnaryCallWithSleep', |
||||
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
||||
response_deserializer=messages_pb2.SimpleResponse.FromString) |
||||
|
||||
call = multicallable(messages_pb2.SimpleRequest()) |
||||
|
||||
await call |
||||
|
||||
self.assertEqual(grpc.StatusCode.OK, await call.code()) |
||||
|
||||
# Check that two calls were made, first one finishing with |
||||
# a deadline and second one finishing ok.. |
||||
self.assertEqual(len(interceptor.calls), 2) |
||||
self.assertEqual(await interceptor.calls[0].code(), |
||||
grpc.StatusCode.DEADLINE_EXCEEDED) |
||||
self.assertEqual(await interceptor.calls[1].code(), |
||||
grpc.StatusCode.OK) |
||||
|
||||
async def test_rpcresponse(self): |
||||
|
||||
class Interceptor(aio.UnaryUnaryClientInterceptor): |
||||
"""Raw responses are seen as reegular calls""" |
||||
|
||||
async def intercept_unary_unary(self, continuation, |
||||
client_call_details, request): |
||||
call = await continuation(client_call_details, request) |
||||
response = await call |
||||
return call |
||||
|
||||
class ResponseInterceptor(aio.UnaryUnaryClientInterceptor): |
||||
"""Return a raw response""" |
||||
response = messages_pb2.SimpleResponse() |
||||
|
||||
async def intercept_unary_unary(self, continuation, |
||||
client_call_details, request): |
||||
return ResponseInterceptor.response |
||||
|
||||
interceptor, interceptor_response = Interceptor(), ResponseInterceptor() |
||||
|
||||
async with aio.insecure_channel( |
||||
self._server_target, |
||||
interceptors=[interceptor, interceptor_response]) as channel: |
||||
|
||||
multicallable = channel.unary_unary( |
||||
'/grpc.testing.TestService/UnaryCall', |
||||
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
||||
response_deserializer=messages_pb2.SimpleResponse.FromString) |
||||
|
||||
call = multicallable(messages_pb2.SimpleRequest()) |
||||
response = await call |
||||
|
||||
# Check that the response returned is the one returned by the |
||||
# interceptor |
||||
self.assertEqual(id(response), id(ResponseInterceptor.response)) |
||||
|
||||
# Check all of the UnaryUnaryCallResponse attributes |
||||
self.assertTrue(call.done()) |
||||
self.assertFalse(call.cancel()) |
||||
self.assertFalse(call.cancelled()) |
||||
self.assertEqual(await call.code(), grpc.StatusCode.OK) |
||||
self.assertEqual(await call.details(), '') |
||||
self.assertEqual(await call.initial_metadata(), None) |
||||
self.assertEqual(await call.trailing_metadata(), None) |
||||
self.assertEqual(await call.debug_error_string(), None) |
||||
|
||||
|
||||
class TestInterceptedUnaryUnaryCall(AioTestBase): |
||||
|
||||
async def setUp(self): |
||||
self._server_target, self._server = await start_test_server() |
||||
|
||||
async def tearDown(self): |
||||
await self._server.stop(None) |
||||
|
||||
async def test_call_ok(self): |
||||
|
||||
class Interceptor(aio.UnaryUnaryClientInterceptor): |
||||
|
||||
async def intercept_unary_unary(self, continuation, |
||||
client_call_details, request): |
||||
call = await continuation(client_call_details, request) |
||||
return call |
||||
|
||||
async with aio.insecure_channel(self._server_target, |
||||
interceptors=[Interceptor() |
||||
]) as channel: |
||||
|
||||
multicallable = channel.unary_unary( |
||||
'/grpc.testing.TestService/UnaryCall', |
||||
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
||||
response_deserializer=messages_pb2.SimpleResponse.FromString) |
||||
call = multicallable(messages_pb2.SimpleRequest()) |
||||
response = await call |
||||
|
||||
self.assertTrue(call.done()) |
||||
self.assertFalse(call.cancelled()) |
||||
self.assertEqual(type(response), messages_pb2.SimpleResponse) |
||||
self.assertEqual(await call.code(), grpc.StatusCode.OK) |
||||
self.assertEqual(await call.details(), '') |
||||
self.assertEqual(await call.initial_metadata(), ()) |
||||
self.assertEqual(await call.trailing_metadata(), ()) |
||||
|
||||
async def test_call_ok_awaited(self): |
||||
|
||||
class Interceptor(aio.UnaryUnaryClientInterceptor): |
||||
|
||||
async def intercept_unary_unary(self, continuation, |
||||
client_call_details, request): |
||||
call = await continuation(client_call_details, request) |
||||
await call |
||||
return call |
||||
|
||||
async with aio.insecure_channel(self._server_target, |
||||
interceptors=[Interceptor() |
||||
]) as channel: |
||||
|
||||
multicallable = channel.unary_unary( |
||||
'/grpc.testing.TestService/UnaryCall', |
||||
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
||||
response_deserializer=messages_pb2.SimpleResponse.FromString) |
||||
call = multicallable(messages_pb2.SimpleRequest()) |
||||
response = await call |
||||
|
||||
self.assertTrue(call.done()) |
||||
self.assertFalse(call.cancelled()) |
||||
self.assertEqual(type(response), messages_pb2.SimpleResponse) |
||||
self.assertEqual(await call.code(), grpc.StatusCode.OK) |
||||
self.assertEqual(await call.details(), '') |
||||
self.assertEqual(await call.initial_metadata(), ()) |
||||
self.assertEqual(await call.trailing_metadata(), ()) |
||||
|
||||
async def test_call_rpc_error(self): |
||||
|
||||
class Interceptor(aio.UnaryUnaryClientInterceptor): |
||||
|
||||
async def intercept_unary_unary(self, continuation, |
||||
client_call_details, request): |
||||
call = await continuation(client_call_details, request) |
||||
return call |
||||
|
||||
async with aio.insecure_channel(self._server_target, |
||||
interceptors=[Interceptor() |
||||
]) as channel: |
||||
|
||||
multicallable = channel.unary_unary( |
||||
'/grpc.testing.TestService/UnaryCallWithSleep', |
||||
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
||||
response_deserializer=messages_pb2.SimpleResponse.FromString) |
||||
|
||||
call = multicallable(messages_pb2.SimpleRequest(), |
||||
timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2) |
||||
|
||||
with self.assertRaises(aio.AioRpcError) as exception_context: |
||||
await call |
||||
|
||||
self.assertTrue(call.done()) |
||||
self.assertFalse(call.cancelled()) |
||||
self.assertEqual(await call.code(), |
||||
grpc.StatusCode.DEADLINE_EXCEEDED) |
||||
self.assertEqual(await call.details(), 'Deadline Exceeded') |
||||
self.assertEqual(await call.initial_metadata(), ()) |
||||
self.assertEqual(await call.trailing_metadata(), ()) |
||||
|
||||
async def test_call_rpc_error_awaited(self): |
||||
|
||||
class Interceptor(aio.UnaryUnaryClientInterceptor): |
||||
|
||||
async def intercept_unary_unary(self, continuation, |
||||
client_call_details, request): |
||||
call = await continuation(client_call_details, request) |
||||
await call |
||||
return call |
||||
|
||||
async with aio.insecure_channel(self._server_target, |
||||
interceptors=[Interceptor() |
||||
]) as channel: |
||||
|
||||
multicallable = channel.unary_unary( |
||||
'/grpc.testing.TestService/UnaryCallWithSleep', |
||||
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
||||
response_deserializer=messages_pb2.SimpleResponse.FromString) |
||||
|
||||
call = multicallable(messages_pb2.SimpleRequest(), |
||||
timeout=UNARY_CALL_WITH_SLEEP_VALUE / 2) |
||||
|
||||
with self.assertRaises(aio.AioRpcError) as exception_context: |
||||
await call |
||||
|
||||
self.assertTrue(call.done()) |
||||
self.assertFalse(call.cancelled()) |
||||
self.assertEqual(await call.code(), |
||||
grpc.StatusCode.DEADLINE_EXCEEDED) |
||||
self.assertEqual(await call.details(), 'Deadline Exceeded') |
||||
self.assertEqual(await call.initial_metadata(), ()) |
||||
self.assertEqual(await call.trailing_metadata(), ()) |
||||
|
||||
async def test_cancel_before_rpc(self): |
||||
|
||||
interceptor_reached = asyncio.Event() |
||||
wait_for_ever = self.loop.create_future() |
||||
|
||||
class Interceptor(aio.UnaryUnaryClientInterceptor): |
||||
|
||||
async def intercept_unary_unary(self, continuation, |
||||
client_call_details, request): |
||||
interceptor_reached.set() |
||||
await wait_for_ever |
||||
|
||||
async with aio.insecure_channel(self._server_target, |
||||
interceptors=[Interceptor() |
||||
]) as channel: |
||||
|
||||
multicallable = channel.unary_unary( |
||||
'/grpc.testing.TestService/UnaryCall', |
||||
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
||||
response_deserializer=messages_pb2.SimpleResponse.FromString) |
||||
call = multicallable(messages_pb2.SimpleRequest()) |
||||
|
||||
self.assertFalse(call.cancelled()) |
||||
self.assertFalse(call.done()) |
||||
|
||||
await interceptor_reached.wait() |
||||
self.assertTrue(call.cancel()) |
||||
|
||||
with self.assertRaises(asyncio.CancelledError): |
||||
await call |
||||
|
||||
self.assertTrue(call.cancelled()) |
||||
self.assertTrue(call.done()) |
||||
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) |
||||
self.assertEqual(await call.details(), |
||||
_LOCAL_CANCEL_DETAILS_EXPECTATION) |
||||
self.assertEqual(await call.initial_metadata(), None) |
||||
self.assertEqual(await call.trailing_metadata(), None) |
||||
|
||||
async def test_cancel_after_rpc(self): |
||||
|
||||
interceptor_reached = asyncio.Event() |
||||
wait_for_ever = self.loop.create_future() |
||||
|
||||
class Interceptor(aio.UnaryUnaryClientInterceptor): |
||||
|
||||
async def intercept_unary_unary(self, continuation, |
||||
client_call_details, request): |
||||
call = await continuation(client_call_details, request) |
||||
await call |
||||
interceptor_reached.set() |
||||
await wait_for_ever |
||||
|
||||
async with aio.insecure_channel(self._server_target, |
||||
interceptors=[Interceptor() |
||||
]) as channel: |
||||
|
||||
multicallable = channel.unary_unary( |
||||
'/grpc.testing.TestService/UnaryCall', |
||||
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
||||
response_deserializer=messages_pb2.SimpleResponse.FromString) |
||||
call = multicallable(messages_pb2.SimpleRequest()) |
||||
|
||||
self.assertFalse(call.cancelled()) |
||||
self.assertFalse(call.done()) |
||||
|
||||
await interceptor_reached.wait() |
||||
self.assertTrue(call.cancel()) |
||||
|
||||
with self.assertRaises(asyncio.CancelledError): |
||||
await call |
||||
|
||||
self.assertTrue(call.cancelled()) |
||||
self.assertTrue(call.done()) |
||||
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) |
||||
self.assertEqual(await call.details(), |
||||
_LOCAL_CANCEL_DETAILS_EXPECTATION) |
||||
self.assertEqual(await call.initial_metadata(), None) |
||||
self.assertEqual(await call.trailing_metadata(), None) |
||||
|
||||
async def test_cancel_inside_interceptor_after_rpc_awaiting(self): |
||||
|
||||
class Interceptor(aio.UnaryUnaryClientInterceptor): |
||||
|
||||
async def intercept_unary_unary(self, continuation, |
||||
client_call_details, request): |
||||
call = await continuation(client_call_details, request) |
||||
call.cancel() |
||||
await call |
||||
return call |
||||
|
||||
async with aio.insecure_channel(self._server_target, |
||||
interceptors=[Interceptor() |
||||
]) as channel: |
||||
|
||||
multicallable = channel.unary_unary( |
||||
'/grpc.testing.TestService/UnaryCall', |
||||
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
||||
response_deserializer=messages_pb2.SimpleResponse.FromString) |
||||
call = multicallable(messages_pb2.SimpleRequest()) |
||||
|
||||
with self.assertRaises(asyncio.CancelledError): |
||||
await call |
||||
|
||||
self.assertTrue(call.cancelled()) |
||||
self.assertTrue(call.done()) |
||||
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) |
||||
self.assertEqual(await call.details(), |
||||
_LOCAL_CANCEL_DETAILS_EXPECTATION) |
||||
self.assertEqual(await call.initial_metadata(), None) |
||||
self.assertEqual(await call.trailing_metadata(), None) |
||||
|
||||
async def test_cancel_inside_interceptor_after_rpc_not_awaiting(self): |
||||
|
||||
class Interceptor(aio.UnaryUnaryClientInterceptor): |
||||
|
||||
async def intercept_unary_unary(self, continuation, |
||||
client_call_details, request): |
||||
call = await continuation(client_call_details, request) |
||||
call.cancel() |
||||
return call |
||||
|
||||
async with aio.insecure_channel(self._server_target, |
||||
interceptors=[Interceptor() |
||||
]) as channel: |
||||
|
||||
multicallable = channel.unary_unary( |
||||
'/grpc.testing.TestService/UnaryCall', |
||||
request_serializer=messages_pb2.SimpleRequest.SerializeToString, |
||||
response_deserializer=messages_pb2.SimpleResponse.FromString) |
||||
call = multicallable(messages_pb2.SimpleRequest()) |
||||
|
||||
with self.assertRaises(asyncio.CancelledError): |
||||
await call |
||||
|
||||
self.assertTrue(call.cancelled()) |
||||
self.assertTrue(call.done()) |
||||
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) |
||||
self.assertEqual(await call.details(), |
||||
_LOCAL_CANCEL_DETAILS_EXPECTATION) |
||||
self.assertEqual(await call.initial_metadata(), tuple()) |
||||
self.assertEqual(await call.trailing_metadata(), None) |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
logging.basicConfig() |
||||
unittest.main(verbosity=2) |
Loading…
Reference in new issue