Merge pull request #17899 from ericgribkoff/undead_server

python: do not store raised exception in _Context.abort()
pull/17925/head
Eric Gribkoff 6 years ago committed by GitHub
commit 96403bc640
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 10
      src/python/grpcio/grpc/_server.py
  2. 28
      src/python/grpcio_tests/tests/unit/_abort_test.py

@ -100,7 +100,7 @@ class _RPCState(object):
self.statused = False
self.rpc_errors = []
self.callbacks = []
self.abortion = None
self.aborted = False
def _raise_rpc_error(state):
@ -287,8 +287,8 @@ class _Context(grpc.ServicerContext):
with self._state.condition:
self._state.code = code
self._state.details = _common.encode(details)
self._state.abortion = Exception()
raise self._state.abortion
self._state.aborted = True
raise Exception()
def abort_with_status(self, status):
self._state.trailing_metadata = status.trailing_metadata
@ -392,7 +392,7 @@ def _call_behavior(rpc_event, state, behavior, argument, request_deserializer):
return behavior(argument, context), True
except Exception as exception: # pylint: disable=broad-except
with state.condition:
if exception is state.abortion:
if state.aborted:
_abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
b'RPC Aborted')
elif exception not in state.rpc_errors:
@ -410,7 +410,7 @@ def _take_response_from_response_iterator(rpc_event, state, response_iterator):
return None, True
except Exception as exception: # pylint: disable=broad-except
with state.condition:
if exception is state.abortion:
if state.aborted:
_abort(state, rpc_event.call, cygrpc.StatusCode.unknown,
b'RPC Aborted')
elif exception not in state.rpc_errors:

@ -15,7 +15,9 @@
import unittest
import collections
import gc
import logging
import weakref
import grpc
@ -39,7 +41,15 @@ class _Status(
pass
class _Object(object):
pass
do_not_leak_me = _Object()
def abort_unary_unary(request, servicer_context):
this_should_not_be_leaked = do_not_leak_me
servicer_context.abort(
grpc.StatusCode.INTERNAL,
_ABORT_DETAILS,
@ -101,6 +111,24 @@ class AbortTest(unittest.TestCase):
self.assertEqual(rpc_error.code(), grpc.StatusCode.INTERNAL)
self.assertEqual(rpc_error.details(), _ABORT_DETAILS)
# This test ensures that abort() does not store the raised exception, which
# on Python 3 (via the `__traceback__` attribute) holds a reference to
# all local vars. Storing the raised exception can prevent GC and stop the
# grpc_call from being unref'ed, even after server shutdown.
def test_abort_does_not_leak_local_vars(self):
global do_not_leak_me # pylint: disable=global-statement
weak_ref = weakref.ref(do_not_leak_me)
# Servicer will abort() after creating a local ref to do_not_leak_me.
with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_ABORT)(_REQUEST)
rpc_error = exception_context.exception
do_not_leak_me = None
# Force garbage collection
gc.collect()
self.assertIsNone(weak_ref())
def test_abort_with_status(self):
with self.assertRaises(grpc.RpcError) as exception_context:
self._channel.unary_unary(_ABORT_WITH_STATUS)(_REQUEST)

Loading…
Cancel
Save