python: Context.abort should fail RPC even for StatusCode.OK

grpc.ServicerContext.abort is documented to always raise an exception
to terminate the RPC. The code argument "must not be StatusCode.OK."
However, if you do pass StatusCode.OK, the RPC terminates successfully
on the client side, but returns None.

_server.py: If the user accidentally passes StatusCode.OK, treat it as
    StatusCode.UNKNOWN. This is what happens if the user accidentally
    passes something that is not a StatusCode instance. Additionally
    set details to ''.

_metadata_code_details_test.py: update test to verify the behavior of
    abort with invalid codes.
pull/13965/head
Evan Jones 7 years ago
parent 0ea629c61e
commit 145b199c4d
  1. 6
      src/python/grpcio/grpc/_server.py
  2. 200
      src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py

@ -278,6 +278,12 @@ class _Context(grpc.ServicerContext):
self._state.trailing_metadata = trailing_metadata self._state.trailing_metadata = trailing_metadata
def abort(self, code, details): def abort(self, code, details):
# treat OK like other invalid arguments: fail the RPC
if code == grpc.StatusCode.OK:
logging.error(
'abort() called with StatusCode.OK; returning UNKNOWN')
code = grpc.StatusCode.UNKNOWN
details = ''
with self._state.condition: with self._state.condition:
self._state.code = code self._state.code = code
self._state.details = _common.encode(details) self._state.details = _common.encode(details)

