|
|
|
@ -41,6 +41,7 @@ _ERROR_IN_STREAM_STREAM = '/test/ErrorInStreamStream' |
|
|
|
|
_ERROR_IN_STREAM_UNARY = '/test/ErrorInStreamUnary' |
|
|
|
|
_ERROR_WITHOUT_RAISE_IN_UNARY_UNARY = '/test/ErrorWithoutRaiseInUnaryUnary' |
|
|
|
|
_ERROR_WITHOUT_RAISE_IN_STREAM_STREAM = '/test/ErrorWithoutRaiseInStreamStream' |
|
|
|
|
_INVALID_TRAILING_METADATA = '/test/InvalidTrailingMetadata' |
|
|
|
|
|
|
|
|
|
_REQUEST = b'\x00\x00\x00' |
|
|
|
|
_RESPONSE = b'\x01\x01\x01' |
|
|
|
@ -99,6 +100,9 @@ class _GenericHandler(grpc.GenericRpcHandler): |
|
|
|
|
_ERROR_WITHOUT_RAISE_IN_STREAM_STREAM: |
|
|
|
|
grpc.stream_stream_rpc_method_handler( |
|
|
|
|
self._error_without_raise_in_stream_stream), |
|
|
|
|
_INVALID_TRAILING_METADATA: |
|
|
|
|
grpc.unary_unary_rpc_method_handler( |
|
|
|
|
self._invalid_trailing_metadata), |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
@ -199,6 +203,30 @@ class _GenericHandler(grpc.GenericRpcHandler): |
|
|
|
|
assert _REQUEST == request |
|
|
|
|
context.set_code(grpc.StatusCode.INTERNAL) |
|
|
|
|
|
|
|
|
|
async def _invalid_trailing_metadata(self, request, context): |
|
|
|
|
assert _REQUEST == request |
|
|
|
|
for invalid_metadata in [ |
|
|
|
|
42, {}, { |
|
|
|
|
'error': 'error' |
|
|
|
|
}, [{ |
|
|
|
|
'error': "error" |
|
|
|
|
}] |
|
|
|
|
]: |
|
|
|
|
try: |
|
|
|
|
context.set_trailing_metadata(invalid_metadata) |
|
|
|
|
except TypeError: |
|
|
|
|
pass |
|
|
|
|
else: |
|
|
|
|
raise ValueError( |
|
|
|
|
f'No TypeError raised for invalid metadata: {invalid_metadata}' |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
await context.abort(grpc.StatusCode.DATA_LOSS, |
|
|
|
|
details="invalid abort", |
|
|
|
|
trailing_metadata=({ |
|
|
|
|
'error': ('error1', 'error2') |
|
|
|
|
})) |
|
|
|
|
|
|
|
|
|
def service(self, handler_details): |
|
|
|
|
if not self._called.done(): |
|
|
|
|
self._called.set_result(None) |
|
|
|
@ -553,6 +581,16 @@ class TestServer(AioTestBase): |
|
|
|
|
await channel.close() |
|
|
|
|
await server.stop(0) |
|
|
|
|
|
|
|
|
|
async def test_invalid_trailing_metadata(self): |
|
|
|
|
call = self._channel.unary_unary(_INVALID_TRAILING_METADATA)(_REQUEST) |
|
|
|
|
|
|
|
|
|
with self.assertRaises(aio.AioRpcError) as exception_context: |
|
|
|
|
await call |
|
|
|
|
|
|
|
|
|
rpc_error = exception_context.exception |
|
|
|
|
self.assertEqual(grpc.StatusCode.UNKNOWN, rpc_error.code()) |
|
|
|
|
self.assertIn('trailing', rpc_error.details()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|
logging.basicConfig(level=logging.DEBUG) |
|
|
|
|