Merge pull request #19988 from gnossen/signal_handler_exception

Gracefully handle exceptions raised by signal handlers on the main thread while unary RPCs are in flight.
pull/20006/head
Richard Belleville 5 years ago committed by GitHub
commit 6d64b2df2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 49
      src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi
  2. 42
      src/python/grpcio_tests/tests/unit/_signal_client.py
  3. 26
      src/python/grpcio_tests/tests/unit/_signal_handling_test.py

@ -146,12 +146,36 @@ cdef _cancel(
cdef _next_call_event(
_ChannelState channel_state, grpc_completion_queue *c_completion_queue,
on_success, deadline):
tag, event = _latent_event(c_completion_queue, deadline)
with channel_state.condition:
on_success(tag)
channel_state.condition.notify_all()
return event
on_success, on_failure, deadline):
"""Block on the next event out of the completion queue.
On success, `on_success` will be invoked with the tag taken from the CQ.
In the case of a failure due to an exception raised in a signal handler,
`on_failure` will be invoked with no arguments. Note that this situation
can only occur on the main thread.
Args:
channel_state: The state for the channel on which the RPC is running.
c_completion_queue: The CQ which will be polled.
on_success: A callable object to be invoked upon successful receipt of a
tag from the CQ.
on_failure: A callable object to be invoked in case a Python exception is
raised from a signal handler during polling.
deadline: The point after which the RPC will time out.
"""
try:
tag, event = _latent_event(c_completion_queue, deadline)
# NOTE(rbellevi): This broad except enables us to clean up resources before
# propagating any exceptions raised by signal handlers to the application.
except:
if on_failure is not None:
on_failure()
raise
else:
with channel_state.condition:
on_success(tag)
channel_state.condition.notify_all()
return event
# TODO(https://github.com/grpc/grpc/issues/14569): This could be a lot simpler.
@ -307,8 +331,14 @@ cdef class SegregatedCall:
def on_success(tag):
_process_segregated_call_tag(
self._channel_state, self._call_state, self._c_completion_queue, tag)
def on_failure():
self._call_state.due.clear()
grpc_call_unref(self._call_state.c_call)
self._call_state.c_call = NULL
self._channel_state.segregated_call_states.remove(self._call_state)
_destroy_c_completion_queue(self._c_completion_queue)
return _next_call_event(
self._channel_state, self._c_completion_queue, on_success, None)
self._channel_state, self._c_completion_queue, on_success, on_failure, None)
cdef SegregatedCall _segregated_call(
@ -461,8 +491,11 @@ cdef class Channel:
queue_deadline = time.time() + 1.0
else:
queue_deadline = None
# NOTE(gnossen): It is acceptable for on_failure to be None here because
# failure conditions can only ever happen on the main thread and this
# method is only ever invoked on the channel spin thread.
return _next_call_event(self._state, self._state.c_call_completion_queue,
on_success, queue_deadline)
on_success, None, queue_deadline)
def segregated_call(
self, int flags, method, host, object deadline, object metadata,

@ -45,6 +45,7 @@ def handle_sigint(unused_signum, unused_frame):
if per_process_rpc_future is not None:
per_process_rpc_future.cancel()
sys.stderr.flush()
# This sys.exit(0) avoids an exception caused by the cancelled RPC.
sys.exit(0)
@ -72,13 +73,48 @@ def main_streaming(server_target):
assert False, _ASSERTION_MESSAGE
def main_unary_with_exception(server_target):
"""Initiate a unary RPC with a signal handler that will raise."""
channel = grpc.insecure_channel(server_target)
try:
channel.unary_unary(UNARY_UNARY)(_MESSAGE, wait_for_ready=True)
except KeyboardInterrupt:
sys.stderr.write("Running signal handler.\n")
sys.stderr.flush()
# This call should not hang.
channel.close()
def main_streaming_with_exception(server_target):
"""Initiate a streaming RPC with a signal handler that will raise."""
channel = grpc.insecure_channel(server_target)
try:
for _ in channel.unary_stream(UNARY_STREAM)(
_MESSAGE, wait_for_ready=True):
pass
except KeyboardInterrupt:
sys.stderr.write("Running signal handler.\n")
sys.stderr.flush()
# This call should not hang.
channel.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Signal test client.')
parser.add_argument('server', help='Server target')
parser.add_argument('arity', help='Arity', choices=('unary', 'streaming'))
parser.add_argument(
'arity', help='RPC arity', choices=('unary', 'streaming'))
'--exception',
help='Whether the signal throws an exception',
action='store_true')
args = parser.parse_args()
if args.arity == 'unary':
if args.arity == 'unary' and not args.exception:
main_unary(args.server)
else:
elif args.arity == 'streaming' and not args.exception:
main_streaming(args.server)
elif args.arity == 'unary' and args.exception:
main_unary_with_exception(args.server)
else:
main_streaming_with_exception(args.server)

@ -166,6 +166,32 @@ class SignalHandlingTest(unittest.TestCase):
self.assertIn(_signal_client.SIGTERM_MESSAGE,
client_stdout.read())
@unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows')
def testUnaryWithException(self):
server_target = '{}:{}'.format(_HOST, self._port)
with tempfile.TemporaryFile(mode='r') as client_stdout:
with tempfile.TemporaryFile(mode='r') as client_stderr:
client = _start_client(('--exception', server_target, 'unary'),
client_stdout, client_stderr)
self._handler.await_connected_client()
client.send_signal(signal.SIGINT)
client.wait()
self.assertEqual(0, client.returncode)
@unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows')
def testStreamingHandlerWithException(self):
server_target = '{}:{}'.format(_HOST, self._port)
with tempfile.TemporaryFile(mode='r') as client_stdout:
with tempfile.TemporaryFile(mode='r') as client_stderr:
client = _start_client(
('--exception', server_target, 'streaming'), client_stdout,
client_stderr)
self._handler.await_connected_client()
client.send_signal(signal.SIGINT)
client.wait()
print(_read_stream(client_stderr))
self.assertEqual(0, client.returncode)
if __name__ == '__main__':
logging.basicConfig()

Loading…
Cancel
Save