diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi index 1bcc61a9856..b30c8710f6a 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/callback_common.pyx.pxi @@ -173,3 +173,23 @@ async def _receive_initial_metadata(GrpcCallWrapper grpc_call_wrapper, cdef tuple ops = (op,) await execute_batch(grpc_call_wrapper, ops, loop) return op.initial_metadata() + +async def _send_error_status_from_server(GrpcCallWrapper grpc_call_wrapper, + grpc_status_code code, + str details, + tuple trailing_metadata, + bint metadata_sent, + object loop): + assert code != StatusCode.ok, 'Expecting non-ok status code.' + cdef SendStatusFromServerOperation op = SendStatusFromServerOperation( + trailing_metadata, + code, + details, + _EMPTY_FLAGS, + ) + cdef tuple ops + if metadata_sent: + ops = (op,) + else: + ops = (op, SendInitialMetadataOperation(None, _EMPTY_FLAG)) + await execute_batch(grpc_call_wrapper, ops, loop) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi index e8852f4f5b9..b8ae832bfc8 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi @@ -21,6 +21,9 @@ cdef class RPCState(GrpcCallWrapper): cdef grpc_call_details details cdef grpc_metadata_array request_metadata cdef AioServer server + cdef object abort_exception + cdef bint metadata_sent + cdef bint status_sent cdef bytes method(self) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi index 8aee3295f55..46594924dea 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi @@ -14,6 +14,7 @@ import inspect +import traceback # TODO(https://github.com/grpc/grpc/issues/20850) refactor this. @@ -34,6 +35,9 @@ cdef class RPCState: self.server = server grpc_metadata_array_init(&self.request_metadata) grpc_call_details_init(&self.details) + self.abort_exception = None + self.metadata_sent = False + self.status_sent = False cdef bytes method(self): return _slice_bytes(self.details.method) @@ -46,10 +50,54 @@ cdef class RPCState: grpc_call_unref(self.call) +# TODO(lidiz) inherit this from Python level `AioRpcStatus`, we need to improve +# current code structure to make it happen. +class AbortError(Exception): pass + + +def _raise_if_aborted(RPCState rpc_state): + """Raise AbortError if RPC is aborted. + + Server method handlers may suppress the abort exception. We need to halt + the RPC execution in that case. This function needs to be called after + running application code. + """ + if rpc_state.abort_exception is not None: + raise rpc_state.abort_exception + + +async def _perform_abort(RPCState rpc_state, + grpc_status_code code, + str details, + tuple trailing_metadata, + object loop): + """Perform the abort logic. + + Sends final status to the client, and then set the RPC into corresponding + state. + """ + if rpc_state.abort_exception is not None: + raise RuntimeError('Abort already called!') + else: + # Keeps track of the exception object. After abort happen, the RPC + # should stop execution. However, if users decided to suppress it, it + # could lead to undefined behavior. + rpc_state.abort_exception = AbortError('Locally aborted.') + + rpc_state.status_sent = True + await _send_error_status_from_server( + rpc_state, + code, + details, + trailing_metadata, + rpc_state.metadata_sent, + loop + ) + + cdef class _ServicerContext: cdef RPCState _rpc_state cdef object _loop - cdef bint _metadata_sent cdef object _request_deserializer cdef object _response_serializer @@ -62,27 +110,46 @@ cdef class _ServicerContext: self._request_deserializer = request_deserializer self._response_serializer = response_serializer self._loop = loop - self._metadata_sent = False async def read(self): + if self._rpc_state.status_sent: + raise RuntimeError('RPC already finished.') cdef bytes raw_message = await _receive_message(self._rpc_state, self._loop) return deserialize(self._request_deserializer, raw_message) async def write(self, object message): + if self._rpc_state.status_sent: + raise RuntimeError('RPC already finished.') await _send_message(self._rpc_state, serialize(self._response_serializer, message), - self._metadata_sent, + self._rpc_state.metadata_sent, self._loop) - if not self._metadata_sent: - self._metadata_sent = True + if not self._rpc_state.metadata_sent: + self._rpc_state.metadata_sent = True async def send_initial_metadata(self, tuple metadata): - if self._metadata_sent: + if self._rpc_state.status_sent: + raise RuntimeError('RPC already finished.') + elif self._rpc_state.metadata_sent: raise RuntimeError('Send initial metadata failed: already sent') else: _send_initial_metadata(self._rpc_state, self._loop) - self._metadata_sent = True + self._rpc_state.metadata_sent = True + + async def abort(self, + object code, + str details='', + tuple trailing_metadata=_EMPTY_METADATA): + await _perform_abort( + self._rpc_state, + code.value[0], + details, + trailing_metadata, + self._loop + ) + + raise self._rpc_state.abort_exception cdef _find_method_handler(str method, list generic_handlers): @@ -120,6 +187,9 @@ async def _handle_unary_unary_rpc(object method_handler, ), ) + # Raises exception if aborted + _raise_if_aborted(rpc_state) + # Serializes the response message cdef bytes response_raw = serialize( method_handler.response_serializer, @@ -138,6 +208,7 @@ async def _handle_unary_unary_rpc(object method_handler, SendMessageOperation(response_raw, _EMPTY_FLAGS), ) await execute_batch(rpc_state, send_ops, loop) + rpc_state.status_sent = True async def _handle_unary_stream_rpc(object method_handler, @@ -167,6 +238,9 @@ async def _handle_unary_stream_rpc(object method_handler, request_message, servicer_context, ) + + # Raises exception if aborted + _raise_if_aborted(rpc_state) else: # The handler uses async generator API async_response_generator = method_handler.unary_stream( @@ -176,6 +250,9 @@ async def _handle_unary_stream_rpc(object method_handler, # Consumes messages from the generator async for response_message in async_response_generator: + # Raises exception if aborted + _raise_if_aborted(rpc_state) + if rpc_state.server._status == AIO_SERVER_STATUS_STOPPED: # The async generator might yield much much later after the # server is destroied. If we proceed, Core will crash badly. @@ -194,6 +271,34 @@ async def _handle_unary_stream_rpc(object method_handler, cdef tuple ops = (op,) await execute_batch(rpc_state, ops, loop) + rpc_state.status_sent = True + + +async def _handle_exceptions(RPCState rpc_state, object rpc_coro, object loop): + try: + try: + await rpc_coro + except AbortError as e: + # Caught AbortError check if it is the same one + assert rpc_state.abort_exception is e, 'Abort error has been replaced!' + return + else: + # Check if the abort exception got suppressed + if rpc_state.abort_exception is not None: + _LOGGER.error( + 'Abort error unexpectedly suppressed: %s', + traceback.format_exception(rpc_state.abort_exception) + ) + except Exception as e: + _LOGGER.exception(e) + if not rpc_state.status_sent and rpc_state.server._status != AIO_SERVER_STATUS_STOPPED: + await _perform_abort( + rpc_state, + StatusCode.unknown, + '%s: %s' % (type(e), e), + _EMPTY_METADATA, + loop + ) async def _handle_cancellation_from_core(object rpc_task, @@ -213,7 +318,11 @@ async def _schedule_rpc_coro(object rpc_coro, RPCState rpc_state, object loop): # Schedules the RPC coroutine. - cdef object rpc_task = loop.create_task(rpc_coro) + cdef object rpc_task = loop.create_task(_handle_exceptions( + rpc_state, + rpc_coro, + loop, + )) await _handle_cancellation_from_core(rpc_task, rpc_state, loop) @@ -224,14 +333,23 @@ async def _handle_rpc(list generic_handlers, RPCState rpc_state, object loop): generic_handlers, ) if method_handler is None: - # TODO(lidiz) return unimplemented error to client side - raise NotImplementedError() + await _perform_abort( + rpc_state, + StatusCode.unimplemented, + b'Method not found!', + _EMPTY_METADATA, + loop + ) + return # TODO(lidiz) extend to all 4 types of RPC if not method_handler.request_streaming and method_handler.response_streaming: - await _handle_unary_stream_rpc(method_handler, - rpc_state, - loop) + try: + await _handle_unary_stream_rpc(method_handler, + rpc_state, + loop) + except Exception as e: + raise elif not method_handler.request_streaming and not method_handler.response_streaming: await _handle_unary_unary_rpc(method_handler, rpc_state, diff --git a/src/python/grpcio_tests/tests_aio/unit/abort_test.py b/src/python/grpcio_tests/tests_aio/unit/abort_test.py new file mode 100644 index 00000000000..c877fb7e6cf --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/unit/abort_test.py @@ -0,0 +1,154 @@ +# Copyright 2019 The gRPC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +import unittest +import time +import gc + +import grpc +from grpc.experimental import aio +from tests_aio.unit._test_base import AioTestBase +from tests.unit.framework.common import test_constants + +_UNARY_UNARY_ABORT = '/test/UnaryUnaryAbort' +_SUPPRESS_ABORT = '/test/SuppressAbort' +_REPLACE_ABORT = '/test/ReplaceAbort' +_ABORT_AFTER_REPLY = '/test/AbortAfterReply' + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x01\x01\x01' +_NUM_STREAM_RESPONSES = 5 + +_ABORT_CODE = grpc.StatusCode.RESOURCE_EXHAUSTED +_ABORT_DETAILS = 'Dummy error details' + + +class _GenericHandler(grpc.GenericRpcHandler): + + @staticmethod + async def _unary_unary_abort(unused_request, context): + await context.abort(_ABORT_CODE, _ABORT_DETAILS) + raise RuntimeError('This line should not be executed') + + @staticmethod + async def _suppress_abort(unused_request, context): + try: + await context.abort(_ABORT_CODE, _ABORT_DETAILS) + except Exception as e: + pass + return _RESPONSE + + @staticmethod + async def _replace_abort(unused_request, context): + try: + await context.abort(_ABORT_CODE, _ABORT_DETAILS) + except Exception as e: + await context.abort(grpc.StatusCode.INVALID_ARGUMENT, + 'Override abort!') + + @staticmethod + async def _abort_after_reply(unused_request, context): + yield _RESPONSE + await context.abort(_ABORT_CODE, _ABORT_DETAILS) + raise RuntimeError('This line should not be executed') + + def service(self, handler_details): + if handler_details.method == _UNARY_UNARY_ABORT: + return grpc.unary_unary_rpc_method_handler(self._unary_unary_abort) + if handler_details.method == _SUPPRESS_ABORT: + return grpc.unary_unary_rpc_method_handler(self._suppress_abort) + if handler_details.method == _REPLACE_ABORT: + return grpc.unary_unary_rpc_method_handler(self._replace_abort) + if handler_details.method == _ABORT_AFTER_REPLY: + return grpc.unary_stream_rpc_method_handler(self._abort_after_reply) + + +async def _start_test_server(): + server = aio.server() + port = server.add_insecure_port('[::]:0') + server.add_generic_rpc_handlers((_GenericHandler(),)) + await server.start() + return 'localhost:%d' % port, server + + +class TestServer(AioTestBase): + + async def setUp(self): + address, self._server = await _start_test_server() + self._channel = aio.insecure_channel(address) + + async def tearDown(self): + await self._channel.close() + await self._server.stop(None) + + async def test_unary_unary_abort(self): + method = self._channel.unary_unary(_UNARY_UNARY_ABORT) + call = method(_REQUEST) + + self.assertEqual(_ABORT_CODE, await call.code()) + self.assertEqual(_ABORT_DETAILS, await call.details()) + + with self.assertRaises(grpc.RpcError) as exception_context: + await call + + rpc_error = exception_context.exception + rpc_error.code() + self.assertEqual(_ABORT_CODE, rpc_error.code()) + self.assertEqual(_ABORT_DETAILS, rpc_error.details()) + + async def test_suppress_abort(self): + method = self._channel.unary_unary(_SUPPRESS_ABORT) + call = method(_REQUEST) + + with self.assertRaises(grpc.RpcError) as exception_context: + await call + + rpc_error = exception_context.exception + rpc_error.code() + self.assertEqual(_ABORT_CODE, rpc_error.code()) + self.assertEqual(_ABORT_DETAILS, rpc_error.details()) + + async def test_replace_abort(self): + method = self._channel.unary_unary(_REPLACE_ABORT) + call = method(_REQUEST) + + with self.assertRaises(grpc.RpcError) as exception_context: + await call + + rpc_error = exception_context.exception + rpc_error.code() + self.assertEqual(_ABORT_CODE, rpc_error.code()) + self.assertEqual(_ABORT_DETAILS, rpc_error.details()) + + async def test_abort_after_reply(self): + method = self._channel.unary_stream(_ABORT_AFTER_REPLY) + call = method(_REQUEST) + + with self.assertRaises(grpc.RpcError) as exception_context: + await call.read() + + rpc_error = exception_context.exception + rpc_error.code() + self.assertEqual(_ABORT_CODE, rpc_error.code()) + self.assertEqual(_ABORT_DETAILS, rpc_error.details()) + + self.assertEqual(_ABORT_CODE, await call.code()) + self.assertEqual(_ABORT_DETAILS, await call.details()) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main(verbosity=2)