Merge pull request #21506 from lidizheng/aio-cancel

[Aio] Improve cancellation mechanism on client side
pull/21595/head
Lidi Zheng 5 years ago committed by GitHub
commit f9aed63225
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 13
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pxd.pxi
  2. 114
      src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
  3. 37
      src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
  4. 188
      src/python/grpcio/grpc/experimental/aio/_call.py
  5. 106
      src/python/grpcio_tests/tests_aio/unit/call_test.py

@ -22,13 +22,12 @@ cdef class _AioCall:
# time we need access to the event loop.
object _loop
# Streaming call only attributes:
#
# An asyncio.Event that indicates if the status is received on the client side.
object _status_received
# A tuple of key value pairs representing the initial metadata sent by peer.
tuple _initial_metadata
# Flag indicating whether cancel has been called. Cancellation from
# Core or peer works perfectly fine with normal procedure. However, we
# need this flag to clean up resources for cancellation from the
# application layer. Directly cancelling tasks might cause segfault
# because Core is holding a pointer for the callback handler.
bint _is_locally_cancelled
cdef grpc_call* _create_grpc_call(self, object timeout, bytes method) except *
cdef void _destroy_grpc_call(self)
cdef AioRpcStatus _cancel_and_create_status(self, object cancellation_future)

@ -33,8 +33,7 @@ cdef class _AioCall:
self._grpc_call_wrapper = GrpcCallWrapper()
self._loop = asyncio.get_event_loop()
self._create_grpc_call(deadline, method)
self._status_received = asyncio.Event(loop=self._loop)
self._is_locally_cancelled = False
def __dealloc__(self):
self._destroy_grpc_call()
@ -78,17 +77,21 @@ cdef class _AioCall:
"""Destroys the corresponding Core object for this RPC."""
grpc_call_unref(self._grpc_call_wrapper.call)
cdef AioRpcStatus _cancel_and_create_status(self, object cancellation_future):
"""Cancels the RPC in Core, and return the final RPC status."""
cdef AioRpcStatus status
def cancel(self, AioRpcStatus status):
"""Cancels the RPC in Core with given RPC status.
Above abstractions must invoke this method to set Core objects into
proper state.
"""
self._is_locally_cancelled = True
cdef object details
cdef char *c_details
cdef grpc_call_error error
# Try to fetch application layer cancellation details in the future.
# * If cancellation details present, cancel with status;
# * If details not present, cancel with unknown reason.
if cancellation_future.done():
status = cancellation_future.result()
if status is not None:
details = str_to_bytes(status.details())
self._references.append(details)
c_details = <char *>details
@ -100,23 +103,13 @@ cdef class _AioCall:
NULL,
)
assert error == GRPC_CALL_OK
return status
else:
# By implementation, grpc_call_cancel always return OK
error = grpc_call_cancel(self._grpc_call_wrapper.call, NULL)
assert error == GRPC_CALL_OK
status = AioRpcStatus(
StatusCode.cancelled,
_UNKNOWN_CANCELLATION_DETAILS,
None,
None,
)
cancellation_future.set_result(status)
return status
async def unary_unary(self,
bytes request,
object cancellation_future,
object initial_metadata_observer,
object status_observer):
"""Performs a unary unary RPC.
@ -145,19 +138,11 @@ cdef class _AioCall:
receive_initial_metadata_op, receive_message_op,
receive_status_on_client_op)
try:
await execute_batch(self._grpc_call_wrapper,
ops,
self._loop)
except asyncio.CancelledError:
status = self._cancel_and_create_status(cancellation_future)
initial_metadata_observer(None)
status_observer(status)
raise
else:
initial_metadata_observer(
receive_initial_metadata_op.initial_metadata()
)
# Executes all operations in one batch.
# Might raise CancelledError, handling it in Python UnaryUnaryCall.
await execute_batch(self._grpc_call_wrapper,
ops,
self._loop)
status = AioRpcStatus(
receive_status_on_client_op.code(),
@ -179,6 +164,11 @@ cdef class _AioCall:
cdef ReceiveStatusOnClientOperation op = ReceiveStatusOnClientOperation(_EMPTY_FLAGS)
cdef tuple ops = (op,)
await execute_batch(self._grpc_call_wrapper, ops, self._loop)
# Halts if the RPC is locally cancelled
if self._is_locally_cancelled:
return
cdef AioRpcStatus status = AioRpcStatus(
op.code(),
op.details(),
@ -186,52 +176,30 @@ cdef class _AioCall:
op.error_string(),
)
status_observer(status)
self._status_received.set()
def _handle_cancellation_from_application(self,
object cancellation_future,
object status_observer):
def _cancellation_action(finished_future):
if not self._status_received.set():
status = self._cancel_and_create_status(finished_future)
status_observer(status)
self._status_received.set()
cancellation_future.add_done_callback(_cancellation_action)
async def _message_async_generator(self):
async def receive_serialized_message(self):
"""Receives one single raw message in bytes."""
cdef bytes received_message
# Infinitely receiving messages, until:
# Receives a message. Returns None when failed:
# * EOF, no more messages to read;
# * The client application cancells;
# * The client application cancels;
# * The server sends final status.
while True:
if self._status_received.is_set():
return
received_message = await _receive_message(
self._grpc_call_wrapper,
self._loop
)
if received_message is None:
# The read operation failed, Core should explain why it fails
await self._status_received.wait()
return
else:
yield received_message
received_message = await _receive_message(
self._grpc_call_wrapper,
self._loop
)
return received_message
async def unary_stream(self,
bytes request,
object cancellation_future,
object initial_metadata_observer,
object status_observer):
"""Actual implementation of the complete unary-stream call.
Needs to pay extra attention to the raise mechanism. If we want to
propagate the final status exception, then we have to raise it.
Othersize, it would end normally and raise `StopAsyncIteration()`.
"""
"""Implementation of the start of a unary-stream call."""
# Peer may prematurely end this RPC at any point. We need a coroutine
# that watches if the server sends the final status.
self._loop.create_task(self._handle_status_once_received(status_observer))
cdef tuple outbound_ops
cdef Operation initial_metadata_op = SendInitialMetadataOperation(
_EMPTY_METADATA,
@ -248,21 +216,13 @@ cdef class _AioCall:
send_close_op,
)
# Actually sends out the request message.
# Sends out the request message.
await execute_batch(self._grpc_call_wrapper,
outbound_ops,
self._loop)
# Peer may prematurely end this RPC at any point. We need a mechanism
# that handles both the normal case and the error case.
self._loop.create_task(self._handle_status_once_received(status_observer))
self._handle_cancellation_from_application(cancellation_future,
status_observer)
outbound_ops,
self._loop)
# Receives initial metadata.
initial_metadata_observer(
await _receive_initial_metadata(self._grpc_call_wrapper,
self._loop),
)
return self._message_async_generator()

@ -26,38 +26,13 @@ cdef class AioChannel:
def close(self):
grpc_channel_destroy(self.channel)
async def unary_unary(self,
bytes method,
bytes request,
object deadline,
object cancellation_future,
object initial_metadata_observer,
object status_observer):
"""Assembles a unary-unary RPC.
def call(self,
bytes method,
object deadline):
"""Assembles a Cython Call object.
Returns:
The response message in bytes.
The _AioCall object.
"""
cdef _AioCall call = _AioCall(self, deadline, method)
return await call.unary_unary(request,
cancellation_future,
initial_metadata_observer,
status_observer)
def unary_stream(self,
bytes method,
bytes request,
object deadline,
object cancellation_future,
object initial_metadata_observer,
object status_observer):
"""Assembles a unary-stream RPC.
Returns:
An async generator that yields raw responses.
"""
cdef _AioCall call = _AioCall(self, deadline, method)
return call.unary_stream(request,
cancellation_future,
initial_metadata_observer,
status_observer)
return call

