Update drn_train.py

fix some bug about train DRN model
own
kongdebug 3 years ago committed by GitHub
parent d29af2909c
commit b2d22d0c89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 137
      tutorials/train/image_restoration/drn_train.py

@ -5,77 +5,76 @@ sys.path.append(os.path.abspath('../PaddleRS'))
import paddle import paddle
import paddlers as pdrs import paddlers as pdrs
if __name__ == "__main__": # 定义训练和验证时的transforms
train_transforms = pdrs.datasets.ComposeTrans(
input_keys=['lq', 'gt'],
output_keys=['lq', 'lqx2', 'gt'],
pipelines=[{
'name': 'SRPairedRandomCrop',
'gt_patch_size': 192,
'scale': 4,
'scale_list': True
}, {
'name': 'PairedRandomHorizontalFlip'
}, {
'name': 'PairedRandomVerticalFlip'
}, {
'name': 'PairedRandomTransposeHW'
}, {
'name': 'Transpose'
}, {
'name': 'Normalize',
'mean': [0.0, 0.0, 0.0],
'std': [1.0, 1.0, 1.0]
}])
# 定义训练和验证时的transforms test_transforms = pdrs.datasets.ComposeTrans(
train_transforms = pdrs.datasets.ComposeTrans( input_keys=['lq', 'gt'],
input_keys=['lq', 'gt'], output_keys=['lq', 'gt'],
output_keys=['lq', 'lqx2', 'gt'], pipelines=[{
pipelines=[{ 'name': 'Transpose'
'name': 'SRPairedRandomCrop', }, {
'gt_patch_size': 192, 'name': 'Normalize',
'scale': 4, 'mean': [0.0, 0.0, 0.0],
'scale_list': True 'std': [1.0, 1.0, 1.0]
}, { }])
'name': 'PairedRandomHorizontalFlip'
}, {
'name': 'PairedRandomVerticalFlip'
}, {
'name': 'PairedRandomTransposeHW'
}, {
'name': 'Transpose'
}, {
'name': 'Normalize',
'mean': [0.0, 0.0, 0.0],
'std': [1.0, 1.0, 1.0]
}])
test_transforms = pdrs.datasets.ComposeTrans( # 定义训练集
input_keys=['lq', 'gt'], train_gt_floder = r"../work/RSdata_for_SR/trian_HR" # 高分辨率影像所在路径
output_keys=['lq', 'gt'], train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4" # 低分辨率影像所在路径
pipelines=[{ num_workers = 4
'name': 'Transpose' batch_size = 8
}, { scale = 4
'name': 'Normalize', train_dataset = pdrs.datasets.SRdataset(
'mean': [0.0, 0.0, 0.0], mode='train',
'std': [1.0, 1.0, 1.0] gt_floder=train_gt_floder,
}]) lq_floder=train_lq_floder,
transforms=train_transforms(),
scale=scale,
num_workers=num_workers,
batch_size=batch_size)
train_dict = train_dataset()
# 定义训练集 # 定义测试集
train_gt_floder = r"../work/RSdata_for_SR/trian_HR" # 高分辨率影像所在路径 test_gt_floder = r"../work/RSdata_for_SR/test_HR"
train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4" # 低分辨率影像所在路径 test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
num_workers = 4 test_dataset = pdrs.datasets.SRdataset(
batch_size = 8 mode='test',
scale = 4 gt_floder=test_gt_floder,
train_dataset = pdrs.datasets.SRdataset( lq_floder=test_lq_floder,
mode='train', transforms=test_transforms(),
gt_floder=train_gt_floder, scale=scale)
lq_floder=train_lq_floder,
transforms=train_transforms(),
scale=scale,
num_workers=num_workers,
batch_size=batch_size)
train_dict = train_dataset()
# 定义测试集 # 初始化模型,可以对网络结构的参数进行调整
test_gt_floder = r"../work/RSdata_for_SR/test_HR" model = pdrs.tasks.DRNet(
test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4" n_blocks=30, n_feats=16, n_colors=3, rgb_range=255, negval=0.2)
test_dataset = pdrs.datasets.SRdataset(
mode='test',
gt_floder=test_gt_floder,
lq_floder=test_lq_floder,
transforms=test_transforms(),
scale=scale)
# 初始化模型,可以对网络结构的参数进行调整 model.train(
model = pdrs.tasks.DRNet( total_iters=100000,
n_blocks=30, n_feats=16, n_colors=3, rgb_range=255, negval=0.2) train_dataset=train_dataset(),
test_dataset=test_dataset(),
model.train( output_dir='output_dir',
total_iters=100000, validate=5000,
train_dataset=train_dataset(), snapshot=5000,
test_dataset=test_dataset(), lr_rate=0.0001,
output_dir='output_dir', log=10)
validate=5000,
snapshot=5000,
lr_rate=0.0001)

Loading…
Cancel
Save