Bypass OOM error for some models (#27)

own
Lin Manhui 3 years ago committed by GitHub
parent 66cf12d3a1
commit ee05f40d72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 11
      tests/rs_models/test_cd_models.py
  2. 34
      tests/rs_models/test_model.py

@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import platform
from itertools import cycle from itertools import cycle
import paddlers import paddlers
from rs_models.test_model import TestModel from rs_models.test_model import TestModel, allow_oom
__all__ = [ __all__ = [
'TestBITModel', 'TestCDNetModel', 'TestChangeStarModel', 'TestDSAMNetModel', 'TestBITModel', 'TestCDNetModel', 'TestChangeStarModel', 'TestDSAMNetModel',
@ -202,6 +201,7 @@ class TestSNUNetModel(TestCDModel):
] # yapf: disable ] # yapf: disable
@allow_oom
class TestSTANetModel(TestCDModel): class TestSTANetModel(TestCDModel):
MODEL_CLASS = paddlers.rs_models.cd.STANet MODEL_CLASS = paddlers.rs_models.cd.STANet
@ -216,6 +216,7 @@ class TestSTANetModel(TestCDModel):
] # yapf: disable ] # yapf: disable
@allow_oom
class TestChangeFormerModel(TestCDModel): class TestChangeFormerModel(TestCDModel):
MODEL_CLASS = paddlers.rs_models.cd.ChangeFormer MODEL_CLASS = paddlers.rs_models.cd.ChangeFormer
@ -226,9 +227,3 @@ class TestChangeFormerModel(TestCDModel):
dict(**base_spec, decoder_softmax=True), dict(**base_spec, decoder_softmax=True),
dict(**base_spec, embed_dim=56) dict(**base_spec, embed_dim=56)
] # yapf: disable ] # yapf: disable
# HACK:FIXME: We observe an OOM error when running TestSTANetModel.test_forward() on a Windows machine.
# Currently, we do not perform this test.
if platform.system() == 'Windows':
TestSTANetModel.test_forward = lambda self: None

@ -18,6 +18,7 @@ import paddle
import numpy as np import numpy as np
from paddle.static import InputSpec from paddle.static import InputSpec
from paddlers.utils import logging
from testing_utils import CommonTest from testing_utils import CommonTest
@ -37,20 +38,26 @@ class _TestModelNamespace:
for i, ( for i, (
input, model, target input, model, target
) in enumerate(zip(self.inputs, self.models, self.targets)): ) in enumerate(zip(self.inputs, self.models, self.targets)):
with self.subTest(i=i): try:
if isinstance(input, list): if isinstance(input, list):
output = model(*input) output = model(*input)
else: else:
output = model(input) output = model(input)
self.check_output(output, target) self.check_output(output, target)
except:
logging.warning(f"Model built with spec{i} failed!")
raise
def test_to_static(self): def test_to_static(self):
for i, ( for i, (
input, model, target input, model, target
) in enumerate(zip(self.inputs, self.models, self.targets)): ) in enumerate(zip(self.inputs, self.models, self.targets)):
with self.subTest(i=i): try:
static_model = paddle.jit.to_static( static_model = paddle.jit.to_static(
model, input_spec=self.get_input_spec(model, input)) model, input_spec=self.get_input_spec(model, input))
except:
logging.warning(f"Model built with spec{i} failed!")
raise
def check_output(self, output, target): def check_output(self, output, target):
pass pass
@ -117,4 +124,27 @@ class _TestModelNamespace:
return input_spec return input_spec
def allow_oom(cls):
    """Class decorator that lets test methods tolerate out-of-memory errors.

    Every plain function attribute of ``cls`` whose name starts with
    ``'test'`` is replaced by a wrapper. The wrapper calls the original
    method and, if it raises ``SystemError``, ``RuntimeError`` or ``OSError``
    with a message containing one of the known OOM markers, logs a warning
    and swallows the exception. Any other exception propagates unchanged.

    Args:
        cls: The test case class to patch. It is modified in place.

    Returns:
        The same ``cls`` object, so this can be used as ``@allow_oom``.
    """
    # Imported locally so the decorator does not rely on these modules being
    # imported at the top of the file.
    import functools
    import inspect

    # Substrings of error messages that are treated as an OOM condition.
    oom_markers = (
        "Out of memory error",
        "(External) CUDNN error(4), CUDNN_STATUS_INTERNAL_ERROR.",
    )

    def _deco(func):
        # Keep the original test's name and docstring on the wrapper.
        @functools.wraps(func)
        def _wrapper(self, *args, **kwargs):
            try:
                return func(self, *args, **kwargs)
            except (SystemError, RuntimeError, OSError) as e:
                msg = str(e)
                if not any(marker in msg for marker in oom_markers):
                    raise
                logging.warning("An OOM error has been ignored.")

        return _wrapper

    # Only plain functions are wrapped: wrapping data attributes (or bound
    # classmethods) named ``test*`` would turn them into broken test methods.
    for key, value in inspect.getmembers(cls, inspect.isfunction):
        if key.startswith('test'):
            setattr(cls, key, _deco(value))
    return cls
TestModel = _TestModelNamespace.TestModel TestModel = _TestModelNamespace.TestModel

Loading…
Cancel
Save