Fixing a segfault in the server shutdown path

pull/21708/head
Lidi Zheng 5 years ago
parent b9083a9edb
commit 80d7acff7c
  1. 1
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pxd.pxi
  2. 74
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  3. 32
      src/python/grpcio_tests/tests_aio/unit/server_test.py

@ -32,6 +32,7 @@ cdef class RPCState(GrpcCallWrapper):
cdef bytes method(self)
cdef tuple invocation_metadata(self)
cdef void raise_for_termination(self) except *
cdef enum AioServerStatus:

@ -20,7 +20,7 @@ import traceback
# TODO(https://github.com/grpc/grpc/issues/20850) refactor this.
_LOGGER = logging.getLogger(__name__)
cdef int _EMPTY_FLAG = 0
# TODO(lidiz) Use a designated value other than None.
cdef str _RPC_FINISHED_DETAILS = 'RPC already finished.'
cdef str _SERVER_STOPPED_DETAILS = 'Server already stopped.'
cdef class _HandlerCallDetails:
@ -29,6 +29,10 @@ cdef class _HandlerCallDetails:
self.invocation_metadata = invocation_metadata
class _ServerStoppedError(RuntimeError):
    """Raised when an RPC tries to proceed after the server has stopped.

    It subclasses RuntimeError, so code that catches RuntimeError also
    catches this error.
    """
cdef class RPCState:
def __cinit__(self, AioServer server):
@ -48,6 +52,23 @@ cdef class RPCState:
cdef tuple invocation_metadata(self):
return _metadata(&self.request_metadata)
cdef void raise_for_termination(self) except *:
    """Raise an exception if the RPC must not continue running.

    Server method handlers may suppress the abort exception. We need to halt
    the RPC execution in that case. This function needs to be called after
    running application code.

    Also, the server may stop unexpectedly. We need to check before calling
    into Core functions; otherwise, the process may segfault.

    Raises:
      The exception stored in `abort_exception`, if one is set.
      RuntimeError: If the status of this RPC has already been sent.
      _ServerStoppedError: If the server status is AIO_SERVER_STATUS_STOPPED.
    """
    # The checks run in this order: aborted, then finished, then stopped.
    if self.abort_exception is not None:
        raise self.abort_exception
    if self.status_sent:
        raise RuntimeError(_RPC_FINISHED_DETAILS)
    if self.server._status == AIO_SERVER_STATUS_STOPPED:
        raise _ServerStoppedError(_SERVER_STOPPED_DETAILS)
def __dealloc__(self):
"""Cleans the Core objects."""
grpc_call_details_destroy(&self.details)
@ -61,17 +82,6 @@ cdef class RPCState:
# Exception type for an aborted RPC.
# NOTE(review): where it is raised is not visible in this hunk -- confirm.
class AbortError(Exception): pass
def _raise_if_aborted(RPCState rpc_state):
    """Raise the stored abort exception if the RPC was aborted.

    Server method handlers may suppress the abort exception. We need to halt
    the RPC execution in that case. This function needs to be called after
    running application code.

    Args:
      rpc_state: The RPCState whose `abort_exception` attribute is inspected.

    Raises:
      The object stored in `rpc_state.abort_exception` when it is not None
      (presumably an AbortError -- confirm where the attribute is assigned).
    """
    if rpc_state.abort_exception is not None:
        raise rpc_state.abort_exception
