From 506db80475d2fed0293cecee83cec86cf20accc3 Mon Sep 17 00:00:00 2001 From: "Chris St. Pierre" Date: Wed, 28 Apr 2021 12:15:11 -0500 Subject: [PATCH] Expose code and details from context on the server side (#25457) * Expose code and details from context on the server side This makes them accessible to server-side interceptors. * Complete synchronous implementation * Add experimental API notes * Complete async implementation * Revert bogus copyright change * Fix metadata tests for py2.7 * Add new test to tests.json * Check error codes in tests * Bump autogenerated Python version --- src/python/grpcio/grpc/__init__.py | 30 ++++++++++ .../grpc/_cython/_cygrpc/aio/server.pyx.pxi | 9 +++ src/python/grpcio/grpc/_server.py | 9 +++ src/python/grpcio/grpc/aio/_base_server.py | 30 ++++++++++ src/python/grpcio_tests/tests/tests.json | 1 + .../tests/unit/_metadata_code_details_test.py | 58 +++++++++++++++++++ .../tests_aio/unit/metadata_test.py | 28 +++++++++ 7 files changed, 165 insertions(+) diff --git a/src/python/grpcio/grpc/__init__.py b/src/python/grpcio/grpc/__init__.py index 69803ed1616..ea393c44c5e 100644 --- a/src/python/grpcio/grpc/__init__.py +++ b/src/python/grpcio/grpc/__init__.py @@ -1174,6 +1174,16 @@ class ServicerContext(six.with_metaclass(abc.ABCMeta, RpcContext)): """ raise NotImplementedError() + def trailing_metadata(self): + """Access value to be used as trailing metadata upon RPC completion. + + This is an EXPERIMENTAL API. + + Returns: + The trailing :term:`metadata` for the RPC. + """ + raise NotImplementedError() + @abc.abstractmethod def abort(self, code, details): """Raises an exception to terminate the RPC with a non-OK status. @@ -1237,6 +1247,26 @@ class ServicerContext(six.with_metaclass(abc.ABCMeta, RpcContext)): """ raise NotImplementedError() + def code(self): + """Accesses the value to be used as status code upon RPC completion. + + This is an EXPERIMENTAL API. + + Returns: + The StatusCode value for the RPC. + """ + raise NotImplementedError() + + def details(self): + """Accesses the value to be used as detail string upon RPC completion. + + This is an EXPERIMENTAL API. + + Returns: + The details string of the RPC. + """ + raise NotImplementedError() + def disable_next_message_compression(self): """Disables compression for the next response message. diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi index 38c0d713522..2b46c0cc29f 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi @@ -197,15 +197,24 @@ cdef class _ServicerContext: def set_trailing_metadata(self, object metadata): self._rpc_state.trailing_metadata = tuple(metadata) + def trailing_metadata(self): + return self._rpc_state.trailing_metadata + def invocation_metadata(self): return self._rpc_state.invocation_metadata() def set_code(self, object code): self._rpc_state.status_code = get_status_code(code) + def code(self): + return self._rpc_state.status_code + def set_details(self, str details): self._rpc_state.status_details = details + def details(self): + return self._rpc_state.status_details + def set_compression(self, object compression): if self._rpc_state.metadata_sent: raise RuntimeError('Compression setting must be specified before sending initial metadata') diff --git a/src/python/grpcio/grpc/_server.py b/src/python/grpcio/grpc/_server.py index 069ffa79822..6c6f02de8a0 100644 --- a/src/python/grpcio/grpc/_server.py +++ b/src/python/grpcio/grpc/_server.py @@ -305,6 +305,9 @@ class _Context(grpc.ServicerContext): with self._state.condition: self._state.trailing_metadata = trailing_metadata + def trailing_metadata(self): + return self._state.trailing_metadata + def abort(self, code, details): # treat OK like other invalid arguments: fail the RPC if code == grpc.StatusCode.OK: @@ -326,10 +329,16 @@ class _Context(grpc.ServicerContext): with self._state.condition: self._state.code = code + def code(self): + return self._state.code + def set_details(self, details): with self._state.condition: self._state.details = _common.encode(details) + def details(self): + return self._state.details + def _finalize_state(self): pass diff --git a/src/python/grpcio/grpc/aio/_base_server.py b/src/python/grpcio/grpc/aio/_base_server.py index d262642c8a9..e22ef3e8303 100644 --- a/src/python/grpcio/grpc/aio/_base_server.py +++ b/src/python/grpcio/grpc/aio/_base_server.py @@ -304,3 +304,33 @@ class ServicerContext(Generic[RequestType, ResponseType], abc.ABC): remaining for the RPC to complete before it is considered to have timed out, or None if no deadline was specified for the RPC. """ + + def trailing_metadata(self): + """Access value to be used as trailing metadata upon RPC completion. + + This is an EXPERIMENTAL API. + + Returns: + The trailing :term:`metadata` for the RPC. + """ + raise NotImplementedError() + + def code(self): + """Accesses the value to be used as status code upon RPC completion. + + This is an EXPERIMENTAL API. + + Returns: + The StatusCode value for the RPC. + """ + raise NotImplementedError() + + def details(self): + """Accesses the value to be used as detail string upon RPC completion. + + This is an EXPERIMENTAL API. + + Returns: + The details string of the RPC. + """ + raise NotImplementedError() diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json index bd4139b4922..93f55bfd52e 100644 --- a/src/python/grpcio_tests/tests/tests.json +++ b/src/python/grpcio_tests/tests/tests.json @@ -60,6 +60,7 @@ "unit._invocation_defects_test.InvocationDefectsTest", "unit._local_credentials_test.LocalCredentialsTest", "unit._logging_test.LoggingTest", + "unit._metadata_code_details_test.InspectContextTest", "unit._metadata_code_details_test.MetadataCodeDetailsTest", "unit._metadata_flags_test.MetadataFlagsTest", "unit._metadata_test.MetadataTest", diff --git a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py index 5b06eb2bfe8..900fabd19af 100644 --- a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py +++ b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py @@ -658,6 +658,64 @@ class MetadataCodeDetailsTest(unittest.TestCase): self.assertEqual(_DETAILS, exception_context.exception.details()) +class _InspectServicer(_Servicer): + + def __init__(self): + super(_InspectServicer, self).__init__() + self.actual_code = None + self.actual_details = None + self.actual_trailing_metadata = None + + def unary_unary(self, request, context): + super(_InspectServicer, self).unary_unary(request, context) + + self.actual_code = context.code() + self.actual_details = context.details() + self.actual_trailing_metadata = context.trailing_metadata() + + +class InspectContextTest(unittest.TestCase): + + def setUp(self): + self._servicer = _InspectServicer() + self._server = test_common.test_server() + self._server.add_generic_rpc_handlers( + (_generic_handler(self._servicer),)) + port = self._server.add_insecure_port('[::]:0') + self._server.start() + + self._channel = grpc.insecure_channel('localhost:{}'.format(port)) + self._unary_unary = self._channel.unary_unary( + '/'.join(( + '', + _SERVICE, + _UNARY_UNARY, + )), + request_serializer=_REQUEST_SERIALIZER, + response_deserializer=_RESPONSE_DESERIALIZER, + ) + + def tearDown(self): + self._server.stop(None) + self._channel.close() + + def testCodeDetailsInContext(self): + self._servicer.set_code(_NON_OK_CODE) + self._servicer.set_details(_DETAILS) + + with self.assertRaises(grpc.RpcError) as exc_info: + self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA) + + err = exc_info.exception + self.assertEqual(_NON_OK_CODE, err.code()) + + self.assertEqual(self._servicer.actual_code, _NON_OK_CODE) + self.assertEqual(self._servicer.actual_details.decode('utf-8'), + _DETAILS) + self.assertEqual(self._servicer.actual_trailing_metadata, + _SERVER_TRAILING_METADATA) + + if __name__ == '__main__': logging.basicConfig() unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py index 2261446b3ea..8f8c39214c8 100644 --- a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py @@ -33,6 +33,7 @@ _TEST_GENERIC_HANDLER = '/test/TestGenericHandler' _TEST_UNARY_STREAM = '/test/TestUnaryStream' _TEST_STREAM_UNARY = '/test/TestStreamUnary' _TEST_STREAM_STREAM = '/test/TestStreamStream' +_TEST_INSPECT_CONTEXT = '/test/TestInspectContext' _REQUEST = b'\x00\x00\x00' _RESPONSE = b'\x01\x01\x01' @@ -75,6 +76,9 @@ _INVALID_METADATA_TEST_CASES = ( ), ) +_NON_OK_CODE = grpc.StatusCode.NOT_FOUND +_DETAILS = 'Test details!' + class _TestGenericHandlerForMethods(grpc.GenericRpcHandler): @@ -95,6 +99,8 @@ class _TestGenericHandlerForMethods(grpc.GenericRpcHandler): grpc.stream_unary_rpc_method_handler(self._test_stream_unary), _TEST_STREAM_STREAM: grpc.stream_stream_rpc_method_handler(self._test_stream_stream), + _TEST_INSPECT_CONTEXT: + grpc.unary_unary_rpc_method_handler(self._test_inspect_context), } @staticmethod @@ -153,6 +159,19 @@ class _TestGenericHandlerForMethods(grpc.GenericRpcHandler): yield _RESPONSE context.set_trailing_metadata(_TRAILING_METADATA) + @staticmethod + async def _test_inspect_context(request, context): + assert _REQUEST == request + context.set_code(_NON_OK_CODE) + context.set_details(_DETAILS) + context.set_trailing_metadata(_TRAILING_METADATA) + + # ensure that we can read back the data we set on the context + assert context.get_code() == _NON_OK_CODE + assert context.get_details() == _DETAILS + assert context.get_trailing_metadata() == _TRAILING_METADATA + return _RESPONSE + def service(self, handler_call_details): return self._routing_table.get(handler_call_details.method) @@ -291,6 +310,15 @@ class TestMetadata(AioTestBase): self.assertEqual(expected_sum, metadata_obj + aio.Metadata( ('third', '3'))) + async def test_inspect_context(self): + multicallable = self._client.unary_unary(_TEST_INSPECT_CONTEXT) + call = multicallable(_REQUEST) + with self.assertRaises(grpc.RpcError) as exc_data: + await call + + err = exc_data.exception + self.assertEqual(_NON_OK_CODE, err.code()) + if __name__ == '__main__': logging.basicConfig(level=logging.DEBUG)