mirror of https://github.com/grpc/grpc.git
Previously, signal handlers were only given a chance to run upon receipt of an entry in the RPC stream. Since there is no time bound on how long that might take, there can be an arbitrarily long time gap between receipt of the signal and the execution of the application's signal handlers. Signal handlers are only run on the main thread. The cpython implementation takes great care to ensure that the main thread does not block for an arbitrarily long period between signal checks. Our indefinite blocking was due to wait() invocations on condition variables without a timeout. This changes all usages of wait() in the the channel implementation to use a wrapper that is responsive to signals even while waiting on an RPC. A test has been added to verify this. Tests are currently disabled under gevent due to https://github.com/grpc/grpc/issues/18980, but a fix for that has been found and should be merged shortly.pull/19481/head
parent
51d6416691
commit
af1b09f7e7
7 changed files with 393 additions and 59 deletions
@ -0,0 +1,82 @@ |
||||
# Copyright 2019 the gRPC authors. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
"""Client for testing responsiveness to signals.""" |
||||
|
||||
from __future__ import print_function |
||||
|
||||
import argparse |
||||
import functools |
||||
import logging |
||||
import signal |
||||
import sys |
||||
|
||||
import grpc |
||||
|
||||
SIGTERM_MESSAGE = "Handling sigterm!" |
||||
|
||||
UNARY_UNARY = "/test/Unary" |
||||
UNARY_STREAM = "/test/ServerStreaming" |
||||
|
||||
_MESSAGE = b'\x00\x00\x00' |
||||
|
||||
_ASSERTION_MESSAGE = "Control flow should never reach here." |
||||
|
||||
# NOTE(gnossen): We use a global variable here so that the signal handler can be |
||||
# installed before the RPC begins. If we do not do this, then we may receive the |
||||
# SIGINT before the signal handler is installed. I'm not happy with per-process |
||||
# global state, but the per-process global state that is signal handlers |
||||
# somewhat forces my hand. |
||||
per_process_rpc_future = None |
||||
|
||||
|
||||
def handle_sigint(unused_signum, unused_frame): |
||||
print(SIGTERM_MESSAGE) |
||||
if per_process_rpc_future is not None: |
||||
per_process_rpc_future.cancel() |
||||
sys.stderr.flush() |
||||
sys.exit(0) |
||||
|
||||
|
||||
def main_unary(server_target): |
||||
global per_process_rpc_future # pylint: disable=global-statement |
||||
with grpc.insecure_channel(server_target) as channel: |
||||
multicallable = channel.unary_unary(UNARY_UNARY) |
||||
signal.signal(signal.SIGINT, handle_sigint) |
||||
per_process_rpc_future = multicallable.future( |
||||
_MESSAGE, wait_for_ready=True) |
||||
result = per_process_rpc_future.result() |
||||
assert False, _ASSERTION_MESSAGE |
||||
|
||||
|
||||
def main_streaming(server_target): |
||||
global per_process_rpc_future # pylint: disable=global-statement |
||||
with grpc.insecure_channel(server_target) as channel: |
||||
signal.signal(signal.SIGINT, handle_sigint) |
||||
per_process_rpc_future = channel.unary_stream(UNARY_STREAM)( |
||||
_MESSAGE, wait_for_ready=True) |
||||
for result in per_process_rpc_future: |
||||
pass |
||||
assert False, _ASSERTION_MESSAGE |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
parser = argparse.ArgumentParser(description='Signal test client.') |
||||
parser.add_argument('server', help='Server target') |
||||
parser.add_argument( |
||||
'arity', help='RPC arity', choices=('unary', 'streaming')) |
||||
args = parser.parse_args() |
||||
if args.arity == 'unary': |
||||
main_unary(args.server) |
||||
else: |
||||
main_streaming(args.server) |
@ -0,0 +1,158 @@ |
||||
# Copyright 2019 the gRPC authors. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
"""Test of responsiveness to signals.""" |
||||
|
||||
from __future__ import print_function |
||||
|
||||
import logging |
||||
import os |
||||
import signal |
||||
import subprocess |
||||
import tempfile |
||||
import threading |
||||
import unittest |
||||
import sys |
||||
|
||||
import grpc |
||||
|
||||
from tests.unit import test_common |
||||
from tests.unit import _signal_client |
||||
|
||||
_CLIENT_PATH = os.path.abspath(os.path.realpath(_signal_client.__file__)) |
||||
_HOST = 'localhost' |
||||
|
||||
|
||||
class _GenericHandler(grpc.GenericRpcHandler): |
||||
|
||||
def __init__(self): |
||||
self._connected_clients_lock = threading.RLock() |
||||
self._connected_clients_event = threading.Event() |
||||
self._connected_clients = 0 |
||||
|
||||
self._unary_unary_handler = grpc.unary_unary_rpc_method_handler( |
||||
self._handle_unary_unary) |
||||
self._unary_stream_handler = grpc.unary_stream_rpc_method_handler( |
||||
self._handle_unary_stream) |
||||
|
||||
def _on_client_connect(self): |
||||
with self._connected_clients_lock: |
||||
self._connected_clients += 1 |
||||
self._connected_clients_event.set() |
||||
|
||||
def _on_client_disconnect(self): |
||||
with self._connected_clients_lock: |
||||
self._connected_clients -= 1 |
||||
if self._connected_clients == 0: |
||||
self._connected_clients_event.clear() |
||||
|
||||
def await_connected_client(self): |
||||
"""Blocks until a client connects to the server.""" |
||||
self._connected_clients_event.wait() |
||||
|
||||
def _handle_unary_unary(self, request, servicer_context): |
||||
"""Handles a unary RPC. |
||||
|
||||
Blocks until the client disconnects and then echoes. |
||||
""" |
||||
stop_event = threading.Event() |
||||
|
||||
def on_rpc_end(): |
||||
self._on_client_disconnect() |
||||
stop_event.set() |
||||
|
||||
servicer_context.add_callback(on_rpc_end) |
||||
self._on_client_connect() |
||||
stop_event.wait() |
||||
return request |
||||
|
||||
def _handle_unary_stream(self, request, servicer_context): |
||||
"""Handles a server streaming RPC. |
||||
|
||||
Blocks until the client disconnects and then echoes. |
||||
""" |
||||
stop_event = threading.Event() |
||||
|
||||
def on_rpc_end(): |
||||
self._on_client_disconnect() |
||||
stop_event.set() |
||||
|
||||
servicer_context.add_callback(on_rpc_end) |
||||
self._on_client_connect() |
||||
stop_event.wait() |
||||
yield request |
||||
|
||||
def service(self, handler_call_details): |
||||
if handler_call_details.method == _signal_client.UNARY_UNARY: |
||||
return self._unary_unary_handler |
||||
elif handler_call_details.method == _signal_client.UNARY_STREAM: |
||||
return self._unary_stream_handler |
||||
else: |
||||
return None |
||||
|
||||
|
||||
def _read_stream(stream): |
||||
stream.seek(0) |
||||
return stream.read() |
||||
|
||||
|
||||
class SignalHandlingTest(unittest.TestCase): |
||||
|
||||
def setUp(self): |
||||
self._server = test_common.test_server() |
||||
self._port = self._server.add_insecure_port('{}:0'.format(_HOST)) |
||||
self._handler = _GenericHandler() |
||||
self._server.add_generic_rpc_handlers((self._handler,)) |
||||
self._server.start() |
||||
|
||||
def tearDown(self): |
||||
self._server.stop(None) |
||||
|
||||
@unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows') |
||||
def testUnary(self): |
||||
"""Tests that the server unary code path does not stall signal handlers.""" |
||||
server_target = '{}:{}'.format(_HOST, self._port) |
||||
with tempfile.TemporaryFile(mode='r') as client_stdout: |
||||
with tempfile.TemporaryFile(mode='r') as client_stderr: |
||||
client = subprocess.Popen( |
||||
(sys.executable, _CLIENT_PATH, server_target, 'unary'), |
||||
stdout=client_stdout, |
||||
stderr=client_stderr) |
||||
self._handler.await_connected_client() |
||||
client.send_signal(signal.SIGINT) |
||||
self.assertFalse(client.wait(), msg=_read_stream(client_stderr)) |
||||
client_stdout.seek(0) |
||||
self.assertIn(_signal_client.SIGTERM_MESSAGE, |
||||
client_stdout.read()) |
||||
|
||||
@unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows') |
||||
def testStreaming(self): |
||||
"""Tests that the server streaming code path does not stall signal handlers.""" |
||||
server_target = '{}:{}'.format(_HOST, self._port) |
||||
with tempfile.TemporaryFile(mode='r') as client_stdout: |
||||
with tempfile.TemporaryFile(mode='r') as client_stderr: |
||||
client = subprocess.Popen( |
||||
(sys.executable, _CLIENT_PATH, server_target, 'streaming'), |
||||
stdout=client_stdout, |
||||
stderr=client_stderr) |
||||
self._handler.await_connected_client() |
||||
client.send_signal(signal.SIGINT) |
||||
self.assertFalse(client.wait(), msg=_read_stream(client_stderr)) |
||||
client_stdout.seek(0) |
||||
self.assertIn(_signal_client.SIGTERM_MESSAGE, |
||||
client_stdout.read()) |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
logging.basicConfig() |
||||
unittest.main(verbosity=2) |
Loading…
Reference in new issue