[Aio] Validate the input type for set_trailing_metadata and abort (#27958)

* [Aio] Validate the input type for set_trailing_metadata and abort

* Correct the checking of sequence type
pull/28018/head
Lidi Zheng 3 years ago committed by GitHub
parent 1654e512b3
commit f7d4f8e13c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 16
      src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi
  2. 2
      src/python/grpcio/grpc/_cython/_cygrpc/aio/server.pyx.pxi
  3. 38
      src/python/grpcio_tests/tests_aio/unit/server_test.py

@ -14,6 +14,8 @@
from cpython.version cimport PY_MAJOR_VERSION, PY_MINOR_VERSION
TYPE_METADATA_STRING = "Tuple[Tuple[str, Union[str, bytes]]...]"
cdef grpc_status_code get_status_code(object code) except *:
if isinstance(code, int):
@ -184,3 +186,17 @@ else:
def get_working_loop():
"""Returns a running event loop."""
return asyncio.get_event_loop()
def raise_if_not_valid_trailing_metadata(object metadata):
if not hasattr(metadata, '__iter__') or isinstance(metadata, dict):
raise TypeError(f'Invalid trailing metadata type, expected {TYPE_METADATA_STRING}: {metadata}')
for item in metadata:
if not isinstance(item, tuple):
raise TypeError(f'Invalid trailing metadata type, expected {TYPE_METADATA_STRING}: {metadata}')
if len(item) != 2:
raise TypeError(f'Invalid trailing metadata type, expected {TYPE_METADATA_STRING}: {metadata}')
if not isinstance(item[0], str):
raise TypeError(f'Invalid trailing metadata type, expected {TYPE_METADATA_STRING}: {metadata}')
if not isinstance(item[1], str) and not isinstance(item[1], bytes):
raise TypeError(f'Invalid trailing metadata type, expected {TYPE_METADATA_STRING}: {metadata}')

@ -175,6 +175,7 @@ cdef class _ServicerContext:
if trailing_metadata == _IMMUTABLE_EMPTY_METADATA and self._rpc_state.trailing_metadata:
trailing_metadata = self._rpc_state.trailing_metadata
else:
raise_if_not_valid_trailing_metadata(trailing_metadata)
self._rpc_state.trailing_metadata = trailing_metadata
if details == '' and self._rpc_state.status_details:
@ -201,6 +202,7 @@ cdef class _ServicerContext:
await self.abort(status.code, status.details, status.trailing_metadata)
def set_trailing_metadata(self, object metadata):
raise_if_not_valid_trailing_metadata(metadata)
self._rpc_state.trailing_metadata = tuple(metadata)
def trailing_metadata(self):

@ -41,6 +41,7 @@ _ERROR_IN_STREAM_STREAM = '/test/ErrorInStreamStream'
_ERROR_IN_STREAM_UNARY = '/test/ErrorInStreamUnary'
_ERROR_WITHOUT_RAISE_IN_UNARY_UNARY = '/test/ErrorWithoutRaiseInUnaryUnary'
_ERROR_WITHOUT_RAISE_IN_STREAM_STREAM = '/test/ErrorWithoutRaiseInStreamStream'
_INVALID_TRAILING_METADATA = '/test/InvalidTrailingMetadata'
_REQUEST = b'\x00\x00\x00'
_RESPONSE = b'\x01\x01\x01'
@ -99,6 +100,9 @@ class _GenericHandler(grpc.GenericRpcHandler):
_ERROR_WITHOUT_RAISE_IN_STREAM_STREAM:
grpc.stream_stream_rpc_method_handler(
self._error_without_raise_in_stream_stream),
_INVALID_TRAILING_METADATA:
grpc.unary_unary_rpc_method_handler(
self._invalid_trailing_metadata),
}
@staticmethod
@ -199,6 +203,30 @@ class _GenericHandler(grpc.GenericRpcHandler):
assert _REQUEST == request
context.set_code(grpc.StatusCode.INTERNAL)
async def _invalid_trailing_metadata(self, request, context):
assert _REQUEST == request
for invalid_metadata in [
42, {}, {
'error': 'error'
}, [{
'error': "error"
}]
]:
try:
context.set_trailing_metadata(invalid_metadata)
except TypeError:
pass
else:
raise ValueError(
f'No TypeError raised for invalid metadata: {invalid_metadata}'
)
await context.abort(grpc.StatusCode.DATA_LOSS,
details="invalid abort",
trailing_metadata=({
'error': ('error1', 'error2')
}))
def service(self, handler_details):
if not self._called.done():
self._called.set_result(None)
@ -553,6 +581,16 @@ class TestServer(AioTestBase):
await channel.close()
await server.stop(0)
async def test_invalid_trailing_metadata(self):
call = self._channel.unary_unary(_INVALID_TRAILING_METADATA)(_REQUEST)
with self.assertRaises(aio.AioRpcError) as exception_context:
await call
rpc_error = exception_context.exception
self.assertEqual(grpc.StatusCode.UNKNOWN, rpc_error.code())
self.assertIn('trailing', rpc_error.details())
if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG)

Loading…
Cancel
Save