From 145b199c4ddcd5e7b985585f3e733ae5092adbbb Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Wed, 10 Jan 2018 10:47:48 -0500 Subject: [PATCH] python: Context.abort should fail RPC even for StatusCode.OK grpc.ServicerContext.abort is documented to always raise an exception to terminate the RPC. The code argument "must not be StatusCode.OK." However, if you do pass StatusCode.OK, the RPC terminates successfully on the client side, but returns None. _server.py: If the user accidentally passes StatusCode.OK, treat it as StatusCode.UNKNOWN. This is what happens if the user accidentally passes something that is not a StatusCode instance. Additionally set details to ''. _metadata_code_details_test.py: update test to verify the behavior of abort with invalid codes. --- src/python/grpcio/grpc/_server.py | 6 + .../tests/unit/_metadata_code_details_test.py | 200 ++++++++++-------- 2 files changed, 119 insertions(+), 87 deletions(-) diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index 1cdb2d45b6c..a89987a86df 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -278,6 +278,12 @@ class _Context(grpc.ServicerContext): self._state.trailing_metadata = trailing_metadata def abort(self, code, details): + # treat OK like other invalid arguments: fail the RPC + if code == grpc.StatusCode.OK: + logging.error( + 'abort() called with StatusCode.OK; returning UNKNOWN') + code = grpc.StatusCode.UNKNOWN + details = '' with self._state.condition: self._state.code = code self._state.details = _common.encode(details) diff --git a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py index bb6ac704970..ca10bd4dab5 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py @@ -50,6 +50,12 @@ _SERVER_TRAILING_METADATA = (('server-trailing-md-key', _NON_OK_CODE = grpc.StatusCode.NOT_FOUND _DETAILS = 'Test details!' +# calling abort should always fail an RPC, even for "invalid" codes +_ABORT_CODES = (_NON_OK_CODE, 3, grpc.StatusCode.OK) +_EXPECTED_CLIENT_CODES = (_NON_OK_CODE, grpc.StatusCode.UNKNOWN, + grpc.StatusCode.UNKNOWN) +_EXPECTED_DETAILS = (_DETAILS, _DETAILS, '') + class _Servicer(object): @@ -302,99 +308,119 @@ class MetadataCodeDetailsTest(unittest.TestCase): self.assertEqual(_DETAILS, response_iterator_call.details()) def testAbortedUnaryUnary(self): - self._servicer.set_code(_NON_OK_CODE) - self._servicer.set_details(_DETAILS) - self._servicer.set_abort_call() - - with self.assertRaises(grpc.RpcError) as exception_context: - self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA) - - self.assertTrue( - test_common.metadata_transmitted( - _CLIENT_METADATA, self._servicer.received_client_metadata())) - self.assertTrue( - test_common.metadata_transmitted( - _SERVER_INITIAL_METADATA, - exception_context.exception.initial_metadata())) - self.assertTrue( - test_common.metadata_transmitted( - _SERVER_TRAILING_METADATA, - exception_context.exception.trailing_metadata())) - self.assertIs(_NON_OK_CODE, exception_context.exception.code()) - self.assertEqual(_DETAILS, exception_context.exception.details()) + test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES, + _EXPECTED_DETAILS) + for abort_code, expected_code, expected_details in test_cases: + self._servicer.set_code(abort_code) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, + self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(expected_code, exception_context.exception.code()) + self.assertEqual(expected_details, + exception_context.exception.details()) def testAbortedUnaryStream(self): - self._servicer.set_code(_NON_OK_CODE) - self._servicer.set_details(_DETAILS) - self._servicer.set_abort_call() - - response_iterator_call = self._unary_stream( - _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) - received_initial_metadata = response_iterator_call.initial_metadata() - with self.assertRaises(grpc.RpcError): - self.assertEqual(len(list(response_iterator_call)), 0) - - self.assertTrue( - test_common.metadata_transmitted( - _CLIENT_METADATA, self._servicer.received_client_metadata())) - self.assertTrue( - test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, - received_initial_metadata)) - self.assertTrue( - test_common.metadata_transmitted( - _SERVER_TRAILING_METADATA, - response_iterator_call.trailing_metadata())) - self.assertIs(_NON_OK_CODE, response_iterator_call.code()) - self.assertEqual(_DETAILS, response_iterator_call.details()) + test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES, + _EXPECTED_DETAILS) + for abort_code, expected_code, expected_details in test_cases: + self._servicer.set_code(abort_code) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + response_iterator_call = self._unary_stream( + _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA) + received_initial_metadata = \ + response_iterator_call.initial_metadata() + with self.assertRaises(grpc.RpcError): + self.assertEqual(len(list(response_iterator_call)), 0) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, + self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, + received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(expected_code, response_iterator_call.code()) + self.assertEqual(expected_details, response_iterator_call.details()) def testAbortedStreamUnary(self): - self._servicer.set_code(_NON_OK_CODE) - self._servicer.set_details(_DETAILS) - self._servicer.set_abort_call() - - with self.assertRaises(grpc.RpcError) as exception_context: - self._stream_unary.with_call( - iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH), - metadata=_CLIENT_METADATA) - - self.assertTrue( - test_common.metadata_transmitted( - _CLIENT_METADATA, self._servicer.received_client_metadata())) - self.assertTrue( - test_common.metadata_transmitted( - _SERVER_INITIAL_METADATA, - exception_context.exception.initial_metadata())) - self.assertTrue( - test_common.metadata_transmitted( - _SERVER_TRAILING_METADATA, - exception_context.exception.trailing_metadata())) - self.assertIs(_NON_OK_CODE, exception_context.exception.code()) - self.assertEqual(_DETAILS, exception_context.exception.details()) + test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES, + _EXPECTED_DETAILS) + for abort_code, expected_code, expected_details in test_cases: + self._servicer.set_code(abort_code) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + with self.assertRaises(grpc.RpcError) as exception_context: + self._stream_unary.with_call( + iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, + self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_INITIAL_METADATA, + exception_context.exception.initial_metadata())) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + exception_context.exception.trailing_metadata())) + self.assertIs(expected_code, exception_context.exception.code()) + self.assertEqual(expected_details, + exception_context.exception.details()) def testAbortedStreamStream(self): - self._servicer.set_code(_NON_OK_CODE) - self._servicer.set_details(_DETAILS) - self._servicer.set_abort_call() - - response_iterator_call = self._stream_stream( - iter([object()] * test_constants.STREAM_LENGTH), - metadata=_CLIENT_METADATA) - received_initial_metadata = response_iterator_call.initial_metadata() - with self.assertRaises(grpc.RpcError): - self.assertEqual(len(list(response_iterator_call)), 0) - - self.assertTrue( - test_common.metadata_transmitted( - _CLIENT_METADATA, self._servicer.received_client_metadata())) - self.assertTrue( - test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, - received_initial_metadata)) - self.assertTrue( - test_common.metadata_transmitted( - _SERVER_TRAILING_METADATA, - response_iterator_call.trailing_metadata())) - self.assertIs(_NON_OK_CODE, response_iterator_call.code()) - self.assertEqual(_DETAILS, response_iterator_call.details()) + test_cases = zip(_ABORT_CODES, _EXPECTED_CLIENT_CODES, + _EXPECTED_DETAILS) + for abort_code, expected_code, expected_details in test_cases: + self._servicer.set_code(abort_code) + self._servicer.set_details(_DETAILS) + self._servicer.set_abort_call() + + response_iterator_call = self._stream_stream( + iter([object()] * test_constants.STREAM_LENGTH), + metadata=_CLIENT_METADATA) + received_initial_metadata = \ + response_iterator_call.initial_metadata() + with self.assertRaises(grpc.RpcError): + self.assertEqual(len(list(response_iterator_call)), 0) + + self.assertTrue( + test_common.metadata_transmitted( + _CLIENT_METADATA, + self._servicer.received_client_metadata())) + self.assertTrue( + test_common.metadata_transmitted(_SERVER_INITIAL_METADATA, + received_initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted( + _SERVER_TRAILING_METADATA, + response_iterator_call.trailing_metadata())) + self.assertIs(expected_code, response_iterator_call.code()) + self.assertEqual(expected_details, response_iterator_call.details()) def testCustomCodeUnaryUnary(self): self._servicer.set_code(_NON_OK_CODE)