diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi index ca637094353..1799780fce4 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/channel.pyx.pxi @@ -146,12 +146,36 @@ cdef _cancel( cdef _next_call_event( _ChannelState channel_state, grpc_completion_queue *c_completion_queue, - on_success, deadline): - tag, event = _latent_event(c_completion_queue, deadline) - with channel_state.condition: - on_success(tag) - channel_state.condition.notify_all() - return event + on_success, on_failure, deadline): + """Block on the next event out of the completion queue. + + On success, `on_success` will be invoked with the tag taken from the CQ. + In the case of a failure due to an exception raised in a signal handler, + `on_failure` will be invoked with no arguments. Note that this situation + can only occur on the main thread. + + Args: + channel_state: The state for the channel on which the RPC is running. + c_completion_queue: The CQ which will be polled. + on_success: A callable object to be invoked upon successful receipt of a + tag from the CQ. + on_failure: A callable object to be invoked in case a Python exception is + raised from a signal handler during polling. + deadline: The point after which the RPC will time out. + """ + try: + tag, event = _latent_event(c_completion_queue, deadline) + # NOTE(rbellevi): This broad except enables us to clean up resources before + # propagating any exceptions raised by signal handlers to the application. + except: + if on_failure is not None: + on_failure() + raise + else: + with channel_state.condition: + on_success(tag) + channel_state.condition.notify_all() + return event # TODO(https://github.com/grpc/grpc/issues/14569): This could be a lot simpler. @@ -307,8 +331,14 @@ cdef class SegregatedCall: def on_success(tag): _process_segregated_call_tag( self._channel_state, self._call_state, self._c_completion_queue, tag) + def on_failure(): + self._call_state.due.clear() + grpc_call_unref(self._call_state.c_call) + self._call_state.c_call = NULL + self._channel_state.segregated_call_states.remove(self._call_state) + _destroy_c_completion_queue(self._c_completion_queue) return _next_call_event( - self._channel_state, self._c_completion_queue, on_success, None) + self._channel_state, self._c_completion_queue, on_success, on_failure, None) cdef SegregatedCall _segregated_call( @@ -461,8 +491,11 @@ cdef class Channel: queue_deadline = time.time() + 1.0 else: queue_deadline = None + # NOTE(gnossen): It is acceptable for on_failure to be None here because + # failure conditions can only ever happen on the main thread and this + # method is only ever invoked on the channel spin thread. return _next_call_event(self._state, self._state.c_call_completion_queue, - on_success, queue_deadline) + on_success, None, queue_deadline) def segregated_call( self, int flags, method, host, object deadline, object metadata, diff --git a/src/python/grpcio_tests/tests/unit/_signal_client.py b/src/python/grpcio_tests/tests/unit/_signal_client.py index 65ddd6d858e..075fe7f7177 100644 --- a/src/python/grpcio_tests/tests/unit/_signal_client.py +++ b/src/python/grpcio_tests/tests/unit/_signal_client.py @@ -45,6 +45,7 @@ def handle_sigint(unused_signum, unused_frame): if per_process_rpc_future is not None: per_process_rpc_future.cancel() sys.stderr.flush() + # This sys.exit(0) avoids an exception caused by the cancelled RPC. sys.exit(0) @@ -72,13 +73,48 @@ def main_streaming(server_target): assert False, _ASSERTION_MESSAGE +def main_unary_with_exception(server_target): + """Initiate a unary RPC with a signal handler that will raise.""" + channel = grpc.insecure_channel(server_target) + try: + channel.unary_unary(UNARY_UNARY)(_MESSAGE, wait_for_ready=True) + except KeyboardInterrupt: + sys.stderr.write("Running signal handler.\n") + sys.stderr.flush() + + # This call should not hang. + channel.close() + + +def main_streaming_with_exception(server_target): + """Initiate a streaming RPC with a signal handler that will raise.""" + channel = grpc.insecure_channel(server_target) + try: + for _ in channel.unary_stream(UNARY_STREAM)( + _MESSAGE, wait_for_ready=True): + pass + except KeyboardInterrupt: + sys.stderr.write("Running signal handler.\n") + sys.stderr.flush() + + # This call should not hang. + channel.close() + + if __name__ == '__main__': parser = argparse.ArgumentParser(description='Signal test client.') parser.add_argument('server', help='Server target') + parser.add_argument('arity', help='Arity', choices=('unary', 'streaming')) parser.add_argument( - 'arity', help='RPC arity', choices=('unary', 'streaming')) + '--exception', + help='Whether the signal throws an exception', + action='store_true') args = parser.parse_args() - if args.arity == 'unary': + if args.arity == 'unary' and not args.exception: main_unary(args.server) - else: + elif args.arity == 'streaming' and not args.exception: main_streaming(args.server) + elif args.arity == 'unary' and args.exception: + main_unary_with_exception(args.server) + else: + main_streaming_with_exception(args.server) diff --git a/src/python/grpcio_tests/tests/unit/_signal_handling_test.py b/src/python/grpcio_tests/tests/unit/_signal_handling_test.py index 8ef156c596d..6f81e0b2d34 100644 --- a/src/python/grpcio_tests/tests/unit/_signal_handling_test.py +++ b/src/python/grpcio_tests/tests/unit/_signal_handling_test.py @@ -166,6 +166,32 @@ class SignalHandlingTest(unittest.TestCase): self.assertIn(_signal_client.SIGTERM_MESSAGE, client_stdout.read()) + @unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows') + def testUnaryWithException(self): + server_target = '{}:{}'.format(_HOST, self._port) + with tempfile.TemporaryFile(mode='r') as client_stdout: + with tempfile.TemporaryFile(mode='r') as client_stderr: + client = _start_client(('--exception', server_target, 'unary'), + client_stdout, client_stderr) + self._handler.await_connected_client() + client.send_signal(signal.SIGINT) + client.wait() + self.assertEqual(0, client.returncode) + + @unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows') + def testStreamingHandlerWithException(self): + server_target = '{}:{}'.format(_HOST, self._port) + with tempfile.TemporaryFile(mode='r') as client_stdout: + with tempfile.TemporaryFile(mode='r') as client_stderr: + client = _start_client( + ('--exception', server_target, 'streaming'), client_stdout, + client_stderr) + self._handler.await_connected_client() + client.send_signal(signal.SIGINT) + client.wait() + print(_read_stream(client_stderr)) + self.assertEqual(0, client.returncode) + if __name__ == '__main__': logging.basicConfig()