Merge pull request #19586 from gnossen/revert_revert_signal_handling

Revert "Merge pull request #19583 from gnossen/revert_signal_handling"
pull/19826/head
Lidi Zheng 5 years ago committed by GitHub
commit 7631e410cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 131
      src/python/grpcio/grpc/_channel.py
  2. 50
      src/python/grpcio/grpc/_common.py
  3. 2
      src/python/grpcio_tests/commands.py
  4. 1
      src/python/grpcio_tests/tests/tests.json
  5. 8
      src/python/grpcio_tests/tests/unit/BUILD.bazel
  6. 84
      src/python/grpcio_tests/tests/unit/_signal_client.py
  7. 172
      src/python/grpcio_tests/tests/unit/_signal_handling_test.py

@ -13,6 +13,7 @@
# limitations under the License.
"""Invocation-side implementation of gRPC Python."""
import functools
import logging
import sys
import threading
@ -81,17 +82,6 @@ def _unknown_code_details(unknown_cygrpc_code, details):
unknown_cygrpc_code, details)
def _wait_once_until(condition, until):
if until is None:
condition.wait()
else:
remaining = until - time.time()
if remaining < 0:
raise grpc.FutureTimeoutError()
else:
condition.wait(timeout=remaining)
class _RPCState(object):
def __init__(self, due, initial_metadata, trailing_metadata, code, details):
@ -178,12 +168,11 @@ def _event_handler(state, response_deserializer):
#pylint: disable=too-many-statements
def _consume_request_iterator(request_iterator, state, call, request_serializer,
event_handler):
if cygrpc.is_fork_support_enabled():
condition_wait_timeout = 1.0
else:
condition_wait_timeout = None
"""Consume a request iterator supplied by the user."""
def consume_request_iterator(): # pylint: disable=too-many-branches
# Iterate over the request iterator until it is exhausted or an error
# condition is encountered.
while True:
return_from_user_request_generator_invoked = False
try:
@ -224,13 +213,18 @@ def _consume_request_iterator(request_iterator, state, call, request_serializer,
state.due.add(cygrpc.OperationType.send_message)
else:
return
while True:
state.condition.wait(condition_wait_timeout)
cygrpc.block_if_fork_in_progress(state)
if state.code is None:
if cygrpc.OperationType.send_message not in state.due:
break
else:
def _done():
return (state.code is not None or
cygrpc.OperationType.send_message not in
state.due)
_common.wait(
state.condition.wait,
_done,
spin_cb=functools.partial(
cygrpc.block_if_fork_in_progress, state))
if state.code is not None:
return
else:
return
@ -281,13 +275,21 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too
with self._state.condition:
return self._state.code is not None
def _is_complete(self):
return self._state.code is not None
def result(self, timeout=None):
until = None if timeout is None else time.time() + timeout
"""Returns the result of the computation or raises its exception.
See grpc.Future.result for the full API contract.
"""
with self._state.condition:
while True:
if self._state.code is None:
_wait_once_until(self._state.condition, until)
elif self._state.code is grpc.StatusCode.OK:
timed_out = _common.wait(
self._state.condition.wait, self._is_complete, timeout=timeout)
if timed_out:
raise grpc.FutureTimeoutError()
else:
if self._state.code is grpc.StatusCode.OK:
return self._state.response
elif self._state.cancelled:
raise grpc.FutureCancelledError()
@ -295,12 +297,17 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too
raise self
def exception(self, timeout=None):
until = None if timeout is None else time.time() + timeout
"""Return the exception raised by the computation.
See grpc.Future.exception for the full API contract.
"""
with self._state.condition:
while True:
if self._state.code is None:
_wait_once_until(self._state.condition, until)
elif self._state.code is grpc.StatusCode.OK:
timed_out = _common.wait(
self._state.condition.wait, self._is_complete, timeout=timeout)
if timed_out:
raise grpc.FutureTimeoutError()
else:
if self._state.code is grpc.StatusCode.OK:
return None
elif self._state.cancelled:
raise grpc.FutureCancelledError()
@ -308,12 +315,17 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too
return self
def traceback(self, timeout=None):
until = None if timeout is None else time.time() + timeout
"""Access the traceback of the exception raised by the computation.
See grpc.future.traceback for the full API contract.
"""
with self._state.condition:
while True:
if self._state.code is None:
_wait_once_until(self._state.condition, until)
elif self._state.code is grpc.StatusCode.OK:
timed_out = _common.wait(
self._state.condition.wait, self._is_complete, timeout=timeout)
if timed_out:
raise grpc.FutureTimeoutError()
else:
if self._state.code is grpc.StatusCode.OK:
return None
elif self._state.cancelled:
raise grpc.FutureCancelledError()
@ -345,8 +357,14 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too
raise StopIteration()
else:
raise self
while True:
self._state.condition.wait()
def _response_ready():
return (
self._state.response is not None or
(cygrpc.OperationType.receive_message not in self._state.due
and self._state.code is not None))
_common.wait(self._state.condition.wait, _response_ready)
if self._state.response is not None:
response = self._state.response
self._state.response = None
@ -386,32 +404,47 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too
def initial_metadata(self):
with self._state.condition:
while self._state.initial_metadata is None:
self._state.condition.wait()
def _done():
return self._state.initial_metadata is not None
_common.wait(self._state.condition.wait, _done)
return self._state.initial_metadata
def trailing_metadata(self):
with self._state.condition:
while self._state.trailing_metadata is None:
self._state.condition.wait()
def _done():
return self._state.trailing_metadata is not None
_common.wait(self._state.condition.wait, _done)
return self._state.trailing_metadata
def code(self):
with self._state.condition:
while self._state.code is None:
self._state.condition.wait()
def _done():
return self._state.code is not None
_common.wait(self._state.condition.wait, _done)
return self._state.code
def details(self):
with self._state.condition:
while self._state.details is None:
self._state.condition.wait()
def _done():
return self._state.details is not None
_common.wait(self._state.condition.wait, _done)
return _common.decode(self._state.details)
def debug_error_string(self):
with self._state.condition:
while self._state.debug_error_string is None:
self._state.condition.wait()
def _done():
return self._state.debug_error_string is not None
_common.wait(self._state.condition.wait, _done)
return _common.decode(self._state.debug_error_string)
def _repr(self):

@ -15,6 +15,7 @@
import logging
import time
import six
import grpc
@ -60,6 +61,8 @@ STATUS_CODE_TO_CYGRPC_STATUS_CODE = {
CYGRPC_STATUS_CODE_TO_STATUS_CODE)
}
MAXIMUM_WAIT_TIMEOUT = 0.1
def encode(s):
if isinstance(s, bytes):
@ -96,3 +99,50 @@ def deserialize(serialized_message, deserializer):
def fully_qualified_method(group, method):
return '/{}/{}'.format(group, method)
def _wait_once(wait_fn, timeout, spin_cb):
wait_fn(timeout=timeout)
if spin_cb is not None:
spin_cb()
def wait(wait_fn, wait_complete_fn, timeout=None, spin_cb=None):
"""Blocks waiting for an event without blocking the thread indefinitely.
See https://github.com/grpc/grpc/issues/19464 for full context. CPython's
`threading.Event.wait` and `threading.Condition.wait` methods, if invoked
without a timeout kwarg, may block the calling thread indefinitely. If the
call is made from the main thread, this means that signal handlers may not
run for an arbitrarily long period of time.
This wrapper calls the supplied wait function with an arbitrary short
timeout to ensure that no signal handler has to wait longer than
MAXIMUM_WAIT_TIMEOUT before executing.
Args:
wait_fn: A callable acceptable a single float-valued kwarg named
`timeout`. This function is expected to be one of `threading.Event.wait`
or `threading.Condition.wait`.
wait_complete_fn: A callable taking no arguments and returning a bool.
When this function returns true, it indicates that waiting should cease.
timeout: An optional float-valued number of seconds after which the wait
should cease.
spin_cb: An optional Callable taking no arguments and returning nothing.
This callback will be called on each iteration of the spin. This may be
used for, e.g. work related to forking.
Returns:
True if a timeout was supplied and it was reached. False otherwise.
"""
if timeout is None:
while not wait_complete_fn():
_wait_once(wait_fn, MAXIMUM_WAIT_TIMEOUT, spin_cb)
else:
end = time.time() + timeout
while not wait_complete_fn():
remaining = min(end - time.time(), MAXIMUM_WAIT_TIMEOUT)
if remaining < 0:
return True
_wait_once(wait_fn, remaining, spin_cb)
return False

