Add unittests for restoration tasks

Bobholamovic 2 years ago
parent 752d8e41cf
commit 4d58b9561d
  1. paddlers/tasks/base.py (2 lines changed)
  2. paddlers/tasks/classifier.py (81 lines changed)
  3. paddlers/tasks/object_detector.py (20 lines changed)
  4. paddlers/tasks/restorer.py (54 lines changed)
  5. paddlers/transforms/operators.py (2 lines changed)
  6. tests/data/data_utils.py (54 lines changed)
  7. tests/deploy/test_predictor.py (61 lines changed)
  8. tests/rs_models/test_cd_models.py (4 lines changed)
  9. tests/rs_models/test_det_models.py (3 lines changed)
  10. tests/rs_models/test_res_models.py (32 lines changed)
  11. tests/rs_models/test_seg_models.py (4 lines changed)
  12. tests/transforms/test_operators.py (43 lines changed)
  13. tutorials/train/README.md (4 lines changed)
  14. tutorials/train/change_detection/changeformer.py (2 lines changed)
  15. tutorials/train/image_restoration/drn.py (4 lines changed)
  16. tutorials/train/image_restoration/esrgan.py (4 lines changed)
  17. tutorials/train/image_restoration/lesrcnn.py (4 lines changed)
  18. tutorials/train/image_restoration/rcan.py (86 lines changed)

@ -267,7 +267,7 @@ class BaseModel(metaclass=ModelMeta):
'The volume of dataset({}) must be larger than batch size({}).'
.format(dataset.num_samples, batch_size))
batch_size_each_card = get_single_card_bs(batch_size=batch_size)
# TODO: Make judgement in detection eval phase.
batch_sampler = DistributedBatchSampler(
dataset,
batch_size=batch_size_each_card,

@ -397,38 +397,37 @@ class BaseClassifier(BaseModel):
):
paddle.distributed.init_parallel_env()
batch_size_each_card = get_single_card_bs(batch_size)
if batch_size_each_card > 1:
batch_size_each_card = 1
batch_size = batch_size_each_card * paddlers.env_info['num']
if batch_size > 1:
logging.warning(
"Classifier only supports batch_size=1 for each gpu/cpu card " \
"during evaluation, so batch_size " \
"is forcibly set to {}.".format(batch_size))
self.eval_data_loader = self.build_data_loader(
eval_dataset, batch_size=batch_size, mode='eval')
logging.info(
"Start to evaluate(total_samples={}, total_steps={})...".format(
eval_dataset.num_samples,
math.ceil(eval_dataset.num_samples * 1.0 / batch_size)))
top1s = []
top5s = []
with paddle.no_grad():
for step, data in enumerate(self.eval_data_loader):
data.append(eval_dataset.transforms.transforms)
outputs = self.run(self.net, data, 'eval')
top1s.append(outputs["top1"])
top5s.append(outputs["top5"])
top1 = np.mean(top1s)
top5 = np.mean(top5s)
eval_metrics = OrderedDict(zip(['top1', 'top5'], [top1, top5]))
if return_details:
# TODO: Add details
return eval_metrics, None
return eval_metrics
"Classifier only supports single card evaluation with batch_size=1 "
"during evaluation, so batch_size is forcibly set to 1.")
batch_size = 1
if nranks < 2 or local_rank == 0:
self.eval_data_loader = self.build_data_loader(
eval_dataset, batch_size=batch_size, mode='eval')
logging.info(
"Start to evaluate(total_samples={}, total_steps={})...".format(
eval_dataset.num_samples, eval_dataset.num_samples))
top1s = []
top5s = []
with paddle.no_grad():
for step, data in enumerate(self.eval_data_loader):
data.append(eval_dataset.transforms.transforms)
outputs = self.run(self.net, data, 'eval')
top1s.append(outputs["top1"])
top5s.append(outputs["top5"])
top1 = np.mean(top1s)
top5 = np.mean(top5s)
eval_metrics = OrderedDict(zip(['top1', 'top5'], [top1, top5]))
if return_details:
# TODO: Add details
return eval_metrics, None
return eval_metrics
def predict(self, img_file, transforms=None):
"""
@ -561,6 +560,26 @@ class BaseClassifier(BaseModel):
raise TypeError(
"`transforms.arrange` must be an ArrangeClassifier object.")
def build_data_loader(self, dataset, batch_size, mode='train'):
if dataset.num_samples < batch_size:
raise ValueError(
'The volume of dataset({}) must be larger than batch size({}).'
.format(dataset.num_samples, batch_size))
if mode != 'train':
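# Unlike the training loader built by the base class, this loader does not use a
# DistributedBatchSampler, so each process that builds it iterates over the full dataset
# (evaluate() additionally restricts evaluation to local_rank 0).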
return paddle.io.DataLoader(
dataset,
batch_size=batch_size,
shuffle=dataset.shuffle,
drop_last=False,
collate_fn=dataset.batch_transforms,
num_workers=dataset.num_workers,
return_list=True,
use_shared_memory=False)
else:
return super(BaseClassifier, self).build_data_loader(
dataset, batch_size, mode)
class ResNet50_vd(BaseClassifier):
def __init__(self,

@ -983,6 +983,26 @@ class PicoDet(BaseDetector):
use_vdl=use_vdl,
resume_checkpoint=resume_checkpoint)
def build_data_loader(self, dataset, batch_size, mode='train'):
if dataset.num_samples < batch_size:
raise ValueError(
'The volume of dataset({}) must be larger than batch size({}).'
.format(dataset.num_samples, batch_size))
if mode != 'train':
return paddle.io.DataLoader(
dataset,
batch_size=batch_size,
shuffle=dataset.shuffle,
drop_last=False,
collate_fn=dataset.batch_transforms,
num_workers=dataset.num_workers,
return_list=True,
use_shared_memory=False)
else:
return super(BaseDetector, self).build_data_loader(dataset,
batch_size, mode)
class YOLOv3(BaseDetector):
def __init__(self,

@ -35,7 +35,7 @@ from .base import BaseModel
from .utils.res_adapters import GANAdapter, OptimizerAdapter
from .utils.infer_nets import InferResNet
__all__ = []
__all__ = ["DRN", "LESRCNN", "ESRGAN"]
class BaseRestorer(BaseModel):
@ -381,22 +381,22 @@ class BaseRestorer(BaseModel):
):
paddle.distributed.init_parallel_env()
batch_size_each_card = get_single_card_bs(batch_size)
if batch_size_each_card > 1:
batch_size_each_card = 1
batch_size = batch_size_each_card * paddlers.env_info['num']
# TODO: Distributed evaluation
if batch_size > 1:
logging.warning(
"Restorer only supports batch_size=1 for each gpu/cpu card " \
"during evaluation, so batch_size " \
"is forcibly set to {}.".format(batch_size))
"Restorer only supports single card evaluation with batch_size=1 "
"during evaluation, so batch_size is forcibly set to 1.")
batch_size = 1
# TODO: Distributed evaluation
if nranks < 2 or local_rank == 0:
self.eval_data_loader = self.build_data_loader(
eval_dataset, batch_size=batch_size, mode='eval')
# XXX: Hard-code crop_border and test_y_channel
psnr = metrics.PSNR(crop_border=4, test_y_channel=True)
ssim = metrics.SSIM(crop_border=4, test_y_channel=True)
logging.info(
"Start to evaluate(total_samples={}, total_steps={})...".format(
eval_dataset.num_samples, eval_dataset.num_samples))
with paddle.no_grad():
for step, data in enumerate(self.eval_data_loader):
data.append(eval_dataset.transforms.transforms)
@ -404,14 +404,18 @@ class BaseRestorer(BaseModel):
psnr.update(outputs['pred'], outputs['tar'])
ssim.update(outputs['pred'], outputs['tar'])
eval_metrics = OrderedDict(
zip(['psnr', 'ssim'], [psnr.accumulate(), ssim.accumulate()]))
# DO NOT use psnr.accumulate() here, otherwise the program hangs in multi-card training.
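# (Presumably accumulate() involves a cross-process synchronization; since only local_rank 0
# enters this evaluation branch, that collective would never complete.)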
assert len(psnr.results) > 0
assert len(ssim.results) > 0
eval_metrics = OrderedDict(
zip(['psnr', 'ssim'],
[np.mean(psnr.results), np.mean(ssim.results)]))
if return_details:
# TODO: Add details
return eval_metrics, None
if return_details:
# TODO: Add details
return eval_metrics, None
return eval_metrics
return eval_metrics
def predict(self, img_file, transforms=None):
"""
@ -591,6 +595,26 @@ class BaseRestorer(BaseModel):
raise TypeError(
"`transforms.arrange` must be an ArrangeRestorer object.")
def build_data_loader(self, dataset, batch_size, mode='train'):
if dataset.num_samples < batch_size:
raise ValueError(
'The volume of dataset({}) must be larger than batch size({}).'
.format(dataset.num_samples, batch_size))
if mode != 'train':
return paddle.io.DataLoader(
dataset,
batch_size=batch_size,
shuffle=dataset.shuffle,
drop_last=False,
collate_fn=dataset.batch_transforms,
num_workers=dataset.num_workers,
return_list=True,
use_shared_memory=False)
else:
return super(BaseRestorer, self).build_data_loader(dataset,
batch_size, mode)
def set_losses(self, losses):
self.losses = losses

@ -1793,7 +1793,7 @@ class SelectBand(Transform):
def __init__(self, band_list=[1, 2, 3], apply_to_tar=True):
super(SelectBand, self).__init__()
self.band_list = band_list
self.appy_to_tar = apply_to_tar
self.apply_to_tar = apply_to_tar
def apply_im(self, image):
image = select_bands(image, self.band_list)

@ -14,7 +14,6 @@
import os.path as osp
import re
import imghdr
import platform
from collections import OrderedDict
from functools import partial, wraps
@ -34,20 +33,6 @@ def norm_path(path):
return path
def is_pic(im_path):
valid_suffix = [
'JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png', 'npy'
]
suffix = im_path.split('.')[-1]
if suffix in valid_suffix:
return True
im_format = imghdr.what(im_path)
_, ext = osp.splitext(im_path)
if im_format == 'tiff' or ext == '.img':
return True
return False
def get_full_path(p, prefix=''):
p = norm_path(p)
return osp.join(prefix, p)
@ -323,15 +308,34 @@ class ConstrDetSample(ConstrSample):
return samples
def build_input_from_file(file_list, prefix='', task='auto', label_list=None):
class ConstrResSample(ConstrSample):
def __init__(self, prefix, label_list, sr_factor=None):
super().__init__(prefix, label_list)
self.sr_factor = sr_factor
def __call__(self, src_path, tar_path):
sample = {
'image': self.get_full_path(src_path),
'target': self.get_full_path(tar_path)
}
if self.sr_factor is not None:
sample['sr_factor'] = self.sr_factor
return sample
def build_input_from_file(file_list,
prefix='',
task='auto',
label_list=None,
**kwargs):
"""
Construct a list of dictionaries from a file list. Each dict in the list can be used as the input to paddlers.transforms.Transform objects.
Args:
file_list (str): Path of file_list.
file_list (str): Path of file list.
prefix (str, optional): A nonempty `prefix` specifies the directory that stores the images and annotation files. Default: ''.
task (str, optional): Supported values are 'seg', 'det', 'cd', 'clas', and 'auto'. When `task` is set to 'auto', automatically determine the task based on the input.
Default: 'auto'.
task (str, optional): Supported values are 'seg', 'det', 'cd', 'clas', 'res', and 'auto'. When `task` is set to 'auto',
automatically determine the task based on the input. Default: 'auto'.
label_list (str|None, optional): Path of label_list. Default: None.
Returns:
@ -339,22 +343,21 @@ def build_input_from_file(file_list, prefix='', task='auto', label_list=None):
"""
def _determine_task(parts):
task = 'unknown'
if len(parts) in (3, 5):
task = 'cd'
elif len(parts) == 2:
if parts[1].isdigit():
task = 'clas'
elif is_pic(osp.join(prefix, parts[1])):
task = 'seg'
else:
elif parts[1].endswith('.xml'):
task = 'det'
else:
if task == 'unknown':
raise RuntimeError(
"Cannot automatically determine the task type. Please specify `task` manually."
)
return task
if task not in ('seg', 'det', 'cd', 'clas', 'auto'):
if task not in ('seg', 'det', 'cd', 'clas', 'res', 'auto'):
raise ValueError("Invalid value of `task`")
samples = []
@ -366,9 +369,8 @@ def build_input_from_file(file_list, prefix='', task='auto', label_list=None):
if task == 'auto':
task = _determine_task(parts)
if ctor is None:
# Select and build sample constructor
ctor_class = globals()['Constr' + task.capitalize() + 'Sample']
ctor = ctor_class(prefix, label_list)
ctor = ctor_class(prefix, label_list, **kwargs)
sample = ctor(*parts)
if isinstance(sample, list):
samples.extend(sample)
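For reference, a minimal usage sketch of the extended interface (mirroring the calls added to tests/transforms/test_operators.py below; the import path is an assumption): extra keyword arguments such as `sr_factor` are forwarded to the task-specific sample constructor.
from data.data_utils import build_input_from_file  # import path is an assumption
# Build restoration ('res') samples; sr_factor is passed through **kwargs to ConstrResSample.
samples = build_input_from_file(
    "data/ssst/test_optical_res.txt",
    prefix="./data/ssst",
    task='res',
    sr_factor=4)
# Each element is a dict of the form {'image': ..., 'target': ..., 'sr_factor': 4},
# ready to be consumed by paddlers.transforms operators.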

@ -105,7 +105,7 @@ class TestPredictor(CommonTest):
dict_[key], expected_dict[key], rtol=1.e-4, atol=1.e-6)
@TestPredictor.add_tests
# @TestPredictor.add_tests
class TestCDPredictor(TestPredictor):
MODULE = pdrs.tasks.change_detector
TRAINER_NAME_TO_EXPORT_OPTS = {
@ -177,7 +177,7 @@ class TestCDPredictor(TestPredictor):
self.assertEqual(len(out_multi_array_t), num_inputs)
@TestPredictor.add_tests
# @TestPredictor.add_tests
class TestClasPredictor(TestPredictor):
MODULE = pdrs.tasks.classifier
TRAINER_NAME_TO_EXPORT_OPTS = {
@ -185,7 +185,7 @@ class TestClasPredictor(TestPredictor):
}
def check_predictor(self, predictor, trainer):
single_input = "data/ssmt/optical_t1.bmp"
single_input = "data/ssst/optical.bmp"
num_inputs = 2
transforms = pdrs.transforms.Compose([
pdrs.transforms.DecodeImg(), pdrs.transforms.Normalize(),
@ -242,7 +242,7 @@ class TestClasPredictor(TestPredictor):
self.check_dict_equal(out_multi_array_p, out_multi_array_t)
@TestPredictor.add_tests
# @TestPredictor.add_tests
class TestDetPredictor(TestPredictor):
MODULE = pdrs.tasks.object_detector
TRAINER_NAME_TO_EXPORT_OPTS = {
@ -253,7 +253,7 @@ class TestDetPredictor(TestPredictor):
# For detection tasks, do NOT ensure the consistency of bboxes.
# This is because the coordinates of bboxes were observed to be very sensitive to numeric errors,
# given that the network is (partially?) randomly initialized.
single_input = "data/ssmt/optical_t1.bmp"
single_input = "data/ssst/optical.bmp"
num_inputs = 2
transforms = pdrs.transforms.Compose([
pdrs.transforms.DecodeImg(), pdrs.transforms.Normalize(),
@ -307,10 +307,55 @@ class TestResPredictor(TestPredictor):
MODULE = pdrs.tasks.restorer
def check_predictor(self, predictor, trainer):
pass
# For restoration tasks, do NOT ensure the consistency of numeric values,
# because the output is of uint8 type.
single_input = "data/ssst/optical.bmp"
num_inputs = 2
transforms = pdrs.transforms.Compose([
pdrs.transforms.DecodeImg(), pdrs.transforms.Normalize(),
pdrs.transforms.ArrangeRestorer('test')
])
# Single input (file path)
input_ = single_input
predictor.predict(input_, transforms=transforms)
trainer.predict(input_, transforms=transforms)
out_single_file_list_p = predictor.predict(
[input_], transforms=transforms)
self.assertEqual(len(out_single_file_list_p), 1)
out_single_file_list_t = trainer.predict(
[input_], transforms=transforms)
self.assertEqual(len(out_single_file_list_t), 1)
@TestPredictor.add_tests
# Single input (ndarray)
input_ = decode_image(
single_input, to_rgb=False) # Reuse the name `input_`
predictor.predict(input_, transforms=transforms)
trainer.predict(input_, transforms=transforms)
out_single_array_list_p = predictor.predict(
[input_], transforms=transforms)
self.assertEqual(len(out_single_array_list_p), 1)
out_single_array_list_t = trainer.predict(
[input_], transforms=transforms)
self.assertEqual(len(out_single_array_list_t), 1)
# Multiple inputs (file paths)
input_ = [single_input] * num_inputs # Reuse the name `input_`
out_multi_file_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_file_p), num_inputs)
out_multi_file_t = trainer.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_file_t), num_inputs)
# Multiple inputs (ndarrays)
input_ = [decode_image(
single_input, to_rgb=False)] * num_inputs # Reuse the name `input_`
out_multi_array_p = predictor.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_p), num_inputs)
out_multi_array_t = trainer.predict(input_, transforms=transforms)
self.assertEqual(len(out_multi_array_t), num_inputs)
# @TestPredictor.add_tests
class TestSegPredictor(TestPredictor):
MODULE = pdrs.tasks.segmenter
TRAINER_NAME_TO_EXPORT_OPTS = {
@ -318,7 +363,7 @@ class TestSegPredictor(TestPredictor):
}
def check_predictor(self, predictor, trainer):
single_input = "data/ssmt/optical_t1.bmp"
single_input = "data/ssst/optical.bmp"
num_inputs = 2
transforms = pdrs.transforms.Compose([
pdrs.transforms.DecodeImg(), pdrs.transforms.Normalize(),

@ -34,9 +34,7 @@ class TestCDModel(TestModel):
self.check_output_equal(len(output), len(target))
for o, t in zip(output, target):
o = o.numpy()
self.check_output_equal(o.shape[0], t.shape[0])
self.check_output_equal(len(o.shape), 4)
self.check_output_equal(o.shape[2:], t.shape[2:])
self.check_output_equal(o.shape, t.shape)
def set_inputs(self):
if self.EF_MODE == 'Concat':

@ -32,3 +32,6 @@ class TestDetModel(TestModel):
def set_inputs(self):
self.inputs = cycle([self.get_randn_tensor(3)])
def set_targets(self):
self.targets = cycle([None])

@ -15,22 +15,32 @@
import paddlers
from rs_models.test_model import TestModel
__all__ = ['TestRCANModel']
__all__ = []
class TestResModel(TestModel):
def check_output(self, output, target):
pass
output = output.numpy()
self.check_output_equal(output.shape, target.shape)
def set_inputs(self):
pass
def set_targets(self):
pass
def _gen_data(specs):
for spec in specs:
c = spec.get('in_channels', 3)
yield self.get_randn_tensor(c)
self.inputs = _gen_data(self.specs)
class TestRCANModel(TestSegModel):
MODEL_CLASS = paddlers.rs_models.res.RCAN
def set_specs(self):
pass
def set_targets(self):
def _gen_data(specs):
for spec in specs:
# XXX: Hard coding
if 'out_channels' in spec:
c = spec['out_channels']
elif 'in_channels' in spec:
c = spec['in_channels']
else:
c = 3
yield [self.get_zeros_array(c)]
self.targets = _gen_data(self.specs)

@ -26,9 +26,7 @@ class TestSegModel(TestModel):
self.check_output_equal(len(output), len(target))
for o, t in zip(output, target):
o = o.numpy()
self.check_output_equal(o.shape[0], t.shape[0])
self.check_output_equal(len(o.shape), 4)
self.check_output_equal(o.shape[2:], t.shape[2:])
self.check_output_equal(o.shape, t.shape)
def set_inputs(self):
def _gen_data(specs):

@ -164,12 +164,15 @@ class TestTransform(CpuCommonTest):
prefix="./data/ssst"),
build_input_from_file(
"data/ssst/test_optical_seg.txt",
task='seg',
prefix="./data/ssst"),
build_input_from_file(
"data/ssst/test_sar_seg.txt",
task='seg',
prefix="./data/ssst"),
build_input_from_file(
"data/ssst/test_multispectral_seg.txt",
task='seg',
prefix="./data/ssst"),
build_input_from_file(
"data/ssst/test_optical_det.txt",
@ -185,7 +188,23 @@ class TestTransform(CpuCommonTest):
label_list="data/ssst/labels_det.txt"),
build_input_from_file(
"data/ssst/test_det_coco.txt",
task='det',
prefix="./data/ssst"),
build_input_from_file(
"data/ssst/test_optical_res.txt",
task='res',
prefix="./data/ssst",
sr_factor=4),
build_input_from_file(
"data/ssst/test_sar_res.txt",
task='res',
prefix="./data/ssst",
sr_factor=4),
build_input_from_file(
"data/ssst/test_multispectral_res.txt",
task='res',
prefix="./data/ssst",
sr_factor=4),
build_input_from_file(
"data/ssmt/test_mixed_binary.txt",
prefix="./data/ssmt"),
@ -227,6 +246,8 @@ class TestTransform(CpuCommonTest):
self.aux_mask_values = [
set(aux_mask.ravel()) for aux_mask in sample['aux_masks']
]
if 'target' in sample:
self.target_shape = sample['target'].shape
return sample
def _out_hook_not_keep_ratio(sample):
@ -243,6 +264,21 @@ class TestTransform(CpuCommonTest):
for aux_mask, amv in zip(sample['aux_masks'],
self.aux_mask_values):
self.assertLessEqual(set(aux_mask.ravel()), amv)
if 'target' in sample:
if 'sr_factor' in sample:
self.check_output_equal(
sample['target'].shape[:2],
T.functions.calc_hr_shape(TARGET_SIZE,
sample['sr_factor']))
else:
self.check_output_equal(sample['target'].shape[:2],
TARGET_SIZE)
self.check_output_equal(
sample['target'].shape[0] / self.target_shape[0],
sample['image'].shape[0] / self.image_shape[0])
self.check_output_equal(
sample['target'].shape[1] / self.target_shape[1],
sample['image'].shape[1] / self.image_shape[1])
# TODO: Test gt_bbox and gt_poly
return sample
@ -260,6 +296,13 @@ class TestTransform(CpuCommonTest):
for aux_mask, ori_aux_mask_shape in zip(sample['aux_masks'],
self.aux_mask_shapes):
__check_ratio(aux_mask.shape, ori_aux_mask_shape)
if 'target' in sample:
self.check_output_equal(
sample['target'].shape[0] / self.target_shape[0],
sample['image'].shape[0] / self.image_shape[0])
self.check_output_equal(
sample['target'].shape[1] / self.target_shape[1],
sample['image'].shape[1] / self.image_shape[1])
# TODO: Test gt_bbox and gt_poly
return sample

@ -9,11 +9,11 @@
|change_detection/changeformer.py | Change Detection | ChangeFormer |
|change_detection/dsamnet.py | Change Detection | DSAMNet |
|change_detection/dsifn.py | Change Detection | DSIFN |
|change_detection/snunet.py | Change Detection | SNUNet |
|change_detection/stanet.py | Change Detection | STANet |
|change_detection/fc_ef.py | Change Detection | FC-EF |
|change_detection/fc_siam_conc.py | Change Detection | FC-Siam-conc |
|change_detection/fc_siam_diff.py | Change Detection | FC-Siam-diff |
|change_detection/snunet.py | Change Detection | SNUNet |
|change_detection/stanet.py | Change Detection | STANet |
|classification/hrnet.py | Scene Classification | HRNet |
|classification/mobilenetv3.py | Scene Classification | MobileNetV3 |
|classification/resnet50_vd.py | Scene Classification | ResNet50-vd |

@ -72,7 +72,7 @@ eval_dataset = pdrs.datasets.CDDataset(
binarize_labels=True)
# Build the ChangeFormer model with default parameters
# For currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/model_zoo.md
# For currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
# For model input arguments, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/change_detector.py
model = pdrs.tasks.cd.ChangeFormer()

@ -75,9 +75,9 @@ model.train(
train_dataset=train_dataset,
train_batch_size=8,
eval_dataset=eval_dataset,
save_interval_epochs=1,
save_interval_epochs=5,
# Log once every this many iterations
log_interval_steps=5,
log_interval_steps=10,
save_dir=EXP_DIR,
# Initial learning rate
learning_rate=0.001,

@ -75,9 +75,9 @@ model.train(
train_dataset=train_dataset,
train_batch_size=8,
eval_dataset=eval_dataset,
save_interval_epochs=1,
save_interval_epochs=5,
# Log once every this many iterations
log_interval_steps=5,
log_interval_steps=10,
save_dir=EXP_DIR,
# Initial learning rate
learning_rate=0.001,

@ -75,9 +75,9 @@ model.train(
train_dataset=train_dataset,
train_batch_size=8,
eval_dataset=eval_dataset,
save_interval_epochs=1,
save_interval_epochs=5,
# Log once every this many iterations
log_interval_steps=5,
log_interval_steps=10,
save_dir=EXP_DIR,
# Initial learning rate
learning_rate=0.001,

@ -1,86 +0,0 @@
#!/usr/bin/env python
# Example script for training the RCAN image restoration model
# Before running this script, please make sure the PaddleRS library is correctly installed
import paddlers as pdrs
from paddlers import transforms as T
# Dataset directory
DATA_DIR = './data/rssr/'
# Path of the training set `file_list` file
TRAIN_FILE_LIST_PATH = './data/rssr/train.txt'
# Path of the validation set `file_list` file
EVAL_FILE_LIST_PATH = './data/rssr/val.txt'
# Experiment directory for saving output model weights and results
EXP_DIR = './output/rcan/'
# Download and decompress the remote sensing super-resolution dataset
pdrs.utils.download_and_decompress(
'https://paddlers.bj.bcebos.com/datasets/rssr.zip', path='./data/')
# Define the data transforms (augmentation, preprocessing, etc.) used for training and validation
# Use Compose to combine multiple transforms; the transforms in Compose are executed sequentially in order
# API reference: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
train_transforms = T.Compose([
# Decode the image
T.DecodeImg(),
# Resize the input image to 256x256
T.Resize(target_size=256),
# Random horizontal flip with 50% probability
T.RandomHorizontalFlip(prob=0.5),
# Random vertical flip with 50% probability
T.RandomVerticalFlip(prob=0.5),
# Normalize the data to [0,1]
T.Normalize(
mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
T.ArrangeRestorer('train')
])
eval_transforms = T.Compose([
T.DecodeImg(),
T.Resize(target_size=256),
# Normalization at validation time must match that used for training
T.Normalize(
mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
T.ArrangeRestorer('eval')
])
# Build the datasets used for training and validation, respectively
train_dataset = pdrs.datasets.ResDataset(
data_dir=DATA_DIR,
file_list=TRAIN_FILE_LIST_PATH,
transforms=train_transforms,
num_workers=0,
shuffle=True)
eval_dataset = pdrs.datasets.ResDataset(
data_dir=DATA_DIR,
file_list=EVAL_FILE_LIST_PATH,
transforms=eval_transforms,
num_workers=0,
shuffle=False)
# Build the RCAN model with default parameters
# For currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
# For model input arguments, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py
model = pdrs.tasks.res.RCAN()
# Train the model
model.train(
num_epochs=10,
train_dataset=train_dataset,
train_batch_size=8,
eval_dataset=eval_dataset,
save_interval_epochs=1,
# Log once every this many iterations
log_interval_steps=50,
save_dir=EXP_DIR,
# Initial learning rate
learning_rate=0.01,
# Whether to use early stopping, terminating training early when accuracy no longer improves
early_stop=False,
# Whether to enable VisualDL logging
use_vdl=True,
# Resume training from a specified checkpoint
resume_checkpoint=None)