@ -41,6 +41,8 @@ _NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
'\tdebug_error_string = "{}"\n'
'>')
_EMPTY_METADATA = tuple()
class AioRpcError(grpc.RpcError):
"""An implementation of RpcError to be used by the asynchronous API.
@ -148,14 +150,14 @@ class Call(_base_call.Call):
_code: grpc.StatusCode
_status: Awaitable[cygrpc.AioRpcStatus]
_initial_metadata: Awaitable[MetadataType]
_cancellation: asyncio.Future
_locally_cancelled: bool
def __init__(self) -> None:
self._loop = asyncio.get_event_loop()
self._code = None
self._status = self._loop.create_future()
self._initial_metadata = self._loop.create_future()
self._cancellation = self._loop.create_future()
self._locally_cancelled = False
def cancel(self) -> bool:
"""Placeholder cancellation method.
@ -167,8 +169,7 @@ class Call(_base_call.Call):
raise NotImplementedError()
def cancelled(self) -> bool:
return self._cancellation.done(
) or self._code == grpc.StatusCode.CANCELLED
return self._code == grpc.StatusCode.CANCELLED
def done(self) -> bool:
return self._status.done()
@ -205,14 +206,22 @@ class Call(_base_call.Call):
cancellation (by application) and Core receiving status from peer. We
make no promise here which one will win.
"""
if self._status.done():
return
else:
self._status.set_result(status)
self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[
status.code()]
# In case of local cancellation, flip the flag.
if status.details() is _LOCAL_CANCELLATION_DETAILS:
self._locally_cancelled = True
async def _raise_rpc_error_if_not_ok(self) -> None:
# In case of the RPC finished without receiving metadata.
if not self._initial_metadata.done():
self._initial_metadata.set_result(_EMPTY_METADATA)
# Sets final status
self._status.set_result(status)
self._code = _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()]
async def _raise_for_status(self) -> None:
    """Raise if the RPC was cancelled locally or ended with a non-OK status.

    Raises:
      asyncio.CancelledError: If the local-cancellation flag is set.
      Whatever `_create_rpc_error` builds: If the final status code is
        anything other than `grpc.StatusCode.OK`.
    """
    # Local cancellation is reported with asyncio's own exception and is
    # checked before waiting on the final status.
    if self._locally_cancelled:
        raise asyncio.CancelledError()

    # Block until the final status future has a result.
    await self._status
    if self._code == grpc.StatusCode.OK:
        return

    metadata = await self.initial_metadata()
    raise _create_rpc_error(metadata, self._status.result())
@ -245,12 +254,11 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
Returned when an instance of `UnaryUnaryMultiCallable` object is called.
"""
_request: RequestType
_deadline: Optional[float]
_channel: cygrpc.AioChannel
_method: bytes
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_call: asyncio.Task
_cython_call: cygrpc._AioCall
def __init__(self, request: RequestType, deadline: Optional[float],
channel: cygrpc.AioChannel, method: bytes,
@ -258,11 +266,10 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
response_deserializer: DeserializingFunction) -> None:
super().__init__()
self._request = request
self._deadline = deadline
self._channel = channel
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._cython_call = self._channel.call(method, deadline)
self._call = self._loop.create_task(self._invoke())
def __del__(self) -> None:
@ -275,28 +282,30 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
serialized_request = _common.serialize(self._request,
self._request_serializer)
# NOTE(lidiz) asyncio.CancelledError is not a good transport for
# status, since the Task class do not cache the exact
# asyncio.CancelledError object. So, the solution is catching the error
# in Cython layer, then cancel the RPC and update the status, finally
# re-raise the CancelledError.
serialized_response = await self._channel.unary_unary(
self._method,
serialized_request,
self._deadline,
self._cancellation,
self._set_initial_metadata,
self._set_status,
)
await self._raise_rpc_error_if_not_ok()
# NOTE(lidiz) asyncio.CancelledError is not a good transport for status,
# because the asyncio.Task class does not cache the exception object.
# https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
try:
serialized_response = await self._cython_call.unary_unary(
serialized_request,
self._set_initial_metadata,
self._set_status,
)
except asyncio.CancelledError:
if self._code != grpc.StatusCode.CANCELLED:
self.cancel()
# Raises here if RPC failed or cancelled
await self._raise_for_status()
return _common.deserialize(serialized_response,
self._response_deserializer)
def _cancel(self, status: cygrpc.AioRpcStatus) -> bool:
"""Forwards the application cancellation reasoning."""
if not self._status.done() and not self._cancellation.done():
self._cancellation.set_result(status)
if not self._status.done():
self._set_status(status)
self._cython_call.cancel(status)
self._call.cancel()
return True
else:
@ -308,16 +317,17 @@ class UnaryUnaryCall(Call, _base_call.UnaryUnaryCall):
_LOCAL_CANCELLATION_DETAILS, None, None))
def __await__(self) -> ResponseType:
"""Wait till the ongoing RPC request finishes.
Returns:
Response of the RPC call.
Raises:
RpcError: Indicating that the RPC terminated with non-OK status.
asyncio.CancelledError: Indicating that the RPC was canceled.
"""
response = yield from self._call
"""Wait till the ongoing RPC request finishes."""
try:
response = yield from self._call
except asyncio.CancelledError:
# Even if we caught all other CancelledError, there is still
# this corner case. If the application cancels immediately after
# the Call object is created, we will observe this
# `CancelledError`.
if not self.cancelled():
self.cancel()
raise
return response
@ -328,13 +338,11 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
Returned when an instance of `UnaryStreamMultiCallable` object is called.
"""
_request: RequestType
_deadline: Optional[float]
_channel: cygrpc.AioChannel
_method: bytes
_request_serializer: SerializingFunction
_response_deserializer: DeserializingFunction
_call: asyncio.Task
_bytes_aiter: AsyncIterable[bytes]
_cython_call: cygrpc._AioCall
_send_unary_request_task: asyncio.Task
_message_aiter: AsyncIterable[ResponseType]
def __init__(self, request: RequestType, deadline: Optional[float],
@ -343,13 +351,13 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
response_deserializer: DeserializingFunction) -> None:
super().__init__()
self._request = request
self._deadline = deadline
self._channel = channel
self._method = method
self._request_serializer = request_serializer
self._response_deserializer = response_deserializer
self._call = self._loop.create_task(self._invoke())
self._message_aiter = self._process()
self._send_unary_request_task = self._loop.create_task(
self._send_unary_request())
self._message_aiter = self._fetch_stream_responses()
self._cython_call = self._channel.call(method, deadline)
def __del__(self) -> None:
if not self._status.done():
@ -357,32 +365,24 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
cygrpc.AioRpcStatus(cygrpc.StatusCode.cancelled,
_GC_CANCELLATION_DETAILS, None, None))
async def _invoke(self) -> ResponseType:
async def _send_unary_request(self) -> ResponseType:
serialized_request = _common.serialize(self._request,
self._request_serializer)
self._bytes_aiter = await self._channel.unary_stream(
self._method,
serialized_request,
self._deadline,
self._cancellation,
self._set_initial_metadata,
self._set_status,
)
async def _process(self) -> ResponseType:
await self._call
async for serialized_response in self._bytes_aiter:
if self._cancellation.done():
await self._status
if self._status.done():
# Raises pre-maturely if final status received here. Generates
# more helpful stack trace for end users.
await self._raise_rpc_error_if_not_ok()
yield _common.deserialize(serialized_response,
self._response_deserializer)
await self._raise_rpc_error_if_not_ok()
try:
await self._cython_call.unary_stream(serialized_request,
self._set_initial_metadata,
self._set_status)
except asyncio.CancelledError:
if self._code != grpc.StatusCode.CANCELLED:
self.cancel()
raise
async def _fetch_stream_responses(self) -> ResponseType:
    """Yield each response message of the stream, then surface the status.

    Waits for the request-sending task, then keeps awaiting `self._read()`
    and yielding its result until `_read()` returns `None`, which is the
    value `_read()` produces when Core has no further message.

    Yields:
      Each response message, in the order it was read.

    Raises:
      asyncio.CancelledError: Propagated from the request task or from
        `_read()`, or raised by `_raise_for_status()` after a local
        cancellation.
      Whatever `_raise_for_status()` raises for a non-OK final status.
    """
    await self._send_unary_request_task

    message = await self._read()
    # Compare against the None sentinel explicitly: a valid message that
    # happens to be falsy (for example an empty bytes payload) must not be
    # mistaken for the end of the stream.
    while message is not None:
        yield message
        message = await self._read()

    # The read returned None, so the stream is over. Mirror `read()` and
    # raise if the RPC was cancelled or finished with a non-OK status,
    # instead of letting `async for` end silently.
    await self._raise_for_status()
def _cancel(self, status: cygrpc.AioRpcStatus) -> bool:
"""Forwards the application cancellation reasoning.
@ -395,8 +395,15 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
and the client calling "cancel" at the same time, this method respects
the winner in Core.
"""
if not self._status.done() and not self._cancellation.done():
self._cancellation.set_result(status)
if not self._status.done():
self._set_status(status)
self._cython_call.cancel(status)
if not self._send_unary_request_task.done():
# Injects CancelledError to the Task. The exception will
# propagate to _fetch_stream_responses as well, if the sending
# is not done.
self._send_unary_request_task.cancel()
return True
else:
return False
@ -409,8 +416,35 @@ class UnaryStreamCall(Call, _base_call.UnaryStreamCall):
def __aiter__(self) -> AsyncIterable[ResponseType]:
return self._message_aiter
async def _read(self) -> ResponseType:
    """Read one response message from the Cython call.

    Returns:
      The message passed through `_common.deserialize`, or `None` when the
      Cython layer returned `None` instead of a message.

    Raises:
      asyncio.CancelledError: Re-raised after making sure the RPC itself is
        cancelled.
    """
    # No response can be read before the request has been sent.
    await self._send_unary_request_task

    try:
        serialized = await self._cython_call.receive_serialized_message()
    except asyncio.CancelledError:
        # The awaiting task was cancelled from outside; cancel the RPC too,
        # unless its status already says CANCELLED.
        if self._code != grpc.StatusCode.CANCELLED:
            self.cancel()
        raise

    if serialized is not None:
        return _common.deserialize(serialized, self._response_deserializer)
    return None
async def read(self) -> ResponseType:
if self._status.done():
await self._raise_rpc_error_if_not_ok()
await self._raise_for_status()
raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
return await self._message_aiter.__anext__()
response_message = await self._read()
if response_message is None:
# If the read operation failed, Core should explain why.
await self._raise_for_status()
# If no exception raised, there is something wrong internally.
assert False, 'Read operation failed with StatusCode.OK'
else:
return response_message

@ -33,6 +33,8 @@ _LOCAL_CANCEL_DETAILS_EXPECTATION = 'Locally cancelled by application!'
_RESPONSE_INTERVAL_US = test_constants.SHORT_TIMEOUT * 1000 * 1000
_UNREACHABLE_TARGET = '0.1:1111'
_INFINITE_INTERVAL_US = 2**31 - 1
class TestUnaryUnaryCall(AioTestBase):
@ -119,24 +121,38 @@ class TestUnaryUnaryCall(AioTestBase):
self.assertFalse(call.cancelled())
# TODO(https://github.com/grpc/grpc/issues/20869) remove sleep.
# Force the loop to execute the RPC task.
await asyncio.sleep(0)
self.assertTrue(call.cancel())
self.assertFalse(call.cancel())
with self.assertRaises(asyncio.CancelledError) as exception_context:
with self.assertRaises(asyncio.CancelledError):
await call
# The info in the RpcError should match the info in Call object.
self.assertTrue(call.cancelled())
self.assertEqual(await call.code(), grpc.StatusCode.CANCELLED)
self.assertEqual(await call.details(),
'Locally cancelled by application!')
# NOTE(lidiz) The CancelledError is almost always re-created,
# so we might not want to use it to transmit data.
# https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
async def test_cancel_unary_unary_in_task(self):
    """Cancelling a task that awaits a unary-unary call cancels the RPC."""
    async with aio.insecure_channel(self._server_target) as channel:
        service = test_pb2_grpc.TestServiceStub(channel)
        awaiter_running = asyncio.Event()
        rpc = service.EmptyCall(messages_pb2.SimpleRequest())

        async def await_rpc():
            # Signal the test body before blocking on the call.
            awaiter_running.set()
            await rpc

        awaiter = self.loop.create_task(await_rpc())
        await awaiter_running.wait()

        # The task must still be pending when it gets cancelled.
        self.assertFalse(awaiter.done())
        awaiter.cancel()

        # The call object itself must report the cancellation.
        self.assertEqual(grpc.StatusCode.CANCELLED, await rpc.code())

        with self.assertRaises(asyncio.CancelledError):
            await awaiter
class TestUnaryStreamCall(AioTestBase):
@ -175,7 +191,7 @@ class TestUnaryStreamCall(AioTestBase):
call.details())
self.assertFalse(call.cancel())
with self.assertRaises(grpc.RpcError) as exception_context:
with self.assertRaises(asyncio.CancelledError):
await call.read()
self.assertTrue(call.cancelled())
@ -206,7 +222,7 @@ class TestUnaryStreamCall(AioTestBase):
self.assertFalse(call.cancel())
self.assertFalse(call.cancel())
with self.assertRaises(grpc.RpcError) as exception_context:
with self.assertRaises(asyncio.CancelledError):
await call.read()
async def test_early_cancel_unary_stream(self):
@ -230,16 +246,11 @@ class TestUnaryStreamCall(AioTestBase):
self.assertTrue(call.cancel())
self.assertFalse(call.cancel())
with self.assertRaises(grpc.RpcError) as exception_context:
with self.assertRaises(asyncio.CancelledError):
await call.read()
self.assertTrue(call.cancelled())
self.assertEqual(grpc.StatusCode.CANCELLED,
exception_context.exception.code())
self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION,
exception_context.exception.details())
self.assertEqual(grpc.StatusCode.CANCELLED, await call.code())
self.assertEqual(_LOCAL_CANCEL_DETAILS_EXPECTATION, await
call.details())
@ -323,6 +334,69 @@ class TestUnaryStreamCall(AioTestBase):
self.assertEqual(await call.code(), grpc.StatusCode.OK)
async def test_cancel_unary_stream_in_task_using_read(self):
    """Cancelling a task blocked in `call.read()` cancels the RPC."""
    async with aio.insecure_channel(self._server_target) as channel:
        service = test_pb2_grpc.TestServiceStub(channel)
        reader_running = asyncio.Event()

        # Ask for a single response scheduled after the largest interval
        # the test uses, so the read stays pending.
        request = messages_pb2.StreamingOutputCallRequest()
        delayed_response = messages_pb2.ResponseParameters(
            size=_RESPONSE_PAYLOAD_SIZE,
            interval_us=_INFINITE_INTERVAL_US,
        )
        request.response_parameters.append(delayed_response)

        # Start the actual RPC.
        rpc = service.StreamingOutputCall(request)

        async def read_once():
            # Signal the test body before blocking on the read.
            reader_running.set()
            await rpc.read()

        reader = self.loop.create_task(read_once())
        await reader_running.wait()

        # The task must still be pending when it gets cancelled.
        self.assertFalse(reader.done())
        reader.cancel()

        # The call object itself must report the cancellation.
        self.assertEqual(grpc.StatusCode.CANCELLED, await rpc.code())

        with self.assertRaises(asyncio.CancelledError):
            await reader
async def test_cancel_unary_stream_in_task_using_async_for(self):
    """Cancelling a task iterating the call with `async for` cancels the RPC."""
    async with aio.insecure_channel(self._server_target) as channel:
        service = test_pb2_grpc.TestServiceStub(channel)
        consumer_running = asyncio.Event()

        # Ask for a single response scheduled after the largest interval
        # the test uses, so the iteration stays pending.
        request = messages_pb2.StreamingOutputCallRequest()
        delayed_response = messages_pb2.ResponseParameters(
            size=_RESPONSE_PAYLOAD_SIZE,
            interval_us=_INFINITE_INTERVAL_US,
        )
        request.response_parameters.append(delayed_response)

        # Start the actual RPC.
        rpc = service.StreamingOutputCall(request)

        async def drain():
            # Signal the test body before blocking on the iteration.
            consumer_running.set()
            async for _ in rpc:
                pass

        consumer = self.loop.create_task(drain())
        await consumer_running.wait()

        # The task must still be pending when it gets cancelled.
        self.assertFalse(consumer.done())
        consumer.cancel()

        # The call object itself must report the cancellation.
        self.assertEqual(grpc.StatusCode.CANCELLED, await rpc.code())

        with self.assertRaises(asyncio.CancelledError):
            await consumer
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)

Loading…
Cancel
Save