# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Based on https://github.com/PaddlePaddle/PaddleNLP/blob/develop/tests/common_test.py

import functools
import shlex
import subprocess
import unittest
import warnings

import numpy as np
import paddle

__all__ = ['CommonTest', 'CpuCommonTest', 'run_script']


def run_script(cmd, silent=True, wd=None, timeout=None, echo=True):
    """Run a shell command and return the completed process.

    Args:
        cmd (str): Command line executed with ``shell=True``; it must come
            from a trusted source.
        silent (bool, optional): If True, discard the command's stdout
            (stderr is left untouched). Default: True.
        wd (str|os.PathLike|None, optional): If given, ``cd`` into this
            directory before running ``cmd``. Default: None.
        timeout (float|None, optional): Forwarded to ``subprocess.run``.
            Default: None.
        echo (bool, optional): If True, print the final command line before
            running it. Default: True.

    Returns:
        subprocess.CompletedProcess: The finished process.

    Raises:
        subprocess.CalledProcessError: If the command exits with non-zero
            status (``check=True``).
        subprocess.TimeoutExpired: If ``timeout`` elapses.
    """
    # XXX: This function is not safe!!! ``cmd`` is interpreted by the shell,
    # so never pass untrusted input here.
    cfg = dict(check=True, shell=True, timeout=timeout)
    if silent:
        cfg['stdout'] = subprocess.DEVNULL
    if wd is not None:
        # Quote the directory so paths containing spaces or shell
        # metacharacters reach ``cd`` as a single literal argument.
        cmd = f"cd {shlex.quote(str(wd))} && {cmd}"
    if echo:
        print(cmd)
    return subprocess.run(cmd, **cfg)


# Assume all elements has same data type
def get_container_type(container):
    """Return the type used to dispatch comparisons of ``container``.

    For a list/tuple, recurse into its first element until a non-list/tuple
    value is reached and return that value's type. An empty list/tuple
    returns its own type (``list`` or ``tuple``). Any other value returns
    ``type(container)``. Only the leading element is inspected at each level.
    """
    container_t = type(container)
    if container_t in [list, tuple]:
        if len(container) == 0:
            return container_t
        return get_container_type(container[0])
    return container_t


