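# NOTE: the following lines are repository web-page residue captured together
# with this file; they are not part of the script.
# You cannot select more than 25 topics.
# Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
# File size: 81 lines, 2.1 KiB.
# Timestamp: 3 years ago.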
# Train an ESRGAN super-resolution model with PaddleRS on paired
# low-quality ("lq") / ground-truth high-resolution ("gt") images.
import os
import sys

# Make a local PaddleRS checkout importable. NOTE: os.path.abspath resolves
# '../PaddleRS' against the current working directory (not this file's
# location), so the script must be launched from the matching directory.
sys.path.append(os.path.abspath('../PaddleRS'))

import paddlers as pdrs  # noqa: E402  (must come after the sys.path tweak)

# Super-resolution factor, shared by the random crop and both datasets so the
# values cannot drift apart.
# NOTE(review): the LR folders below are the "x4" variants; update those paths
# too if this value is ever changed.
scale = 4


def _chw_normalize_steps():
    """Return the pipeline steps shared by the train and test transforms.

    The steps are ``Transpose`` followed by ``Normalize`` with mean 0.0 and
    std 255.0 for each of the three channels (presumably mapping 8-bit pixel
    values into [0, 1] -- confirm against the PaddleRS/PaddleGAN Normalize
    implementation).

    New dict/list objects are built on every call so the train and test
    pipelines never share mutable state.
    """
    return [{
        'name': 'Transpose'
    }, {
        'name': 'Normalize',
        'mean': [0.0, 0.0, 0.0],
        'std': [255.0, 255.0, 255.0]
    }]


# Training transforms: paired random crop (128-pixel gt patches), paired random
# horizontal/vertical flips and H/W transpose, then the shared tail.
train_transforms = pdrs.datasets.ComposeTrans(
    input_keys=['lq', 'gt'],
    output_keys=['lq', 'gt'],
    pipelines=[{
        'name': 'SRPairedRandomCrop',
        'gt_patch_size': 128,
        'scale': scale
    }, {
        'name': 'PairedRandomHorizontalFlip'
    }, {
        'name': 'PairedRandomVerticalFlip'
    }, {
        'name': 'PairedRandomTransposeHW'
    }] + _chw_normalize_steps())

# Evaluation transforms: no augmentation, only the shared tail.
test_transforms = pdrs.datasets.ComposeTrans(
    input_keys=['lq', 'gt'],
    output_keys=['lq', 'gt'],
    pipelines=_chw_normalize_steps())

# Training set.
# NOTE(review): "trian_HR" looks like a misspelling of "train_HR" (the LR folder
# is "train_LR"); it is kept as-is because it must match the directory that
# actually exists on disk -- confirm against the dataset layout.
train_gt_floder = r"../work/RSdata_for_SR/trian_HR"  # high-resolution images
train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4"  # low-resolution images
num_workers = 6
batch_size = 32

# The ComposeTrans objects are called and their results are passed on, as the
# PaddleRS API used here expects (presumably the call returns the underlying
# transform config -- verify in paddlers.datasets).
train_dataset = pdrs.datasets.SRdataset(
    mode='train',
    gt_floder=train_gt_floder,
    lq_floder=train_lq_floder,
    transforms=train_transforms(),
    scale=scale,
    num_workers=num_workers,
    batch_size=batch_size)

# Test set.
test_gt_floder = r"../work/RSdata_for_SR/test_HR"
test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
test_dataset = pdrs.datasets.SRdataset(
    mode='test',
    gt_floder=test_gt_floder,
    lq_floder=test_lq_floder,
    transforms=test_transforms(),
    scale=scale)

# Build the model; network-structure parameters can be tuned here.
# With loss_type='gan', perceptual, adversarial and pixel losses are used;
# with loss_type='pixel', only the pixel loss is used.
model = pdrs.tasks.ESRGANet(loss_type='pixel')

# Train for 1,000,000 iterations; validate and snapshot every 5000 iterations
# and log every 100. The four 250000-iteration `periods` sum to total_iters,
# so keep them consistent if either value changes.
model.train(
    total_iters=1000000,
    train_dataset=train_dataset(),
    test_dataset=test_dataset(),
    output_dir='output_dir',
    validate=5000,
    snapshot=5000,
    log=100,
    lr_rate=0.0001,
    periods=[250000, 250000, 250000, 250000],
    restart_weights=[1, 1, 1, 1])