From d3c1499a87a4e3d504c3be56b37003cce18ed280 Mon Sep 17 00:00:00 2001
From: Bobholamovic
Date: Tue, 27 Sep 2022 21:59:36 +0800
Subject: [PATCH] Fix restoration bugs

---
 docs/apis/train.md                |  1 +
 docs/dev/dev_guide.md             |  2 +-
 paddlers/tasks/base.py            |  6 ++--
 paddlers/tasks/change_detector.py |  1 +
 paddlers/tasks/classifier.py      |  1 +
 paddlers/tasks/object_detector.py |  1 +
 paddlers/tasks/restorer.py        | 48 +++++++++++++++++++++++++------
 paddlers/tasks/segmenter.py       |  1 +
 paddlers/utils/checkpoint.py      | 11 +++++--
 9 files changed, 57 insertions(+), 15 deletions(-)

diff --git a/docs/apis/train.md b/docs/apis/train.md
index db9528f..7de47d3 100644
--- a/docs/apis/train.md
+++ b/docs/apis/train.md
@@ -30,6 +30,7 @@
 
 - The `sr_factor` parameter is generally supported and specifies the super-resolution factor; for models that do not support super-resolution reconstruction, set `sr_factor` to `None`.
 - The `losses` parameter specifies the loss function(s) used during training; the argument must be a callable object or a dictionary. A manually specified `losses` must have the same format as the return value of the subclass's `default_loss()` method.
+- The `min_max` parameter specifies the value range of the model's inputs and outputs; if it is `None`, the default value range of the class is used.
 - Different subclasses support model-specific input parameters. For details, please refer to the [model definitions](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/rs_models/res) and the [trainer definitions](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py).
 
 ### Initialize `BaseSegmenter` Subclass Objects
diff --git a/docs/dev/dev_guide.md b/docs/dev/dev_guide.md
index e4de181..8e95232 100644
--- a/docs/dev/dev_guide.md
+++ b/docs/dev/dev_guide.md
@@ -64,7 +64,7 @@ Args:
 
 2. Find the trainer definition file corresponding to the task in the `paddlers/tasks` directory (for example, `paddlers/tasks/change_detector.py` for the change detection task).
 3. Append the new trainer definition at the end of the file. The trainer must inherit from the related base class (for example, `BaseChangeDetector`), override the `__init__()` method, and override other methods as needed. The requirements for the trainer's `__init__()` method are as follows (see the illustrative sketch after this file's hunks):
-    - For the change detection, scene classification, object detection, and image segmentation tasks, the first input parameter of `__init__()` is `num_classes`, the number of classes the model outputs. For the change detection, scene classification, and image segmentation tasks, the second input parameter is `use_mixed_loss`, which indicates whether the user uses the predefined mixed loss; the third input parameter is `losses`, the loss function(s) used during training. For the image restoration task, the first parameter is `losses`, with the same meaning as above; the second parameter is `rs_factor`, the super-resolution scaling factor.
+    - For the change detection, scene classification, object detection, and image segmentation tasks, the first input parameter of `__init__()` is `num_classes`, the number of classes the model outputs. For the change detection, scene classification, and image segmentation tasks, the second input parameter is `use_mixed_loss`, which indicates whether the user uses the predefined mixed loss; the third input parameter is `losses`, the loss function(s) used during training. For the image restoration task, the first parameter is `losses`, with the same meaning as above; the second parameter is `sr_factor`, the super-resolution scaling factor; and the third parameter is `min_max`, the value range of the input and output images.
     - All input parameters of `__init__()` must have default values, and **with the default values the model must accept 3-channel RGB input**.
     - `__init__()` must update the `params` dictionary; the key-value pairs in this dictionary are used as input parameters when the model is constructed.
 
diff --git a/paddlers/tasks/base.py b/paddlers/tasks/base.py
index 0c32bb5..a95befb 100644
--- a/paddlers/tasks/base.py
+++ b/paddlers/tasks/base.py
@@ -90,7 +90,8 @@ class BaseModel(metaclass=ModelMeta):
                        pretrain_weights=None,
                        save_dir='.',
                        resume_checkpoint=None,
-                       is_backbone_weights=False):
+                       is_backbone_weights=False,
+                       load_optim_state=True):
         if pretrain_weights is not None and \
                 not osp.exists(pretrain_weights):
             if not osp.isdir(save_dir):
@@ -148,7 +149,8 @@ class BaseModel(metaclass=ModelMeta):
                 self.net,
                 self.optimizer,
                 model_name=self.model_name,
-                checkpoint=resume_checkpoint)
+                checkpoint=resume_checkpoint,
+                load_optim_state=load_optim_state)
 
     def get_model_info(self, get_raw_params=False, inplace=True):
         if inplace:
diff --git a/paddlers/tasks/change_detector.py b/paddlers/tasks/change_detector.py
index 59d3991..5ec1437 100644
--- a/paddlers/tasks/change_detector.py
+++ b/paddlers/tasks/change_detector.py
@@ -528,6 +528,7 @@ class BaseChangeDetector(BaseModel):
             return eval_metrics, eval_details
         return eval_metrics
 
+    @paddle.no_grad()
     def predict(self, img_file, transforms=None):
         """
         Do inference.
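The trainer conventions described in the `docs/dev/dev_guide.md` hunk above can be summed up in a short, hedged sketch. The class `MyRestorer` and its `n_feats` argument are hypothetical and not part of this patch; the argument order (`losses`, `sr_factor`, `min_max`, all with defaults) and the `params` update follow the guideline and mirror the restorer trainers changed later in this patch.

```python
from paddlers.tasks.restorer import BaseRestorer


class MyRestorer(BaseRestorer):
    """Hypothetical trainer used only to illustrate the documented conventions."""

    def __init__(self,
                 losses=None,    # 1st: training loss(es), same meaning as in BaseRestorer
                 sr_factor=4,    # 2nd: super-resolution scaling factor
                 min_max=None,   # 3rd: input/output value range, added by this patch
                 n_feats=64,     # model-specific argument (hypothetical)
                 **params):
        # Every argument has a default; with the defaults the model is expected
        # to accept 3-channel RGB input. `params` is forwarded to build_net().
        params.update({'n_feats': n_feats})
        super(MyRestorer, self).__init__(
            model_name='MyRestorer',
            losses=losses,
            sr_factor=sr_factor,
            min_max=min_max,
            **params)
```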
diff --git a/paddlers/tasks/classifier.py b/paddlers/tasks/classifier.py
index 5773f62..c1074a3 100644
--- a/paddlers/tasks/classifier.py
+++ b/paddlers/tasks/classifier.py
@@ -432,6 +432,7 @@ class BaseClassifier(BaseModel):
 
         return eval_metrics
 
+    @paddle.no_grad()
     def predict(self, img_file, transforms=None):
         """
         Do inference.
diff --git a/paddlers/tasks/object_detector.py b/paddlers/tasks/object_detector.py
index f42d7cc..e750075 100644
--- a/paddlers/tasks/object_detector.py
+++ b/paddlers/tasks/object_detector.py
@@ -567,6 +567,7 @@ class BaseDetector(BaseModel):
             return scores, self.eval_details
         return scores
 
+    @paddle.no_grad()
     def predict(self, img_file, transforms=None):
         """
         Do inference.
diff --git a/paddlers/tasks/restorer.py b/paddlers/tasks/restorer.py
index ff8708b..9943f80 100644
--- a/paddlers/tasks/restorer.py
+++ b/paddlers/tasks/restorer.py
@@ -43,7 +43,12 @@ class BaseRestorer(BaseModel):
     MIN_MAX = (0., 1.)
     TEST_OUT_KEY = None
 
-    def __init__(self, model_name, losses=None, sr_factor=None, **params):
+    def __init__(self,
+                 model_name,
+                 losses=None,
+                 sr_factor=None,
+                 min_max=None,
+                 **params):
         self.init_params = locals()
         if 'with_net' in self.init_params:
             del self.init_params['with_net']
@@ -55,6 +60,8 @@ class BaseRestorer(BaseModel):
             params.pop('with_net', None)
             self.net = self.build_net(**params)
         self.find_unused_parameters = True
+        if min_max is None:
+            self.min_max = self.MIN_MAX
 
     def build_net(self, **params):
         # Currently, only use models from cmres.
@@ -283,11 +290,13 @@ class BaseRestorer(BaseModel):
                 exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         is_backbone_weights = pretrain_weights == 'IMAGENET'
+        # XXX: Currently, do not load optimizer state dict.
         self.initialize_net(
             pretrain_weights=pretrain_weights,
            save_dir=pretrained_dir,
            resume_checkpoint=resume_checkpoint,
-            is_backbone_weights=is_backbone_weights)
+            is_backbone_weights=is_backbone_weights,
+            load_optim_state=False)
 
        self.train_loop(
            num_epochs=num_epochs,
@@ -434,6 +443,7 @@ class BaseRestorer(BaseModel):
 
         return eval_metrics
 
+    @paddle.no_grad()
     def predict(self, img_file, transforms=None):
         """
         Do inference.
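For context, a hedged usage sketch of the new `min_max` argument handled in the `__init__()` hunk above. Only the `min_max is None` branch appears in this hunk, so the behaviour for an explicit range (keeping the user-supplied tuple as `self.min_max`) is an assumption here, not something the hunk shows.

```python
from paddlers.tasks.restorer import ESRGAN

# min_max=None -> the trainer falls back to the class default range,
# BaseRestorer.MIN_MAX, i.e. (0., 1.).
model_unit_range = ESRGAN(sr_factor=4, use_gan=False)

# Assumed behaviour for an explicit range (the non-None branch is not shown in
# this hunk): inputs and outputs are treated as lying in [0, 255], and the
# value-range normalization in the next hunk rescales with exactly this range.
model_byte_range = ESRGAN(sr_factor=4, use_gan=False, min_max=(0., 255.))
```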
@@ -653,9 +663,9 @@ class BaseRestorer(BaseModel):
         if copy:
             im = im.copy()
         if clip:
-            im = np.clip(im, self.MIN_MAX[0], self.MIN_MAX[1])
-        im -= im.min()
-        im /= im.max() + 1e-32
+            im = np.clip(im, self.min_max[0], self.min_max[1])
+        im -= self.min_max[0]
+        im /= self.min_max[1] - self.min_max[0]
         if quantize:
             im *= 255
             im = im.astype('uint8')
@@ -668,6 +678,7 @@ class DRN(BaseRestorer):
     def __init__(self,
                  losses=None,
                  sr_factor=4,
+                 min_max=None,
                  scales=(2, 4),
                  n_blocks=30,
                  n_feats=16,
@@ -691,7 +702,11 @@ class DRN(BaseRestorer):
         self.dual_loss_weight = dual_loss_weight
         self.scales = scales
         super(DRN, self).__init__(
-            model_name='DRN', losses=losses, sr_factor=sr_factor, **params)
+            model_name='DRN',
+            losses=losses,
+            sr_factor=sr_factor,
+            min_max=min_max,
+            **params)
 
     def build_net(self, **params):
         from ppgan.modules.init import init_weights
@@ -769,6 +784,7 @@ class LESRCNN(BaseRestorer):
     def __init__(self,
                  losses=None,
                  sr_factor=4,
+                 min_max=None,
                  multi_scale=False,
                  group=1,
                  **params):
@@ -778,7 +794,11 @@ class LESRCNN(BaseRestorer):
             'group': group
         })
         super(LESRCNN, self).__init__(
-            model_name='LESRCNN', losses=losses, sr_factor=sr_factor, **params)
+            model_name='LESRCNN',
+            losses=losses,
+            sr_factor=sr_factor,
+            min_max=min_max,
+            **params)
 
     def build_net(self, **params):
         net = ppgan.models.generators.LESRCNNGenerator(**params)
@@ -789,6 +809,7 @@ class ESRGAN(BaseRestorer):
     def __init__(self,
                  losses=None,
                  sr_factor=4,
+                 min_max=None,
                  use_gan=True,
                  in_channels=3,
                  out_channels=3,
@@ -805,7 +826,11 @@ class ESRGAN(BaseRestorer):
         })
         self.use_gan = use_gan
         super(ESRGAN, self).__init__(
-            model_name='ESRGAN', losses=losses, sr_factor=sr_factor, **params)
+            model_name='ESRGAN',
+            losses=losses,
+            sr_factor=sr_factor,
+            min_max=min_max,
+            **params)
 
     def build_net(self, **params):
         from ppgan.modules.init import init_weights
@@ -932,6 +957,7 @@ class RCAN(BaseRestorer):
     def __init__(self,
                  losses=None,
                  sr_factor=4,
+                 min_max=None,
                  n_resgroups=10,
                  n_resblocks=20,
                  n_feats=64,
@@ -950,4 +976,8 @@ class RCAN(BaseRestorer):
             'reduction': reduction
         })
         super(RCAN, self).__init__(
-            model_name='RCAN', losses=losses, sr_factor=sr_factor, **params)
+            model_name='RCAN',
+            losses=losses,
+            sr_factor=sr_factor,
+            min_max=min_max,
+            **params)
diff --git a/paddlers/tasks/segmenter.py b/paddlers/tasks/segmenter.py
index 5cac754..a9a60a7 100644
--- a/paddlers/tasks/segmenter.py
+++ b/paddlers/tasks/segmenter.py
@@ -497,6 +497,7 @@ class BaseSegmenter(BaseModel):
             return eval_metrics, eval_details
         return eval_metrics
 
+    @paddle.no_grad()
     def predict(self, img_file, transforms=None):
         """
         Do inference.
diff --git a/paddlers/utils/checkpoint.py b/paddlers/utils/checkpoint.py
index 401d7e3..ca59645 100644
--- a/paddlers/utils/checkpoint.py
+++ b/paddlers/utils/checkpoint.py
@@ -527,11 +527,16 @@ def load_optimizer(optimizer, state_dict_path):
     optimizer.set_state_dict(optim_state_dict)
 
 
-def load_checkpoint(model, optimizer, model_name, checkpoint):
+def load_checkpoint(model,
+                    optimizer,
+                    model_name,
+                    checkpoint,
+                    load_optim_state=True):
     logging.info("Loading checkpoint from {}".format(checkpoint))
     load_pretrain_weights(
         model,
         pretrain_weights=osp.join(checkpoint, 'model.pdparams'),
         model_name=model_name)
-    load_optimizer(
-        optimizer, state_dict_path=osp.join(checkpoint, "model.pdopt"))
+    if load_optim_state:
+        load_optimizer(
+            optimizer, state_dict_path=osp.join(checkpoint, "model.pdopt"))
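The value-range hunk in `paddlers/tasks/restorer.py` above replaces per-image rescaling (`im -= im.min()`, `im /= im.max() + 1e-32`) with rescaling against the fixed `min_max` range. Below is a minimal NumPy sketch of the new behaviour for reference; the standalone function `normalize_fixed_range` is illustrative only, since the patch implements this logic as a method on `BaseRestorer`.

```python
import numpy as np


def normalize_fixed_range(im, min_max=(0., 1.), clip=True, quantize=True):
    """Rescale `im` from the known range `min_max` (assumes min_max[1] > min_max[0])."""
    im = im.astype('float32')  # astype returns a copy; the input array is left untouched
    if clip:
        im = np.clip(im, min_max[0], min_max[1])
    # Fixed-range rescaling, as in the patched code: the result no longer
    # depends on the per-image minimum and maximum.
    im -= min_max[0]
    im /= min_max[1] - min_max[0]
    if quantize:
        im *= 255
        im = im.astype('uint8')
    return im


# A model trained with min_max=(0., 255.): out-of-range values are clipped,
# in-range values map proportionally onto 0-255.
out = np.array([[-3.0, 0.0], [127.5, 300.0]])
print(normalize_fixed_range(out, min_max=(0., 255.)))  # values become 0, 0, 127, 255
```

Rescaling with the fixed range keeps outputs comparable across images, whereas the old per-image min/max rescaling could stretch each prediction differently.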