Implement abort mechanism for server side

pull/21582/head
Lidi Zheng 5 years ago
parent b0d7e680cb
commit cddd0a0419
  1. 20
      src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi
  2. 3
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi
  3. 138
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  4. 154
      src/python/grpcio_tests/tests_aio/unit/abort_test.py

@ -173,3 +173,23 @@ async def _receive_initial_metadata(GrpcCallWrapper grpc_call_wrapper,
cdef tuple ops = (op,)
await execute_batch(grpc_call_wrapper, ops, loop)
return op.initial_metadata()
async def _send_error_status_from_server(GrpcCallWrapper grpc_call_wrapper,
grpc_status_code code,
str details,
tuple trailing_metadata,
bint metadata_sent,
object loop):
assert code != StatusCode.ok, 'Expecting non-ok status code.'
cdef SendStatusFromServerOperation op = SendStatusFromServerOperation(
trailing_metadata,
code,
details,
_EMPTY_FLAGS,
)
cdef tuple ops
if metadata_sent:
ops = (op,)
else:
ops = (op, SendInitialMetadataOperation(None, _EMPTY_FLAG))
await execute_batch(grpc_call_wrapper, ops, loop)

@ -21,6 +21,9 @@ cdef class RPCState(GrpcCallWrapper):
cdef grpc_call_details details
cdef grpc_metadata_array request_metadata
cdef AioServer server
cdef object abort_exception
cdef bint metadata_sent
cdef bint status_sent
cdef bytes method(self)

@ -14,6 +14,7 @@
import inspect
import traceback
# TODO(https://github.com/grpc/grpc/issues/20850) refactor this.
@ -34,6 +35,9 @@ cdef class RPCState:
self.server = server
grpc_metadata_array_init(&self.request_metadata)
grpc_call_details_init(&self.details)
self.abort_exception = None
self.metadata_sent = False
self.status_sent = False
cdef bytes method(self):
return _slice_bytes(self.details.method)
@ -46,10 +50,54 @@ cdef class RPCState:
grpc_call_unref(self.call)
# TODO(lidiz) inherit this from Python level `AioRpcStatus`, we need to improve
# current code structure to make it happen.
class AbortError(Exception): pass
def _raise_if_aborted(RPCState rpc_state):
"""Raise AbortError if RPC is aborted.
Server method handlers may suppress the abort exception. We need to halt
the RPC execution in that case. This function needs to be called after
running application code.
"""
if rpc_state.abort_exception is not None:
raise rpc_state.abort_exception
async def _perform_abort(RPCState rpc_state,
grpc_status_code code,
str details,
tuple trailing_metadata,
object loop):
"""Perform the abort logic.
Sends final status to the client, and then set the RPC into corresponding
state.
"""
if rpc_state.abort_exception is not None:
raise RuntimeError('Abort already called!')
else:
# Keeps track of the exception object. After abort happen, the RPC
# should stop execution. However, if users decided to suppress it, it
# could lead to undefined behavior.
rpc_state.abort_exception = AbortError('Locally aborted.')
rpc_state.status_sent = True
await _send_error_status_from_server(
rpc_state,
code,
details,
trailing_metadata,
rpc_state.metadata_sent,
loop
)
cdef class _ServicerContext:
cdef RPCState _rpc_state
cdef object _loop
cdef bint _metadata_sent
cdef object _request_deserializer
cdef object _response_serializer
@ -62,27 +110,46 @@ cdef class _ServicerContext:
self._request_deserializer = request_deserializer
self._response_serializer = response_serializer
self._loop = loop
self._metadata_sent = False
async def read(self):
if self._rpc_state.status_sent:
raise RuntimeError('RPC already finished.')
cdef bytes raw_message = await _receive_message(self._rpc_state, self._loop)
return deserialize(self._request_deserializer,
raw_message)
async def write(self, object message):
if self._rpc_state.status_sent:
raise RuntimeError('RPC already finished.')
await _send_message(self._rpc_state,
serialize(self._response_serializer, message),
self._metadata_sent,
self._rpc_state.metadata_sent,
self._loop)
if not self._metadata_sent:
self._metadata_sent = True
if not self._rpc_state.metadata_sent:
self._rpc_state.metadata_sent = True
async def send_initial_metadata(self, tuple metadata):
if self._metadata_sent:
if self._rpc_state.status_sent:
raise RuntimeError('RPC already finished.')
elif self._rpc_state.metadata_sent:
raise RuntimeError('Send initial metadata failed: already sent')
else:
_send_initial_metadata(self._rpc_state, self._loop)
self._metadata_sent = True
self._rpc_state.metadata_sent = True
async def abort(self,
object code,
str details='',
tuple trailing_metadata=_EMPTY_METADATA):
await _perform_abort(
self._rpc_state,
code.value[0],
details,
trailing_metadata,
self._loop
)
raise self._rpc_state.abort_exception
cdef _find_method_handler(str method, list generic_handlers):
@ -120,6 +187,9 @@ async def _handle_unary_unary_rpc(object method_handler,
),
)
# Raises exception if aborted
_raise_if_aborted(rpc_state)
# Serializes the response message
cdef bytes response_raw = serialize(
method_handler.response_serializer,
@ -138,6 +208,7 @@ async def _handle_unary_unary_rpc(object method_handler,
SendMessageOperation(response_raw, _EMPTY_FLAGS),
)
await execute_batch(rpc_state, send_ops, loop)
rpc_state.status_sent = True
async def _handle_unary_stream_rpc(object method_handler,
@ -167,6 +238,9 @@ async def _handle_unary_stream_rpc(object method_handler,
request_message,
servicer_context,
)
# Raises exception if aborted
_raise_if_aborted(rpc_state)
else:
# The handler uses async generator API
async_response_generator = method_handler.unary_stream(
@ -176,6 +250,9 @@ async def _handle_unary_stream_rpc(object method_handler,
# Consumes messages from the generator
async for response_message in async_response_generator:
# Raises exception if aborted
_raise_if_aborted(rpc_state)
if rpc_state.server._status == AIO_SERVER_STATUS_STOPPED:
# The async generator might yield much much later after the
# server is destroied. If we proceed, Core will crash badly.
@ -194,6 +271,34 @@ async def _handle_unary_stream_rpc(object method_handler,
cdef tuple ops = (op,)
await execute_batch(rpc_state, ops, loop)
rpc_state.status_sent = True
async def _handle_exceptions(RPCState rpc_state, object rpc_coro, object loop):
try:
try:
await rpc_coro
except AbortError as e:
# Caught AbortError check if it is the same one
assert rpc_state.abort_exception is e, 'Abort error has been replaced!'
return
else:
# Check if the abort exception got suppressed
if rpc_state.abort_exception is not None:
_LOGGER.error(
'Abort error unexpectedly suppressed: %s',
traceback.format_exception(rpc_state.abort_exception)
)
except Exception as e:
_LOGGER.exception(e)
if not rpc_state.status_sent and rpc_state.server._status != AIO_SERVER_STATUS_STOPPED:
await _perform_abort(
rpc_state,
StatusCode.unknown,
'%s: %s' % (type(e), e),
_EMPTY_METADATA,
loop
)
async def _handle_cancellation_from_core(object rpc_task,
@ -213,7 +318,11 @@ async def _schedule_rpc_coro(object rpc_coro,
RPCState rpc_state,
object loop):
# Schedules the RPC coroutine.
cdef object rpc_task = loop.create_task(rpc_coro)
cdef object rpc_task = loop.create_task(_handle_exceptions(
rpc_state,
rpc_coro,
loop,
))
await _handle_cancellation_from_core(rpc_task, rpc_state, loop)
@ -224,14 +333,23 @@ async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop):
generic_handlers,
)
if method_handler is None:
# TODO(lidiz) return unimplemented error to client side
raise NotImplementedError()
await _perform_abort(
rpc_state,
StatusCode.unimplemented,
b'Method not found!',
_EMPTY_METADATA,
loop
)
return
# TODO(lidiz) extend to all 4 types of RPC
if not method_handler.request_streaming and method_handler.response_streaming:
try:
await _handle_unary_stream_rpc(method_handler,
rpc_state,
loop)
except Exception as e:
raise
elif not method_handler.request_streaming and not method_handler.response_streaming:
await _handle_unary_unary_rpc(method_handler,
rpc_state,

@ -0,0 +1,154 @@
# Copyright 2019 The gRPC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import unittest
import time
import gc
import grpc
from grpc.experimental import aio
from tests_aio.unit._test_base import AioTestBase
from tests.unit.framework.common import test_constants
_UNARY_UNARY_ABORT = '/test/UnaryUnaryAbort'
_SUPPRESS_ABORT = '/test/SuppressAbort'
_REPLACE_ABORT = '/test/ReplaceAbort'
_ABORT_AFTER_REPLY = '/test/AbortAfterReply'
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x01\x01\x01'
_NUM_STREAM_RESPONSES = 5
_ABORT_CODE = grpc.StatusCode.RESOURCE_EXHAUSTED
_ABORT_DETAILS = 'Dummy error details'
class _GenericHandler(grpc.GenericRpcHandler):
@staticmethod
async def _unary_unary_abort(unused_request, context):
await context.abort(_ABORT_CODE, _ABORT_DETAILS)
raise RuntimeError('This line should not be executed')
@staticmethod
async def _suppress_abort(unused_request, context):
try:
await context.abort(_ABORT_CODE, _ABORT_DETAILS)
except Exception as e:
pass
return _RESPONSE
@staticmethod
async def _replace_abort(unused_request, context):
try:
await context.abort(_ABORT_CODE, _ABORT_DETAILS)
except Exception as e:
await context.abort(grpc.StatusCode.INVALID_ARGUMENT,
'Override abort!')
@staticmethod
async def _abort_after_reply(unused_request, context):
yield _RESPONSE
await context.abort(_ABORT_CODE, _ABORT_DETAILS)
raise RuntimeError('This line should not be executed')
def service(self, handler_details):
if handler_details.method == _UNARY_UNARY_ABORT:
return grpc.unary_unary_rpc_method_handler(self._unary_unary_abort)
if handler_details.method == _SUPPRESS_ABORT:
return grpc.unary_unary_rpc_method_handler(self._suppress_abort)
if handler_details.method == _REPLACE_ABORT:
return grpc.unary_unary_rpc_method_handler(self._replace_abort)
if handler_details.method == _ABORT_AFTER_REPLY:
return grpc.unary_stream_rpc_method_handler(self._abort_after_reply)
async def _start_test_server():
server = aio.server()
port = server.add_insecure_port('[::]:0')
server.add_generic_rpc_handlers((_GenericHandler(),))
await server.start()
return 'localhost:%d' % port, server
class TestServer(AioTestBase):
async def setUp(self):
address, self._server = await _start_test_server()
self._channel = aio.insecure_channel(address)
async def tearDown(self):
await self._channel.close()
await self._server.stop(None)
async def test_unary_unary_abort(self):
method = self._channel.unary_unary(_UNARY_UNARY_ABORT)
call = method(_REQUEST)
self.assertEqual(_ABORT_CODE, await call.code())
self.assertEqual(_ABORT_DETAILS, await call.details())
with self.assertRaises(grpc.RpcError) as exception_context:
await call
rpc_error = exception_context.exception
rpc_error.code()
self.assertEqual(_ABORT_CODE, rpc_error.code())
self.assertEqual(_ABORT_DETAILS, rpc_error.details())
async def test_suppress_abort(self):
method = self._channel.unary_unary(_SUPPRESS_ABORT)
call = method(_REQUEST)
with self.assertRaises(grpc.RpcError) as exception_context:
await call
rpc_error = exception_context.exception
rpc_error.code()
self.assertEqual(_ABORT_CODE, rpc_error.code())
self.assertEqual(_ABORT_DETAILS, rpc_error.details())
async def test_replace_abort(self):
method = self._channel.unary_unary(_REPLACE_ABORT)
call = method(_REQUEST)
with self.assertRaises(grpc.RpcError) as exception_context:
await call
rpc_error = exception_context.exception
rpc_error.code()
self.assertEqual(_ABORT_CODE, rpc_error.code())
self.assertEqual(_ABORT_DETAILS, rpc_error.details())
async def test_abort_after_reply(self):
method = self._channel.unary_stream(_ABORT_AFTER_REPLY)
call = method(_REQUEST)
with self.assertRaises(grpc.RpcError) as exception_context:
await call.read()
rpc_error = exception_context.exception
rpc_error.code()
self.assertEqual(_ABORT_CODE, rpc_error.code())
self.assertEqual(_ABORT_DETAILS, rpc_error.details())
self.assertEqual(_ABORT_CODE, await call.code())
self.assertEqual(_ABORT_DETAILS, await call.details())
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)
Loading…
Cancel
Save