From f7d4f8e13cf80b22355169d69838d13ecb146214 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Thu, 11 Nov 2021 13:44:55 -0800 Subject: [PATCH] [Aio] Validate the input type for set_trailing_metadata and abort (#27958) * [Aio] Validate the input type for set_trailing_metadata and abort * Correct the checking of sequence type --- .../grpc/_cython/_cygrpc/aio/common.pyx.pxi | 16 ++++++++ .../grpc/_cython/_cygrpc/aio/server.pyx.pxi | 2 + .../tests_aio/unit/server_test.py | 38 +++++++++++++++++++ 3 files changed, 56 insertions(+) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi index d5113ae7d14..2bbe5498900 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi @@ -14,6 +14,8 @@ from cpython.version cimport PY_MAJOR_VERSION, PY_MINOR_VERSION +TYPE_METADATA_STRING = "Tuple[Tuple[str, Union[str, bytes]]...]" + cdef grpc_status_code get_status_code(object code) except *: if isinstance(code, int): @@ -184,3 +186,17 @@ else: def get_working_loop(): """Returns a running event loop.""" return asyncio.get_event_loop() + + +def raise_if_not_valid_trailing_metadata(object metadata): + if not hasattr(metadata, '__iter__') or isinstance(metadata, dict): + raise TypeError(f'Invalid trailing metadata type, expected {TYPE_METADATA_STRING}: {metadata}') + for item in metadata: + if not isinstance(item, tuple): + raise TypeError(f'Invalid trailing metadata type, expected {TYPE_METADATA_STRING}: {metadata}') + if len(item) != 2: + raise TypeError(f'Invalid trailing metadata type, expected {TYPE_METADATA_STRING}: {metadata}') + if not isinstance(item[0], str): + raise TypeError(f'Invalid trailing metadata type, expected {TYPE_METADATA_STRING}: {metadata}') + if not isinstance(item[1], str) and not isinstance(item[1], bytes): + raise TypeError(f'Invalid trailing metadata type, expected {TYPE_METADATA_STRING}: {metadata}') diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi index 2976ca4415e..1023a2006a6 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi @@ -175,6 +175,7 @@ cdef class _ServicerContext: if trailing_metadata == _IMMUTABLE_EMPTY_METADATA and self._rpc_state.trailing_metadata: trailing_metadata = self._rpc_state.trailing_metadata else: + raise_if_not_valid_trailing_metadata(trailing_metadata) self._rpc_state.trailing_metadata = trailing_metadata if details == '' and self._rpc_state.status_details: @@ -201,6 +202,7 @@ cdef class _ServicerContext: await self.abort(status.code, status.details, status.trailing_metadata) def set_trailing_metadata(self, object metadata): + raise_if_not_valid_trailing_metadata(metadata) self._rpc_state.trailing_metadata = tuple(metadata) def trailing_metadata(self): diff --git a/src/python/grpcio_tests/tests_aio/unit/server_test.py b/src/python/grpcio_tests/tests_aio/unit/server_test.py index 3be34a7a729..8a6995e99aa 100644 --- a/src/python/grpcio_tests/tests_aio/unit/server_test.py +++ b/src/python/grpcio_tests/tests_aio/unit/server_test.py @@ -41,6 +41,7 @@ _ERROR_IN_STREAM_STREAM = '/test/ErrorInStreamStream' _ERROR_IN_STREAM_UNARY = '/test/ErrorInStreamUnary' _ERROR_WITHOUT_RAISE_IN_UNARY_UNARY = '/test/ErrorWithoutRaiseInUnaryUnary' _ERROR_WITHOUT_RAISE_IN_STREAM_STREAM = '/test/ErrorWithoutRaiseInStreamStream' +_INVALID_TRAILING_METADATA = '/test/InvalidTrailingMetadata' _REQUEST = b'\x00\x00\x00' _RESPONSE = b'\x01\x01\x01' @@ -99,6 +100,9 @@ class _GenericHandler(grpc.GenericRpcHandler): _ERROR_WITHOUT_RAISE_IN_STREAM_STREAM: grpc.stream_stream_rpc_method_handler( self._error_without_raise_in_stream_stream), + _INVALID_TRAILING_METADATA: + grpc.unary_unary_rpc_method_handler( + self._invalid_trailing_metadata), } @staticmethod @@ -199,6 +203,30 @@ class _GenericHandler(grpc.GenericRpcHandler): assert _REQUEST == request context.set_code(grpc.StatusCode.INTERNAL) + async def _invalid_trailing_metadata(self, request, context): + assert _REQUEST == request + for invalid_metadata in [ + 42, {}, { + 'error': 'error' + }, [{ + 'error': "error" + }] + ]: + try: + context.set_trailing_metadata(invalid_metadata) + except TypeError: + pass + else: + raise ValueError( + f'No TypeError raised for invalid metadata: {invalid_metadata}' + ) + + await context.abort(grpc.StatusCode.DATA_LOSS, + details="invalid abort", + trailing_metadata=({ + 'error': ('error1', 'error2') + })) + def service(self, handler_details): if not self._called.done(): self._called.set_result(None) @@ -553,6 +581,16 @@ class TestServer(AioTestBase): await channel.close() await server.stop(0) + async def test_invalid_trailing_metadata(self): + call = self._channel.unary_unary(_INVALID_TRAILING_METADATA)(_REQUEST) + + with self.assertRaises(aio.AioRpcError) as exception_context: + await call + + rpc_error = exception_context.exception + self.assertEqual(grpc.StatusCode.UNKNOWN, rpc_error.code()) + self.assertIn('trailing', rpc_error.details()) + if __name__ == '__main__': logging.basicConfig(level=logging.DEBUG)