python: Context.abort should fail RPC even for StatusCode.OK

grpc.ServicerContext.abort is documented to always raise an exception
to terminate the RPC. The code argument "must not be StatusCode.OK."
However, if you do pass StatusCode.OK, the RPC terminates successfully
on the client side, but returns None.

_server.py: If the user accidentally passes StatusCode.OK, treat it as
    StatusCode.UNKNOWN. This is what happens if the user accidentally
    passes something that is not a StatusCode instance. Additionally
    set details to ''.

_metadata_code_details_test.py: update test to verify the behavior of
    abort with invalid codes.
pull/13965/head
Evan Jones 7 years ago
parent 0ea629c61e
commit 145b199c4d
  1. 6
      src/python/grpcio/grpc/_server.py
  2. 200
      src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py

@ -278,6 +278,12 @@ class _Context(grpc.ServicerContext):
self._state.trailing_metadata = trailing_metadata
def abort(self, code, details):
# treat OK like other invalid arguments: fail the RPC
if code == grpc.StatusCode.OK:
logging.error(
'abort() called with StatusCode.OK; returning UNKNOWN')
code = grpc.StatusCode.UNKNOWN
details = ''
with self._state.condition:
self._state.code = code
self._state.details = _common.encode(details)

@ -50,6 +50,12 @@ _SERVER_TRAILING_METADATA = (('server-trailing-md-key',
_NON_OK_CODE = grpc.StatusCode.NOT_FOUND
_DETAILS = 'Test details!'
# calling abort should always fail an RPC, even for "invalid" codes
_ABORT_CODES = (_NON_OK_CODE, 3, grpc.StatusCode.OK)
_EXPECTED_CLIENT_CODES = (_NON_OK_CODE, grpc.StatusCode.UNKNOWN,
grpc.StatusCode.UNKNOWN)
_EXPECTED_DETAILS = (_DETAILS, _DETAILS, '')
class _Servicer(object):
@ -302,99 +308,119 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self.assertEqual(_DETAILS, response_iterator_call.details())
def testAbortedUnaryUnary(self):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
self._servicer.set_abort_call()
with self.assertRaises(grpc.RpcError) as exception_context:
self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)
self.assertTrue(
test_common.metadata_transmitted(
_CLIENT_METADATA, self._servicer.received_client_metadata()))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_INITIAL_METADATA,
exception_context.exception.initial_metadata()))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
exception_context.exception.trailing_metadata()))
self.assertIs(_NON_OK_CODE, exception_context.exception.code())
self.assertEqual(_DETAILS, exception_context.exception.details())
test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES,
_EXPECTED_DETAILS)
for abort_code, expected_code, expected_details in test_cases:
self._servicer.set_code(abort_code)
self._servicer.set_details(_DETAILS)
self._servicer.set_abort_call()
with self.assertRaises(grpc.RpcError) as exception_context:
self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)
self.assertTrue(
test_common.metadata_transmitted(
_CLIENT_METADATA,
self._servicer.received_client_metadata()))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_INITIAL_METADATA,
exception_context.exception.initial_metadata()))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
exception_context.exception.trailing_metadata()))
self.assertIs(expected_code, exception_context.exception.code())
self.assertEqual(expected_details,
exception_context.exception.details())
def testAbortedUnaryStream(self):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
self._servicer.set_abort_call()
response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
self.assertEqual(len(list(response_iterator_call)), 0)
self.assertTrue(
test_common.metadata_transmitted(
_CLIENT_METADATA, self._servicer.received_client_metadata()))
self.assertTrue(
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
response_iterator_call.trailing_metadata()))
self.assertIs(_NON_OK_CODE, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES,
_EXPECTED_DETAILS)
for abort_code, expected_code, expected_details in test_cases:
self._servicer.set_code(abort_code)
self._servicer.set_details(_DETAILS)
self._servicer.set_abort_call()
response_iterator_call = self._unary_stream(
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = \
response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
self.assertEqual(len(list(response_iterator_call)), 0)
self.assertTrue(
test_common.metadata_transmitted(
_CLIENT_METADATA,
self._servicer.received_client_metadata()))
self.assertTrue(
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
response_iterator_call.trailing_metadata()))
self.assertIs(expected_code, response_iterator_call.code())
self.assertEqual(expected_details, response_iterator_call.details())
def testAbortedStreamUnary(self):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
self._servicer.set_abort_call()
with self.assertRaises(grpc.RpcError) as exception_context:
self._stream_unary.with_call(
iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA)
self.assertTrue(
test_common.metadata_transmitted(
_CLIENT_METADATA, self._servicer.received_client_metadata()))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_INITIAL_METADATA,
exception_context.exception.initial_metadata()))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
exception_context.exception.trailing_metadata()))
self.assertIs(_NON_OK_CODE, exception_context.exception.code())
self.assertEqual(_DETAILS, exception_context.exception.details())
test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES,
_EXPECTED_DETAILS)
for abort_code, expected_code, expected_details in test_cases:
self._servicer.set_code(abort_code)
self._servicer.set_details(_DETAILS)
self._servicer.set_abort_call()
with self.assertRaises(grpc.RpcError) as exception_context:
self._stream_unary.with_call(
iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA)
self.assertTrue(
test_common.metadata_transmitted(
_CLIENT_METADATA,
self._servicer.received_client_metadata()))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_INITIAL_METADATA,
exception_context.exception.initial_metadata()))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
exception_context.exception.trailing_metadata()))
self.assertIs(expected_code, exception_context.exception.code())
self.assertEqual(expected_details,
exception_context.exception.details())
def testAbortedStreamStream(self):
self._servicer.set_code(_NON_OK_CODE)
self._servicer.set_details(_DETAILS)
self._servicer.set_abort_call()
response_iterator_call = self._stream_stream(
iter([object()] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA)
received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
self.assertEqual(len(list(response_iterator_call)), 0)
self.assertTrue(
test_common.metadata_transmitted(
_CLIENT_METADATA, self._servicer.received_client_metadata()))
self.assertTrue(
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
response_iterator_call.trailing_metadata()))
self.assertIs(_NON_OK_CODE, response_iterator_call.code())
self.assertEqual(_DETAILS, response_iterator_call.details())
test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES,
_EXPECTED_DETAILS)
for abort_code, expected_code, expected_details in test_cases:
self._servicer.set_code(abort_code)
self._servicer.set_details(_DETAILS)
self._servicer.set_abort_call()
response_iterator_call = self._stream_stream(
iter([object()] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA)
received_initial_metadata = \
response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError):
self.assertEqual(len(list(response_iterator_call)), 0)
self.assertTrue(
test_common.metadata_transmitted(
_CLIENT_METADATA,
self._servicer.received_client_metadata()))
self.assertTrue(
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
received_initial_metadata))
self.assertTrue(
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
response_iterator_call.trailing_metadata()))
self.assertIs(expected_code, response_iterator_call.code())
self.assertEqual(expected_details, response_iterator_call.details())
def testCustomCodeUnaryUnary(self):
self._servicer.set_code(_NON_OK_CODE)

Loading…
Cancel
Save