Fixing a segfault in the server shutdown path

pull/21708/head
Lidi Zheng 5 years ago
parent b9083a9edb
commit 80d7acff7c
  1. 1
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi
  2. 74
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  3. 32
      src/python/grpcio_tests/tests_aio/unit/server_test.py

@ -32,6 +32,7 @@ cdef class RPCState(GrpcCallWrapper):
cdef bytes method(self) cdef bytes method(self)
cdef tuple invocation_metadata(self) cdef tuple invocation_metadata(self)
cdef void raise_for_termination(self) except *
cdef enum AioServerStatus: cdef enum AioServerStatus:

@ -20,7 +20,7 @@ import traceback
# TODO(https://github.com/grpc/grpc/issues/20850) refactor this. # TODO(https://github.com/grpc/grpc/issues/20850) refactor this.
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
cdef int _EMPTY_FLAG = 0 cdef int _EMPTY_FLAG = 0
# TODO(lidiz) Use a designated value other than None. cdef str _RPC_FINISHED_DETAILS = 'RPC already finished.'
cdef str _SERVER_STOPPED_DETAILS = 'Server already stopped.' cdef str _SERVER_STOPPED_DETAILS = 'Server already stopped.'
cdef class _HandlerCallDetails: cdef class _HandlerCallDetails:
@ -29,6 +29,10 @@ cdef class _HandlerCallDetails:
self.invocation_metadata = invocation_metadata self.invocation_metadata = invocation_metadata
class _ServerStoppedError(RuntimeError):
    """Raised if the server is stopped.

    RPCState.raise_for_termination raises this error when the server status
    is AIO_SERVER_STATUS_STOPPED. _handle_exceptions catches it in its own
    branch, ahead of the generic `except Exception` branch, and logs
    'Aborting RPC due to server stop.' at INFO level.
    """
cdef class RPCState: cdef class RPCState:
def __cinit__(self, AioServer server): def __cinit__(self, AioServer server):
@ -48,6 +52,23 @@ cdef class RPCState:
cdef tuple invocation_metadata(self): cdef tuple invocation_metadata(self):
return _metadata(&self.request_metadata) return _metadata(&self.request_metadata)
cdef void raise_for_termination(self) except *:
    """Raise an exception if this RPC is no longer running.

    Server method handlers may suppress the abort exception. We need to halt
    the RPC execution in that case. This function needs to be called after
    running application code.

    Also, the server may stop unexpectedly. We need to check before calling
    into Core functions; otherwise, the process may segfault.

    The checks run in the order below and the first one that matches raises:

    1. `self.abort_exception` is set: that same exception is raised.
    2. `self.status_sent` is true: RuntimeError(_RPC_FINISHED_DETAILS).
    3. The server status is AIO_SERVER_STATUS_STOPPED:
       _ServerStoppedError(_SERVER_STOPPED_DETAILS).

    Returns normally (no value) when none of the conditions hold.
    """
    # A pending abort takes precedence over every other termination reason.
    if self.abort_exception is not None:
        raise self.abort_exception
    # The final status has already been sent for this RPC.
    if self.status_sent:
        raise RuntimeError(_RPC_FINISHED_DETAILS)
    # The server is stopped; callers must not proceed into Core calls.
    if self.server._status == AIO_SERVER_STATUS_STOPPED:
        raise _ServerStoppedError(_SERVER_STOPPED_DETAILS)
