Merge pull request #13667 from mehrdada/servicercontext-abort

Introduce ServicerContext.abort for terminating RPCs with non-OK status
pull/13731/head
Mehrdad Afshari 7 years ago committed by GitHub
commit bd247184e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 40
      src/python/grpcio/grpc/__init__.py
  2. 21
      src/python/grpcio/grpc/_server.py
  3. 3
      src/python/grpcio_testing/grpc_testing/_server/_servicer_context.py
  4. 237
      src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py

@ -834,28 +834,48 @@ class ServicerContext(six.with_metaclass(abc.ABCMeta, RpcContext)):
""" """
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod
def abort(self, code, details):
"""Raises an exception to terminate the RPC with a non-OK status.
The code and details passed as arguments will supercede any existing
ones.
Args:
code: A StatusCode object to be sent to the client.
It must not be StatusCode.OK.
details: An ASCII-encodable string to be sent to the client upon
termination of the RPC.
Raises:
Exception: An exception is always raised to signal the abortion the
RPC to the gRPC runtime.
"""
raise NotImplementedError()
@abc.abstractmethod @abc.abstractmethod
def set_code(self, code): def set_code(self, code):
"""Sets the value to be used as status code upon RPC completion. """Sets the value to be used as status code upon RPC completion.
This method need not be called by method implementations if they wish the This method need not be called by method implementations if they wish
gRPC runtime to determine the status code of the RPC. the gRPC runtime to determine the status code of the RPC.
Args: Args:
code: A StatusCode object to be sent to the client. code: A StatusCode object to be sent to the client.
""" """
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod @abc.abstractmethod
def set_details(self, details): def set_details(self, details):
"""Sets the value to be used as detail string upon RPC completion. """Sets the value to be used as detail string upon RPC completion.
This method need not be called by method implementations if they have no This method need not be called by method implementations if they have
details to transmit. no details to transmit.
Args: Args:
details: An arbitrary string to be sent to the client upon completion. details: An ASCII-encodable string to be sent to the client upon
""" termination of the RPC.
"""
raise NotImplementedError() raise NotImplementedError()