@ -145,6 +145,8 @@ class TestGevent(setuptools.Command):
'unit._exit_test.ExitTest.test_in_flight_partial_unary_stream_call',
'unit._exit_test.ExitTest.test_in_flight_partial_stream_unary_call',
'unit._exit_test.ExitTest.test_in_flight_partial_stream_stream_call',
# TODO(https://github.com/grpc/grpc/issues/18980): Reenable.
'unit._signal_handling_test.SignalHandlingTest',
'unit._metadata_flags_test',
'health_check._health_servicer_test.HealthServicerTest.test_cancelled_watch_removed_from_watch_list',
# TODO(https://github.com/grpc/grpc/issues/17330) enable these three tests

@ -67,6 +67,7 @@
"unit._server_ssl_cert_config_test.ServerSSLCertReloadTestWithoutClientAuth",
"unit._server_test.ServerTest",
"unit._session_cache_test.SSLSessionCacheTest",
"unit._signal_handling_test.SignalHandlingTest",
"unit._version_test.VersionTest",
"unit.beta._beta_features_test.BetaFeaturesTest",
"unit.beta._beta_features_test.ContextManagementAndLifecycleTest",

@ -16,6 +16,7 @@ GRPCIO_TESTS_UNIT = [
"_credentials_test.py",
"_dns_resolver_test.py",
"_empty_message_test.py",
"_error_message_encoding_test.py",
"_exit_test.py",
"_interceptor_test.py",
"_invalid_metadata_test.py",
@ -27,6 +28,7 @@ GRPCIO_TESTS_UNIT = [
# "_reconnect_test.py",
"_resource_exhausted_test.py",
"_rpc_test.py",
"_signal_handling_test.py",
# TODO(ghostwriternr): To be added later.
# "_server_ssl_cert_config_test.py",
"_server_test.py",
@ -39,6 +41,11 @@ py_library(
srcs = ["_tcp_proxy.py"],
)
py_library(
name = "_signal_client",
srcs = ["_signal_client.py"],
)
py_library(
name = "resources",
srcs = ["resources.py"],
@ -87,6 +94,7 @@ py_library(
":_server_shutdown_scenarios",
":_from_grpc_import_star",
":_tcp_proxy",
":_signal_client",
"//src/python/grpcio_tests/tests/unit/framework/common",
"//src/python/grpcio_tests/tests/testing",
requirement('six'),

@ -0,0 +1,84 @@
# Copyright 2019 the gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Client for testing responsiveness to signals."""
from __future__ import print_function
import argparse
import functools
import logging
import signal
import sys
import grpc
SIGTERM_MESSAGE = "Handling sigterm!"
UNARY_UNARY = "/test/Unary"
UNARY_STREAM = "/test/ServerStreaming"
_MESSAGE = b'\x00\x00\x00'
_ASSERTION_MESSAGE = "Control flow should never reach here."
# NOTE(gnossen): We use a global variable here so that the signal handler can be
# installed before the RPC begins. If we do not do this, then we may receive the
# SIGINT before the signal handler is installed. I'm not happy with per-process
# global state, but the per-process global state that is signal handlers
# somewhat forces my hand.
per_process_rpc_future = None
def handle_sigint(unused_signum, unused_frame):
print(SIGTERM_MESSAGE)
if per_process_rpc_future is not None:
per_process_rpc_future.cancel()
sys.stderr.flush()
sys.exit(0)
def main_unary(server_target):
"""Initiate a unary RPC to be interrupted by a SIGINT."""
global per_process_rpc_future # pylint: disable=global-statement
with grpc.insecure_channel(server_target) as channel:
multicallable = channel.unary_unary(UNARY_UNARY)
signal.signal(signal.SIGINT, handle_sigint)
per_process_rpc_future = multicallable.future(
_MESSAGE, wait_for_ready=True)
result = per_process_rpc_future.result()
assert False, _ASSERTION_MESSAGE
def main_streaming(server_target):
"""Initiate a streaming RPC to be interrupted by a SIGINT."""
global per_process_rpc_future # pylint: disable=global-statement
with grpc.insecure_channel(server_target) as channel:
signal.signal(signal.SIGINT, handle_sigint)
per_process_rpc_future = channel.unary_stream(UNARY_STREAM)(
_MESSAGE, wait_for_ready=True)
for result in per_process_rpc_future:
pass
assert False, _ASSERTION_MESSAGE
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Signal test client.')
parser.add_argument('server', help='Server target')
parser.add_argument(
'arity', help='RPC arity', choices=('unary', 'streaming'))
args = parser.parse_args()
if args.arity == 'unary':
main_unary(args.server)
else:
main_streaming(args.server)

@ -0,0 +1,172 @@
# Copyright 2019 the gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test of responsiveness to signals."""
import logging
import os
import signal
import subprocess
import tempfile
import threading
import unittest
import sys
import grpc
from tests.unit import test_common
from tests.unit import _signal_client
_CLIENT_PATH = None
if sys.executable is not None:
_CLIENT_PATH = os.path.abspath(os.path.realpath(_signal_client.__file__))
else:
# NOTE(rbellevi): For compatibility with internal testing.
if len(sys.argv) != 2:
raise RuntimeError("Must supply path to executable client.")
client_name = sys.argv[1].split("/")[-1]
del sys.argv[1] # For compatibility with test runner.
_CLIENT_PATH = os.path.realpath(
os.path.join(os.path.dirname(os.path.abspath(__file__)), client_name))
_HOST = 'localhost'
class _GenericHandler(grpc.GenericRpcHandler):
def __init__(self):
self._connected_clients_lock = threading.RLock()
self._connected_clients_event = threading.Event()
self._connected_clients = 0
self._unary_unary_handler = grpc.unary_unary_rpc_method_handler(
self._handle_unary_unary)
self._unary_stream_handler = grpc.unary_stream_rpc_method_handler(
self._handle_unary_stream)
def _on_client_connect(self):
with self._connected_clients_lock:
self._connected_clients += 1
self._connected_clients_event.set()
def _on_client_disconnect(self):
with self._connected_clients_lock:
self._connected_clients -= 1
if self._connected_clients == 0:
self._connected_clients_event.clear()
def await_connected_client(self):
"""Blocks until a client connects to the server."""
self._connected_clients_event.wait()
def _handle_unary_unary(self, request, servicer_context):
"""Handles a unary RPC.
Blocks until the client disconnects and then echoes.
"""
stop_event = threading.Event()
def on_rpc_end():
self._on_client_disconnect()
stop_event.set()
servicer_context.add_callback(on_rpc_end)
self._on_client_connect()
stop_event.wait()
return request
def _handle_unary_stream(self, request, servicer_context):
"""Handles a server streaming RPC.
Blocks until the client disconnects and then echoes.
"""
stop_event = threading.Event()
def on_rpc_end():
self._on_client_disconnect()
stop_event.set()
servicer_context.add_callback(on_rpc_end)
self._on_client_connect()
stop_event.wait()
yield request
def service(self, handler_call_details):
if handler_call_details.method == _signal_client.UNARY_UNARY:
return self._unary_unary_handler
elif handler_call_details.method == _signal_client.UNARY_STREAM:
return self._unary_stream_handler
else:
return None
def _read_stream(stream):
stream.seek(0)
return stream.read()
def _start_client(args, stdout, stderr):
invocation = None
if sys.executable is not None:
invocation = (sys.executable, _CLIENT_PATH) + tuple(args)
else:
invocation = (_CLIENT_PATH,) + tuple(args)
return subprocess.Popen(invocation, stdout=stdout, stderr=stderr)
class SignalHandlingTest(unittest.TestCase):
def setUp(self):
self._server = test_common.test_server()
self._port = self._server.add_insecure_port('{}:0'.format(_HOST))
self._handler = _GenericHandler()
self._server.add_generic_rpc_handlers((self._handler,))
self._server.start()
def tearDown(self):
self._server.stop(None)
@unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows')
def testUnary(self):
"""Tests that the server unary code path does not stall signal handlers."""
server_target = '{}:{}'.format(_HOST, self._port)
with tempfile.TemporaryFile(mode='r') as client_stdout:
with tempfile.TemporaryFile(mode='r') as client_stderr:
client = _start_client((server_target, 'unary'), client_stdout,
client_stderr)
self._handler.await_connected_client()
client.send_signal(signal.SIGINT)
self.assertFalse(client.wait(), msg=_read_stream(client_stderr))
client_stdout.seek(0)
self.assertIn(_signal_client.SIGTERM_MESSAGE,
client_stdout.read())
@unittest.skipIf(os.name == 'nt', 'SIGINT not supported on windows')
def testStreaming(self):
"""Tests that the server streaming code path does not stall signal handlers."""
server_target = '{}:{}'.format(_HOST, self._port)
with tempfile.TemporaryFile(mode='r') as client_stdout:
with tempfile.TemporaryFile(mode='r') as client_stderr:
client = _start_client((server_target, 'streaming'),
client_stdout, client_stderr)
self._handler.await_connected_client()
client.send_signal(signal.SIGINT)
self.assertFalse(client.wait(), msg=_read_stream(client_stderr))
client_stdout.seek(0)
self.assertIn(_signal_client.SIGTERM_MESSAGE,
client_stdout.read())
if __name__ == '__main__':
logging.basicConfig()
unittest.main(verbosity=2)
Loading…
Cancel
Save