Convert local cancellation exception into CancelledError

pull/21506/head
Lidi Zheng 5 years ago
parent a3d7733dd0
commit 4e3d980f70
  1. 24
      src/python/grpcio/grpc/experimental/aio/_call.py
  2. 33
      src/python/grpcio_tests/tests_aio/unit/call_test.py

@ -150,12 +150,14 @@ class Call(_base_call.Call):
_code: grpc.StatusCode
_status: Awaitable[cygrpc.AioRpcStatus]
_initial_metadata: Awaitable[MetadataType]
_locally_cancelled: bool
def __init__(self) -> None:
self._loop = asyncio.get_event_loop()
self._code = None
self._status = self._loop.create_future()
self._initial_metadata = self._loop.create_future()
self._locally_cancelled = False
def cancel(self) -> bool:
"""Placeholder cancellation method.
@ -204,6 +206,10 @@ class Call(_base_call.Call):
cancellation (by application) and Core receiving status from peer. We
make no promise here which one will win.
"""
# In case of local cancellation, flip the flag.
if status.details() is _LOCAL_CANCELLATION_DETAILS:
self._locally_cancelled = True
# In case of the RPC finished without receiving metadata.
if not self._initial_metadata.done():
self._initial_metadata.set_result(_EMPTY_METADATA)
@ -212,7 +218,9 @@ class Call(_base_call.Call):
self._status.set_result(status)
self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()]
async def _raise_rpc_error_if_not_ok(self) -> None:
async def _raise_if_not_ok(self) -> None:
if self._locally_cancelled:
raise asyncio.CancelledError()
await self._status
if self._code != grpc.StatusCode.OK:
raise _create_rpc_error(await self.initial_metadata(),
@ -287,8 +295,8 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
if self._code != grpc.StatusCode.CANCELLED:
self.cancel()
# Raises RpcError here if RPC failed or cancelled
await self._raise_rpc_error_if_not_ok()
# Raises here if RPC failed or cancelled
await self._raise_if_not_ok()
return _common.deserialize(serialized_response,
self._response_deserializer)
@ -319,7 +327,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
# `CancelledError`.
if not self.cancelled():
self.cancel()
raise _create_rpc_error(_EMPTY_METADATA, self._status.result())
raise
return response
@ -367,7 +375,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
except asyncio.CancelledError:
if self._code != grpc.StatusCode.CANCELLED:
self.cancel()
await self._raise_rpc_error_if_not_ok()
raise
async def _fetch_stream_responses(self) -> ResponseType:
await self._send_unary_request_task
@ -418,7 +426,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
except asyncio.CancelledError:
if self._code != grpc.StatusCode.CANCELLED:
self.cancel()
await self._raise_rpc_error_if_not_ok()
raise
if raw_response is None:
return None
@ -428,14 +436,14 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
async def read(self) -> ResponseType:
if self._status.done():
await self._raise_rpc_error_if_not_ok()
await self._raise_if_not_ok()
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
response_message = await self._read()
if response_message is None:
# If the read operation failed, Core should explain why.
await self._raise_rpc_error_if_not_ok()
await self._raise_if_not_ok()
# If no exception raised, there is something wrong internally.
assert False, 'Read operation failed with StatusCode.OK'
else:

@ -124,18 +124,10 @@ class TestUnaryUnaryCall(AioTestBase):
self.assertTrue(call.cancel())
self.assertFalse(call.cancel())
with self.assertRaises(grpc.RpcError) as exception_context:
with self.assertRaises(asyncio.CancelledError):
await call
# The info in the RpcError should match the info in Call object.
rpc_error = exception_context.exception
self.assertEqual(rpc_error.code(), await call.code())
self.assertEqual(rpc_error.details(), await call.details())
self.assertEqual(rpc_error.trailing_metadata(), await
call.trailing_metadata())
self.assertEqual(rpc_error.debug_error_string(), await
call.debug_error_string())
self.assertTrue(call.cancelled())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
@ -159,10 +151,8 @@ class TestUnaryUnaryCall(AioTestBase):
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
with self.assertRaises(grpc.RpcError) as exception_context:
with self.assertRaises(asyncio.CancelledError):
await task
self.assertEqual(grpc.StatusCode.CANCELLED,
exception_context.exception.code())
class TestUnaryStreamCall(AioTestBase):
@ -201,7 +191,7 @@ class TestUnaryStreamCall(AioTestBase):
call.details())
self.assertFalse(call.cancel())
with self.assertRaises(grpc.RpcError) as exception_context:
with self.assertRaises(asyncio.CancelledError):
await call.read()
self.assertTrue(call.cancelled())
@ -232,7 +222,7 @@ class TestUnaryStreamCall(AioTestBase):
self.assertFalse(call.cancel())
self.assertFalse(call.cancel())
with self.assertRaises(grpc.RpcError) as exception_context:
with self.assertRaises(asyncio.CancelledError):
await call.read()
async def test_early_cancel_unary_stream(self):
@ -256,16 +246,11 @@ class TestUnaryStreamCall(AioTestBase):
self.assertTrue(call.cancel())
self.assertFalse(call.cancel())
with self.assertRaises(grpc.RpcError) as exception_context:
with self.assertRaises(asyncio.CancelledError):
await call.read()
self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED,
exception_context.exception.code())
self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION,
exception_context.exception.details())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
call.details())
@ -377,10 +362,8 @@ class TestUnaryStreamCall(AioTestBase):
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
with self.assertRaises(grpc.RpcError) as exception_context:
with self.assertRaises(asyncio.CancelledError):
await task
self.assertEqual(grpc.StatusCode.CANCELLED,
exception_context.exception.code())
async def test_cancel_unary_stream_in_task_using_async_for(self):
async with aio.insecure_channel(self._server_target) as channel:
@ -411,10 +394,8 @@ class TestUnaryStreamCall(AioTestBase):
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
with self.assertRaises(grpc.RpcError) as exception_context:
with self.assertRaises(asyncio.CancelledError):
await task
self.assertEqual(grpc.StatusCode.CANCELLED,
exception_context.exception.code())
if __name__ == '__main__':

Loading…
Cancel
Save