|
|
@ -18,6 +18,7 @@ import paddle |
|
|
|
import numpy as np |
|
|
|
import numpy as np |
|
|
|
from paddle.static import InputSpec |
|
|
|
from paddle.static import InputSpec |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from paddlers.utils import logging |
|
|
|
from testing_utils import CommonTest |
|
|
|
from testing_utils import CommonTest |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -37,20 +38,26 @@ class _TestModelNamespace: |
|
|
|
for i, ( |
|
|
|
for i, ( |
|
|
|
input, model, target |
|
|
|
input, model, target |
|
|
|
) in enumerate(zip(self.inputs, self.models, self.targets)): |
|
|
|
) in enumerate(zip(self.inputs, self.models, self.targets)): |
|
|
|
with self.subTest(i=i): |
|
|
|
try: |
|
|
|
if isinstance(input, list): |
|
|
|
if isinstance(input, list): |
|
|
|
output = model(*input) |
|
|
|
output = model(*input) |
|
|
|
else: |
|
|
|
else: |
|
|
|
output = model(input) |
|
|
|
output = model(input) |
|
|
|
self.check_output(output, target) |
|
|
|
self.check_output(output, target) |
|
|
|
|
|
|
|
except: |
|
|
|
|
|
|
|
logging.warning(f"Model built with spec{i} failed!") |
|
|
|
|
|
|
|
raise |
|
|
|
|
|
|
|
|
|
|
|
def test_to_static(self): |
|
|
|
def test_to_static(self): |
|
|
|
for i, ( |
|
|
|
for i, ( |
|
|
|
input, model, target |
|
|
|
input, model, target |
|
|
|
) in enumerate(zip(self.inputs, self.models, self.targets)): |
|
|
|
) in enumerate(zip(self.inputs, self.models, self.targets)): |
|
|
|
with self.subTest(i=i): |
|
|
|
try: |
|
|
|
static_model = paddle.jit.to_static( |
|
|
|
static_model = paddle.jit.to_static( |
|
|
|
model, input_spec=self.get_input_spec(model, input)) |
|
|
|
model, input_spec=self.get_input_spec(model, input)) |
|
|
|
|
|
|
|
except: |
|
|
|
|
|
|
|
logging.warning(f"Model built with spec{i} failed!") |
|
|
|
|
|
|
|
raise |
|
|
|
|
|
|
|
|
|
|
|
def check_output(self, output, target): |
|
|
|
def check_output(self, output, target): |
|
|
|
pass |
|
|
|
pass |
|
|
@ -117,4 +124,27 @@ class _TestModelNamespace: |
|
|
|
return input_spec |
|
|
|
return input_spec |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def allow_oom(cls): |
|
|
|
|
|
|
|
def _deco(func): |
|
|
|
|
|
|
|
def _wrapper(self, *args, **kwargs): |
|
|
|
|
|
|
|
try: |
|
|
|
|
|
|
|
func(self, *args, **kwargs) |
|
|
|
|
|
|
|
except (SystemError, RuntimeError, OSError) as e: |
|
|
|
|
|
|
|
msg = str(e) |
|
|
|
|
|
|
|
if "Out of memory error" in msg \ |
|
|
|
|
|
|
|
or "(External) CUDNN error(4), CUDNN_STATUS_INTERNAL_ERROR." in msg: |
|
|
|
|
|
|
|
logging.warning("An OOM error has been ignored.") |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
raise |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return _wrapper |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for key, value in inspect.getmembers(cls): |
|
|
|
|
|
|
|
if key.startswith('test'): |
|
|
|
|
|
|
|
value = _deco(value) |
|
|
|
|
|
|
|
setattr(cls, key, value) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return cls |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
TestModel = _TestModelNamespace.TestModel |
|
|
|
TestModel = _TestModelNamespace.TestModel |
|
|
|