Gracefully handle errors from callbacks.

In https://github.com/grpc/grpc/issues/19910, it was pointed out that
raising an exception from a Future callback causes the channel spin
thread to terminate. If there are outstanding events on the channel when
the thread dies, calls to Channel.close() block indefinitely.

This commit ensures that the channel spin thread does not die. Instead,
exceptions will be logged at ERROR level.
pull/20015/head
Richard Belleville 6 years ago
parent 073b234308
commit 09a270d6ad
  1. 3
      src/python/grpcio/grpc/__init__.py
  2. 9
      src/python/grpcio/grpc/_channel.py
  3. 67
      src/python/grpcio_tests/tests/unit/_channel_close_test.py

@ -192,6 +192,9 @@ class Future(six.with_metaclass(abc.ABCMeta)):
If the computation has already completed, the callback will be called If the computation has already completed, the callback will be called
immediately. immediately.
Exceptions raised in the callback will be logged at ERROR level, but
will not terminate any threads of execution.
Args: Args:
fn: A callable taking this Future object as its single parameter. fn: A callable taking this Future object as its single parameter.
""" """

@ -159,7 +159,14 @@ def _event_handler(state, response_deserializer):
state.condition.notify_all() state.condition.notify_all()
done = not state.due done = not state.due
for callback in callbacks: for callback in callbacks:
# TODO(gnossen): Are these *only* user callbacks?
try:
callback() callback()
except Exception as e: # pylint: disable=broad-except
# NOTE(rbellevi): We suppress but log errors here so as not to
# kill the channel spin thread.
logging.error('Exception in callback %s: %s', repr(
callback.func), repr(e))
return done and state.fork_epoch >= cygrpc.get_fork_epoch() return done and state.fork_epoch >= cygrpc.get_fork_epoch()
return handle_event return handle_event
@ -338,7 +345,7 @@ class _Rendezvous(grpc.RpcError, grpc.Future, grpc.Call): # pylint: disable=too
def add_done_callback(self, fn): def add_done_callback(self, fn):
with self._state.condition: with self._state.condition:
if self._state.code is None: if self._state.code is None:
self._state.callbacks.append(lambda: fn(self)) self._state.callbacks.append(functools.partial(fn, self))
return return
fn(self) fn(self)

@ -27,8 +27,11 @@ _BEAT = 0.5
_SOME_TIME = 5 _SOME_TIME = 5
_MORE_TIME = 10 _MORE_TIME = 10
_STREAM_URI = 'Meffod'
_UNARY_URI = 'MeffodMan'
class _MethodHandler(grpc.RpcMethodHandler):
class _StreamingMethodHandler(grpc.RpcMethodHandler):
request_streaming = True request_streaming = True
response_streaming = True response_streaming = True
@ -40,13 +43,28 @@ class _MethodHandler(grpc.RpcMethodHandler):
yield request * 2 yield request * 2
_METHOD_HANDLER = _MethodHandler() class _UnaryMethodHandler(grpc.RpcMethodHandler):
request_streaming = False
response_streaming = False
request_deserializer = None
response_serializer = None
def unary_unary(self, request, servicer_context):
return request * 2
_STREAMING_METHOD_HANDLER = _StreamingMethodHandler()
_UNARY_METHOD_HANDLER = _UnaryMethodHandler()
class _GenericHandler(grpc.GenericRpcHandler): class _GenericHandler(grpc.GenericRpcHandler):
def service(self, handler_call_details): def service(self, handler_call_details):
return _METHOD_HANDLER if handler_call_details.method == _STREAM_URI:
return _STREAMING_METHOD_HANDLER
else:
return _UNARY_METHOD_HANDLER
_GENERIC_HANDLER = _GenericHandler() _GENERIC_HANDLER = _GenericHandler()
@ -94,6 +112,24 @@ class _Pipe(object):
self.close() self.close()
class EndlessIterator(object):
    """Iterator that returns the same message on every step and never ends.

    Passed as the request iterator of a streaming call so that the request
    stream is never exhausted.

    Args:
      msg: The object handed back by every call to next().
    """

    def __init__(self, msg):
        self._msg = msg

    def __iter__(self):
        return self

    def _next(self):
        # Single implementation shared by both iterator-protocol spellings.
        return self._msg

    def __next__(self):
        # Python 3 iterator protocol.
        return self._next()

    def next(self):
        # Python 2 iterator protocol.
        return self._next()
