diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi index 4703337b60c..5c79460879d 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi @@ -32,6 +32,7 @@ cdef class RPCState(GrpcCallWrapper): cdef bytes method(self) cdef tuple invocation_metadata(self) + cdef void raise_for_termination(self) except * cdef enum AioServerStatus: diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi index 98410c9b502..65bfa6b30ca 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi @@ -20,7 +20,7 @@ import traceback # TODO(https://github.com/grpc/grpc/issues/20850) refactor this. _LOGGER = logging.getLogger(__name__) cdef int _EMPTY_FLAG = 0 -# TODO(lidiz) Use a designated value other than None. +cdef str _RPC_FINISHED_DETAILS = 'RPC already finished.' cdef str _SERVER_STOPPED_DETAILS = 'Server already stopped.' cdef class _HandlerCallDetails: @@ -29,6 +29,10 @@ cdef class _HandlerCallDetails: self.invocation_metadata = invocation_metadata +class _ServerStoppedError(RuntimeError): + """Raised if the server is stopped.""" + + cdef class RPCState: def __cinit__(self, AioServer server): @@ -48,6 +52,23 @@ cdef class RPCState: cdef tuple invocation_metadata(self): return _metadata(&self.request_metadata) + cdef void raise_for_termination(self) except *: + """Raise exceptions if RPC is not running. + + Server method handlers may suppress the abort exception. We need to halt + the RPC execution in that case. This function needs to be called after + running application code. + + Also, the server may stop unexpected. We need to check before calling + into Core functions, otherwise, segfault. + """ + if self.abort_exception is not None: + raise self.abort_exception + if self.status_sent: + raise RuntimeError(_RPC_FINISHED_DETAILS) + if self.server._status == AIO_SERVER_STATUS_STOPPED: + raise _ServerStoppedError(_SERVER_STOPPED_DETAILS) + def __dealloc__(self): """Cleans the Core objects.""" grpc_call_details_destroy(&self.details) @@ -61,17 +82,6 @@ cdef class RPCState: class AbortError(Exception): pass -def _raise_if_aborted(RPCState rpc_state): - """Raise AbortError if RPC is aborted. - - Server method handlers may suppress the abort exception. We need to halt - the RPC execution in that case. This function needs to be called after - running application code. - """ - if rpc_state.abort_exception is not None: - raise rpc_state.abort_exception - - cdef class _ServicerContext: cdef RPCState _rpc_state cdef object _loop @@ -90,10 +100,8 @@ cdef class _ServicerContext: async def read(self): cdef bytes raw_message - if self._rpc_state.server._status == AIO_SERVER_STATUS_STOPPED: - raise RuntimeError(_SERVER_STOPPED_DETAILS) - if self._rpc_state.status_sent: - raise RuntimeError('RPC already finished.') + self._rpc_state.raise_for_termination() + if self._rpc_state.client_closed: return EOF raw_message = await _receive_message(self._rpc_state, self._loop) @@ -104,10 +112,8 @@ cdef class _ServicerContext: raw_message) async def write(self, object message): - if self._rpc_state.server._status == AIO_SERVER_STATUS_STOPPED: - raise RuntimeError(_SERVER_STOPPED_DETAILS) - if self._rpc_state.status_sent: - raise RuntimeError('RPC already finished.') + self._rpc_state.raise_for_termination() + await _send_message(self._rpc_state, serialize(self._response_serializer, message), self._rpc_state.metadata_sent, @@ -116,11 +122,9 @@ cdef class _ServicerContext: self._rpc_state.metadata_sent = True async def send_initial_metadata(self, tuple metadata): - if self._rpc_state.status_sent: - raise RuntimeError('RPC already finished.') - elif self._rpc_state.server._status == AIO_SERVER_STATUS_STOPPED: - raise RuntimeError(_SERVER_STOPPED_DETAILS) - elif self._rpc_state.metadata_sent: + self._rpc_state.raise_for_termination() + + if self._rpc_state.metadata_sent: raise RuntimeError('Send initial metadata failed: already sent') else: await _send_initial_metadata(self._rpc_state, metadata, self._loop) @@ -191,7 +195,7 @@ async def _finish_handler_with_unary_response(RPCState rpc_state, ) # Raises exception if aborted - _raise_if_aborted(rpc_state) + rpc_state.raise_for_termination() # Serializes the response message cdef bytes response_raw = serialize( @@ -238,9 +242,6 @@ async def _finish_handler_with_stream_responses(RPCState rpc_state, request, servicer_context, ) - - # Raises exception if aborted - _raise_if_aborted(rpc_state) else: # The handler uses async generator API async_response_generator = stream_handler( @@ -251,15 +252,12 @@ async def _finish_handler_with_stream_responses(RPCState rpc_state, # Consumes messages from the generator async for response_message in async_response_generator: # Raises exception if aborted - _raise_if_aborted(rpc_state) + rpc_state.raise_for_termination() - if rpc_state.server._status == AIO_SERVER_STATUS_STOPPED: - # The async generator might yield much much later after the - # server is destroied. If we proceed, Core will crash badly. - _LOGGER.info('Aborting RPC due to server stop.') - return - else: - await servicer_context.write(response_message) + await servicer_context.write(response_message) + + # Raises exception if aborted + rpc_state.raise_for_termination() # Sends the final status of this RPC cdef SendStatusFromServerOperation op = SendStatusFromServerOperation( @@ -418,6 +416,8 @@ async def _handle_exceptions(RPCState rpc_state, object rpc_coro, object loop): ) except (KeyboardInterrupt, SystemExit): raise + except _ServerStoppedError: + _LOGGER.info('Aborting RPC due to server stop.') except Exception as e: _LOGGER.exception(e) if not rpc_state.status_sent and rpc_state.server._status != AIO_SERVER_STATUS_STOPPED: diff --git a/src/python/grpcio_tests/tests_aio/unit/server_test.py b/src/python/grpcio_tests/tests_aio/unit/server_test.py index b7e4b233cb5..367d54c82cc 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_test.py @@ -37,6 +37,7 @@ _STREAM_STREAM_ASYNC_GEN = '/test/StreamStreamAsyncGen' _STREAM_STREAM_READER_WRITER = '/test/StreamStreamReaderWriter' _STREAM_STREAM_EVILLY_MIXED = '/test/StreamStreamEvillyMixed' _UNIMPLEMENTED_METHOD = '/test/UnimplementedMethod' +_ERROR_IN_STREAM_STREAM = '/test/ErrorInStreamStream' _REQUEST = b'\x00\x00\x00' _RESPONSE = b'\x01\x01\x01' @@ -82,6 +83,9 @@ class _GenericHandler(grpc.GenericRpcHandler): _STREAM_STREAM_EVILLY_MIXED: grpc.stream_stream_rpc_method_handler( self._stream_stream_evilly_mixed), + _ERROR_IN_STREAM_STREAM: + grpc.stream_stream_rpc_method_handler( + self._error_in_stream_stream), } @staticmethod @@ -158,6 +162,12 @@ class _GenericHandler(grpc.GenericRpcHandler): for _ in range(_NUM_STREAM_RESPONSES - 1): await context.write(_RESPONSE) + async def _error_in_stream_stream(self, request_iterator, unused_context): + async for request in request_iterator: + assert _REQUEST == request + raise RuntimeError('A testing RuntimeError!') + yield _RESPONSE + def service(self, handler_details): self._called.set_result(None) return self._routing_table.get(handler_details.method) @@ -401,6 +411,28 @@ class TestServer(AioTestBase): rpc_error = exception_context.exception self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code()) + async def test_shutdown_during_stream_stream(self): + stream_stream_call = self._channel.stream_stream( + _STREAM_STREAM_ASYNC_GEN) + call = stream_stream_call() + + # Don't half close the RPC yet, keep it alive. + await call.write(_REQUEST) + await self._server.stop(None) + + self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code()) + # No segfault + + async def test_error_in_stream_stream(self): + stream_stream_call = self._channel.stream_stream( + _ERROR_IN_STREAM_STREAM) + call = stream_stream_call() + + # Don't half close the RPC yet, keep it alive. + await call.write(_REQUEST) + + # Don't segfault here + self.assertEqual(grpc.StatusCode.UNKNOWN, await call.code()) if __name__ == '__main__': logging.basicConfig(level=logging.DEBUG)