mirror of https://github.com/grpc/grpc.git
Merge pull request #19586 from gnossen/revert_revert_signal_handling
Revert "Merge pull request #19583 from gnossen/revert_signal_handling"pull/19826/head
commit
7631e410cf
7 changed files with 409 additions and 59 deletions
@ -0,0 +1,84 @@ |
||||
# Copyright 2019 the gRPC authors. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
"""Client for testing responsiveness to signals.""" |
||||
|
||||
from __future__ import print_function |
||||
|
||||
import argparse |
||||
import functools |
||||
import logging |
||||
import signal |
||||
import sys |
||||
|
||||
import grpc |
||||
|
||||
SIGTERM_MESSAGE = "Handling sigterm!" |
||||
|
||||
UNARY_UNARY = "/test/Unary" |
||||
UNARY_STREAM = "/test/ServerStreaming" |
||||
|
||||
_MESSAGE = b'\x00\x00\x00' |
||||
|
||||
_ASSERTION_MESSAGE = "Control flow should never reach here." |
||||
|
||||
# NOTE(gnossen): We use a global variable here so that the signal handler can be |
||||
# installed before the RPC begins. If we do not do this, then we may receive the |
||||
# SIGINT before the signal handler is installed. I'm not happy with per-process |
||||
# global state, but the per-process global state that is signal handlers |
||||
# somewhat forces my hand. |
||||
per_process_rpc_future = None |
||||
|
||||
|
||||
def handle_sigint(unused_signum, unused_frame): |
||||
print(SIGTERM_MESSAGE) |
||||
if per_process_rpc_future is not None: |
||||
per_process_rpc_future.cancel() |
||||
sys.stderr.flush() |
||||
sys.exit(0) |
||||
|
||||
|
||||
def main_unary(server_target): |
||||
"""Initiate a unary RPC to be interrupted by a SIGINT.""" |
||||
global per_process_rpc_future # pylint: disable=global-statement |
||||
with grpc.insecure_channel(server_target) as channel: |
||||
multicallable = channel.unary_unary(UNARY_UNARY) |
||||
signal.signal(signal.SIGINT, handle_sigint) |
||||
per_process_rpc_future = multicallable.future( |
||||
_MESSAGE, wait_for_ready=True) |
||||
result = per_process_rpc_future.result() |
||||
assert False, _ASSERTION_MESSAGE |
||||
|
||||
|
||||
def main_streaming(server_target): |
||||
"""Initiate a streaming RPC to be interrupted by a SIGINT.""" |
||||
global per_process_rpc_future # pylint: disable=global-statement |
||||
with grpc.insecure_channel(server_target) as channel: |
||||
signal.signal(signal.SIGINT, handle_sigint) |
||||
per_process_rpc_future = channel.unary_stream(UNARY_STREAM)( |
||||
_MESSAGE, wait_for_ready=True) |
||||
for result in per_process_rpc_future: |
||||
pass |
||||
assert False, _ASSERTION_MESSAGE |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
parser = argparse.ArgumentParser(description='Signal test client.') |
||||
parser.add_argument('server', help='Server target') |
||||
parser.add_argument( |
||||
'arity', help='RPC arity', choices=('unary', 'streaming')) |
||||
args = parser.parse_args() |
||||
if args.arity == 'unary': |
||||
main_unary(args.server) |
||||
else: |
||||
main_streaming(args.server) |
@ -0,0 +1,172 @@ |
||||
# Copyright 2019 the gRPC authors. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
"""Test of responsiveness to signals.""" |
||||
|
||||
import logging |
||||
import os |
||||
import signal |
||||
import subprocess |
||||
import tempfile |
||||
import threading |
||||
import unittest |
||||
import sys |
||||
|
||||
import grpc |
||||
|
||||
from tests.unit import test_common |
||||
from tests.unit import _signal_client |
||||
|
||||
_CLIENT_PATH = None |
||||
if sys.executable is not None: |
||||
_CLIENT_PATH = os.path.abspath(os.path.realpath(_signal_client.__file__)) |
||||
else: |
||||
# NOTE(rbellevi): For compatibility with internal testing. |
||||
if len(sys.argv) != 2: |
||||
raise RuntimeError("Must supply path to executable client.") |
||||
client_name = sys.argv[1].split("/")[-1] |
||||
del sys.argv[1] # For compatibility with test runner. |
||||
_CLIENT_PATH = os.path.realpath( |
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)), client_name)) |
||||
|
||||
_HOST = 'localhost' |
||||
|
||||
|
||||
class _GenericHandler(grpc.GenericRpcHandler): |
||||
|
||||
def __init__(self): |
||||
self._connected_clients_lock = threading.RLock() |
||||
self._connected_clients_event = threading.Event() |
||||
self._connected_clients = 0 |
||||
|
||||
self._unary_unary_handler = grpc.unary_unary_rpc_method_handler( |
||||
self._handle_unary_unary) |
||||
self._unary_stream_handler = grpc.unary_stream_rpc_method_handler( |
||||
self._handle_unary_stream) |
||||
|
||||
def _on_client_connect(self): |
||||
with self._connected_clients_lock: |
||||
self._connected_clients += 1 |
||||
self._connected_clients_event.set() |
||||
|
||||
def _on_client_disconnect(self): |
||||
with self._connected_clients_lock: |
||||
self._connected_clients -= 1 |
||||
if self._connected_clients == 0: |
||||
self._connected_clients_event.clear() |
||||
|
||||
def await_connected_client(self): |
||||
"""Blocks until a client connects to the server.""" |
||||
self._connected_clients_event.wait() |
||||
|
||||
def _handle_unary_unary(self, request, servicer_context): |
||||
"""Handles a unary RPC. |
||||
|
||||
Blocks until the client disconnects and then echoes. |
||||
""" |
||||
stop_event = threading.Event() |
||||
|
||||
def on_rpc_end(): |
||||
self._on_client_disconnect() |
||||
stop_event.set() |
||||
|
||||
servicer_context.add_callback(on_rpc_end) |
||||
self._on_client_connect() |
||||
stop_event.wait() |
||||
return request |
||||
|
||||
def _handle_unary_stream(self, request, servicer_context): |
||||
"""Handles a server streaming RPC. |
||||
|
||||
Blocks until the client disconnects and then echoes. |
||||
""" |
||||
stop_event = threading.Event() |
||||
|
||||
def on_rpc_end(): |
||||
self._on_client_disconnect() |
||||
stop_event.set() |
||||
|
||||
servicer_context.add_callback(on_rpc_end) |
||||
self._on_client_connect() |
||||
stop_event.wait() |
||||
yield request |
||||
|
||||
def service(self, handler_call_details): |
||||
if handler_call_details.method == _signal_client.UNARY_UNARY: |
||||
return self._unary_unary_handler |
||||
elif handler_call_details.method == _signal_client.UNARY_STREAM: |
||||
return self._unary_stream_handler |
||||
else: |
||||
return None |
||||
|
||||
|
||||
def _read_stream(stream): |
||||
stream.seek(0) |
||||
return stream.read() |
||||
|
||||
|
||||
def _start_client(args, stdout, stderr): |
||||
invocation = None |
||||
if sys.executable is not None: |
||||
invocation = (sys.executable, _CLIENT_PATH) + tuple(args) |
||||
else: |
||||
invocation = (_CLIENT_PATH,) + tuple(args) |
||||
return subprocess.Popen(invocation, stdout=stdout, stderr=stderr) |
||||
|
||||
|
||||
class SignalHandlingTest(unittest.TestCase): |
||||
|
||||
def setUp(self): |
||||
self._server = test_common.test_server() |
||||
self._port = self._server.add_insecure_port('{}:0'.format(_HOST)) |
||||
self._handler = _GenericHandler() |
||||
self._server.add_generic_rpc_handlers((self._handler,)) |
||||
self._server.start() |
||||
|
||||
def tearDown(self): |
||||
self._server.stop(None) |
||||
|
||||
@unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows') |
||||
def testUnary(self): |
||||
"""Tests that the server unary code path does not stall signal handlers.""" |
||||
server_target = '{}:{}'.format(_HOST, self._port) |
||||
with tempfile.TemporaryFile(mode='r') as client_stdout: |
||||
with tempfile.TemporaryFile(mode='r') as client_stderr: |
||||
client = _start_client((server_target, 'unary'), client_stdout, |
||||
client_stderr) |
||||
self._handler.await_connected_client() |
||||
client.send_signal(signal.SIGINT) |
||||
self.assertFalse(client.wait(), msg=_read_stream(client_stderr)) |
||||
client_stdout.seek(0) |
||||
self.assertIn(_signal_client.SIGTERM_MESSAGE, |
||||
client_stdout.read()) |
||||
|
||||
@unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows') |
||||
def testStreaming(self): |
||||
"""Tests that the server streaming code path does not stall signal handlers.""" |
||||
server_target = '{}:{}'.format(_HOST, self._port) |
||||
with tempfile.TemporaryFile(mode='r') as client_stdout: |
||||
with tempfile.TemporaryFile(mode='r') as client_stderr: |
||||
client = _start_client((server_target, 'streaming'), |
||||
client_stdout, client_stderr) |
||||
self._handler.await_connected_client() |
||||
client.send_signal(signal.SIGINT) |
||||
self.assertFalse(client.wait(), msg=_read_stream(client_stderr)) |
||||
client_stdout.seek(0) |
||||
self.assertIn(_signal_client.SIGTERM_MESSAGE, |
||||
client_stdout.read()) |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
logging.basicConfig() |
||||
unittest.main(verbosity=2) |
Loading…
Reference in new issue