class ChannelCloseTest(unittest.TestCase): class ChannelCloseTest(unittest.TestCase):
def setUp(self): def setUp(self):
@ -108,7 +144,7 @@ class ChannelCloseTest(unittest.TestCase):
def test_close_immediately_after_call_invocation(self): def test_close_immediately_after_call_invocation(self):
channel = grpc.insecure_channel('localhost:{}'.format(self._port)) channel = grpc.insecure_channel('localhost:{}'.format(self._port))
multi_callable = channel.stream_stream('Meffod') multi_callable = channel.stream_stream(_STREAM_URI)
request_iterator = _Pipe(()) request_iterator = _Pipe(())
response_iterator = multi_callable(request_iterator) response_iterator = multi_callable(request_iterator)
channel.close() channel.close()
@ -118,7 +154,7 @@ class ChannelCloseTest(unittest.TestCase):
def test_close_while_call_active(self): def test_close_while_call_active(self):
channel = grpc.insecure_channel('localhost:{}'.format(self._port)) channel = grpc.insecure_channel('localhost:{}'.format(self._port))
multi_callable = channel.stream_stream('Meffod') multi_callable = channel.stream_stream(_STREAM_URI)
request_iterator = _Pipe((b'abc',)) request_iterator = _Pipe((b'abc',))
response_iterator = multi_callable(request_iterator) response_iterator = multi_callable(request_iterator)
next(response_iterator) next(response_iterator)
@ -130,7 +166,7 @@ class ChannelCloseTest(unittest.TestCase):
def test_context_manager_close_while_call_active(self): def test_context_manager_close_while_call_active(self):
with grpc.insecure_channel('localhost:{}'.format( with grpc.insecure_channel('localhost:{}'.format(
self._port)) as channel: # pylint: disable=bad-continuation self._port)) as channel: # pylint: disable=bad-continuation
multi_callable = channel.stream_stream('Meffod') multi_callable = channel.stream_stream(_STREAM_URI)
request_iterator = _Pipe((b'abc',)) request_iterator = _Pipe((b'abc',))
response_iterator = multi_callable(request_iterator) response_iterator = multi_callable(request_iterator)
next(response_iterator) next(response_iterator)
@ -141,7 +177,7 @@ class ChannelCloseTest(unittest.TestCase):
def test_context_manager_close_while_many_calls_active(self): def test_context_manager_close_while_many_calls_active(self):
with grpc.insecure_channel('localhost:{}'.format( with grpc.insecure_channel('localhost:{}'.format(
self._port)) as channel: # pylint: disable=bad-continuation self._port)) as channel: # pylint: disable=bad-continuation
multi_callable = channel.stream_stream('Meffod') multi_callable = channel.stream_stream(_STREAM_URI)
request_iterators = tuple( request_iterators = tuple(
_Pipe((b'abc',)) _Pipe((b'abc',))
for _ in range(test_constants.THREAD_CONCURRENCY)) for _ in range(test_constants.THREAD_CONCURRENCY))
@ -158,7 +194,7 @@ class ChannelCloseTest(unittest.TestCase):
def test_many_concurrent_closes(self): def test_many_concurrent_closes(self):
channel = grpc.insecure_channel('localhost:{}'.format(self._port)) channel = grpc.insecure_channel('localhost:{}'.format(self._port))
multi_callable = channel.stream_stream('Meffod') multi_callable = channel.stream_stream(_STREAM_URI)
request_iterator = _Pipe((b'abc',)) request_iterator = _Pipe((b'abc',))
response_iterator = multi_callable(request_iterator) response_iterator = multi_callable(request_iterator)
next(response_iterator) next(response_iterator)
@ -181,6 +217,21 @@ class ChannelCloseTest(unittest.TestCase):
self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED) self.assertIs(response_iterator.code(), grpc.StatusCode.CANCELLED)
def test_exception_in_callback(self):
    """Check that a done-callback which raises does not hang the test.

    A unary-unary future gets a callback that raises; the test passes if
    future.result() returns and the channel context manager exits.
    """
    with grpc.insecure_channel('localhost:{}'.format(
            self._port)) as channel:
        stream_multi_callable = channel.stream_stream(_STREAM_URI)
        endless_iterator = EndlessIterator(b'abc')
        # The response iterator is never read. NOTE(review): it is
        # presumably held only so that a streaming call is still
        # outstanding on the channel when the callback raises -- confirm.
        stream_response_iterator = stream_multi_callable(endless_iterator)
        future = channel.unary_unary(_UNARY_URI).future(b'abc')

        def on_done_callback(future):
            raise Exception("This should not cause a deadlock.")

        future.add_done_callback(on_done_callback)
        future.result()
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig() logging.basicConfig()

Loading…
Cancel
Save