cdef class _ServicerContext:
cdef RPCState _rpc_state
cdef object _loop
@ -90,10 +100,8 @@ cdef class _ServicerContext:
async def read(self):
cdef bytes raw_message
if self._rpc_state.server._status == AIO_SERVER_STATUS_STOPPED:
raise RuntimeError(_SERVER_STOPPED_DETAILS)
if self._rpc_state.status_sent:
raise RuntimeError('RPC already finished.')
self._rpc_state.raise_for_termination()
if self._rpc_state.client_closed:
return EOF
raw_message = await _receive_message(self._rpc_state, self._loop)
@ -104,10 +112,8 @@ cdef class _ServicerContext:
raw_message)
async def write(self, object message):
if self._rpc_state.server._status == AIO_SERVER_STATUS_STOPPED:
raise RuntimeError(_SERVER_STOPPED_DETAILS)
if self._rpc_state.status_sent:
raise RuntimeError('RPC already finished.')
self._rpc_state.raise_for_termination()
await _send_message(self._rpc_state,
serialize(self._response_serializer, message),
self._rpc_state.metadata_sent,
@ -116,11 +122,9 @@ cdef class _ServicerContext:
self._rpc_state.metadata_sent = True
async def send_initial_metadata(self, tuple metadata):
if self._rpc_state.status_sent:
raise RuntimeError('RPC already finished.')
elif self._rpc_state.server._status == AIO_SERVER_STATUS_STOPPED:
raise RuntimeError(_SERVER_STOPPED_DETAILS)
elif self._rpc_state.metadata_sent:
self._rpc_state.raise_for_termination()
if self._rpc_state.metadata_sent:
raise RuntimeError('Send initial metadata failed: already sent')
else:
await _send_initial_metadata(self._rpc_state, metadata, self._loop)
@ -191,7 +195,7 @@ async def _finish_handler_with_unary_response(RPCState rpc_state,
)
# Raises exception if aborted
_raise_if_aborted(rpc_state)
rpc_state.raise_for_termination()
# Serializes the response message
cdef bytes response_raw = serialize(
@ -238,9 +242,6 @@ async def _finish_handler_with_stream_responses(RPCState rpc_state,
request,
servicer_context,
)
# Raises exception if aborted
_raise_if_aborted(rpc_state)
else:
# The handler uses async generator API
async_response_generator = stream_handler(
@ -251,15 +252,12 @@ async def _finish_handler_with_stream_responses(RPCState rpc_state,
# Consumes messages from the generator
async for response_message in async_response_generator:
# Raises exception if aborted
_raise_if_aborted(rpc_state)
rpc_state.raise_for_termination()
if rpc_state.server._status == AIO_SERVER_STATUS_STOPPED:
# The async generator might yield much, much later, after the
# server is destroyed. If we proceed, Core will crash badly.
_LOGGER.info('Aborting RPC due to server stop.')
return
else:
await servicer_context.write(response_message)
await servicer_context.write(response_message)
# Raises exception if aborted
rpc_state.raise_for_termination()
# Sends the final status of this RPC
cdef SendStatusFromServerOperation op = SendStatusFromServerOperation(
@ -418,6 +416,8 @@ async def _handle_exceptions(RPCState rpc_state, object rpc_coro, object loop):
)
except (KeyboardInterrupt, SystemExit):
raise
except _ServerStoppedError:
_LOGGER.info('Aborting RPC due to server stop.')
except Exception as e:
_LOGGER.exception(e)
if not rpc_state.status_sent and rpc_state.server._status != AIO_SERVER_STATUS_STOPPED:

@ -37,6 +37,7 @@ _STREAM_STREAM_ASYNC_GEN = '/test/StreamStreamAsyncGen'
_STREAM_STREAM_READER_WRITER = '/test/StreamStreamReaderWriter'
_STREAM_STREAM_EVILLY_MIXED = '/test/StreamStreamEvillyMixed'
_UNIMPLEMENTED_METHOD = '/test/UnimplementedMethod'
_ERROR_IN_STREAM_STREAM = '/test/ErrorInStreamStream'
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x01\x01\x01'
@ -82,6 +83,9 @@ class _GenericHandler(grpc.GenericRpcHandler):
_STREAM_STREAM_EVILLY_MIXED:
grpc.stream_stream_rpc_method_handler(
self._stream_stream_evilly_mixed),
_ERROR_IN_STREAM_STREAM:
grpc.stream_stream_rpc_method_handler(
self._error_in_stream_stream),
}
@staticmethod
@ -158,6 +162,12 @@ class _GenericHandler(grpc.GenericRpcHandler):
for _ in range(_NUM_STREAM_RESPONSES - 1):
await context.write(_RESPONSE)
async def _error_in_stream_stream(self, request_iterator, unused_context):
    """Stream-stream test handler that deliberately fails with RuntimeError.

    Args:
      request_iterator: Async iterator of incoming request messages.
      unused_context: The servicer context; not used by this handler.

    Raises:
      RuntimeError: Always, once a request has been received and checked.
    """
    # NOTE(review): indentation was reconstructed from a flattened diff. The
    # raise is placed inside the loop because test_error_in_stream_stream
    # sends a single request without half-closing and still expects the RPC
    # to fail -- confirm against the repository.
    async for request in request_iterator:
        assert _REQUEST == request
        raise RuntimeError('A testing RuntimeError!')
    # The presence of a yield makes this function an async generator, which
    # is the handler flavor this test case is meant to exercise.
    yield _RESPONSE
def service(self, handler_details):
    """Return the method handler registered for the requested method.

    Args:
      handler_details: Call details object; only its `method` attribute is
        read here.

    Returns:
      The result of `self._routing_table.get(handler_details.method)`
      (presumably None for an unregistered method, if `_routing_table` is a
      dict -- confirm).
    """
    # Signal that the handler lookup happened.
    # NOTE(review): `_called` looks like a future; verify it is resolved
    # only once per handler instance.
    self._called.set_result(None)
    return self._routing_table.get(handler_details.method)
@ -401,6 +411,28 @@ class TestServer(AioTestBase):
rpc_error = exception_context.exception
self.assertEqual(grpc.StatusCode.UNIMPLEMENTED, rpc_error.code())
async def test_shutdown_during_stream_stream(self):
    """Stopping the server while a stream-stream RPC is open must not crash.

    The client writes one request, the server is stopped while the call is
    still open, and the call is expected to finish with UNAVAILABLE.
    """
    stream_stream_call = self._channel.stream_stream(
        _STREAM_STREAM_ASYNC_GEN)
    call = stream_stream_call()
    # Don't half close the RPC yet, keep it alive.
    await call.write(_REQUEST)
    # Stop the server while the RPC above is still in flight.
    await self._server.stop(None)
    self.assertEqual(grpc.StatusCode.UNAVAILABLE, await call.code())
    # No segfault: reaching this point means the process survived the stop.
async def test_error_in_stream_stream(self):
    """An exception raised inside a stream-stream handler must not crash.

    The client writes one request to the handler registered under
    _ERROR_IN_STREAM_STREAM and expects the call to finish with UNKNOWN.
    """
    stream_stream_call = self._channel.stream_stream(
        _ERROR_IN_STREAM_STREAM)
    call = stream_stream_call()
    # Don't half close the RPC yet, keep it alive.
    await call.write(_REQUEST)
    # Don't segfault here
    self.assertEqual(grpc.StatusCode.UNKNOWN, await call.code())
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)

Loading…
Cancel
Save