diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index b44840272c9..6caaece82c4 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -100,7 +100,7 @@ class _RPCState(object): self.statused = False self.rpc_errors = [] self.callbacks = [] - self.abortion = None + self.aborted = False def _raise_rpc_error(state): @@ -287,8 +287,8 @@ class _Context(grpc.ServicerContext): with self._state.condition: self._state.code = code self._state.details = _common.encode(details) - self._state.abortion = Exception() - raise self._state.abortion + self._state.aborted = True + raise Exception() def abort_with_status(self, status): self._state.trailing_metadata = status.trailing_metadata @@ -392,7 +392,7 @@ def _call_behavior(rpc_event, state, behavior, argument, request_deserializer): return behavior(argument, context), True except Exception as exception: # pylint: disable=broad-except with state.condition: - if exception is state.abortion: + if state.aborted: _abort(state, rpc_event.call, cygrpc.StatusCode.unknown, b'RPC Aborted') elif exception not in state.rpc_errors: @@ -410,7 +410,7 @@ def _take_response_from_response_iterator(rpc_event, state, response_iterator): return None, True except Exception as exception: # pylint: disable=broad-except with state.condition: - if exception is state.abortion: + if state.aborted: _abort(state, rpc_event.call, cygrpc.StatusCode.unknown, b'RPC Aborted') elif exception not in state.rpc_errors: diff --git a/src/python/grpcio_tests/tests/unit/_abort_test.py b/src/python/grpcio_tests/tests/unit/_abort_test.py index 6438f6897a0..636f1379ad8 100644 --- a/src/python/grpcio_tests/tests/unit/_abort_test.py +++ b/src/python/grpcio_tests/tests/unit/_abort_test.py @@ -15,7 +15,9 @@ import unittest import collections +import gc import logging +import weakref import grpc @@ -39,7 +41,15 @@ class _Status( pass +class _Object(object): + pass + + +do_not_leak_me = _Object() + + def abort_unary_unary(request, servicer_context): + this_should_not_be_leaked = do_not_leak_me servicer_context.abort( grpc.StatusCode.INTERNAL, _ABORT_DETAILS, @@ -101,6 +111,24 @@ class AbortTest(unittest.TestCase): self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL) self.assertEqual(rpc_error.details(), _ABORT_DETAILS) + # This test ensures that abort() does not store the raised exception, which + # on Python 3 (via the `__traceback__` attribute) holds a reference to + # all local vars. Storing the raised exception can prevent GC and stop the + # grpc_call from being unref'ed, even after server shutdown. + def test_abort_does_not_leak_local_vars(self): + global do_not_leak_me # pylint: disable=global-statement + weak_ref = weakref.ref(do_not_leak_me) + + # Servicer will abort() after creating a local ref to do_not_leak_me. + with self.assertRaises(grpc.RpcError) as exception_context: + self._channel.unary_unary(_ABORT)(_REQUEST) + rpc_error = exception_context.exception + + do_not_leak_me = None + # Force garbage collection + gc.collect() + self.assertIsNone(weak_ref()) + def test_abort_with_status(self): with self.assertRaises(grpc.RpcError) as exception_context: self._channel.unary_unary(_ABORT_WITH_STATUS)(_REQUEST)