diff --git a/tutorials/train/image_restoration/drn_train.py b/tutorials/train/image_restoration/drn_train.py index 899338c..ecfc358 100644 --- a/tutorials/train/image_restoration/drn_train.py +++ b/tutorials/train/image_restoration/drn_train.py @@ -5,77 +5,76 @@ sys.path.append(os.path.abspath('../PaddleRS')) import paddle import paddlers as pdrs -if __name__ == "__main__": +# 定义训练和验证时的transforms +train_transforms = pdrs.datasets.ComposeTrans( + input_keys=['lq', 'gt'], + output_keys=['lq', 'lqx2', 'gt'], + pipelines=[{ + 'name': 'SRPairedRandomCrop', + 'gt_patch_size': 192, + 'scale': 4, + 'scale_list': True + }, { + 'name': 'PairedRandomHorizontalFlip' + }, { + 'name': 'PairedRandomVerticalFlip' + }, { + 'name': 'PairedRandomTransposeHW' + }, { + 'name': 'Transpose' + }, { + 'name': 'Normalize', + 'mean': [0.0, 0.0, 0.0], + 'std': [1.0, 1.0, 1.0] + }]) - # 定义训练和验证时的transforms - train_transforms = pdrs.datasets.ComposeTrans( - input_keys=['lq', 'gt'], - output_keys=['lq', 'lqx2', 'gt'], - pipelines=[{ - 'name': 'SRPairedRandomCrop', - 'gt_patch_size': 192, - 'scale': 4, - 'scale_list': True - }, { - 'name': 'PairedRandomHorizontalFlip' - }, { - 'name': 'PairedRandomVerticalFlip' - }, { - 'name': 'PairedRandomTransposeHW' - }, { - 'name': 'Transpose' - }, { - 'name': 'Normalize', - 'mean': [0.0, 0.0, 0.0], - 'std': [1.0, 1.0, 1.0] - }]) +test_transforms = pdrs.datasets.ComposeTrans( + input_keys=['lq', 'gt'], + output_keys=['lq', 'gt'], + pipelines=[{ + 'name': 'Transpose' + }, { + 'name': 'Normalize', + 'mean': [0.0, 0.0, 0.0], + 'std': [1.0, 1.0, 1.0] + }]) - test_transforms = pdrs.datasets.ComposeTrans( - input_keys=['lq', 'gt'], - output_keys=['lq', 'gt'], - pipelines=[{ - 'name': 'Transpose' - }, { - 'name': 'Normalize', - 'mean': [0.0, 0.0, 0.0], - 'std': [1.0, 1.0, 1.0] - }]) +# 定义训练集 +train_gt_floder = r"../work/RSdata_for_SR/trian_HR" # 高分辨率影像所在路径 +train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4" # 低分辨率影像所在路径 +num_workers = 4 +batch_size = 8 +scale = 4 +train_dataset = pdrs.datasets.SRdataset( + mode='train', + gt_floder=train_gt_floder, + lq_floder=train_lq_floder, + transforms=train_transforms(), + scale=scale, + num_workers=num_workers, + batch_size=batch_size) +train_dict = train_dataset() - # 定义训练集 - train_gt_floder = r"../work/RSdata_for_SR/trian_HR" # 高分辨率影像所在路径 - train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4" # 低分辨率影像所在路径 - num_workers = 4 - batch_size = 8 - scale = 4 - train_dataset = pdrs.datasets.SRdataset( - mode='train', - gt_floder=train_gt_floder, - lq_floder=train_lq_floder, - transforms=train_transforms(), - scale=scale, - num_workers=num_workers, - batch_size=batch_size) - train_dict = train_dataset() +# 定义测试集 +test_gt_floder = r"../work/RSdata_for_SR/test_HR" +test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4" +test_dataset = pdrs.datasets.SRdataset( + mode='test', + gt_floder=test_gt_floder, + lq_floder=test_lq_floder, + transforms=test_transforms(), + scale=scale) - # 定义测试集 - test_gt_floder = r"../work/RSdata_for_SR/test_HR" - test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4" - test_dataset = pdrs.datasets.SRdataset( - mode='test', - gt_floder=test_gt_floder, - lq_floder=test_lq_floder, - transforms=test_transforms(), - scale=scale) +# 初始化模型,可以对网络结构的参数进行调整 +model = pdrs.tasks.DRNet( + n_blocks=30, n_feats=16, n_colors=3, rgb_range=255, negval=0.2) - # 初始化模型,可以对网络结构的参数进行调整 - model = pdrs.tasks.DRNet( - n_blocks=30, n_feats=16, n_colors=3, rgb_range=255, negval=0.2) - - model.train( - total_iters=100000, - train_dataset=train_dataset(), - test_dataset=test_dataset(), - output_dir='output_dir', - validate=5000, - snapshot=5000, - lr_rate=0.0001) +model.train( + total_iters=100000, + train_dataset=train_dataset(), + test_dataset=test_dataset(), + output_dir='output_dir', + validate=5000, + snapshot=5000, + lr_rate=0.0001, + log=10)