@ -96,6 +96,7 @@ class _RPCState(object):
self.statused = False self.statused = False
self.rpc_errors = [] self.rpc_errors = []
self.callbacks = [] self.callbacks = []
self.abortion = None
def _raise_rpc_error(state): def _raise_rpc_error(state):
@ -273,6 +274,13 @@ class _Context(grpc.ServicerContext):
with self._state.condition: with self._state.condition:
self._state.trailing_metadata = trailing_metadata self._state.trailing_metadata = trailing_metadata
def abort(self, code, details):
with self._state.condition:
self._state.code = code
self._state.details = _common.encode(details)
self._state.abortion = Exception()
raise self._state.abortion
def set_code(self, code): def set_code(self, code):
with self._state.condition: with self._state.condition:
self._state.code = code self._state.code = code
@ -369,7 +377,10 @@ def _call_behavior(rpc_event, state, behavior, argument, request_deserializer):
return behavior(argument, context), True return behavior(argument, context), True
except Exception as exception: # pylint: disable=broad-except except Exception as exception: # pylint: disable=broad-except
with state.condition: with state.condition:
if exception not in state.rpc_errors: if exception is state.abortion:
_abort(state, rpc_event.operation_call,
cygrpc.StatusCode.unknown, b'RPC Aborted')
elif exception not in state.rpc_errors:
details = 'Exception calling application: {}'.format(exception) details = 'Exception calling application: {}'.format(exception)
logging.exception(details) logging.exception(details)
_abort(state, rpc_event.operation_call, _abort(state, rpc_event.operation_call,
@ -384,7 +395,10 @@ def _take_response_from_response_iterator(rpc_event, state, response_iterator):
return None, True return None, True
except Exception as exception: # pylint: disable=broad-except except Exception as exception: # pylint: disable=broad-except
with state.condition: with state.condition:
if exception not in state.rpc_errors: if exception is state.abortion:
_abort(state, rpc_event.operation_call,
cygrpc.StatusCode.unknown, b'RPC Aborted')
elif exception not in state.rpc_errors:
details = 'Exception iterating responses: {}'.format(exception) details = 'Exception iterating responses: {}'.format(exception)
logging.exception(details) logging.exception(details)
_abort(state, rpc_event.operation_call, _abort(state, rpc_event.operation_call,
@ -430,12 +444,11 @@ def _send_response(rpc_event, state, serialized_response):
def _status(rpc_event, state, serialized_response): def _status(rpc_event, state, serialized_response):
with state.condition: with state.condition:
if state.client is not _CANCELLED: if state.client is not _CANCELLED:
trailing_metadata = state.trailing_metadata
code = _completion_code(state) code = _completion_code(state)
details = _details(state) details = _details(state)
operations = [ operations = [
cygrpc.operation_send_status_from_server( cygrpc.operation_send_status_from_server(
trailing_metadata, code, details, _EMPTY_FLAGS), state.trailing_metadata, code, details, _EMPTY_FLAGS),
] ]
if state.initial_metadata_allowed: if state.initial_metadata_allowed:
operations.append( operations.append(

@ -67,6 +67,9 @@ class ServicerContext(grpc.ServicerContext):
self._rpc.set_trailing_metadata( self._rpc.set_trailing_metadata(
_common.fuss_with_metadata(trailing_metadata)) _common.fuss_with_metadata(trailing_metadata))
def abort(self, code, details):
raise NotImplementedError()
def set_code(self, code): def set_code(self, code):
self._rpc.set_code(code) self._rpc.set_code(code)

@ -56,6 +56,7 @@ class _Servicer(object):
def __init__(self): def __init__(self):
self._lock = threading.Lock() self._lock = threading.Lock()
self._abort_call = False
self._code = None self._code = None
self._details = None self._details = None
self._exception = False self._exception = False
@ -67,10 +68,13 @@ class _Servicer(object):
self._received_client_metadata = context.invocation_metadata() self._received_client_metadata = context.invocation_metadata()
context.send_initial_metadata(_SERVER_INITIAL_METADATA) context.send_initial_metadata(_SERVER_INITIAL_METADATA)
context.set_trailing_metadata(_SERVER_TRAILING_METADATA) context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
if self._code is not None: if self._abort_call:
context.set_code(self._code) context.abort(self._code, self._details)
if self._details is not None: else:
context.set_details(self._details) if self._code is not None:
context.set_code(self._code)
if self._details is not None:
context.set_details(self._details)
if self._exception: if self._exception:
raise test_control.Defect() raise test_control.Defect()
else: else:
@ -81,10 +85,13 @@ class _Servicer(object):
self._received_client_metadata = context.invocation_metadata() self._received_client_metadata = context.invocation_metadata()
context.send_initial_metadata(_SERVER_INITIAL_METADATA) context.send_initial_metadata(_SERVER_INITIAL_METADATA)
context.set_trailing_metadata(_SERVER_TRAILING_METADATA) context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
if self._code is not None: if self._abort_call:
context.set_code(self._code) context.abort(self._code, self._details)
if self._details is not None: else:
context.set_details(self._details) if self._code is not None:
context.set_code(self._code)
if self._details is not None:
context.set_details(self._details)
for _ in range(test_constants.STREAM_LENGTH // 2): for _ in range(test_constants.STREAM_LENGTH // 2):
yield _SERIALIZED_RESPONSE yield _SERIALIZED_RESPONSE
if self._exception: if self._exception:
@ -95,14 +102,16 @@ class _Servicer(object):
self._received_client_metadata = context.invocation_metadata() self._received_client_metadata = context.invocation_metadata()
context.send_initial_metadata(_SERVER_INITIAL_METADATA) context.send_initial_metadata(_SERVER_INITIAL_METADATA)
context.set_trailing_metadata(_SERVER_TRAILING_METADATA) context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
if self._code is not None:
context.set_code(self._code)
if self._details is not None:
context.set_details(self._details)
# TODO(https://github.com/grpc/grpc/issues/6891): just ignore the # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the
# request iterator. # request iterator.
for ignored_request in request_iterator: list(request_iterator)
pass if self._abort_call:
context.abort(self._code, self._details)
else:
if self._code is not None:
context.set_code(self._code)
if self._details is not None:
context.set_details(self._details)
if self._exception: if self._exception:
raise test_control.Defect() raise test_control.Defect()
else: else:
@ -113,19 +122,25 @@ class _Servicer(object):
self._received_client_metadata = context.invocation_metadata() self._received_client_metadata = context.invocation_metadata()
context.send_initial_metadata(_SERVER_INITIAL_METADATA) context.send_initial_metadata(_SERVER_INITIAL_METADATA)
context.set_trailing_metadata(_SERVER_TRAILING_METADATA) context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
if self._code is not None:
context.set_code(self._code)
if self._details is not None:
context.set_details(self._details)
# TODO(https://github.com/grpc/grpc/issues/6891): just ignore the # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the
# request iterator. # request iterator.
for ignored_request in request_iterator: list(request_iterator)
pass if self._abort_call:
context.abort(self._code, self._details)
else:
if self._code is not None:
context.set_code(self._code)
if self._details is not None:
context.set_details(self._details)
for _ in range(test_constants.STREAM_LENGTH // 3): for _ in range(test_constants.STREAM_LENGTH // 3):
yield object() yield object()
if self._exception: if self._exception:
raise test_control.Defect() raise test_control.Defect()
def set_abort_call(self):
with self._lock:
self._abort_call = True
def set_code(self, code): def set_code(self, code):
with self._lock: with self._lock:
self._code = code self._code = code
@ -212,11 +227,10 @@ class MetadataCodeDetailsTest(unittest.TestCase):
def testSuccessfulUnaryStream(self): def testSuccessfulUnaryStream(self):
self._servicer.set_details(_DETAILS) self._servicer.set_details(_DETAILS)
call = self._unary_stream( response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata() received_initial_metadata = response_iterator_call.initial_metadata()
for _ in call: list(response_iterator_call)
pass
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted( test_common.metadata_transmitted(
@ -225,10 +239,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata)) received_initial_metadata))
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, test_common.metadata_transmitted(
call.trailing_metadata())) _SERVER_TRAILING_METADATA,
self.assertIs(grpc.StatusCode.OK, call.code()) response_iterator_call.trailing_metadata()))
self.assertEqual(_DETAILS, call.details()) self.assertIs(grpc.StatusCode.OK, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
def testSuccessfulStreamUnary(self): def testSuccessfulStreamUnary(self):
self._servicer.set_details(_DETAILS) self._servicer.set_details(_DETAILS)
@ -252,12 +267,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
def testSuccessfulStreamStream(self): def testSuccessfulStreamStream(self):
self._servicer.set_details(_DETAILS) self._servicer.set_details(_DETAILS)
call = self._stream_stream( response_iterator_call = self._stream_stream(
iter([object()] * test_constants.STREAM_LENGTH), iter([object()] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA) metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata() received_initial_metadata = response_iterator_call.initial_metadata()
for _ in call: list(response_iterator_call)
pass
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted( test_common.metadata_transmitted(
@ -266,10 +280,106 @@ class MetadataCodeDetailsTest(unittest.TestCase):
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata)) received_initial_metadata))
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, test_common.metadata_transmitted(
call.trailing_metadata())) _SERVER_TRAILING_METADATA,
self.assertIs(grpc.StatusCode.OK, call.code()) response_iterator_call.trailing_metadata()))
self.assertEqual(_DETAILS, call.details()) self.assertIs(grpc.StatusCode.OK, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
def testAbortedUnaryUnary(self):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
self._servicer.set_abort_call()
with self.assertRaises(grpc.RpcError) as exception_context:
self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)
self.assertTrue(
test_common.metadata_transmitted(
_CLIENT_METADATA, self._servicer.received_client_metadata()))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_INITIAL_METADATA,
exception_context.exception.initial_metadata()))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
exception_context.exception.trailing_metadata()))
self.assertIs(_NON_OK_CODE, exception_context.exception.code())
self.assertEqual(_DETAILS, exception_context.exception.details())
def testAbortedUnaryStream(self):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
self._servicer.set_abort_call()
response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
self.assertEqual(len(list(response_iterator_call)), 0)
self.assertTrue(
test_common.metadata_transmitted(
_CLIENT_METADATA, self._servicer.received_client_metadata()))
self.assertTrue(
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
response_iterator_call.trailing_metadata()))
self.assertIs(_NON_OK_CODE, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
def testAbortedStreamUnary(self):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
self._servicer.set_abort_call()
with self.assertRaises(grpc.RpcError) as exception_context:
self._stream_unary.with_call(
iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA)
self.assertTrue(
test_common.metadata_transmitted(
_CLIENT_METADATA, self._servicer.received_client_metadata()))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_INITIAL_METADATA,
exception_context.exception.initial_metadata()))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
exception_context.exception.trailing_metadata()))
self.assertIs(_NON_OK_CODE, exception_context.exception.code())
self.assertEqual(_DETAILS, exception_context.exception.details())
def testAbortedStreamStream(self):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
self._servicer.set_abort_call()
response_iterator_call = self._stream_stream(
iter([object()] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA)
received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
self.assertEqual(len(list(response_iterator_call)), 0)
self.assertTrue(
test_common.metadata_transmitted(
_CLIENT_METADATA, self._servicer.received_client_metadata()))
self.assertTrue(
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
response_iterator_call.trailing_metadata()))
self.assertIs(_NON_OK_CODE, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
def testCustomCodeUnaryUnary(self): def testCustomCodeUnaryUnary(self):
self._servicer.set_code(_NON_OK_CODE) self._servicer.set_code(_NON_OK_CODE)
@ -296,12 +406,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self._servicer.set_code(_NON_OK_CODE) self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS) self._servicer.set_details(_DETAILS)
call = self._unary_stream( response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata() received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError): with self.assertRaises(grpc.RpcError):
for _ in call: list(response_iterator_call)
pass
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted( test_common.metadata_transmitted(
@ -310,10 +419,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata)) received_initial_metadata))
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, test_common.metadata_transmitted(
call.trailing_metadata())) _SERVER_TRAILING_METADATA,
self.assertIs(_NON_OK_CODE, call.code()) response_iterator_call.trailing_metadata()))
self.assertEqual(_DETAILS, call.details()) self.assertIs(_NON_OK_CODE, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
def testCustomCodeStreamUnary(self): def testCustomCodeStreamUnary(self):
self._servicer.set_code(_NON_OK_CODE) self._servicer.set_code(_NON_OK_CODE)
@ -342,13 +452,12 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self._servicer.set_code(_NON_OK_CODE) self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS) self._servicer.set_details(_DETAILS)
call = self._stream_stream( response_iterator_call = self._stream_stream(
iter([object()] * test_constants.STREAM_LENGTH), iter([object()] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA) metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata() received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError) as exception_context: with self.assertRaises(grpc.RpcError) as exception_context:
for _ in call: list(response_iterator_call)
pass
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted( test_common.metadata_transmitted(
@ -390,12 +499,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self._servicer.set_details(_DETAILS) self._servicer.set_details(_DETAILS)
self._servicer.set_exception() self._servicer.set_exception()
call = self._unary_stream( response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata() received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError): with self.assertRaises(grpc.RpcError):
for _ in call: list(response_iterator_call)
pass
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted( test_common.metadata_transmitted(
@ -404,10 +512,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata)) received_initial_metadata))
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, test_common.metadata_transmitted(
call.trailing_metadata())) _SERVER_TRAILING_METADATA,
self.assertIs(_NON_OK_CODE, call.code()) response_iterator_call.trailing_metadata()))
self.assertEqual(_DETAILS, call.details()) self.assertIs(_NON_OK_CODE, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
def testCustomCodeExceptionStreamUnary(self): def testCustomCodeExceptionStreamUnary(self):
self._servicer.set_code(_NON_OK_CODE) self._servicer.set_code(_NON_OK_CODE)
@ -438,13 +547,12 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self._servicer.set_details(_DETAILS) self._servicer.set_details(_DETAILS)
self._servicer.set_exception() self._servicer.set_exception()
call = self._stream_stream( response_iterator_call = self._stream_stream(
iter([object()] * test_constants.STREAM_LENGTH), iter([object()] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA) metadata=_CLIENT_METADATA)
received_initial_metadata = call.initial_metadata() received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError): with self.assertRaises(grpc.RpcError):
for _ in call: list(response_iterator_call)
pass
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted( test_common.metadata_transmitted(
@ -453,10 +561,11 @@ class MetadataCodeDetailsTest(unittest.TestCase):
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata)) received_initial_metadata))
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted(_SERVER_TRAILING_METADATA, test_common.metadata_transmitted(
call.trailing_metadata())) _SERVER_TRAILING_METADATA,
self.assertIs(_NON_OK_CODE, call.code()) response_iterator_call.trailing_metadata()))
self.assertEqual(_DETAILS, call.details()) self.assertIs(_NON_OK_CODE, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
def testCustomCodeReturnNoneUnaryUnary(self): def testCustomCodeReturnNoneUnaryUnary(self):
self._servicer.set_code(_NON_OK_CODE) self._servicer.set_code(_NON_OK_CODE)

Loading…
Cancel
Save