def __dealloc__(self): def __dealloc__(self):
"""Cleans the Core objects.""" """Cleans the Core objects."""
grpc_call_details_destroy(&self.details) grpc_call_details_destroy(&self.details)
@ -61,17 +82,6 @@ cdef class RPCState:
class AbortError(Exception): pass class AbortError(Exception): pass
def _raise_if_aborted(RPCState rpc_state):
"""Raise AbortError if RPC is aborted.
Server method handlers may suppress the abort exception. We need to halt
the RPC execution in that case. This function needs to be called after
running application code.
"""
if rpc_state.abort_exception is not None:
raise rpc_state.abort_exception
cdef class _ServicerContext: cdef class _ServicerContext:
cdef RPCState _rpc_state cdef RPCState _rpc_state
cdef object _loop cdef object _loop
@ -90,10 +100,8 @@ cdef class _ServicerContext:
async def read(self): async def read(self):
cdef bytes raw_message cdef bytes raw_message
if self._rpc_state.server._status == AIO_SERVER_STATUS_STOPPED: self._rpc_state.raise_for_termination()
raise RuntimeError(_SERVER_STOPPED_DETAILS)
if self._rpc_state.status_sent:
raise RuntimeError('RPC already finished.')
if self._rpc_state.client_closed: if self._rpc_state.client_closed:
return EOF return EOF
raw_message = await _receive_message(self._rpc_state, self._loop) raw_message = await _receive_message(self._rpc_state, self._loop)
@ -104,10 +112,8 @@ cdef class _ServicerContext:
raw_message) raw_message)
async def write(self, object message): async def write(self, object message):
if self._rpc_state.server._status == AIO_SERVER_STATUS_STOPPED: self._rpc_state.raise_for_termination()
raise RuntimeError(_SERVER_STOPPED_DETAILS)
if self._rpc_state.status_sent:
raise RuntimeError('RPC already finished.')
await _send_message(self._rpc_state, await _send_message(self._rpc_state,
serialize(self._response_serializer, message), serialize(self._response_serializer, message),
self._rpc_state.metadata_sent, self._rpc_state.metadata_sent,
@ -116,11 +122,9 @@ cdef class _ServicerContext:
self._rpc_state.metadata_sent = True self._rpc_state.metadata_sent = True
async def send_initial_metadata(self, tuple metadata): async def send_initial_metadata(self, tuple metadata):
if self._rpc_state.status_sent: self._rpc_state.raise_for_termination()
raise RuntimeError('RPC already finished.')
elif self._rpc_state.server._status == AIO_SERVER_STATUS_STOPPED: if self._rpc_state.metadata_sent:
raise RuntimeError(_SERVER_STOPPED_DETAILS)
elif self._rpc_state.metadata_sent:
raise RuntimeError('Send initial metadata failed: already sent') raise RuntimeError('Send initial metadata failed: already sent')
else: else:
await _send_initial_metadata(self._rpc_state, metadata, self._loop) await _send_initial_metadata(self._rpc_state, metadata, self._loop)
@ -191,7 +195,7 @@ async def _finish_handler_with_unary_response(RPCState rpc_state,
) )
# Raises exception if aborted # Raises exception if aborted
_raise_if_aborted(rpc_state) rpc_state.raise_for_termination()
# Serializes the response message # Serializes the response message
cdef bytes response_raw = serialize( cdef bytes response_raw = serialize(
@ -238,9 +242,6 @@ async def _finish_handler_with_stream_responses(RPCState rpc_state,
request, request,
servicer_context, servicer_context,
) )
# Raises exception if aborted
_raise_if_aborted(rpc_state)
else: else:
# The handler uses async generator API # The handler uses async generator API
async_response_generator = stream_handler( async_response_generator = stream_handler(
@ -251,15 +252,12 @@ async def _finish_handler_with_stream_responses(RPCState rpc_state,
# Consumes messages from the generator # Consumes messages from the generator
async for response_message in async_response_generator: async for response_message in async_response_generator:
# Raises exception if aborted # Raises exception if aborted
_raise_if_aborted(rpc_state) rpc_state.raise_for_termination()
if rpc_state.server._status == AIO_SERVER_STATUS_STOPPED: await servicer_context.write(response_message)
# The async generator might yield much much later after the
# server is destroyed. If we proceed, Core will crash badly. # Raises exception if aborted
_LOGGER.info('Aborting RPC due to server stop.') rpc_state.raise_for_termination()
return
else:
await servicer_context.write(response_message)
# Sends the final status of this RPC # Sends the final status of this RPC
cdef SendStatusFromServerOperation op = SendStatusFromServerOperation( cdef SendStatusFromServerOperation op = SendStatusFromServerOperation(
@ -418,6 +416,8 @@ async def _handle_exceptions(RPCState rpc_state, object rpc_coro, object loop):
) )
except (KeyboardInterrupt, SystemExit): except (KeyboardInterrupt, SystemExit):
raise raise
except _ServerStoppedError:
_LOGGER.info('Aborting RPC due to server stop.')
except Exception as e: except Exception as e:
_LOGGER.exception(e) _LOGGER.exception(e)
if not rpc_state.status_sent and rpc_state.server._status != AIO_SERVER_STATUS_STOPPED: if not rpc_state.status_sent and rpc_state.server._status != AIO_SERVER_STATUS_STOPPED:

@ -37,6 +37,7 @@ _STREAM_STREAM_ASYNC_GEN = '/test/StreamStreamAsyncGen'
_STREAM_STREAM_READER_WRITER = '/test/StreamStreamReaderWriter' _STREAM_STREAM_READER_WRITER = '/test/StreamStreamReaderWriter'
_STREAM_STREAM_EVILLY_MIXED = '/test/StreamStreamEvillyMixed' _STREAM_STREAM_EVILLY_MIXED = '/test/StreamStreamEvillyMixed'
_UNIMPLEMENTED_METHOD = '/test/UnimplementedMethod' _UNIMPLEMENTED_METHOD = '/test/UnimplementedMethod'
_ERROR_IN_STREAM_STREAM = '/test/ErrorInStreamStream'
_REQUEST = b'\x00\x00\x00' _REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x01\x01\x01' _RESPONSE = b'\x01\x01\x01'
@ -82,6 +83,9 @@ class _GenericHandler(grpc.GenericRpcHandler):
_STREAM_STREAM_EVILLY_MIXED: _STREAM_STREAM_EVILLY_MIXED:
grpc.stream_stream_rpc_method_handler( grpc.stream_stream_rpc_method_handler(
self._stream_stream_evilly_mixed), self._stream_stream_evilly_mixed),
_ERROR_IN_STREAM_STREAM:
grpc.stream_stream_rpc_method_handler(
self._error_in_stream_stream),
} }
@staticmethod @staticmethod
@ -158,6 +162,12 @@ class _GenericHandler(grpc.GenericRpcHandler):
for _ in range(_NUM_STREAM_RESPONSES - 1): for _ in range(_NUM_STREAM_RESPONSES - 1):
await context.write(_RESPONSE) await context.write(_RESPONSE)
# Stream-stream test handler that fails on purpose: it iterates the incoming
# requests, asserts that a request equals _REQUEST, and raises
# RuntimeError('A testing RuntimeError!').
# NOTE(review): indentation was lost in this view, so it is unclear whether the
# trailing `yield` sits inside the `async for` loop (after the raise) or after
# the loop - confirm against the real file. Either way, the presence of `yield`
# makes this function an async generator.
async def _error_in_stream_stream(self, request_iterator, unused_context):
async for request in request_iterator:
assert _REQUEST == request
raise RuntimeError('A testing RuntimeError!')
yield _RESPONSE
def service(self, handler_details): def service(self, handler_details):
self._called.set_result(None) self._called.set_result(None)
return self._routing_table.get(handler_details.method) return self._routing_table.get(handler_details.method)
@ -401,6 +411,28 @@ class TestServer(AioTestBase):
rpc_error = exception_context.exception rpc_error = exception_context.exception
self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code()) self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())
async def test_shutdown_during_stream_stream(self):
    """Stop the server while a stream-stream RPC is still open.

    The call is expected to finish with StatusCode.UNAVAILABLE.
    """
    call = self._channel.stream_stream(_STREAM_STREAM_ASYNC_GEN)()

    # Write one request but do not half-close, so the RPC is still alive
    # when the server is stopped.
    await call.write(_REQUEST)

    await self._server.stop(None)

    # Reaching this assertion also shows that shutdown did not segfault.
    final_code = await call.code()
    self.assertEqual(grpc.StatusCode.UNAVAILABLE, final_code)
async def test_error_in_stream_stream(self):
    """Call the handler that raises and check the resulting status code.

    The call is expected to finish with StatusCode.UNKNOWN.
    """
    call = self._channel.stream_stream(_ERROR_IN_STREAM_STREAM)()

    # Write one request but do not half-close, so the RPC stays alive.
    await call.write(_REQUEST)

    # Fetching the status code must not segfault.
    final_code = await call.code()
    self.assertEqual(grpc.StatusCode.UNKNOWN, final_code)
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)

Loading…
Cancel
Save