diff --git a/tests/rs_models/test_cd_models.py b/tests/rs_models/test_cd_models.py index f712fd3..8478ea4 100644 --- a/tests/rs_models/test_cd_models.py +++ b/tests/rs_models/test_cd_models.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import platform from itertools import cycle import paddlers -from rs_models.test_model import TestModel +from rs_models.test_model import TestModel, allow_oom __all__ = [ 'TestBITModel', 'TestCDNetModel', 'TestChangeStarModel', 'TestDSAMNetModel', @@ -202,6 +201,7 @@ class TestSNUNetModel(TestCDModel): ] # yapf: disable +@allow_oom class TestSTANetModel(TestCDModel): MODEL_CLASS = paddlers.rs_models.cd.STANet @@ -216,6 +216,7 @@ class TestSTANetModel(TestCDModel): ] # yapf: disable +@allow_oom class TestChangeFormerModel(TestCDModel): MODEL_CLASS = paddlers.rs_models.cd.ChangeFormer @@ -226,9 +227,3 @@ class TestChangeFormerModel(TestCDModel): dict(**base_spec, decoder_softmax=True), dict(**base_spec, embed_dim=56) ] # yapf: disable - - -# HACK:FIXME: We observe an OOM error when running TestSTANetModel.test_forward() on a Windows machine. -# Currently, we do not perform this test. -if platform.system() == 'Windows': - TestSTANetModel.test_forward = lambda self: None diff --git a/tests/rs_models/test_model.py b/tests/rs_models/test_model.py index 06c4777..3c1c555 100644 --- a/tests/rs_models/test_model.py +++ b/tests/rs_models/test_model.py @@ -18,6 +18,7 @@ import paddle import numpy as np from paddle.static import InputSpec +from paddlers.utils import logging from testing_utils import CommonTest @@ -37,20 +38,26 @@ class _TestModelNamespace: for i, ( input, model, target ) in enumerate(zip(self.inputs, self.models, self.targets)): - with self.subTest(i=i): + try: if isinstance(input, list): output = model(*input) else: output = model(input) self.check_output(output, target) + except: + logging.warning(f"Model built with spec{i} failed!") + raise def test_to_static(self): for i, ( input, model, target ) in enumerate(zip(self.inputs, self.models, self.targets)): - with self.subTest(i=i): + try: static_model = paddle.jit.to_static( model, input_spec=self.get_input_spec(model, input)) + except: + logging.warning(f"Model built with spec{i} failed!") + raise def check_output(self, output, target): pass @@ -117,4 +124,27 @@ class _TestModelNamespace: return input_spec +def allow_oom(cls): + def _deco(func): + def _wrapper(self, *args, **kwargs): + try: + func(self, *args, **kwargs) + except (SystemError, RuntimeError, OSError) as e: + msg = str(e) + if "Out of memory error" in msg \ + or "(External) CUDNN error(4), CUDNN_STATUS_INTERNAL_ERROR." in msg: + logging.warning("An OOM error has been ignored.") + else: + raise + + return _wrapper + + for key, value in inspect.getmembers(cls): + if key.startswith('test'): + value = _deco(value) + setattr(cls, key, value) + + return cls + + TestModel = _TestModelNamespace.TestModel