You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
152 lines
4.8 KiB
152 lines
4.8 KiB
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
|
|
import inspect |
|
|
|
import paddle |
|
import numpy as np |
|
from paddle.static import InputSpec |
|
|
|
from paddlers.utils import logging |
|
from testing_utils import CommonTest |
|
|
|
|
|
class _TestModelNamespace:
    # Holder for the shared base test case; the module re-exports the nested
    # class at top level as ``TestModel``.

    class TestModel(CommonTest):
        """Base test case for models built from keyword-argument specs.

        Subclasses set ``MODEL_CLASS`` and override ``set_specs``,
        ``set_inputs`` and ``set_targets`` (and usually ``check_output``).
        Specs, inputs and targets are consumed position-wise: the model built
        from the i-th spec is called with the i-th input, and its output is
        handed to ``check_output`` together with the i-th target. Iteration
        stops at the shortest of the three sequences (``zip`` semantics).
        """

        MODEL_CLASS = None
        DEFAULT_HW = (256, 256)
        DEFAULT_BATCH_SIZE = 2

        def setUp(self):
            """Populate specs, inputs, targets and models, in that order."""
            self.set_specs()
            self.set_inputs()
            self.set_targets()
            self.set_models()

        def test_forward(self):
            """Run every model on its input and validate the output."""
            for i, (inp, model, target) in enumerate(
                    zip(self.inputs, self.models, self.targets)):
                try:
                    if isinstance(inp, list):
                        # A list is unpacked as positional arguments.
                        output = model(*inp)
                    else:
                        output = model(inp)
                    self.check_output(output, target)
                except Exception:
                    logging.warning(f"Model built with spec{i} failed!")
                    raise

        def test_to_static(self):
            """Convert every model with ``paddle.jit.to_static``.

            Only the conversion call is made; the returned static model is
            not invoked here.
            """
            for i, (inp, model, _target) in enumerate(
                    zip(self.inputs, self.models, self.targets)):
                try:
                    paddle.jit.to_static(
                        model, input_spec=self.get_input_spec(model, inp))
                except Exception:
                    logging.warning(f"Model built with spec{i} failed!")
                    raise

        def check_output(self, output, target):
            """Validate ``output`` against ``target``; no-op by default."""
            pass

        def set_specs(self):
            """Set ``self.specs`` (list of constructor kwarg dicts)."""
            self.specs = []

        def set_models(self):
            """Set ``self.models`` to a lazy generator over ``self.specs``.

            Each model is built only when the generator is consumed, and the
            generator can be iterated a single time.
            NOTE(review): laziness presumably limits peak memory by building
            one model at a time -- confirm before changing to a list.
            """
            self.models = (self.build_model(spec) for spec in self.specs)

        def set_inputs(self):
            """Set ``self.inputs``; a list entry means positional args."""
            self.inputs = []

        def set_targets(self):
            """Set ``self.targets`` passed to ``check_output``."""
            self.targets = []

        def build_model(self, spec):
            """Instantiate ``MODEL_CLASS`` from a spec mapping.

            Besides constructor keyword arguments, ``spec`` may hold two
            control keys that are stripped before construction:

            * ``'_phase'`` -- ``'train'`` (default) calls ``model.train()``,
              ``'eval'`` calls ``model.eval()``; any other value calls
              neither.
            * ``'_stop_grad'`` -- if truthy (default ``False``), sets
              ``stop_gradient = True`` on every model parameter.

            The caller's mapping is left unmodified.

            Returns:
                The constructed model.
            """
            # Copy so that popping control keys does not mutate the caller's
            # spec (which may be shared or reused across tests).
            kwargs = dict(spec)
            phase = kwargs.pop('_phase', 'train')
            stop_grad = kwargs.pop('_stop_grad', False)

            model = self.MODEL_CLASS(**kwargs)

            if phase == 'train':
                model.train()
            elif phase == 'eval':
                model.eval()
            if stop_grad:
                for p in model.parameters():
                    p.stop_gradient = True

            return model

        def get_shape(self, c, b=None, h=None, w=None):
            """Return the shape tuple ``(b, c, h, w)``.

            ``b`` defaults to ``DEFAULT_BATCH_SIZE``; ``h`` and ``w`` each
            default independently to the matching entry of ``DEFAULT_HW``.
            """
            default_h, default_w = self.DEFAULT_HW
            if h is None:
                h = default_h
            if w is None:
                w = default_w
            if b is None:
                b = self.DEFAULT_BATCH_SIZE
            return (b, c, h, w)

        def get_zeros_array(self, c, b=None, h=None, w=None):
            """Return a NumPy zeros array (default float64) of ``get_shape``."""
            shape = self.get_shape(c, b, h, w)
            return np.zeros(shape)

        def get_randn_tensor(self, c, b=None, h=None, w=None):
            """Return a ``paddle.randn`` tensor of ``get_shape`` shape."""
            shape = self.get_shape(c, b, h, w)
            return paddle.randn(shape)

        def get_input_spec(self, model, input):
            """Build ``InputSpec`` objects describing ``input`` for ``model``.

            Inputs are paired in order with the parameter names of
            ``model.forward``; pairing stops at the shorter of the two.
            A non-list ``input`` is treated as a single argument.
            """
            if not isinstance(input, list):
                input = [input]
            param_names = inspect.signature(model.forward).parameters
            # XXX: Hard-code dtype
            return [
                InputSpec(shape=tensor.shape, name=name, dtype='float32')
                for name, tensor in zip(param_names, input)
            ]
|
|
|
|
|
def allow_oom(cls): |
|
def _deco(func): |
|
def _wrapper(self, *args, **kwargs): |
|
try: |
|
func(self, *args, **kwargs) |
|
except (SystemError, RuntimeError, OSError, MemoryError) as e: |
|
# XXX: This may not cover all OOM cases. |
|
msg = str(e) |
|
if "Out of memory error" in msg \ |
|
or "(External) CUDNN error(4), CUDNN_STATUS_INTERNAL_ERROR." in msg \ |
|
or isinstance(e, MemoryError): |
|
logging.warning("An OOM error has been ignored.") |
|
else: |
|
raise |
|
|
|
return _wrapper |
|
|
|
for key, value in inspect.getmembers(cls): |
|
if key.startswith('test'): |
|
value = _deco(value) |
|
setattr(cls, key, value) |
|
|
|
return cls |
|
|
|
|
|
# Expose the nested base test case under a module-level name so that model
# test modules can subclass it.
TestModel = getattr(_TestModelNamespace, 'TestModel')
|
|
|