Fix restoration bugs

Bobholamovic 2 years ago
parent fded16b588
commit d3c1499a87
Changed files (9):
  1. docs/apis/train.md (1)
  2. docs/dev/dev_guide.md (2)
  3. paddlers/tasks/base.py (6)
  4. paddlers/tasks/change_detector.py (1)
  5. paddlers/tasks/classifier.py (1)
  6. paddlers/tasks/object_detector.py (1)
  7. paddlers/tasks/restorer.py (48)
  8. paddlers/tasks/segmenter.py (1)
  9. paddlers/utils/checkpoint.py (11)

docs/apis/train.md

@@ -30,6 +30,7 @@
 - Most models support the `sr_factor` parameter, which specifies the super-resolution factor; for models that do not support super-resolution reconstruction, set `sr_factor` to `None`.
 - The loss function used during training can be specified via the `losses` parameter; the argument must be a callable object or a dictionary. A manually specified `losses` must have the same format as the return value of the subclass's `default_loss()` method.
+- The value range of the model's inputs and outputs can be specified via the `min_max` parameter; if it is `None`, the class's default value range is used.
 - Different subclasses support model-specific input parameters. For details, please refer to the [model definitions](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/rs_models/res) and the [trainer definitions](https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py).

 ### Initializing a `BaseSegmenter` subclass object
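The new `min_max` option is passed when constructing a restoration trainer. A minimal sketch; the concrete trainer and value range below are assumptions for illustration:

```python
from paddlers.tasks.restorer import ESRGAN

# Assume the images are stored as 8-bit data in [0, 255] rather than being
# pre-scaled to [0, 1]; min_max tells the trainer which range applies.
model = ESRGAN(
    losses=None,           # fall back to the trainer's default_loss()
    sr_factor=4,           # 4x super-resolution
    min_max=(0., 255.))
```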

docs/dev/dev_guide.md

@@ -64,7 +64,7 @@ Args:
 2. Find the trainer definition file for the task in the `paddlers/tasks` directory (for example, the change detection task corresponds to `paddlers/tasks/change_detector.py`).
 3. Append the new trainer definition at the end of the file. The trainer must inherit from the relevant base class (for example, `BaseChangeDetector`), override the `__init__()` method, and override other methods as needed. The requirements for the trainer's `__init__()` method are as follows:
-    - For the change detection, scene classification, object detection, and image segmentation tasks, the first input parameter of `__init__()` is `num_classes`, the number of classes the model outputs. For the change detection, scene classification, and image segmentation tasks, the second input parameter is `use_mixed_loss`, which indicates whether to use the default mixed loss; the third input parameter is `losses`, the loss function used during training. For the image restoration task, the first parameter is `losses`, with the same meaning as above; the second parameter is `sr_factor`, the super-resolution scaling factor.
+    - For the change detection, scene classification, object detection, and image segmentation tasks, the first input parameter of `__init__()` is `num_classes`, the number of classes the model outputs. For the change detection, scene classification, and image segmentation tasks, the second input parameter is `use_mixed_loss`, which indicates whether to use the default mixed loss; the third input parameter is `losses`, the loss function used during training. For the image restoration task, the first parameter is `losses`, with the same meaning as above; the second parameter is `sr_factor`, the super-resolution scaling factor; the third parameter is `min_max`, the value range of the input and output images.
     - All input parameters of `__init__()` must have default values, and **with the default values, the model must accept 3-channel RGB input**.
     - The `params` dictionary must be updated in `__init__()`; its key-value pairs will be used as input arguments when constructing the model.
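For the image restoration case, a minimal sketch of a trainer that follows these conventions might look as follows (the trainer name `MyNet` and the `n_feats` argument are hypothetical; a real trainer must wrap a model defined under `paddlers/rs_models/res`):

```python
from paddlers.tasks.restorer import BaseRestorer


class MyNet(BaseRestorer):
    """Hypothetical trainer used only to illustrate the conventions above."""

    def __init__(self,
                 losses=None,    # loss function(s); None means default_loss()
                 sr_factor=4,    # super-resolution scaling factor
                 min_max=None,   # value range of input/output images
                 n_feats=64,     # model-specific argument (hypothetical)
                 **params):
        # Key-value pairs added to `params` are forwarded to the model
        # constructor by build_net().
        params.update({'n_feats': n_feats})
        super(MyNet, self).__init__(
            model_name='MyNet',
            losses=losses,
            sr_factor=sr_factor,
            min_max=min_max,
            **params)
```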

paddlers/tasks/base.py

@@ -90,7 +90,8 @@ class BaseModel(metaclass=ModelMeta):
                        pretrain_weights=None,
                        save_dir='.',
                        resume_checkpoint=None,
-                       is_backbone_weights=False):
+                       is_backbone_weights=False,
+                       load_optim_state=True):
         if pretrain_weights is not None and \
                 not osp.exists(pretrain_weights):
             if not osp.isdir(save_dir):
@@ -148,7 +149,8 @@ class BaseModel(metaclass=ModelMeta):
                 self.net,
                 self.optimizer,
                 model_name=self.model_name,
-                checkpoint=resume_checkpoint)
+                checkpoint=resume_checkpoint,
+                load_optim_state=load_optim_state)

     def get_model_info(self, get_raw_params=False, inplace=True):
         if inplace:
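With the new flag, a caller can restore network weights from a checkpoint while keeping the freshly constructed optimizer. A minimal sketch of such a call, assuming `model` is a trainer whose network and optimizer have already been set up (normally this happens inside `train()`), and with hypothetical paths:

```python
model.initialize_net(
    pretrain_weights=None,
    save_dir='./output/pretrain',
    resume_checkpoint='./output/epoch_10',  # hypothetical checkpoint directory
    is_backbone_weights=False,
    load_optim_state=False)  # load model.pdparams only; skip model.pdopt
```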

paddlers/tasks/change_detector.py

@@ -528,6 +528,7 @@ class BaseChangeDetector(BaseModel):
             return eval_metrics, eval_details
         return eval_metrics

+    @paddle.no_grad()
     def predict(self, img_file, transforms=None):
         """
         Do inference.
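`@paddle.no_grad()` (also added to the other `predict()` methods below) disables gradient tracking for the whole call, which lowers memory use during inference. A small self-contained sketch of the mechanism, unrelated to any particular trainer:

```python
import paddle

conv = paddle.nn.Conv2D(3, 8, kernel_size=3)

@paddle.no_grad()
def infer(img):
    # No gradient graph is recorded inside this function.
    return conv(img)

out = infer(paddle.rand([1, 3, 64, 64]))
print(out.stop_gradient)  # True: the output carries no gradient history
```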

paddlers/tasks/classifier.py

@@ -432,6 +432,7 @@ class BaseClassifier(BaseModel):
         return eval_metrics

+    @paddle.no_grad()
     def predict(self, img_file, transforms=None):
         """
         Do inference.

paddlers/tasks/object_detector.py

@@ -567,6 +567,7 @@ class BaseDetector(BaseModel):
             return scores, self.eval_details
         return scores

+    @paddle.no_grad()
     def predict(self, img_file, transforms=None):
         """
         Do inference.

paddlers/tasks/restorer.py

@@ -43,7 +43,12 @@ class BaseRestorer(BaseModel):
     MIN_MAX = (0., 1.)
     TEST_OUT_KEY = None

-    def __init__(self, model_name, losses=None, sr_factor=None, **params):
+    def __init__(self,
+                 model_name,
+                 losses=None,
+                 sr_factor=None,
+                 min_max=None,
+                 **params):
         self.init_params = locals()
         if 'with_net' in self.init_params:
             del self.init_params['with_net']
@@ -55,6 +60,8 @@ class BaseRestorer(BaseModel):
             params.pop('with_net', None)
             self.net = self.build_net(**params)
         self.find_unused_parameters = True
+        if min_max is None:
+            self.min_max = self.MIN_MAX

     def build_net(self, **params):
         # Currently, only use models from cmres.
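The intended interaction between the class-level default and an explicit argument can be pictured with a small standalone sketch (an illustration of the precedence, not the library code):

```python
class _Demo:
    # Class-level default, analogous to BaseRestorer.MIN_MAX.
    MIN_MAX = (0., 1.)

    def __init__(self, min_max=None):
        # Fall back to the class default only when no range is given.
        self.min_max = self.MIN_MAX if min_max is None else min_max

print(_Demo().min_max)            # (0.0, 1.0)
print(_Demo((0., 255.)).min_max)  # (0.0, 255.0)
```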
@@ -283,11 +290,13 @@ class BaseRestorer(BaseModel):
                     exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         is_backbone_weights = pretrain_weights == 'IMAGENET'
+        # XXX: Currently, do not load optimizer state dict.
         self.initialize_net(
             pretrain_weights=pretrain_weights,
             save_dir=pretrained_dir,
             resume_checkpoint=resume_checkpoint,
-            is_backbone_weights=is_backbone_weights)
+            is_backbone_weights=is_backbone_weights,
+            load_optim_state=False)

         self.train_loop(
             num_epochs=num_epochs,
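In practice, resuming restorer training from a checkpoint now restores the network weights but starts from a fresh optimizer state. A hedged usage sketch; the datasets and paths are assumptions:

```python
model.train(
    num_epochs=20,
    train_dataset=train_ds,   # assumed, previously constructed datasets
    eval_dataset=eval_ds,
    save_dir='./output/esrgan',
    resume_checkpoint='./output/esrgan/epoch_10')  # weights restored, optimizer reset
```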
@@ -434,6 +443,7 @@ class BaseRestorer(BaseModel):
         return eval_metrics

+    @paddle.no_grad()
     def predict(self, img_file, transforms=None):
         """
         Do inference.
@@ -653,9 +663,9 @@ class BaseRestorer(BaseModel):
         if copy:
             im = im.copy()
         if clip:
-            im = np.clip(im, self.MIN_MAX[0], self.MIN_MAX[1])
-        im -= im.min()
-        im /= im.max() + 1e-32
+            im = np.clip(im, self.min_max[0], self.min_max[1])
+        im -= self.min_max[0]
+        im /= self.min_max[1] - self.min_max[0]
         if quantize:
             im *= 255
             im = im.astype('uint8')
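The switch from per-image stretching to the fixed `min_max` range changes how output pixels are mapped before quantization; a small NumPy comparison with made-up values:

```python
import numpy as np

im = np.array([[0.2, 0.4], [0.6, 0.8]], dtype='float32')
min_max = (0., 1.)

# Old behavior: stretch each image to span [0, 1], losing absolute brightness.
per_image = (im - im.min()) / (im.max() - im.min() + 1e-32)

# New behavior: map against the declared value range, keeping results
# comparable across images.
fixed = (np.clip(im, *min_max) - min_max[0]) / (min_max[1] - min_max[0])

print(per_image)  # [[0.    0.333] [0.667 1.   ]]
print(fixed)      # [[0.2 0.4] [0.6 0.8]]
```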
@@ -668,6 +678,7 @@ class DRN(BaseRestorer):
     def __init__(self,
                  losses=None,
                  sr_factor=4,
+                 min_max=None,
                  scales=(2, 4),
                  n_blocks=30,
                  n_feats=16,
@@ -691,7 +702,11 @@ class DRN(BaseRestorer):
         self.dual_loss_weight = dual_loss_weight
         self.scales = scales
         super(DRN, self).__init__(
-            model_name='DRN', losses=losses, sr_factor=sr_factor, **params)
+            model_name='DRN',
+            losses=losses,
+            sr_factor=sr_factor,
+            min_max=min_max,
+            **params)

     def build_net(self, **params):
         from ppgan.modules.init import init_weights
@@ -769,6 +784,7 @@ class LESRCNN(BaseRestorer):
     def __init__(self,
                  losses=None,
                  sr_factor=4,
+                 min_max=None,
                  multi_scale=False,
                  group=1,
                  **params):
@@ -778,7 +794,11 @@ class LESRCNN(BaseRestorer):
             'group': group
         })
         super(LESRCNN, self).__init__(
-            model_name='LESRCNN', losses=losses, sr_factor=sr_factor, **params)
+            model_name='LESRCNN',
+            losses=losses,
+            sr_factor=sr_factor,
+            min_max=min_max,
+            **params)

     def build_net(self, **params):
         net = ppgan.models.generators.LESRCNNGenerator(**params)
@@ -789,6 +809,7 @@ class ESRGAN(BaseRestorer):
     def __init__(self,
                  losses=None,
                  sr_factor=4,
+                 min_max=None,
                  use_gan=True,
                  in_channels=3,
                  out_channels=3,
@@ -805,7 +826,11 @@ class ESRGAN(BaseRestorer):
         })
         self.use_gan = use_gan
         super(ESRGAN, self).__init__(
-            model_name='ESRGAN', losses=losses, sr_factor=sr_factor, **params)
+            model_name='ESRGAN',
+            losses=losses,
+            sr_factor=sr_factor,
+            min_max=min_max,
+            **params)

     def build_net(self, **params):
         from ppgan.modules.init import init_weights
@@ -932,6 +957,7 @@ class RCAN(BaseRestorer):
     def __init__(self,
                  losses=None,
                  sr_factor=4,
+                 min_max=None,
                  n_resgroups=10,
                  n_resblocks=20,
                  n_feats=64,
@@ -950,4 +976,8 @@ class RCAN(BaseRestorer):
             'reduction': reduction
         })
         super(RCAN, self).__init__(
-            model_name='RCAN', losses=losses, sr_factor=sr_factor, **params)
+            model_name='RCAN',
+            losses=losses,
+            sr_factor=sr_factor,
+            min_max=min_max,
+            **params)

paddlers/tasks/segmenter.py

@@ -497,6 +497,7 @@ class BaseSegmenter(BaseModel):
             return eval_metrics, eval_details
         return eval_metrics

+    @paddle.no_grad()
     def predict(self, img_file, transforms=None):
         """
         Do inference.

paddlers/utils/checkpoint.py

@@ -527,11 +527,16 @@ def load_optimizer(optimizer, state_dict_path):
     optimizer.set_state_dict(optim_state_dict)


-def load_checkpoint(model, optimizer, model_name, checkpoint):
+def load_checkpoint(model,
+                    optimizer,
+                    model_name,
+                    checkpoint,
+                    load_optim_state=True):
     logging.info("Loading checkpoint from {}".format(checkpoint))
     load_pretrain_weights(
         model,
         pretrain_weights=osp.join(checkpoint, 'model.pdparams'),
         model_name=model_name)
-    load_optimizer(
-        optimizer, state_dict_path=osp.join(checkpoint, "model.pdopt"))
+    if load_optim_state:
+        load_optimizer(
+            optimizer, state_dict_path=osp.join(checkpoint, "model.pdopt"))
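A hedged sketch of calling the updated helper directly; `net`, `optimizer`, and the checkpoint path are assumed to exist:

```python
from paddlers.utils.checkpoint import load_checkpoint

# Restores weights from <checkpoint>/model.pdparams; with load_optim_state=False,
# <checkpoint>/model.pdopt is ignored and the optimizer keeps its initial state.
load_checkpoint(
    net,                      # assumed paddle.nn.Layer
    optimizer,                # assumed paddle optimizer instance
    model_name='ESRGAN',
    checkpoint='./output/esrgan/epoch_10',
    load_optim_state=False)
```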
