Updated several threads to use CleanupThread for clean exit

pull/7001/head
Ken Payson 9 years ago
parent be22335879
commit aa4bb51d95
  1. 65
      src/python/grpcio/grpc/_channel.py
  2. 38
      src/python/grpcio/grpc/beta/_server_adaptations.py
  3. 1
      src/python/grpcio/tests/tests.json
  4. 249
      src/python/grpcio/tests/unit/_exit_scenarios.py
  5. 185
      src/python/grpcio/tests/unit/_exit_test.py

@ -179,6 +179,7 @@ def _event_handler(state, call, response_deserializer):
def _consume_request_iterator(
request_iterator, state, call, request_serializer):
event_handler = _event_handler(state, call, None)
def consume_request_iterator():
for request in request_iterator:
serialized_request = _common.serialize(request, request_serializer)
@ -212,8 +213,18 @@ def _consume_request_iterator(
)
call.start_batch(cygrpc.Operations(operations), event_handler)
state.due.add(cygrpc.OperationType.send_close_from_client)
thread = threading.Thread(target=consume_request_iterator)
thread.start()
def stop_consumption_thread(timeout):
with state.condition:
if state.code is None:
call.cancel()
state.cancelled = True
_abort(state, grpc.StatusCode.CANCELLED, 'Cancelled!')
state.condition.notify_all()
consumption_thread = _common.CleanupThread(
stop_consumption_thread, target=consume_request_iterator)
consumption_thread.start()
class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call):
@ -652,16 +663,27 @@ class _ChannelCallState(object):
self.managed_calls = None
def _call_spin(state):
while True:
event = state.completion_queue.poll()
completed_call = event.tag(event)
if completed_call is not None:
with state.lock:
state.managed_calls.remove(completed_call)
if not state.managed_calls:
state.managed_calls = None
return
def _run_channel_spin_thread(state):
def channel_spin():
while True:
event = state.completion_queue.poll()
completed_call = event.tag(event)
if completed_call is not None:
with state.lock:
state.managed_calls.remove(completed_call)
if not state.managed_calls:
state.managed_calls = None
return
def stop_channel_spin(timeout):
with state.lock:
if state.managed_calls is not None:
for call in state.managed_calls:
call.cancel()
channel_spin_thread = _common.CleanupThread(
stop_channel_spin, target=channel_spin)
channel_spin_thread.start()
def _create_channel_managed_call(state):
@ -690,8 +712,7 @@ def _create_channel_managed_call(state):
parent, flags, state.completion_queue, method, host, deadline)
if state.managed_calls is None:
state.managed_calls = set((call,))
spin_thread = threading.Thread(target=_call_spin, args=(state,))
spin_thread.start()
_run_channel_spin_thread(state)
else:
state.managed_calls.add(call)
return call
@ -784,11 +805,18 @@ def _poll_connectivity(state, channel, initial_try_to_connect):
_spawn_delivery(state, callbacks)
def _moot(state):
with state.lock:
del state.callbacks_and_connectivities[:]
def _subscribe(state, callback, try_to_connect):
with state.lock:
if not state.callbacks_and_connectivities and not state.polling:
polling_thread = threading.Thread(
target=_poll_connectivity,
def cancel_all_subscriptions(timeout):
_moot(state)
polling_thread = _common.CleanupThread(
cancel_all_subscriptions, target=_poll_connectivity,
args=(state, state.channel, bool(try_to_connect)))
polling_thread.start()
state.polling = True
@ -812,11 +840,6 @@ def _unsubscribe(state, callback):
break
def _moot(state):
with state.lock:
del state.callbacks_and_connectivities[:]
def _options(options):
if options is None:
pairs = ((cygrpc.ChannelArgKey.primary_user_agent_string, _USER_AGENT),)

@ -161,14 +161,24 @@ class _Callback(stream.Consumer):
self._condition.wait()
def _pipe_requests(request_iterator, request_consumer, servicer_context):
for request in request_iterator:
if not servicer_context.is_active():
return
request_consumer.consume(request)
if not servicer_context.is_active():
return
request_consumer.terminate()
def _run_request_pipe_thread(request_iterator, request_consumer,
servicer_context):
thread_joined = threading.Event()
def pipe_requests():
for request in request_iterator:
if not servicer_context.is_active() or thread_joined.is_set():
return
request_consumer.consume(request)
if not servicer_context.is_active() or thread_joined.is_set():
return
request_consumer.terminate()
def stop_request_pipe(timeout):
thread_joined.set()
request_pipe_thread = _common.CleanupThread(
stop_request_pipe, target=pipe_requests)
request_pipe_thread.start()
def _adapt_unary_unary_event(unary_unary_event):
@ -206,10 +216,8 @@ def _adapt_stream_unary_event(stream_unary_event):
raise abandonment.Abandoned()
request_consumer = stream_unary_event(
callback.consume_and_terminate, _FaceServicerContext(servicer_context))
request_pipe_thread = threading.Thread(
target=_pipe_requests,
args=(request_iterator, request_consumer, servicer_context,))
request_pipe_thread.start()
_run_request_pipe_thread(
request_iterator, request_consumer, servicer_context)
return callback.draw_all_values()[0]
return adaptation
@ -221,10 +229,8 @@ def _adapt_stream_stream_event(stream_stream_event):
raise abandonment.Abandoned()
request_consumer = stream_stream_event(
callback, _FaceServicerContext(servicer_context))
request_pipe_thread = threading.Thread(
target=_pipe_requests,
args=(request_iterator, request_consumer, servicer_context,))
request_pipe_thread.start()
_run_request_pipe_thread(
request_iterator, request_consumer, servicer_context)
while True:
response = callback.draw_one_value()
if response is None:

@ -13,6 +13,7 @@
"_connectivity_channel_test.ChannelConnectivityTest",
"_connectivity_channel_test.ConnectivityStatesTest",
"_empty_message_test.EmptyMessageTest",
"_exit_test.ExitTest",
"_face_interface_test.DynamicInvokerBlockingInvocationInlineServiceTest",
"_face_interface_test.DynamicInvokerFutureInvocationAsynchronousEventServiceTest",
"_face_interface_test.GenericInvokerBlockingInvocationInlineServiceTest",

@ -0,0 +1,249 @@
# Copyright 2016, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Defines a number of module-scope gRPC scenarios to test clean exit."""
import argparse
import threading
import time
import grpc
from tests.unit.framework.common import test_constants
WAIT_TIME = 1000
REQUEST = b'request'
UNSTARTED_SERVER = 'unstarted_server'
RUNNING_SERVER = 'running_server'
POLL_CONNECTIVITY_NO_SERVER = 'poll_connectivity_no_server'
POLL_CONNECTIVITY = 'poll_connectivity'
IN_FLIGHT_UNARY_UNARY_CALL = 'in_flight_unary_unary_call'
IN_FLIGHT_UNARY_STREAM_CALL = 'in_flight_unary_stream_call'
IN_FLIGHT_STREAM_UNARY_CALL = 'in_flight_stream_unary_call'
IN_FLIGHT_STREAM_STREAM_CALL = 'in_flight_stream_stream_call'
IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL = 'in_flight_partial_unary_stream_call'
IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL = 'in_flight_partial_stream_unary_call'
IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL = 'in_flight_partial_stream_stream_call'
UNARY_UNARY = b'/test/UnaryUnary'
UNARY_STREAM = b'/test/UnaryStream'
STREAM_UNARY = b'/test/StreamUnary'
STREAM_STREAM = b'/test/StreamStream'
PARTIAL_UNARY_STREAM = b'/test/PartialUnaryStream'
PARTIAL_STREAM_UNARY = b'/test/PartialStreamUnary'
PARTIAL_STREAM_STREAM = b'/test/PartialStreamStream'
TEST_TO_METHOD = {
IN_FLIGHT_UNARY_UNARY_CALL: UNARY_UNARY,
IN_FLIGHT_UNARY_STREAM_CALL: UNARY_STREAM,
IN_FLIGHT_STREAM_UNARY_CALL: STREAM_UNARY,
IN_FLIGHT_STREAM_STREAM_CALL: STREAM_STREAM,
IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL: PARTIAL_UNARY_STREAM,
IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL: PARTIAL_STREAM_UNARY,
IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL: PARTIAL_STREAM_STREAM,
}
def hang_unary_unary(request, servicer_context):
time.sleep(WAIT_TIME)
def hang_unary_stream(request, servicer_context):
time.sleep(WAIT_TIME)
def hang_partial_unary_stream(request, servicer_context):
for _ in range(test_constants.STREAM_LENGTH // 2):
yield request
time.sleep(WAIT_TIME)
def hang_stream_unary(request_iterator, servicer_context):
time.sleep(WAIT_TIME)
def hang_partial_stream_unary(request_iterator, servicer_context):
for _ in range(test_constants.STREAM_LENGTH // 2):
next(request_iterator)
time.sleep(WAIT_TIME)
def hang_stream_stream(request_iterator, servicer_context):
time.sleep(WAIT_TIME)
def hang_partial_stream_stream(request_iterator, servicer_context):
for _ in range(test_constants.STREAM_LENGTH // 2):
yield next(request_iterator)
time.sleep(WAIT_TIME)
class MethodHandler(grpc.RpcMethodHandler):
def __init__(self, request_streaming, response_streaming, partial_hang):
self.request_streaming = request_streaming
self.response_streaming = response_streaming
self.request_deserializer = None
self.response_serializer = None
self.unary_unary = None
self.unary_stream = None
self.stream_unary = None
self.stream_stream = None
if self.request_streaming and self.response_streaming:
if partial_hang:
self.stream_stream = hang_partial_stream_stream
else:
self.stream_stream = hang_stream_stream
elif self.request_streaming:
if partial_hang:
self.stream_unary = hang_partial_stream_unary
else:
self.stream_unary = hang_stream_unary
elif self.response_streaming:
if partial_hang:
self.unary_stream = hang_partial_unary_stream
else:
self.unary_stream = hang_unary_stream
else:
self.unary_unary = hang_unary_unary
class GenericHandler(grpc.GenericRpcHandler):
def service(self, handler_call_details):
if handler_call_details.method == UNARY_UNARY:
return MethodHandler(False, False, False)
elif handler_call_details.method == UNARY_STREAM:
return MethodHandler(False, True, False)
elif handler_call_details.method == STREAM_UNARY:
return MethodHandler(True, False, False)
elif handler_call_details.method == STREAM_STREAM:
return MethodHandler(True, True, False)
elif handler_call_details.method == PARTIAL_UNARY_STREAM:
return MethodHandler(False, True, True)
elif handler_call_details.method == PARTIAL_STREAM_UNARY:
return MethodHandler(True, False, True)
elif handler_call_details.method == PARTIAL_STREAM_STREAM:
return MethodHandler(True, True, True)
else:
return None
# Traditional executors will not exit until all their
# current jobs complete. Because we submit jobs that will
# never finish, we don't want to block exit on these jobs.
class DaemonPool(object):
def submit(self, fn, *args, **kwargs):
thread = threading.Thread(target=fn, args=args, kwargs=kwargs)
thread.daemon = True
thread.start()
def shutdown(self, wait=True):
pass
def infinite_request_iterator():
while True:
yield REQUEST
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('scenario', type=str)
parser.add_argument(
'--wait_for_interrupt', dest='wait_for_interrupt', action='store_true')
args = parser.parse_args()
if args.scenario == UNSTARTED_SERVER:
server = grpc.server((), DaemonPool())
if args.wait_for_interrupt:
time.sleep(WAIT_TIME)
elif args.scenario == RUNNING_SERVER:
server = grpc.server((), DaemonPool())
port = server.add_insecure_port('[::]:0')
server.start()
if args.wait_for_interrupt:
time.sleep(WAIT_TIME)
elif args.scenario == POLL_CONNECTIVITY_NO_SERVER:
channel = grpc.insecure_channel('localhost:12345')
def connectivity_callback(connectivity):
pass
channel.subscribe(connectivity_callback, try_to_connect=True)
if args.wait_for_interrupt:
time.sleep(WAIT_TIME)
elif args.scenario == POLL_CONNECTIVITY:
server = grpc.server((), DaemonPool())
port = server.add_insecure_port('[::]:0')
server.start()
channel = grpc.insecure_channel('localhost:%d' % port)
def connectivity_callback(connectivity):
pass
channel.subscribe(connectivity_callback, try_to_connect=True)
if args.wait_for_interrupt:
time.sleep(WAIT_TIME)
else:
handler = GenericHandler()
server = grpc.server((), DaemonPool())
port = server.add_insecure_port('[::]:0')
server.add_generic_rpc_handlers((handler,))
server.start()
channel = grpc.insecure_channel('localhost:%d' % port)
method = TEST_TO_METHOD[args.scenario]
if args.scenario == IN_FLIGHT_UNARY_UNARY_CALL:
multi_callable = channel.unary_unary(method)
future = multi_callable.future(REQUEST)
result, call = multi_callable.with_call(REQUEST)
elif (args.scenario == IN_FLIGHT_UNARY_STREAM_CALL or
args.scenario == IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL):
multi_callable = channel.unary_stream(method)
response_iterator = multi_callable(REQUEST)
for response in response_iterator:
pass
elif (args.scenario == IN_FLIGHT_STREAM_UNARY_CALL or
args.scenario == IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL):
multi_callable = channel.stream_unary(method)
future = multi_callable.future(infinite_request_iterator())
result, call = multi_callable.with_call(
[REQUEST] * test_constants.STREAM_LENGTH)
elif (args.scenario == IN_FLIGHT_STREAM_STREAM_CALL or
args.scenario == IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL):
multi_callable = channel.stream_stream(method)
response_iterator = multi_callable(infinite_request_iterator())
for response in response_iterator:
pass

@ -0,0 +1,185 @@
# Copyright 2016, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests clean exit of server/client on Python Interpreter exit/sigint.
The tests in this module spawn a subprocess for each test case, the
test is considered successful if it doesn't hang/timeout.
"""
import atexit
import os
import signal
import six
import subprocess
import sys
import threading
import time
import unittest
from tests.unit import _exit_scenarios
SCENARIO_FILE = os.path.abspath(os.path.join(
os.path.dirname(os.path.realpath(__file__)), '_exit_scenarios.py'))
INTERPRETER = sys.executable
BASE_COMMAND = [INTERPRETER, SCENARIO_FILE]
BASE_SIGTERM_COMMAND = BASE_COMMAND + ['--wait_for_interrupt']
INIT_TIME = 1.0
processes = []
process_lock = threading.Lock()
# Make sure we attempt to clean up any
# processes we may have left running
def cleanup_processes():
with process_lock:
for process in processes:
try:
process.kill()
except Exception:
pass
atexit.register(cleanup_processes)
def interrupt_and_wait(process):
with process_lock:
processes.append(process)
time.sleep(INIT_TIME)
os.kill(process.pid, signal.SIGINT)
process.wait()
def wait(process):
with process_lock:
processes.append(process)
process.wait()
class ExitTest(unittest.TestCase):
def test_unstarted_server(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.UNSTARTED_SERVER],
stdout=sys.stdout, stderr=sys.stderr)
wait(process)
def test_unstarted_server_terminate(self):
process = subprocess.Popen(
BASE_SIGTERM_COMMAND + [_exit_scenarios.UNSTARTED_SERVER],
stdout=sys.stdout)
interrupt_and_wait(process)
def test_running_server(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.RUNNING_SERVER],
stdout=sys.stdout, stderr=sys.stderr)
wait(process)
def test_running_server_terminate(self):
process = subprocess.Popen(
BASE_SIGTERM_COMMAND + [_exit_scenarios.RUNNING_SERVER],
stdout=sys.stdout, stderr=sys.stderr)
interrupt_and_wait(process)
def test_poll_connectivity_no_server(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY_NO_SERVER],
stdout=sys.stdout, stderr=sys.stderr)
wait(process)
def test_poll_connectivity_no_server_terminate(self):
process = subprocess.Popen(
BASE_SIGTERM_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY_NO_SERVER],
stdout=sys.stdout, stderr=sys.stderr)
interrupt_and_wait(process)
def test_poll_connectivity(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY],
stdout=sys.stdout, stderr=sys.stderr)
wait(process)
def test_poll_connectivity_terminate(self):
process = subprocess.Popen(
BASE_SIGTERM_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY],
stdout=sys.stdout, stderr=sys.stderr)
interrupt_and_wait(process)
def test_in_flight_unary_unary_call(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_UNARY_CALL],
stdout=sys.stdout, stderr=sys.stderr)
interrupt_and_wait(process)
@unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
def test_in_flight_unary_stream_call(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_STREAM_CALL],
stdout=sys.stdout, stderr=sys.stderr)
interrupt_and_wait(process)
def test_in_flight_stream_unary_call(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_UNARY_CALL],
stdout=sys.stdout, stderr=sys.stderr)
interrupt_and_wait(process)
@unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
def test_in_flight_stream_stream_call(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_STREAM_CALL],
stdout=sys.stdout, stderr=sys.stderr)
interrupt_and_wait(process)
@unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
def test_in_flight_partial_unary_stream_call(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL],
stdout=sys.stdout, stderr=sys.stderr)
interrupt_and_wait(process)
def test_in_flight_partial_stream_unary_call(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL],
stdout=sys.stdout, stderr=sys.stderr)
interrupt_and_wait(process)
@unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
def test_in_flight_partial_stream_stream_call(self):
process = subprocess.Popen(
BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL],
stdout=sys.stdout, stderr=sys.stderr)
interrupt_and_wait(process)
if __name__ == '__main__':
unittest.main(verbosity=2)
Loading…
Cancel
Save