Merge pull request #13667 from mehrdada/servicercontext-abort

Introduce ServicerContext.abort for terminating RPCs with non-OK status
pull/13731/head
Mehrdad Afshari 7 years ago committed by GitHub
commit bd247184e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 40
      src/python/grpcio/grpc/__init__.py
  2. 21
      src/python/grpcio/grpc/_server.py
  3. 3
      src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py
  4. 237
      src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py

@ -834,28 +834,48 @@ class ServicerContext(six.with_metaclass(abc.ABCMeta, RpcContext)):
"""
raise NotImplementedError()
@abc.abstractmethod
def abort(self, code, details):
"""Raises an exception to terminate the RPC with a non-OK status.
The code and details passed as arguments will supercede any existing
ones.
Args:
code: A StatusCode object to be sent to the client.
It must not be StatusCode.OK.
details: An ASCII-encodable string to be sent to the client upon
termination of the RPC.
Raises:
Exception: An exception is always raised to signal the abortion the
RPC to the gRPC runtime.
"""
raise NotImplementedError()
@abc.abstractmethod
def set_code(self, code):
"""Sets the value to be used as status code upon RPC completion.
This method need not be called by method implementations if they wish the
gRPC runtime to determine the status code of the RPC.
This method need not be called by method implementations if they wish
the gRPC runtime to determine the status code of the RPC.
Args:
code: A StatusCode object to be sent to the client.
"""
Args:
code: A StatusCode object to be sent to the client.
"""
raise NotImplementedError()
@abc.abstractmethod
def set_details(self, details):
"""Sets the value to be used as detail string upon RPC completion.
This method need not be called by method implementations if they have no
details to transmit.
This method need not be called by method implementations if they have
no details to transmit.
Args:
details: An arbitrary string to be sent to the client upon completion.
"""
Args:
details: An ASCII-encodable string to be sent to the client upon
termination of the RPC.
"""
raise NotImplementedError()

@ -96,6 +96,7 @@ class _RPCState(object):
self.statused = False
self.rpc_errors = []
self.callbacks = []
self.abortion = None
def _raise_rpc_error(state):
@ -273,6 +274,13 @@ class _Context(grpc.ServicerContext):
with self._state.condition:
self._state.trailing_metadata = trailing_metadata
def abort(self, code, details):
with self._state.condition:
self._state.code = code
self._state.details = _common.encode(details)
self._state.abortion = Exception()
raise self._state.abortion
def set_code(self, code):
with self._state.condition:
self._state.code = code
@ -369,7 +377,10 @@ def _call_behavior(rpc_event, state, behavior, argument, request_deserializer):
return behavior(argument, context), True
except Exception as exception: # pylint: disable=broad-except
with state.condition:
if exception not in state.rpc_errors:
if exception is state.abortion:
_abort(state, rpc_event.operation_call,
cygrpc.StatusCode.unknown, b'RPC Aborted')
elif exception not in state.rpc_errors:
details = 'Exception calling application: {}'.format(exception)
logging.exception(details)
_abort(state, rpc_event.operation_call,
@ -384,7 +395,10 @@ def _take_response_from_response_iterator(rpc_event, state, response_iterator):
return None, True
except Exception as exception: # pylint: disable=broad-except
with state.condition:
if exception not in state.rpc_errors:
if exception is state.abortion:
_abort(state, rpc_event.operation_call,
cygrpc.StatusCode.unknown, b'RPC Aborted')
elif exception not in state.rpc_errors:
details = 'Exception iterating responses: {}'.format(exception)
logging.exception(details)
_abort(state, rpc_event.operation_call,
@ -430,12 +444,11 @@ def _send_response(rpc_event, state, serialized_response):
def _status(rpc_event, state, serialized_response):
with state.condition:
if state.client is not _CANCELLED:
trailing_metadata = state.trailing_metadata
code = _completion_code(state)
details = _details(state)
operations = [
cygrpc.operation_send_status_from_server(
trailing_metadata, code, details, _EMPTY_FLAGS),
state.trailing_metadata, code, details, _EMPTY_FLAGS),
]
if state.initial_metadata_allowed:
operations.append(

@ -67,6 +67,9 @@ class ServicerContext(grpc.ServicerContext):
self._rpc.set_trailing_metadata(
_common.fuss_with_metadata(trailing_metadata))
def abort(self, code, details):
raise NotImplementedError()
def set_code(self, code):
self._rpc.set_code(code)

@ -56,6 +56,7 @@ class _Servicer(object):
def __init__(self):
self._lock = threading.Lock()
self._abort_call = False
self._code = None
self._details = None
self._exception = False
@ -67,10 +68,13 @@ class _Servicer(object):
self._received_client_metadata = context.invocation_metadata()
context.send_initial_metadata(_SERVER_INITIAL_METADATA)
context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
if self._code is not None:
context.set_code(self._code)
if self._details is not None:
context.set_details(self._details)
if self._abort_call:
context.abort(self._code, self._details)
else:
if self._code is not None:
context.set_code(self._code)
if self._details is not None:
context.set_details(self._details)
if self._exception:
raise test_control.Defect()
else:
@ -81,10 +85,13 @@ class _Servicer(object):
self._received_client_metadata = context.invocation_metadata()
context.send_initial_metadata(_SERVER_INITIAL_METADATA)
context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
if self._code is not None:
context.set_code(self._code)
if self._details is not None:
context.set_details(self._details)
if self._abort_call:
context.abort(self._code, self._details)
else:
if self._code is not None:
context.set_code(self._code)
if self._details is not None:
context.set_details(self._details)
for _ in range(test_constants.STREAM_LENGTH // 2):
yield _SERIALIZED_RESPONSE
if self._exception:
@ -95,14 +102,16 @@ class _Servicer(object):
self._received_client_metadata = context.invocation_metadata()
context.send_initial_metadata(_SERVER_INITIAL_METADATA)
context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
if self._code is not None:
context.set_code(self._code)
if self._details is not None:
context.set_details(self._details)
# TODO(https://github.com/grpc/grpc/issues/6891): just ignore the
# request iterator.
for ignored_request in request_iterator:
pass
list(request_iterator)
if self._abort_call:
context.abort(self._code, self._details)
else:
if self._code is not None:
context.set_code(self._code)
if self._details is not None:
context.set_details(self._details)
if self._exception:
raise test_control.Defect()
else:
@ -113,19 +122,25 @@ class _Servicer(object):
self._received_client_metadata = context.invocation_metadata()
context.send_initial_metadata(_SERVER_INITIAL_METADATA)
context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
if self._code is not None:
context.set_code(self._code)
if self._details is not None:
context.set_details(self._details)
# TODO(https://github.com/grpc/grpc/issues/6891): just ignore the
# request iterator.
for ignored_request in request_iterator:
pass
list(request_iterator)
if self._abort_call:
context.abort(self._code, self._details)
else:
if self._code is not None:
context.set_code(self._code)
if self._details is not None:
context.set_details(self._details)
for _ in range(test_constants.STREAM_LENGTH // 3):
yield object()
if self._exception:
raise test_control.Defect()
def set_abort_call(self):
with self._lock:
self._abort_call = True
def set_code(self, code):
with self._lock:
self._code = code
@ -212,11 +227,10 @@ class MetadataCodeDetailsTest(unittest.TestCase):
def testSuccessfulUnaryStream(self):
self._servicer.set_details(_DETAILS)
call = self._unary_stream(
response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata()
for _ in call:
pass
received_initial_metadata = response_iterator_call.initial_metadata()
list(response_iterator_call)
self.assertTrue(
test_common.metadata_transmitted(
@ -225,10 +239,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
call.trailing_metadata()))
self.assertIs(grpc.StatusCode.OK, call.code())
self.assertEqual(_DETAILS, call.details())
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
response_iterator_call.trailing_metadata()))
self.assertIs(grpc.StatusCode.OK, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
def testSuccessfulStreamUnary(self):
self._servicer.set_details(_DETAILS)
@ -252,12 +267,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
def testSuccessfulStreamStream(self):
self._servicer.set_details(_DETAILS)
call = self._stream_stream(
response_iterator_call = self._stream_stream(
iter([object()] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata()
for _ in call:
pass
received_initial_metadata = response_iterator_call.initial_metadata()
list(response_iterator_call)
self.assertTrue(
test_common.metadata_transmitted(
@ -266,10 +280,106 @@ class MetadataCodeDetailsTest(unittest.TestCase):
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
call.trailing_metadata()))
self.assertIs(grpc.StatusCode.OK, call.code())
self.assertEqual(_DETAILS, call.details())
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
response_iterator_call.trailing_metadata()))
self.assertIs(grpc.StatusCode.OK, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
def testAbortedUnaryUnary(self):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
self._servicer.set_abort_call()
with self.assertRaises(grpc.RpcError) as exception_context:
self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)
self.assertTrue(
test_common.metadata_transmitted(
_CLIENT_METADATA, self._servicer.received_client_metadata()))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_INITIAL_METADATA,
exception_context.exception.initial_metadata()))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
exception_context.exception.trailing_metadata()))
self.assertIs(_NON_OK_CODE, exception_context.exception.code())
self.assertEqual(_DETAILS, exception_context.exception.details())
def testAbortedUnaryStream(self):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
self._servicer.set_abort_call()
response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
self.assertEqual(len(list(response_iterator_call)), 0)
self.assertTrue(
test_common.metadata_transmitted(
_CLIENT_METADATA, self._servicer.received_client_metadata()))
self.assertTrue(
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
response_iterator_call.trailing_metadata()))
self.assertIs(_NON_OK_CODE, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
def testAbortedStreamUnary(self):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
self._servicer.set_abort_call()
with self.assertRaises(grpc.RpcError) as exception_context:
self._stream_unary.with_call(
iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA)
self.assertTrue(
test_common.metadata_transmitted(
_CLIENT_METADATA, self._servicer.received_client_metadata()))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_INITIAL_METADATA,
exception_context.exception.initial_metadata()))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
exception_context.exception.trailing_metadata()))
self.assertIs(_NON_OK_CODE, exception_context.exception.code())
self.assertEqual(_DETAILS, exception_context.exception.details())
def testAbortedStreamStream(self):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
self._servicer.set_abort_call()
response_iterator_call = self._stream_stream(
iter([object()] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA)
received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
self.assertEqual(len(list(response_iterator_call)), 0)
self.assertTrue(
test_common.metadata_transmitted(
_CLIENT_METADATA, self._servicer.received_client_metadata()))
self.assertTrue(
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
response_iterator_call.trailing_metadata()))
self.assertIs(_NON_OK_CODE, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
def testCustomCodeUnaryUnary(self):
self._servicer.set_code(_NON_OK_CODE)
@ -296,12 +406,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
call = self._unary_stream(
response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata()
received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
for _ in call:
pass
list(response_iterator_call)
self.assertTrue(
test_common.metadata_transmitted(
@ -310,10 +419,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
call.trailing_metadata()))
self.assertIs(_NON_OK_CODE, call.code())
self.assertEqual(_DETAILS, call.details())
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
response_iterator_call.trailing_metadata()))
self.assertIs(_NON_OK_CODE, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
def testCustomCodeStreamUnary(self):
self._servicer.set_code(_NON_OK_CODE)
@ -342,13 +452,12 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
call = self._stream_stream(
response_iterator_call = self._stream_stream(
iter([object()] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata()
received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError) as exception_context:
for _ in call:
pass
list(response_iterator_call)
self.assertTrue(
test_common.metadata_transmitted(
@ -390,12 +499,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self._servicer.set_details(_DETAILS)
self._servicer.set_exception()
call = self._unary_stream(
response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata()
received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
for _ in call:
pass
list(response_iterator_call)
self.assertTrue(
test_common.metadata_transmitted(
@ -404,10 +512,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
call.trailing_metadata()))
self.assertIs(_NON_OK_CODE, call.code())
self.assertEqual(_DETAILS, call.details())
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
response_iterator_call.trailing_metadata()))
self.assertIs(_NON_OK_CODE, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
def testCustomCodeExceptionStreamUnary(self):
self._servicer.set_code(_NON_OK_CODE)
@ -438,13 +547,12 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self._servicer.set_details(_DETAILS)
self._servicer.set_exception()
call = self._stream_stream(
response_iterator_call = self._stream_stream(
iter([object()] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata()
received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
for _ in call:
pass
list(response_iterator_call)
self.assertTrue(
test_common.metadata_transmitted(
@ -453,10 +561,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
call.trailing_metadata()))
self.assertIs(_NON_OK_CODE, call.code())
self.assertEqual(_DETAILS, call.details())
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
response_iterator_call.trailing_metadata()))
self.assertIs(_NON_OK_CODE, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
def testCustomCodeReturnNoneUnaryUnary(self):
self._servicer.set_code(_NON_OK_CODE)

Loading…
Cancel
Save