Not mask AioRpcError and CancelledError at interceptor level

pull/21455/head
Pau Freixes 5 years ago
parent a2667b80c3
commit 33765f5ee5
  1. 2
      src/python/grpcio/grpc/experimental/aio/_call.py
  2. 158
      src/python/grpcio/grpc/experimental/aio/_interceptor.py
  3. 156
      src/python/grpcio_tests/tests_aio/unit/interceptor_test.py

@ -233,7 +233,7 @@ class Call(_base_call.Call):
if self._code is grpc.StatusCode.OK:
return _OK_CALL_REPRESENTATION.format(
self.__class__.__name__, self._code,
self._status.result().self._status.result().details())
self._status.result().details())
else:
return _NON_OK_CALL_REPRESENTATION.format(
self.__class__.__name__, self._code,

@ -22,7 +22,7 @@ import grpc
from grpc._cython import cygrpc
from . import _base_call
from ._call import UnaryUnaryCall
from ._call import UnaryUnaryCall, AioRpcError
from ._utils import _timeout_to_deadline
from ._typing import (RequestType, SerializingFunction, DeserializingFunction,
MetadataType, ResponseType)
@ -135,19 +135,9 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
if interceptor:
continuation = functools.partial(_run_interceptor, interceptors)
try:
call_or_response = await interceptor.intercept_unary_unary(
continuation, client_call_details, request)
except grpc.RpcError as err:
# gRPC error is masked inside an artificial call,
# caller will see this error if and only
# if it runs an `await call` operation
return UnaryUnaryCallRpcError(err)
except asyncio.CancelledError:
# Cancellation is masked inside an artificial call,
# caller will see this error if and only
# if it runs an `await call` operation
return UnaryUnaryCancelledError()
call_or_response = await interceptor.intercept_unary_unary(
continuation, client_call_details, request)
if isinstance(call_or_response, _base_call.UnaryUnaryCall):
return call_or_response
@ -176,14 +166,25 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
if not self._interceptors_task.done():
return False
call = self._interceptors_task.result()
return call.cancelled()
try:
call = self._interceptors_task.result()
except AioRpcError:
return False
except asyncio.CancelledError:
return True
else:
return call.cancelled()
def done(self) -> bool:
if not self._interceptors_task.done():
return False
return True
try:
call = self._interceptors_task.result()
except (AioRpcError, asyncio.CancelledError):
return True
else:
return call.done()
def add_done_callback(self, unused_callback) -> None:
raise NotImplementedError()
@ -192,19 +193,54 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
raise NotImplementedError()
async def initial_metadata(self) -> Optional[MetadataType]:
return await (await self._interceptors_task).initial_metadata()
try:
call = await self._interceptors_task
except AioRpcError as err:
return err.initial_metadata()
except asyncio.CancelledError:
return None
else:
return await call.initial_metadata()
async def trailing_metadata(self) -> Optional[MetadataType]:
return await (await self._interceptors_task).trailing_metadata()
try:
call = await self._interceptors_task
except AioRpcError as err:
return err.trailing_metadata()
except asyncio.CancelledError:
return None
else:
return await call.trailing_metadata()
async def code(self) -> grpc.StatusCode:
return await (await self._interceptors_task).code()
try:
call = await self._interceptors_task
except AioRpcError as err:
return err.code()
except asyncio.CancelledError:
return grpc.StatusCode.CANCELLED
else:
return await call.code()
async def details(self) -> str:
return await (await self._interceptors_task).details()
try:
call = await self._interceptors_task
except AioRpcError as err:
return err.details()
except asyncio.CancelledError:
return _LOCAL_CANCELLATION_DETAILS
else:
return await call.details()
async def debug_error_string(self) -> Optional[str]:
return await (await self._interceptors_task).debug_error_string()
try:
call = await self._interceptors_task
except AioRpcError as err:
return err.debug_error_string()
except asyncio.CancelledError:
return ''
else:
return await call.debug_error_string()
def __await__(self):
call = yield from self._interceptors_task.__await__()
@ -212,47 +248,6 @@ class InterceptedUnaryUnaryCall(_base_call.UnaryUnaryCall):
return response
class UnaryUnaryCallRpcError(_base_call.UnaryUnaryCall):
"""Final UnaryUnaryCall class finished with an RpcError."""
_error: grpc.RpcError
def __init__(self, error: grpc.RpcError) -> None:
self._error = error
def cancel(self) -> bool:
return False
def cancelled(self) -> bool:
return False
def done(self) -> bool:
return True
def add_done_callback(self, unused_callback) -> None:
raise NotImplementedError()
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
async def initial_metadata(self) -> Optional[MetadataType]:
return None
async def trailing_metadata(self) -> Optional[MetadataType]:
return self._error.initial_metadata()
async def code(self) -> grpc.StatusCode:
return self._error.code()
async def details(self) -> str:
return self._error.details()
async def debug_error_string(self) -> Optional[str]:
return self._error.debug_error_string()
def __await__(self):
raise self._error
class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
"""Final UnaryUnaryCall class finished with a response."""
_response: ResponseType
@ -296,40 +291,3 @@ class UnaryUnaryCallResponse(_base_call.UnaryUnaryCall):
# for telling the interpreter that __await__ is a generator.
yield None
return self._response
class UnaryUnaryCancelledError(_base_call.UnaryUnaryCall):
"""Final UnaryUnaryCall class finished with an asyncio.CancelledError."""
def cancel(self) -> bool:
return False
def cancelled(self) -> bool:
return True
def done(self) -> bool:
return True
def add_done_callback(self, unused_callback) -> None:
raise NotImplementedError()
def time_remaining(self) -> Optional[float]:
raise NotImplementedError()
async def initial_metadata(self) -> Optional[MetadataType]:
return None
async def trailing_metadata(self) -> Optional[MetadataType]:
return None
async def code(self) -> grpc.StatusCode:
return grpc.StatusCode.CANCELLED
async def details(self) -> str:
return _LOCAL_CANCELLATION_DETAILS
async def debug_error_string(self) -> Optional[str]:
return None
def __await__(self):
raise asyncio.CancelledError()

@ -177,6 +177,7 @@ class TestUnaryUnaryClientInterceptor(AioTestBase):
self.calls.append(call)
new_client_call_details = aio.ClientCallDetails(
method=client_call_details.method,
timeout=None,
@ -212,61 +213,6 @@ class TestUnaryUnaryClientInterceptor(AioTestBase):
self.assertEqual(await interceptor.calls[1].code(),
grpc.StatusCode.OK)
async def test_rpcerror_raised_when_call_is_awaited(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
"""RpcErrors are only seen when the call is awaited"""
def __init__(self):
self.deadline_seen = False
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
try:
await call
except aio.AioRpcError as err:
if err.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
self.deadline_seen = True
raise
# This point should never be reached
raise Exception()
interceptor_a, interceptor_b = (Interceptor(), Interceptor())
server_target, server = await start_test_server()
async with aio.insecure_channel(
server_target, interceptors=[interceptor_a,
interceptor_b]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest(), timeout=0.1)
with self.assertRaises(grpc.RpcError) as exception_context:
await call
# Check that the two interceptors catch the deadline exception
# only when the call was awaited
self.assertTrue(interceptor_a.deadline_seen)
self.assertTrue(interceptor_b.deadline_seen)
# Check all of the UnaryUnaryCallRpcError attributes
self.assertTrue(call.done())
self.assertFalse(call.cancel())
self.assertFalse(call.cancelled())
self.assertEqual(await call.code(),
grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertEqual(await call.details(), 'Deadline Exceeded')
self.assertEqual(await call.initial_metadata(), None)
self.assertEqual(await call.trailing_metadata(), ())
self.assertEqual(await call.debug_error_string(), None)
async def test_rpcresponse(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
@ -348,6 +294,106 @@ class TestInterceptedUnaryUnaryCall(AioTestBase):
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
async def test_call_ok_awaited(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
await call
return call
server_target, _ = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
call = multicallable(messages_pb2.SimpleRequest())
response = await call
self.assertTrue(call.done())
self.assertFalse(call.cancelled())
self.assertEqual(type(response), messages_pb2.SimpleResponse)
self.assertEqual(await call.code(), grpc.StatusCode.OK)
self.assertEqual(await call.details(), '')
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
async def test_call_rpcerror(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
return call
server_target, server = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
await server.stop(None)
call = multicallable(messages_pb2.SimpleRequest(), timeout=0.1)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
self.assertTrue(call.done())
self.assertFalse(call.cancelled())
self.assertEqual(await call.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertEqual(await call.details(), 'Deadline Exceeded')
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
async def test_call_rpcerror_awaited(self):
class Interceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation,
client_call_details, request):
call = await continuation(client_call_details, request)
await call
return call
server_target, server = await start_test_server() # pylint: disable=unused-variable
async with aio.insecure_channel(server_target,
interceptors=[Interceptor()
]) as channel:
multicallable = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
await server.stop(None)
call = multicallable(messages_pb2.SimpleRequest(), timeout=0.1)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
self.assertTrue(call.done())
self.assertFalse(call.cancelled())
self.assertEqual(await call.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
self.assertEqual(await call.details(), 'Deadline Exceeded')
self.assertEqual(await call.initial_metadata(), ())
self.assertEqual(await call.trailing_metadata(), ())
async def test_cancel_before_rpc(self):
interceptor_reached = asyncio.Event()

Loading…
Cancel
Save