mirror of https://github.com/grpc/grpc.git
Merge pull request #19586 from gnossen/revert_revert_signal_handling
Revert "Merge pull request #19583 from gnossen/revert_signal_handling"pull/19826/head
commit
7631e410cf
7 changed files with 409 additions and 59 deletions
@ -0,0 +1,84 @@ |
|||||||
|
# Copyright 2019 the gRPC authors. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
"""Client for testing responsiveness to signals.""" |
||||||
|
|
||||||
|
from __future__ import print_function |
||||||
|
|
||||||
|
import argparse |
||||||
|
import functools |
||||||
|
import logging |
||||||
|
import signal |
||||||
|
import sys |
||||||
|
|
||||||
|
import grpc |
||||||
|
|
||||||
|
SIGTERM_MESSAGE = "Handling sigterm!" |
||||||
|
|
||||||
|
UNARY_UNARY = "/test/Unary" |
||||||
|
UNARY_STREAM = "/test/ServerStreaming" |
||||||
|
|
||||||
|
_MESSAGE = b'\x00\x00\x00' |
||||||
|
|
||||||
|
_ASSERTION_MESSAGE = "Control flow should never reach here." |
||||||
|
|
||||||
|
# NOTE(gnossen): We use a global variable here so that the signal handler can be |
||||||
|
# installed before the RPC begins. If we do not do this, then we may receive the |
||||||
|
# SIGINT before the signal handler is installed. I'm not happy with per-process |
||||||
|
# global state, but the per-process global state that is signal handlers |
||||||
|
# somewhat forces my hand. |
||||||
|
per_process_rpc_future = None |
||||||
|
|
||||||
|
|
||||||
|
def handle_sigint(unused_signum, unused_frame): |
||||||
|
print(SIGTERM_MESSAGE) |
||||||
|
if per_process_rpc_future is not None: |
||||||
|
per_process_rpc_future.cancel() |
||||||
|
sys.stderr.flush() |
||||||
|
sys.exit(0) |
||||||
|
|
||||||
|
|
||||||
|
def main_unary(server_target): |
||||||
|
"""Initiate a unary RPC to be interrupted by a SIGINT.""" |
||||||
|
global per_process_rpc_future # pylint: disable=global-statement |
||||||
|
with grpc.insecure_channel(server_target) as channel: |
||||||
|
multicallable = channel.unary_unary(UNARY_UNARY) |
||||||
|
signal.signal(signal.SIGINT, handle_sigint) |
||||||
|
per_process_rpc_future = multicallable.future( |
||||||
|
_MESSAGE, wait_for_ready=True) |
||||||
|
result = per_process_rpc_future.result() |
||||||
|
assert False, _ASSERTION_MESSAGE |
||||||
|
|
||||||
|
|
||||||
|
def main_streaming(server_target): |
||||||
|
"""Initiate a streaming RPC to be interrupted by a SIGINT.""" |
||||||
|
global per_process_rpc_future # pylint: disable=global-statement |
||||||
|
with grpc.insecure_channel(server_target) as channel: |
||||||
|
signal.signal(signal.SIGINT, handle_sigint) |
||||||
|
per_process_rpc_future = channel.unary_stream(UNARY_STREAM)( |
||||||
|
_MESSAGE, wait_for_ready=True) |
||||||
|
for result in per_process_rpc_future: |
||||||
|
pass |
||||||
|
assert False, _ASSERTION_MESSAGE |
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__': |
||||||
|
parser = argparse.ArgumentParser(description='Signal test client.') |
||||||
|
parser.add_argument('server', help='Server target') |
||||||
|
parser.add_argument( |
||||||
|
'arity', help='RPC arity', choices=('unary', 'streaming')) |
||||||
|
args = parser.parse_args() |
||||||
|
if args.arity == 'unary': |
||||||
|
main_unary(args.server) |
||||||
|
else: |
||||||
|
main_streaming(args.server) |
@ -0,0 +1,172 @@ |
|||||||
|
# Copyright 2019 the gRPC authors. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
"""Test of responsiveness to signals.""" |
||||||
|
|
||||||
|
import logging |
||||||
|
import os |
||||||
|
import signal |
||||||
|
import subprocess |
||||||
|
import tempfile |
||||||
|
import threading |
||||||
|
import unittest |
||||||
|
import sys |
||||||
|
|
||||||
|
import grpc |
||||||
|
|
||||||
|
from tests.unit import test_common |
||||||
|
from tests.unit import _signal_client |
||||||
|
|
||||||
|
_CLIENT_PATH = None |
||||||
|
if sys.executable is not None: |
||||||
|
_CLIENT_PATH = os.path.abspath(os.path.realpath(_signal_client.__file__)) |
||||||
|
else: |
||||||
|
# NOTE(rbellevi): For compatibility with internal testing. |
||||||
|
if len(sys.argv) != 2: |
||||||
|
raise RuntimeError("Must supply path to executable client.") |
||||||
|
client_name = sys.argv[1].split("/")[-1] |
||||||
|
del sys.argv[1] # For compatibility with test runner. |
||||||
|
_CLIENT_PATH = os.path.realpath( |
||||||
|
os.path.join(os.path.dirname(os.path.abspath(__file__)), client_name)) |
||||||
|
|
||||||
|
_HOST = 'localhost' |
||||||
|
|
||||||
|
|
||||||
|
class _GenericHandler(grpc.GenericRpcHandler): |
||||||
|
|
||||||
|
def __init__(self): |
||||||
|
self._connected_clients_lock = threading.RLock() |
||||||
|
self._connected_clients_event = threading.Event() |
||||||
|
self._connected_clients = 0 |
||||||
|
|
||||||
|
self._unary_unary_handler = grpc.unary_unary_rpc_method_handler( |
||||||
|
self._handle_unary_unary) |
||||||
|
self._unary_stream_handler = grpc.unary_stream_rpc_method_handler( |
||||||
|
self._handle_unary_stream) |
||||||
|
|
||||||
|
def _on_client_connect(self): |
||||||
|
with self._connected_clients_lock: |
||||||
|
self._connected_clients += 1 |
||||||
|
self._connected_clients_event.set() |
||||||
|
|
||||||
|
def _on_client_disconnect(self): |
||||||
|
with self._connected_clients_lock: |
||||||
|
self._connected_clients -= 1 |
||||||
|
if self._connected_clients == 0: |
||||||
|
self._connected_clients_event.clear() |
||||||
|
|
||||||
|
def await_connected_client(self): |
||||||
|
"""Blocks until a client connects to the server.""" |
||||||
|
self._connected_clients_event.wait() |
||||||
|
|
||||||
|
def _handle_unary_unary(self, request, servicer_context): |
||||||
|
"""Handles a unary RPC. |
||||||
|
|
||||||
|
Blocks until the client disconnects and then echoes. |
||||||
|
""" |
||||||
|
stop_event = threading.Event() |
||||||
|
|
||||||
|
def on_rpc_end(): |
||||||
|
self._on_client_disconnect() |
||||||
|
stop_event.set() |
||||||
|
|
||||||
|
servicer_context.add_callback(on_rpc_end) |
||||||
|
self._on_client_connect() |
||||||
|
stop_event.wait() |
||||||
|
return request |
||||||
|
|
||||||
|
def _handle_unary_stream(self, request, servicer_context): |
||||||
|
"""Handles a server streaming RPC. |
||||||
|
|
||||||
|
Blocks until the client disconnects and then echoes. |
||||||
|
""" |
||||||
|
stop_event = threading.Event() |
||||||
|
|
||||||
|
def on_rpc_end(): |
||||||
|
self._on_client_disconnect() |
||||||
|
stop_event.set() |
||||||
|
|
||||||
|
servicer_context.add_callback(on_rpc_end) |
||||||
|
self._on_client_connect() |
||||||
|
stop_event.wait() |
||||||
|
yield request |
||||||
|
|
||||||
|
def service(self, handler_call_details): |
||||||
|
if handler_call_details.method == _signal_client.UNARY_UNARY: |
||||||
|
return self._unary_unary_handler |
||||||
|
elif handler_call_details.method == _signal_client.UNARY_STREAM: |
||||||
|
return self._unary_stream_handler |
||||||
|
else: |
||||||
|
return None |
||||||
|
|
||||||
|
|
||||||
|
def _read_stream(stream): |
||||||
|
stream.seek(0) |
||||||
|
return stream.read() |
||||||
|
|
||||||
|
|
||||||
|
def _start_client(args, stdout, stderr): |
||||||
|
invocation = None |
||||||
|
if sys.executable is not None: |
||||||
|
invocation = (sys.executable, _CLIENT_PATH) + tuple(args) |
||||||
|
else: |
||||||
|
invocation = (_CLIENT_PATH,) + tuple(args) |
||||||
|
return subprocess.Popen(invocation, stdout=stdout, stderr=stderr) |
||||||
|
|
||||||
|
|
||||||
|
class SignalHandlingTest(unittest.TestCase): |
||||||
|
|
||||||
|
def setUp(self): |
||||||
|
self._server = test_common.test_server() |
||||||
|
self._port = self._server.add_insecure_port('{}:0'.format(_HOST)) |
||||||
|
self._handler = _GenericHandler() |
||||||
|
self._server.add_generic_rpc_handlers((self._handler,)) |
||||||
|
self._server.start() |
||||||
|
|
||||||
|
def tearDown(self): |
||||||
|
self._server.stop(None) |
||||||
|
|
||||||
|
@unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows') |
||||||
|
def testUnary(self): |
||||||
|
"""Tests that the server unary code path does not stall signal handlers.""" |
||||||
|
server_target = '{}:{}'.format(_HOST, self._port) |
||||||
|
with tempfile.TemporaryFile(mode='r') as client_stdout: |
||||||
|
with tempfile.TemporaryFile(mode='r') as client_stderr: |
||||||
|
client = _start_client((server_target, 'unary'), client_stdout, |
||||||
|
client_stderr) |
||||||
|
self._handler.await_connected_client() |
||||||
|
client.send_signal(signal.SIGINT) |
||||||
|
self.assertFalse(client.wait(), msg=_read_stream(client_stderr)) |
||||||
|
client_stdout.seek(0) |
||||||
|
self.assertIn(_signal_client.SIGTERM_MESSAGE, |
||||||
|
client_stdout.read()) |
||||||
|
|
||||||
|
@unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows') |
||||||
|
def testStreaming(self): |
||||||
|
"""Tests that the server streaming code path does not stall signal handlers.""" |
||||||
|
server_target = '{}:{}'.format(_HOST, self._port) |
||||||
|
with tempfile.TemporaryFile(mode='r') as client_stdout: |
||||||
|
with tempfile.TemporaryFile(mode='r') as client_stderr: |
||||||
|
client = _start_client((server_target, 'streaming'), |
||||||
|
client_stdout, client_stderr) |
||||||
|
self._handler.await_connected_client() |
||||||
|
client.send_signal(signal.SIGINT) |
||||||
|
self.assertFalse(client.wait(), msg=_read_stream(client_stderr)) |
||||||
|
client_stdout.seek(0) |
||||||
|
self.assertIn(_signal_client.SIGTERM_MESSAGE, |
||||||
|
client_stdout.read()) |
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__': |
||||||
|
logging.basicConfig() |
||||||
|
unittest.main(verbosity=2) |
Loading…
Reference in new issue