From 4e3d980f7038f229a09121cf885e9fcf39221d1d Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Thu, 2 Jan 2020 10:23:06 -0800 Subject: [PATCH] Convert local cancellation exception into CancelledError --- .../grpcio/grpc/experimental/aio/_call.py | 24 +++++++++----- .../grpcio_tests/tests_aio/unit/call_test.py | 33 ++++--------------- 2 files changed, 23 insertions(+), 34 deletions(-) diff --git a/src/python/grpcio/grpc/experimental/aio/_call.py b/src/python/grpcio/grpc/experimental/aio/_call.py index a8969b06edb..1557b678d66 100644 --- a/src/python/grpcio/grpc/experimental/aio/_call.py +++ b/src/python/grpcio/grpc/experimental/aio/_call.py @@ -150,12 +150,14 @@ class Call(_base_call.Call): _code: grpc.StatusCode _status: Awaitable[cygrpc.AioRpcStatus] _initial_metadata: Awaitable[MetadataType] + _locally_cancelled: bool def __init__(self) -> None: self._loop = asyncio.get_event_loop() self._code = None self._status = self._loop.create_future() self._initial_metadata = self._loop.create_future() + self._locally_cancelled = False def cancel(self) -> bool: """Placeholder cancellation method. @@ -204,6 +206,10 @@ class Call(_base_call.Call): cancellation (by application) and Core receiving status from peer. We make no promise here which one will win. """ + # In case of local cancellation, flip the flag. + if status.details() is _LOCAL_CANCELLATION_DETAILS: + self._locally_cancelled = True + # In case of the RPC finished without receiving metadata. if not self._initial_metadata.done(): self._initial_metadata.set_result(_EMPTY_METADATA) @@ -212,7 +218,9 @@ class Call(_base_call.Call): self._status.set_result(status) self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()] - async def _raise_rpc_error_if_not_ok(self) -> None: + async def _raise_if_not_ok(self) -> None: + if self._locally_cancelled: + raise asyncio.CancelledError() await self._status if self._code != grpc.StatusCode.OK: raise _create_rpc_error(await self.initial_metadata(), @@ -287,8 +295,8 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): if self._code != grpc.StatusCode.CANCELLED: self.cancel() - # Raises RpcError here if RPC failed or cancelled - await self._raise_rpc_error_if_not_ok() + # Raises here if RPC failed or cancelled + await self._raise_if_not_ok() return _common.deserialize(serialized_response, self._response_deserializer) @@ -319,7 +327,7 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall): # `CancelledError`. if not self.cancelled(): self.cancel() - raise _create_rpc_error(_EMPTY_METADATA, self._status.result()) + raise return response @@ -367,7 +375,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): except asyncio.CancelledError: if self._code != grpc.StatusCode.CANCELLED: self.cancel() - await self._raise_rpc_error_if_not_ok() + raise async def _fetch_stream_responses(self) -> ResponseType: await self._send_unary_request_task @@ -418,7 +426,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): except asyncio.CancelledError: if self._code != grpc.StatusCode.CANCELLED: self.cancel() - await self._raise_rpc_error_if_not_ok() + raise if raw_response is None: return None @@ -428,14 +436,14 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall): async def read(self) -> ResponseType: if self._status.done(): - await self._raise_rpc_error_if_not_ok() + await self._raise_if_not_ok() raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS) response_message = await self._read() if response_message is None: # If the read operation failed, Core should explain why. - await self._raise_rpc_error_if_not_ok() + await self._raise_if_not_ok() # If no exception raised, there is something wrong internally. assert False, 'Read operation failed with StatusCode.OK' else: diff --git a/src/python/grpcio_tests/tests_aio/unit/call_test.py b/src/python/grpcio_tests/tests_aio/unit/call_test.py index cecce1c79d5..bdf2fbfda6b 100644 --- a/src/python/grpcio_tests/tests_aio/unit/call_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/call_test.py @@ -124,18 +124,10 @@ class TestUnaryUnaryCall(AioTestBase): self.assertTrue(call.cancel()) self.assertFalse(call.cancel()) - with self.assertRaises(grpc.RpcError) as exception_context: + with self.assertRaises(asyncio.CancelledError): await call # The info in the RpcError should match the info in Call object. - rpc_error = exception_context.exception - self.assertEqual(rpc_error.code(), await call.code()) - self.assertEqual(rpc_error.details(), await call.details()) - self.assertEqual(rpc_error.trailing_metadata(), await - call.trailing_metadata()) - self.assertEqual(rpc_error.debug_error_string(), await - call.debug_error_string()) - self.assertTrue(call.cancelled()) self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED) self.assertEqual(await call.details(), @@ -159,10 +151,8 @@ class TestUnaryUnaryCall(AioTestBase): self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) - with self.assertRaises(grpc.RpcError) as exception_context: + with self.assertRaises(asyncio.CancelledError): await task - self.assertEqual(grpc.StatusCode.CANCELLED, - exception_context.exception.code()) class TestUnaryStreamCall(AioTestBase): @@ -201,7 +191,7 @@ class TestUnaryStreamCall(AioTestBase): call.details()) self.assertFalse(call.cancel()) - with self.assertRaises(grpc.RpcError) as exception_context: + with self.assertRaises(asyncio.CancelledError): await call.read() self.assertTrue(call.cancelled()) @@ -232,7 +222,7 @@ class TestUnaryStreamCall(AioTestBase): self.assertFalse(call.cancel()) self.assertFalse(call.cancel()) - with self.assertRaises(grpc.RpcError) as exception_context: + with self.assertRaises(asyncio.CancelledError): await call.read() async def test_early_cancel_unary_stream(self): @@ -256,16 +246,11 @@ class TestUnaryStreamCall(AioTestBase): self.assertTrue(call.cancel()) self.assertFalse(call.cancel()) - with self.assertRaises(grpc.RpcError) as exception_context: + with self.assertRaises(asyncio.CancelledError): await call.read() self.assertTrue(call.cancelled()) - self.assertEqual(grpc.StatusCode.CANCELLED, - exception_context.exception.code()) - self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, - exception_context.exception.details()) - self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await call.details()) @@ -377,10 +362,8 @@ class TestUnaryStreamCall(AioTestBase): self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) - with self.assertRaises(grpc.RpcError) as exception_context: + with self.assertRaises(asyncio.CancelledError): await task - self.assertEqual(grpc.StatusCode.CANCELLED, - exception_context.exception.code()) async def test_cancel_unary_stream_in_task_using_async_for(self): async with aio.insecure_channel(self._server_target) as channel: @@ -411,10 +394,8 @@ class TestUnaryStreamCall(AioTestBase): self.assertEqual(grpc.StatusCode.CANCELLED, await call.code()) - with self.assertRaises(grpc.RpcError) as exception_context: + with self.assertRaises(asyncio.CancelledError): await task - self.assertEqual(grpc.StatusCode.CANCELLED, - exception_context.exception.code()) if __name__ == '__main__':