class _CommonTestNamespace:
    # Wrap the subclasses of unittest.TestCase that are expected to be
    # inherited from.

    class CommonTest(unittest.TestCase):
        """Base test case that runs each ``test*`` method on every place in
        ``self.places`` and provides tolerance-aware output checks.

        Set ``CATCH_WARNINGS = True`` on a subclass to make any warning other
        than ``UserWarning`` emitted during a test fail that test.
        """

        # When True, ``setUpClass`` also wraps tests with ``_catch_warnings``.
        CATCH_WARNINGS = False

        def __init__(self, methodName='runTest'):
            super().__init__(methodName=methodName)
            self.config = {}
            # Always run on CPU; add GPU when Paddle was built with CUDA.
            self.places = ['cpu']
            if paddle.is_compiled_with_cuda():
                self.places.append('gpu')

        @classmethod
        def setUpClass(cls):
            """
            Set the decorators for all test function
            """
            super().setUpClass()
            decorator_names = ["_test_places"]
            if cls.CATCH_WARNINGS:
                decorator_names.append("_catch_warnings")
            # Iterate over a snapshot because the loop rebinds class attributes.
            for key, value in list(cls.__dict__.items()):
                # Skip non-test names and non-callable attributes (e.g. a
                # ``test_cases`` data list) that must not be replaced.
                if not key.startswith('test') or not callable(value):
                    continue
                # A second setUpClass call (or a test shared between classes)
                # would otherwise nest the device loop around itself.
                if getattr(value, '_common_test_wrapped', False):
                    continue
                for name in decorator_names:
                    value = getattr(CommonTest, name)(value)
                value._common_test_wrapped = True
                setattr(cls, key, value)

        def _catch_warnings(func):
            """
            Catch the warnings and treat them as errors for each test.
            ``UserWarning`` is ignored.
            """

            @functools.wraps(func)
            def wrapper(self, *args, **kwargs):
                with warnings.catch_warnings(record=True) as w:
                    # Drop inherited filters (catch_warnings restores them on
                    # exit) so outer configuration cannot hide warnings.
                    warnings.resetwarnings()
                    # Ignore specified warnings
                    warning_white_list = [UserWarning]
                    for warning in warning_white_list:
                        warnings.simplefilter("ignore", warning)
                    func(self, *args, **kwargs)
                msg = None if len(w) == 0 else w[0].message
                self.assertFalse(len(w) > 0, msg)

            return wrapper

        def _test_places(func):
            """
            Run the test once per entry of ``self.places``, calling
            ``paddle.set_device`` before each run.
            """

            @functools.wraps(func)
            def wrapper(self, *args, **kwargs):
                for place in self.places:
                    paddle.set_device(place)
                    func(self, *args, **kwargs)

            return wrapper

        def _check_output_impl(self,
                               result,
                               expected_result,
                               rtol,
                               atol,
                               equal=True):
            """Shared implementation of ``check_output_equal`` and
            ``check_output_not_equal``.

            Non-empty lists/tuples are dispatched on the type of their leading
            innermost element (see ``get_container_type``); empty ones are
            compared directly.

            Raises:
                ValueError: If the dispatch type is not supported.
            """
            normal_types = (str, int, bool, set, list, tuple, np.bool_,
                            np.int32, np.int64, np.str_)
            float_types = (float, np.ndarray, np.float32, np.float64)
            error_msg = "Output has diff at place:{}. \nExpect: {} \nBut Got: {} in class {}"

            def fmt(expected, got):
                return error_msg.format(paddle.get_device(), expected, got,
                                        self.__class__.__name__)

            result_t = type(result)
            if result_t in [list, tuple]:
                result_t = get_container_type(result)

            if result_t in normal_types:
                assert_normal = self.assertEqual if equal else self.assertNotEqual
                assert_normal(
                    result, expected_result, msg=fmt(expected_result, result))
                return

            if result_t not in float_types:
                raise ValueError(
                    "result type must be str, int, bool, set, list, tuple, "
                    "np.bool_, np.int32, np.int64, np.str_, float, np.ndarray, "
                    f"np.float32, np.float64, but got {result_t}")

            if result_t == np.ndarray:
                # np.shape also works when ``result`` is a list of arrays.
                result_shape = np.shape(result)
                expected_shape = np.shape(expected_result)
                if equal:
                    # Check shape first: allclose broadcasts, and raises on
                    # shapes that cannot broadcast.
                    self.assertEqual(
                        result_shape,
                        expected_shape,
                        msg=fmt(expected_shape, result_shape))
                elif result_shape != expected_shape:
                    # Different shapes already make the outputs unequal.
                    return

            close = np.allclose(
                result, expected_result, rtol=rtol, atol=atol)
            value_msg = fmt(expected_result, result)
            if equal:
                self.assertTrue(close, msg=value_msg)
            else:
                self.assertFalse(close, msg=value_msg)

        def check_output_equal(self,
                               result,
                               expected_result,
                               rtol=1.e-5,
                               atol=1.e-8):
            """
            Check whether result and expected result are equal, including shape.

            Args:
                result (str|int|bool|set|float|np.ndarray|list|tuple):
                    The result needs to be checked.
                expected_result (str|int|bool|set|float|np.ndarray|list|tuple):
                    The type has to be same as result's. Use the expected
                    result to check result.
                rtol (float, optional):
                    relative tolerance, default 1.e-5.
                atol (float, optional):
                    absolute tolerance, default 1.e-8
            """
            self._check_output_impl(result, expected_result, rtol, atol)

        def check_output_not_equal(self,
                                   result,
                                   expected_result,
                                   rtol=1.e-5,
                                   atol=1.e-8):
            """
            Check whether result and expected result are not equal. For
            ndarrays, differing shapes count as not equal; otherwise the
            values must not be close.

            Args:
                result (str|int|bool|set|float|np.ndarray|list|tuple):
                    The result needs to be checked.
                expected_result (str|int|bool|set|float|np.ndarray|list|tuple):
                    The type has to be same as result's. Use the expected
                    result to check result.
                rtol (float, optional):
                    relative tolerance, default 1.e-5.
                atol (float, optional):
                    absolute tolerance, default 1.e-8
            """
            self._check_output_impl(
                result, expected_result, rtol, atol, equal=False)

    class CpuCommonTest(CommonTest):
        """``CommonTest`` variant that runs tests on CPU only."""

        def __init__(self, methodName='runTest'):
            super().__init__(methodName=methodName)
            self.places = ['cpu']


CommonTest = _CommonTestNamespace.CommonTest
CpuCommonTest = _CommonTestNamespace.CpuCommonTest