@ -50,6 +50,12 @@ _SERVER_TRAILING_METADATA = (('server-trailing-md-key',
_NON_OK_CODE = grpc.StatusCode.NOT_FOUND _NON_OK_CODE = grpc.StatusCode.NOT_FOUND
_DETAILS = 'Test details!' _DETAILS = 'Test details!'
# calling abort should always fail an RPC, even for "invalid" codes
_ABORT_CODES = (_NON_OK_CODE, 3, grpc.StatusCode.OK)
_EXPECTED_CLIENT_CODES = (_NON_OK_CODE, grpc.StatusCode.UNKNOWN,
grpc.StatusCode.UNKNOWN)
_EXPECTED_DETAILS = (_DETAILS, _DETAILS, '')
class _Servicer(object): class _Servicer(object):
@ -302,99 +308,119 @@ class MetadataCodeDetailsTest(unittest.TestCase):
self.assertEqual(_DETAILS, response_iterator_call.details()) self.assertEqual(_DETAILS, response_iterator_call.details())
def testAbortedUnaryUnary(self): def testAbortedUnaryUnary(self):
self._servicer.set_code(_NON_OK_CODE) test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES,
self._servicer.set_details(_DETAILS) _EXPECTED_DETAILS)
self._servicer.set_abort_call() for abort_code, expected_code, expected_details in test_cases:
self._servicer.set_code(abort_code)
with self.assertRaises(grpc.RpcError) as exception_context: self._servicer.set_details(_DETAILS)
self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA) self._servicer.set_abort_call()
self.assertTrue( with self.assertRaises(grpc.RpcError) as exception_context:
test_common.metadata_transmitted( self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)
_CLIENT_METADATA, self._servicer.received_client_metadata()))
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted( test_common.metadata_transmitted(
_SERVER_INITIAL_METADATA, _CLIENT_METADATA,
exception_context.exception.initial_metadata())) self._servicer.received_client_metadata()))
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted( test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA, _SERVER_INITIAL_METADATA,
exception_context.exception.trailing_metadata())) exception_context.exception.initial_metadata()))
self.assertIs(_NON_OK_CODE, exception_context.exception.code()) self.assertTrue(
self.assertEqual(_DETAILS, exception_context.exception.details()) test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
exception_context.exception.trailing_metadata()))
self.assertIs(expected_code, exception_context.exception.code())
self.assertEqual(expected_details,
exception_context.exception.details())
def testAbortedUnaryStream(self): def testAbortedUnaryStream(self):
self._servicer.set_code(_NON_OK_CODE) test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES,
self._servicer.set_details(_DETAILS) _EXPECTED_DETAILS)
self._servicer.set_abort_call() for abort_code, expected_code, expected_details in test_cases:
self._servicer.set_code(abort_code)
response_iterator_call = self._unary_stream( self._servicer.set_details(_DETAILS)
_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) self._servicer.set_abort_call()
received_initial_metadata = response_iterator_call.initial_metadata()
with self.assertRaises(grpc.RpcError): response_iterator_call = self._unary_stream(
self.assertEqual(len(list(response_iterator_call)), 0) _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
received_initial_metadata = \
self.assertTrue( response_iterator_call.initial_metadata()
test_common.metadata_transmitted( with self.assertRaises(grpc.RpcError):
_CLIENT_METADATA, self._servicer.received_client_metadata())) self.assertEqual(len(list(response_iterator_call)), 0)
self.assertTrue(
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, self.assertTrue(
received_initial_metadata)) test_common.metadata_transmitted(
self.assertTrue( _CLIENT_METADATA,
test_common.metadata_transmitted( self._servicer.received_client_metadata()))
_SERVER_TRAILING_METADATA, self.assertTrue(
response_iterator_call.trailing_metadata())) test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
self.assertIs(_NON_OK_CODE, response_iterator_call.code()) received_initial_metadata))
self.assertEqual(_DETAILS, response_iterator_call.details()) self.assertTrue(
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
response_iterator_call.trailing_metadata()))
self.assertIs(expected_code, response_iterator_call.code())
self.assertEqual(expected_details, response_iterator_call.details())
def testAbortedStreamUnary(self): def testAbortedStreamUnary(self):
self._servicer.set_code(_NON_OK_CODE) test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES,
self._servicer.set_details(_DETAILS) _EXPECTED_DETAILS)
self._servicer.set_abort_call() for abort_code, expected_code, expected_details in test_cases:
self._servicer.set_code(abort_code)
with self.assertRaises(grpc.RpcError) as exception_context: self._servicer.set_details(_DETAILS)
self._stream_unary.with_call( self._servicer.set_abort_call()
iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
metadata=_CLIENT_METADATA) with self.assertRaises(grpc.RpcError) as exception_context:
self._stream_unary.with_call(
self.assertTrue( iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
test_common.metadata_transmitted( metadata=_CLIENT_METADATA)
_CLIENT_METADATA, self._servicer.received_client_metadata()))
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted( test_common.metadata_transmitted(
_SERVER_INITIAL_METADATA, _CLIENT_METADATA,
exception_context.exception.initial_metadata())) self._servicer.received_client_metadata()))
self.assertTrue( self.assertTrue(
test_common.metadata_transmitted( test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA, _SERVER_INITIAL_METADATA,
exception_context.exception.trailing_metadata())) exception_context.exception.initial_metadata()))
self.assertIs(_NON_OK_CODE, exception_context.exception.code()) self.assertTrue(
self.assertEqual(_DETAILS, exception_context.exception.details()) test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
exception_context.exception.trailing_metadata()))
self.assertIs(expected_code, exception_context.exception.code())
self.assertEqual(expected_details,
exception_context.exception.details())
def testAbortedStreamStream(self): def testAbortedStreamStream(self):
self._servicer.set_code(_NON_OK_CODE) test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES,
self._servicer.set_details(_DETAILS) _EXPECTED_DETAILS)
self._servicer.set_abort_call() for abort_code, expected_code, expected_details in test_cases:
self._servicer.set_code(abort_code)
response_iterator_call = self._stream_stream( self._servicer.set_details(_DETAILS)
iter([object()] * test_constants.STREAM_LENGTH), self._servicer.set_abort_call()
metadata=_CLIENT_METADATA)
received_initial_metadata = response_iterator_call.initial_metadata() response_iterator_call = self._stream_stream(
with self.assertRaises(grpc.RpcError): iter([object()] * test_constants.STREAM_LENGTH),
self.assertEqual(len(list(response_iterator_call)), 0) metadata=_CLIENT_METADATA)
received_initial_metadata = \
self.assertTrue( response_iterator_call.initial_metadata()
test_common.metadata_transmitted( with self.assertRaises(grpc.RpcError):
_CLIENT_METADATA, self._servicer.received_client_metadata())) self.assertEqual(len(list(response_iterator_call)), 0)
self.assertTrue(
test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, self.assertTrue(
received_initial_metadata)) test_common.metadata_transmitted(
self.assertTrue( _CLIENT_METADATA,
test_common.metadata_transmitted( self._servicer.received_client_metadata()))
_SERVER_TRAILING_METADATA, self.assertTrue(
response_iterator_call.trailing_metadata())) test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
self.assertIs(_NON_OK_CODE, response_iterator_call.code()) received_initial_metadata))
self.assertEqual(_DETAILS, response_iterator_call.details()) self.assertTrue(
test_common.metadata_transmitted(
_SERVER_TRAILING_METADATA,
response_iterator_call.trailing_metadata()))
self.assertIs(expected_code, response_iterator_call.code())
self.assertEqual(expected_details, response_iterator_call.details())
def testCustomCodeUnaryUnary(self): def testCustomCodeUnaryUnary(self):
self._servicer.set_code(_NON_OK_CODE) self._servicer.set_code(_NON_OK_CODE)

Loading…
Cancel
Save