Adding more catch clauses for CancelledError

pull/21506/head
Lidi Zheng 5 years ago
parent e8283e4818
commit d49b0849f0
  1. 6
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  2. 43
      src/python/grpcio/grpc/experimental/aio/_call.py
  3. 92
      src/python/grpcio_tests/tests_aio/unit/call_test.py

@ -77,6 +77,11 @@ cdef class _AioCall:
"""Destroys the corresponding Core object for this RPC."""
grpc_call_unref(self._grpc_call_wrapper.call)
@property
def locally_cancelled(self):
"""Grant Python layer access of the cancelled flag."""
return self._is_locally_cancelled
def cancel(self, AioRpcStatus status):
"""Cancels the RPC in Core with given RPC status.
@ -145,6 +150,7 @@ cdef class _AioCall:
receive_status_on_client_op)
# Executes all operations in one batch.
# Might raise CancelledError, handling it in Python UnaryUnaryCall.
await execute_batch(self._grpc_call_wrapper,
ops,
self._loop)

@ -15,7 +15,6 @@
import asyncio
from typing import AsyncIterable, Awaitable, Dict, Optional
import logging
import grpc
from grpc import _common
@ -42,6 +41,8 @@ _NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'\tdebug_error_string = "{}"\n'
'>')
_EMPTY_METADATA = tuple()
class AioRpcError(grpc.RpcError):
"""An implementation of RpcError to be used by the asynchronous API.
@ -205,7 +206,7 @@ class Call(_base_call.Call):
"""
# In case of the RPC finished without receiving metadata.
if not self._initial_metadata.done():
self._initial_metadata.set_result(None)
self._initial_metadata.set_result(_EMPTY_METADATA)
# Sets final status
self._status.set_result(status)
@ -283,10 +284,10 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
self._set_status,
)
except asyncio.CancelledError:
# Only this class can inject the CancelledError into the RPC
# coroutine, so we are certain that this exception is due to local
# cancellation.
assert self._code == grpc.StatusCode.CANCELLED
if self._code != grpc.StatusCode.CANCELLED:
self.cancel()
# Raises RpcError here if RPC failed or cancelled
await self._raise_rpc_error_if_not_ok()
return _common.deserialize(serialized_response,
@ -357,8 +358,16 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
async def _send_unary_request(self) -> ResponseType:
serialized_request = _common.serialize(self._request,
self._request_serializer)
await self._cython_call.unary_stream(
serialized_request, self._set_initial_metadata, self._set_status)
try:
await self._cython_call.unary_stream(
serialized_request,
self._set_initial_metadata,
self._set_status
)
except asyncio.CancelledError:
if self._code != grpc.StatusCode.CANCELLED:
self.cancel()
await self._raise_rpc_error_if_not_ok()
async def _fetch_stream_responses(self) -> ResponseType:
await self._send_unary_request_task
@ -400,12 +409,21 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
return self._message_aiter
async def _read(self) -> ResponseType:
serialized_response = await self._cython_call.receive_serialized_message(
)
if serialized_response is None:
# Wait for the request being sent
await self._send_unary_request_task
# Reads response message from Core
try:
raw_response = await self._cython_call.receive_serialized_message()
except asyncio.CancelledError:
if self._code != grpc.StatusCode.CANCELLED:
self.cancel()
await self._raise_rpc_error_if_not_ok()
if raw_response is None:
return None
else:
return _common.deserialize(serialized_response,
return _common.deserialize(raw_response,
self._response_deserializer)
async def read(self) -> ResponseType:
@ -414,6 +432,7 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
response_message = await self._read()
if response_message is None:
# If the read operation failed, Core should explain why.
await self._raise_rpc_error_if_not_ok()

@ -33,6 +33,8 @@ _LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
_RESPONSE_INTERVAL_US = test_constants.SHORT_TIMEOUT * 1000 * 1000
_UNREACHABLE_TARGET = '0.1:1111'
_INFINITE_INTERVAL_US = 2**31-1
class TestUnaryUnaryCall(AioTestBase):
@ -143,6 +145,29 @@ class TestUnaryUnaryCall(AioTestBase):
self.assertEqual(await call.details(),
'Locally cancelled by application!')
async def test_cancel_unary_unary_in_task(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
coro_started = asyncio.Event()
call = stub.EmptyCall(messages_pb2.SimpleRequest())
async def another_coro():
coro_started.set()
await call
task = self.loop.create_task(another_coro())
await coro_started.wait()
self.assertFalse(task.done())
task.cancel()
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
with self.assertRaises(grpc.RpcError) as exception_context:
await task
self.assertEqual(grpc.StatusCode.CANCELLED,
exception_context.exception.code())
class TestUnaryStreamCall(AioTestBase):
@ -328,6 +353,73 @@ class TestUnaryStreamCall(AioTestBase):
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_cancel_unary_stream_in_task_using_read(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
coro_started = asyncio.Event()
# Configs the server method to block forever
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_INFINITE_INTERVAL_US,
))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
async def another_coro():
coro_started.set()
await call.read()
task = self.loop.create_task(another_coro())
await coro_started.wait()
self.assertFalse(task.done())
task.cancel()
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
with self.assertRaises(grpc.RpcError) as exception_context:
await task
self.assertEqual(grpc.StatusCode.CANCELLED,
exception_context.exception.code())
async def test_cancel_unary_stream_in_task_using_async_for(self):
async with aio.insecure_channel(self._server_target) as channel:
stub = test_pb2_grpc.TestServiceStub(channel)
coro_started = asyncio.Event()
# Configs the server method to block forever
request = messages_pb2.StreamingOutputCallRequest()
request.response_parameters.append(
messages_pb2.ResponseParameters(
size=_RESPONSE_PAYLOAD_SIZE,
interval_us=_INFINITE_INTERVAL_US,
))
# Invokes the actual RPC
call = stub.StreamingOutputCall(request)
async def another_coro():
coro_started.set()
async for _ in call:
pass
task = self.loop.create_task(another_coro())
await coro_started.wait()
self.assertFalse(task.done())
task.cancel()
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
with self.assertRaises(grpc.RpcError) as exception_context:
await task
self.assertEqual(grpc.StatusCode.CANCELLED,
exception_context.exception.code())
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)

Loading…
Cancel
Save