add train and test code for super-resolution models LESRCNN and ESRGAN

kongdebug 3 years ago
parent d2e80e829d
commit d29af2909c
  1. +1    paddlers/datasets/__init__.py
  2. +99   paddlers/datasets/sr_dataset.py
  3. +1    paddlers/tasks/__init__.py
  4. +753  paddlers/tasks/imagerestorer.py
  5. +81   tutorials/train/image_restoration/drn_train.py
  6. +80   tutorials/train/image_restoration/esrgan_train.py
  7. +78   tutorials/train/image_restoration/lesrcnn_train.py
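In brief, `SRdataset` and `ComposeTrans` only build configuration dictionaries, which the new task classes (`DRNet`, `LESRCNNet`, `ESRGANet`, all subclasses of `BasicSRNet`) hand to the ppgan-style `Restorer` trainer. Below is a minimal sketch of that flow, condensed from the tutorial scripts added in this commit; the `data/HR` and `data/LR/x4` folders are hypothetical placeholders for a paired 4x dataset, and the full, runnable scripts are in the tutorials further down.

import paddlers as pdrs

# Declare the preprocessing pipeline as config dicts and compose it.
train_transforms = pdrs.datasets.ComposeTrans(
    input_keys=['lq', 'gt'],
    output_keys=['lq', 'gt'],
    pipelines=[{
        'name': 'SRPairedRandomCrop',
        'gt_patch_size': 128,
        'scale': 4
    }, {
        'name': 'Transpose'
    }, {
        'name': 'Normalize',
        'mean': [0.0, 0.0, 0.0],
        'std': [255.0, 255.0, 255.0]
    }])
test_transforms = pdrs.datasets.ComposeTrans(
    input_keys=['lq', 'gt'],
    output_keys=['lq', 'gt'],
    pipelines=[{
        'name': 'Transpose'
    }, {
        'name': 'Normalize',
        'mean': [0.0, 0.0, 0.0],
        'std': [255.0, 255.0, 255.0]
    }])

# Wrap the paired HR/LR folders into ppgan-style dataset config dicts.
train_dataset = pdrs.datasets.SRdataset(
    mode='train',
    gt_floder='data/HR',     # hypothetical high-resolution folder
    lq_floder='data/LR/x4',  # hypothetical low-resolution folder
    transforms=train_transforms(),
    scale=4,
    num_workers=4,
    batch_size=16)
test_dataset = pdrs.datasets.SRdataset(
    mode='test',
    gt_floder='data/HR',     # a separate test split would normally go here
    lq_floder='data/LR/x4',
    transforms=test_transforms(),
    scale=4)

# The task class assembles the model/optimizer/lr_scheduler configs
# and drives the Restorer trainer.
model = pdrs.tasks.LESRCNNet(scale=4)
model.train(
    total_iters=100000,
    train_dataset=train_dataset(),
    test_dataset=test_dataset(),
    output_dir='output_dir',
    validate=5000,
    snapshot=5000,
    log=100,
    lr_rate=0.0001)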

@@ -16,3 +16,4 @@ from .voc import VOCDetection
from .seg_dataset import SegDataset
from .cd_dataset import CDDataset
from .clas_dataset import ClasDataset
from .sr_dataset import SRdataset, ComposeTrans

@@ -0,0 +1,99 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Super-resolution dataset definitions


class SRdataset(object):
    def __init__(self,
                 mode,
                 gt_floder,
                 lq_floder,
                 transforms,
                 scale,
                 num_workers=4,
                 batch_size=8):
        if mode == 'train':
            preprocess = []
            preprocess.append({
                'name': 'LoadImageFromFile',
                'key': 'lq'
            })  # image loading op
            preprocess.append({'name': 'LoadImageFromFile', 'key': 'gt'})
            preprocess.append(transforms)  # transforms
            self.dataset = {
                'name': 'SRDataset',
                'gt_folder': gt_floder,
                'lq_folder': lq_floder,
                'num_workers': num_workers,
                'batch_size': batch_size,
                'scale': scale,
                'preprocess': preprocess
            }

        if mode == "test":
            preprocess = []
            preprocess.append({'name': 'LoadImageFromFile', 'key': 'lq'})
            preprocess.append({'name': 'LoadImageFromFile', 'key': 'gt'})
            preprocess.append(transforms)
            self.dataset = {
                'name': 'SRDataset',
                'gt_folder': gt_floder,
                'lq_folder': lq_floder,
                'scale': scale,
                'preprocess': preprocess
            }

    def __call__(self):
        return self.dataset


# Compose the configured transforms and return them as a dict
class ComposeTrans(object):
    def __init__(self, input_keys, output_keys, pipelines):
        if not isinstance(pipelines, list):
            raise TypeError(
                'Type of transforms is invalid. Must be List, but received is {}'
                .format(type(pipelines)))

        if len(pipelines) < 1:
            raise ValueError(
                'Length of transforms must not be less than 1, but received is {}'
                .format(len(pipelines)))

        self.transforms = pipelines
        self.output_length = len(output_keys)  # an output_keys of length 3 means DRN training
        self.input_keys = input_keys
        self.output_keys = output_keys

    def __call__(self):
        pipeline = []
        for op in self.transforms:
            if op['name'] == 'SRPairedRandomCrop':
                op['keys'] = ['image'] * 2
            else:
                op['keys'] = ['image'] * self.output_length
            pipeline.append(op)

        if self.output_length == 2:
            transform_dict = {
                'name': 'Transforms',
                'input_keys': self.input_keys,
                'pipeline': pipeline
            }
        else:
            transform_dict = {
                'name': 'Transforms',
                'input_keys': self.input_keys,
                'output_keys': self.output_keys,
                'pipeline': pipeline
            }

        return transform_dict

@@ -17,3 +17,4 @@ from .segmenter import *
from .changedetector import *
from .classifier import *
from .load_model import load_model
from .imagerestorer import *

@@ -0,0 +1,753 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import datetime

import paddle
from paddle.distributed import ParallelEnv

from ..models.ppgan.datasets.builder import build_dataloader
from ..models.ppgan.models.builder import build_model
from ..models.ppgan.utils.visual import tensor2img, save_image
from ..models.ppgan.utils.filesystem import makedirs, save, load
from ..models.ppgan.utils.timer import TimeAverager
from ..models.ppgan.utils.profiler import add_profiler_step
from ..models.ppgan.utils.logger import setup_logger


# Dict subclass that exposes its entries as attributes
class AttrDict(dict):
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        if key in self.__dict__:
            self.__dict__[key] = value
        else:
            self[key] = value


# Recursively convert nested dicts into AttrDict instances
def create_attr_dict(config_dict):
    from ast import literal_eval
    for key, value in config_dict.items():
        if type(value) is dict:
            config_dict[key] = value = AttrDict(value)
        if isinstance(value, str):
            try:
                value = literal_eval(value)
            except BaseException:
                pass
        if isinstance(value, AttrDict):
            create_attr_dict(config_dict[key])
        else:
            config_dict[key] = value


# Infinite data loader: restarts the underlying dataloader on StopIteration
class IterLoader:
    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._epoch = 1

    @property
    def epoch(self):
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            self._epoch += 1
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        return data

    def __len__(self):
        return len(self._dataloader)
# Base trainer for image restoration models
class Restorer:
    """
    # trainer calling logic:
    #
    #     build_model                                        ||    model(BaseModel)
    #          |                                             ||
    #     build_dataloader                                   ||    dataloader
    #          |                                             ||
    #     model.setup_lr_schedulers                          ||    lr_scheduler
    #          |                                             ||
    #     model.setup_optimizers                             ||    optimizers
    #          |                                             ||
    #     train loop (model.setup_input + model.train_iter)  ||    train loop
    #          |                                             ||
    #     print log (model.get_current_losses)               ||
    #          |                                             ||
    #     save checkpoint (model.nets)                       \/
    """

    def __init__(self, cfg, logger):
        # base config
        # self.logger = logging.getLogger(__name__)
        self.logger = logger
        self.cfg = cfg
        self.output_dir = cfg.output_dir
        self.max_eval_steps = cfg.model.get('max_eval_steps', None)

        self.local_rank = ParallelEnv().local_rank
        self.world_size = ParallelEnv().nranks
        self.log_interval = cfg.log_config.interval
        self.visual_interval = cfg.log_config.visiual_interval
        self.weight_interval = cfg.snapshot_config.interval

        self.start_epoch = 1
        self.current_epoch = 1
        self.current_iter = 1
        self.inner_iter = 1
        self.batch_id = 0
        self.global_steps = 0

        # build model
        self.model = build_model(cfg.model)

        # prepare for multiple gpus
        if ParallelEnv().nranks > 1:
            self.distributed_data_parallel()

        # build metrics
        self.metrics = None
        self.is_save_img = True
        validate_cfg = cfg.get('validate', None)
        if validate_cfg and 'metrics' in validate_cfg:
            self.metrics = self.model.setup_metrics(validate_cfg['metrics'])
        if validate_cfg and 'save_img' in validate_cfg:
            self.is_save_img = validate_cfg['save_img']

        self.enable_visualdl = cfg.get('enable_visualdl', False)
        if self.enable_visualdl:
            import visualdl
            self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir)

        # evaluate only
        if not cfg.is_train:
            return

        # build train dataloader
        self.train_dataloader = build_dataloader(cfg.dataset.train)
        self.iters_per_epoch = len(self.train_dataloader)

        # build lr scheduler
        # TODO: has a better way?
        if 'lr_scheduler' in cfg and 'iters_per_epoch' in cfg.lr_scheduler:
            cfg.lr_scheduler.iters_per_epoch = self.iters_per_epoch
        self.lr_schedulers = self.model.setup_lr_schedulers(cfg.lr_scheduler)

        # build optimizers
        self.optimizers = self.model.setup_optimizers(self.lr_schedulers,
                                                      cfg.optimizer)

        self.epochs = cfg.get('epochs', None)
        if self.epochs:
            self.total_iters = self.epochs * self.iters_per_epoch
            self.by_epoch = True
        else:
            self.by_epoch = False
            self.total_iters = cfg.total_iters

        if self.by_epoch:
            self.weight_interval *= self.iters_per_epoch

        self.validate_interval = -1
        if cfg.get('validate', None) is not None:
            self.validate_interval = cfg.validate.get('interval', -1)

        self.time_count = {}
        self.best_metric = {}
        self.model.set_total_iter(self.total_iters)
        self.profiler_options = cfg.profiler_options

    def distributed_data_parallel(self):
        paddle.distributed.init_parallel_env()
        find_unused_parameters = self.cfg.get('find_unused_parameters', False)
        for net_name, net in self.model.nets.items():
            self.model.nets[net_name] = paddle.DataParallel(
                net, find_unused_parameters=find_unused_parameters)

    def learning_rate_scheduler_step(self):
        if isinstance(self.model.lr_scheduler, dict):
            for lr_scheduler in self.model.lr_scheduler.values():
                lr_scheduler.step()
        elif isinstance(self.model.lr_scheduler,
                        paddle.optimizer.lr.LRScheduler):
            self.model.lr_scheduler.step()
        else:
            raise ValueError(
                'lr scheduler must be a dict or an instance of LRScheduler')
    def train(self):
        reader_cost_averager = TimeAverager()
        batch_cost_averager = TimeAverager()

        iter_loader = IterLoader(self.train_dataloader)

        # set model.is_train = True
        self.model.setup_train_mode(is_train=True)
        while self.current_iter < (self.total_iters + 1):
            self.current_epoch = iter_loader.epoch
            self.inner_iter = self.current_iter % self.iters_per_epoch

            add_profiler_step(self.profiler_options)

            start_time = step_start_time = time.time()
            data = next(iter_loader)
            reader_cost_averager.record(time.time() - step_start_time)

            # unpack data from dataset and apply preprocessing
            # data input should be dict
            self.model.setup_input(data)
            self.model.train_iter(self.optimizers)

            batch_cost_averager.record(
                time.time() - step_start_time,
                num_samples=self.cfg['dataset']['train'].get('batch_size', 1))

            step_start_time = time.time()

            if self.current_iter % self.log_interval == 0:
                self.data_time = reader_cost_averager.get_average()
                self.step_time = batch_cost_averager.get_average()
                self.ips = batch_cost_averager.get_ips_average()
                self.print_log()

                reader_cost_averager.reset()
                batch_cost_averager.reset()

            if self.current_iter % self.visual_interval == 0 and self.local_rank == 0:
                self.visual('visual_train')

            self.learning_rate_scheduler_step()

            if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0:
                self.test()

            if self.current_iter % self.weight_interval == 0:
                self.save(self.current_iter, 'weight', keep=-1)
                self.save(self.current_iter)

            self.current_iter += 1

    def test(self):
        if not hasattr(self, 'test_dataloader'):
            self.test_dataloader = build_dataloader(
                self.cfg.dataset.test, is_train=False)
        iter_loader = IterLoader(self.test_dataloader)
        if self.max_eval_steps is None:
            self.max_eval_steps = len(self.test_dataloader)

        if self.metrics:
            for metric in self.metrics.values():
                metric.reset()

        # set model.is_train = False
        self.model.setup_train_mode(is_train=False)

        for i in range(self.max_eval_steps):
            if self.max_eval_steps < self.log_interval or i % self.log_interval == 0:
                self.logger.info('Test iter: [%d/%d]' % (
                    i * self.world_size, self.max_eval_steps * self.world_size))

            data = next(iter_loader)
            self.model.setup_input(data)
            self.model.test_iter(metrics=self.metrics)

            if self.is_save_img:
                visual_results = {}
                current_paths = self.model.get_image_paths()
                current_visuals = self.model.get_current_visuals()

                if len(current_visuals) > 0 and len(
                        list(current_visuals.values())[0].shape) == 4:
                    num_samples = list(current_visuals.values())[0].shape[0]
                else:
                    num_samples = 1

                for j in range(num_samples):
                    if j < len(current_paths):
                        short_path = os.path.basename(current_paths[j])
                        basename = os.path.splitext(short_path)[0]
                    else:
                        basename = '{:04d}_{:04d}'.format(i, j)
                    for k, img_tensor in current_visuals.items():
                        name = '%s_%s' % (basename, k)
                        if len(img_tensor.shape) == 4:
                            visual_results.update({name: img_tensor[j]})
                        else:
                            visual_results.update({name: img_tensor})

                self.visual(
                    'visual_test',
                    visual_results=visual_results,
                    step=self.batch_id,
                    is_save_image=True)

        if self.metrics:
            for metric_name, metric in self.metrics.items():
                self.logger.info("Metric {}: {:.4f}".format(
                    metric_name, metric.accumulate()))
    def print_log(self):
        losses = self.model.get_current_losses()

        message = ''
        if self.by_epoch:
            message += 'Epoch: %d/%d, iter: %d/%d ' % (
                self.current_epoch, self.epochs, self.inner_iter,
                self.iters_per_epoch)
        else:
            message += 'Iter: %d/%d ' % (self.current_iter, self.total_iters)

        message += f'lr: {self.current_learning_rate:.3e} '

        for k, v in losses.items():
            message += '%s: %.3f ' % (k, v)
            if self.enable_visualdl:
                self.vdl_logger.add_scalar(k, v, step=self.global_steps)

        if hasattr(self, 'step_time'):
            message += 'batch_cost: %.5f sec ' % self.step_time

        if hasattr(self, 'data_time'):
            message += 'reader_cost: %.5f sec ' % self.data_time

        if hasattr(self, 'ips'):
            message += 'ips: %.5f images/s ' % self.ips

        if hasattr(self, 'step_time'):
            eta = self.step_time * (self.total_iters - self.current_iter)
            eta = eta if eta > 0 else 0

            eta_str = str(datetime.timedelta(seconds=int(eta)))
            message += f'eta: {eta_str}'

        # print the message
        self.logger.info(message)

    @property
    def current_learning_rate(self):
        for optimizer in self.model.optimizers.values():
            return optimizer.get_lr()

    def visual(self,
               results_dir,
               visual_results=None,
               step=None,
               is_save_image=False):
        """
        Visualize the images, either with visualdl or by writing them to a directory.

        Parameters:
            results_dir (str) -- directory name which contains saved images
            visual_results (dict) -- the results images dict
            step (int) -- global steps, used in visualdl
            is_save_image (bool) -- whether to write to the directory instead of visualdl
        """
        self.model.compute_visuals()

        if visual_results is None:
            visual_results = self.model.get_current_visuals()

        min_max = self.cfg.get('min_max', None)
        if min_max is None:
            min_max = (-1., 1.)

        image_num = self.cfg.get('image_num', None)
        if (image_num is None) or (not self.enable_visualdl):
            image_num = 1

        for label, image in visual_results.items():
            image_numpy = tensor2img(image, min_max, image_num)
            if (not is_save_image) and self.enable_visualdl:
                self.vdl_logger.add_image(
                    results_dir + '/' + label,
                    image_numpy,
                    step=step if step else self.global_steps,
                    dataformats="HWC" if image_num == 1 else "NCHW")
            else:
                if self.cfg.is_train:
                    if self.by_epoch:
                        msg = 'epoch%.3d_' % self.current_epoch
                    else:
                        msg = 'iter%.3d_' % self.current_iter
                else:
                    msg = ''
                makedirs(os.path.join(self.output_dir, results_dir))
                img_path = os.path.join(self.output_dir, results_dir,
                                        msg + '%s.png' % (label))
                save_image(image_numpy, img_path)
    def save(self, epoch, name='checkpoint', keep=1):
        if self.local_rank != 0:
            return

        assert name in ['checkpoint', 'weight']

        state_dicts = {}
        if self.by_epoch:
            save_filename = 'epoch_%s_%s.pdparams' % (
                epoch // self.iters_per_epoch, name)
        else:
            save_filename = 'iter_%s_%s.pdparams' % (epoch, name)

        os.makedirs(self.output_dir, exist_ok=True)
        save_path = os.path.join(self.output_dir, save_filename)
        for net_name, net in self.model.nets.items():
            state_dicts[net_name] = net.state_dict()

        if name == 'weight':
            save(state_dicts, save_path)
            return

        state_dicts['epoch'] = epoch

        for opt_name, opt in self.model.optimizers.items():
            state_dicts[opt_name] = opt.state_dict()

        save(state_dicts, save_path)

        if keep > 0:
            try:
                if self.by_epoch:
                    checkpoint_name_to_be_removed = os.path.join(
                        self.output_dir, 'epoch_%s_%s.pdparams' % (
                            (epoch - keep * self.weight_interval) //
                            self.iters_per_epoch, name))
                else:
                    checkpoint_name_to_be_removed = os.path.join(
                        self.output_dir, 'iter_%s_%s.pdparams' %
                        (epoch - keep * self.weight_interval, name))

                if os.path.exists(checkpoint_name_to_be_removed):
                    os.remove(checkpoint_name_to_be_removed)

            except Exception as e:
                self.logger.info('remove old checkpoints error: {}'.format(e))

    def resume(self, checkpoint_path):
        state_dicts = load(checkpoint_path)
        if state_dicts.get('epoch', None) is not None:
            self.start_epoch = state_dicts['epoch'] + 1
            self.global_steps = self.iters_per_epoch * state_dicts['epoch']
            self.current_iter = state_dicts['epoch'] + 1

        for net_name, net in self.model.nets.items():
            net.set_state_dict(state_dicts[net_name])

        for opt_name, opt in self.model.optimizers.items():
            opt.set_state_dict(state_dicts[opt_name])

    def load(self, weight_path):
        state_dicts = load(weight_path)

        for net_name, net in self.model.nets.items():
            if net_name in state_dicts:
                net.set_state_dict(state_dicts[net_name])
                self.logger.info('Loaded pretrained weight for net {}'.format(
                    net_name))
            else:
                self.logger.warning(
                    'Can not find state dict of net {}. Skip load pretrained weight for net {}'
                    .format(net_name, net_name))

    def close(self):
        """
        Close file handlers and other resources when training finishes.
        """
        if self.enable_visualdl:
            self.vdl_logger.close()
# Base training class for super-resolution models
class BasicSRNet:
    def __init__(self):
        self.model = {}
        self.optimizer = {}
        self.lr_scheduler = {}
        self.min_max = ''

    def train(
            self,
            total_iters,
            train_dataset,
            test_dataset,
            output_dir,
            validate,
            snapshot,
            log,
            lr_rate,
            evaluate_weights='',
            resume='',
            pretrain_weights='',
            periods=[100000],
            restart_weights=[1], ):
        self.lr_scheduler['learning_rate'] = lr_rate

        if self.lr_scheduler['name'] == 'CosineAnnealingRestartLR':
            self.lr_scheduler['periods'] = periods
            self.lr_scheduler['restart_weights'] = restart_weights

        validate = {
            'interval': validate,
            'save_img': False,
            'metrics': {
                'psnr': {
                    'name': 'PSNR',
                    'crop_border': 4,
                    'test_y_channel': True
                },
                'ssim': {
                    'name': 'SSIM',
                    'crop_border': 4,
                    'test_y_channel': True
                }
            }
        }
        log_config = {'interval': log, 'visiual_interval': 500}
        snapshot_config = {'interval': snapshot}

        cfg = {
            'total_iters': total_iters,
            'output_dir': output_dir,
            'min_max': self.min_max,
            'model': self.model,
            'dataset': {
                'train': train_dataset,
                'test': test_dataset
            },
            'lr_scheduler': self.lr_scheduler,
            'optimizer': self.optimizer,
            'validate': validate,
            'log_config': log_config,
            'snapshot_config': snapshot_config
        }

        cfg = AttrDict(cfg)
        create_attr_dict(cfg)

        cfg.is_train = True
        cfg.profiler_options = None
        cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime())

        if cfg.model.name == 'BaseSRModel':
            floderModelName = cfg.model.generator.name
        else:
            floderModelName = cfg.model.name
        cfg.output_dir = os.path.join(cfg.output_dir,
                                      floderModelName + cfg.timestamp)

        logger_cfg = setup_logger(cfg.output_dir)
        logger_cfg.info('Configs: {}'.format(cfg))

        if paddle.is_compiled_with_cuda():
            paddle.set_device('gpu')
        else:
            paddle.set_device('cpu')

        # build trainer
        trainer = Restorer(cfg, logger_cfg)

        # to continue training or to evaluate, the checkpoint must contain epoch and optimizer info
        if len(resume) > 0:
            trainer.resume(resume)
        # for evaluation or finetuning, only the generator weights are loaded
        elif len(pretrain_weights) > 0:
            trainer.load(pretrain_weights)

        if len(evaluate_weights) > 0:
            trainer.load(evaluate_weights)
            trainer.test()
            return

        # train; save the weights on keyboard interrupt
        try:
            trainer.train()
        except KeyboardInterrupt as e:
            trainer.save(trainer.current_epoch)
        trainer.close()
# DRN model training
class DRNet(BasicSRNet):
    def __init__(self,
                 n_blocks=30,
                 n_feats=16,
                 n_colors=3,
                 rgb_range=255,
                 negval=0.2):
        super(DRNet, self).__init__()
        self.min_max = '(0., 255.)'
        self.generator = {
            'name': 'DRNGenerator',
            'scale': (2, 4),
            'n_blocks': n_blocks,
            'n_feats': n_feats,
            'n_colors': n_colors,
            'rgb_range': rgb_range,
            'negval': negval
        }
        self.pixel_criterion = {'name': 'L1Loss'}
        self.model = {
            'name': 'DRN',
            'generator': self.generator,
            'pixel_criterion': self.pixel_criterion
        }
        self.optimizer = {
            'optimG': {
                'name': 'Adam',
                'net_names': ['generator'],
                'weight_decay': 0.0,
                'beta1': 0.9,
                'beta2': 0.999
            },
            'optimD': {
                'name': 'Adam',
                'net_names': ['dual_model_0', 'dual_model_1'],
                'weight_decay': 0.0,
                'beta1': 0.9,
                'beta2': 0.999
            }
        }
        self.lr_scheduler = {
            'name': 'CosineAnnealingRestartLR',
            'eta_min': 1e-07
        }


# Training for the lightweight super-resolution model LESRCNN
class LESRCNNet(BasicSRNet):
    def __init__(self, scale=4, multi_scale=False, group=1):
        super(LESRCNNet, self).__init__()
        self.min_max = '(0., 1.)'
        self.generator = {
            'name': 'LESRCNNGenerator',
            'scale': scale,
            'multi_scale': multi_scale,
            'group': group
        }
        self.pixel_criterion = {'name': 'L1Loss'}
        self.model = {
            'name': 'BaseSRModel',
            'generator': self.generator,
            'pixel_criterion': self.pixel_criterion
        }
        self.optimizer = {
            'name': 'Adam',
            'net_names': ['generator'],
            'beta1': 0.9,
            'beta2': 0.99
        }
        self.lr_scheduler = {
            'name': 'CosineAnnealingRestartLR',
            'eta_min': 1e-07
        }


# ESRGAN model training
# With loss_type='gan', the perceptual, adversarial, and pixel losses are used
# With loss_type='pixel', only the pixel loss is used
class ESRGANet(BasicSRNet):
    def __init__(self, loss_type='gan', in_nc=3, out_nc=3, nf=64, nb=23):
        super(ESRGANet, self).__init__()
        self.min_max = '(0., 1.)'
        self.generator = {
            'name': 'RRDBNet',
            'in_nc': in_nc,
            'out_nc': out_nc,
            'nf': nf,
            'nb': nb
        }

        if loss_type == 'gan':
            # define the loss functions
            self.pixel_criterion = {'name': 'L1Loss', 'loss_weight': 0.01}
            self.discriminator = {
                'name': 'VGGDiscriminator128',
                'in_channels': 3,
                'num_feat': 64
            }
            self.perceptual_criterion = {
                'name': 'PerceptualLoss',
                'layer_weights': {
                    '34': 1.0
                },
                'perceptual_weight': 1.0,
                'style_weight': 0.0,
                'norm_img': False
            }
            self.gan_criterion = {
                'name': 'GANLoss',
                'gan_mode': 'vanilla',
                'loss_weight': 0.005
            }
            # define the model
            self.model = {
                'name': 'ESRGAN',
                'generator': self.generator,
                'discriminator': self.discriminator,
                'pixel_criterion': self.pixel_criterion,
                'perceptual_criterion': self.perceptual_criterion,
                'gan_criterion': self.gan_criterion
            }
            self.optimizer = {
                'optimG': {
                    'name': 'Adam',
                    'net_names': ['generator'],
                    'weight_decay': 0.0,
                    'beta1': 0.9,
                    'beta2': 0.99
                },
                'optimD': {
                    'name': 'Adam',
                    'net_names': ['discriminator'],
                    'weight_decay': 0.0,
                    'beta1': 0.9,
                    'beta2': 0.99
                }
            }
            self.lr_scheduler = {
                'name': 'MultiStepDecay',
                'milestones': [50000, 100000, 200000, 300000],
                'gamma': 0.5
            }
        else:
            self.pixel_criterion = {'name': 'L1Loss'}
            self.model = {
                'name': 'BaseSRModel',
                'generator': self.generator,
                'pixel_criterion': self.pixel_criterion
            }
            self.optimizer = {
                'name': 'Adam',
                'net_names': ['generator'],
                'beta1': 0.9,
                'beta2': 0.99
            }
            self.lr_scheduler = {
                'name': 'CosineAnnealingRestartLR',
                'eta_min': 1e-07
            }

@@ -0,0 +1,81 @@
import os
import sys
sys.path.append(os.path.abspath('../PaddleRS'))

import paddle
import paddlers as pdrs

if __name__ == "__main__":
    # transforms for training and validation
    train_transforms = pdrs.datasets.ComposeTrans(
        input_keys=['lq', 'gt'],
        output_keys=['lq', 'lqx2', 'gt'],
        pipelines=[{
            'name': 'SRPairedRandomCrop',
            'gt_patch_size': 192,
            'scale': 4,
            'scale_list': True
        }, {
            'name': 'PairedRandomHorizontalFlip'
        }, {
            'name': 'PairedRandomVerticalFlip'
        }, {
            'name': 'PairedRandomTransposeHW'
        }, {
            'name': 'Transpose'
        }, {
            'name': 'Normalize',
            'mean': [0.0, 0.0, 0.0],
            'std': [1.0, 1.0, 1.0]
        }])

    test_transforms = pdrs.datasets.ComposeTrans(
        input_keys=['lq', 'gt'],
        output_keys=['lq', 'gt'],
        pipelines=[{
            'name': 'Transpose'
        }, {
            'name': 'Normalize',
            'mean': [0.0, 0.0, 0.0],
            'std': [1.0, 1.0, 1.0]
        }])

    # training set
    train_gt_floder = r"../work/RSdata_for_SR/trian_HR"  # path to the high-resolution images
    train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4"  # path to the low-resolution images
    num_workers = 4
    batch_size = 8
    scale = 4
    train_dataset = pdrs.datasets.SRdataset(
        mode='train',
        gt_floder=train_gt_floder,
        lq_floder=train_lq_floder,
        transforms=train_transforms(),
        scale=scale,
        num_workers=num_workers,
        batch_size=batch_size)
    train_dict = train_dataset()

    # test set
    test_gt_floder = r"../work/RSdata_for_SR/test_HR"
    test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
    test_dataset = pdrs.datasets.SRdataset(
        mode='test',
        gt_floder=test_gt_floder,
        lq_floder=test_lq_floder,
        transforms=test_transforms(),
        scale=scale)

    # initialize the model; the network structure parameters can be adjusted
    model = pdrs.tasks.DRNet(
        n_blocks=30, n_feats=16, n_colors=3, rgb_range=255, negval=0.2)

    model.train(
        total_iters=100000,
        train_dataset=train_dataset(),
        test_dataset=test_dataset(),
        output_dir='output_dir',
        validate=5000,
        snapshot=5000,
        log=100,
        lr_rate=0.0001)

@@ -0,0 +1,80 @@
import os
import sys
sys.path.append(os.path.abspath('../PaddleRS'))

import paddlers as pdrs

# transforms for training and validation
train_transforms = pdrs.datasets.ComposeTrans(
    input_keys=['lq', 'gt'],
    output_keys=['lq', 'gt'],
    pipelines=[{
        'name': 'SRPairedRandomCrop',
        'gt_patch_size': 128,
        'scale': 4
    }, {
        'name': 'PairedRandomHorizontalFlip'
    }, {
        'name': 'PairedRandomVerticalFlip'
    }, {
        'name': 'PairedRandomTransposeHW'
    }, {
        'name': 'Transpose'
    }, {
        'name': 'Normalize',
        'mean': [0.0, 0.0, 0.0],
        'std': [255.0, 255.0, 255.0]
    }])

test_transforms = pdrs.datasets.ComposeTrans(
    input_keys=['lq', 'gt'],
    output_keys=['lq', 'gt'],
    pipelines=[{
        'name': 'Transpose'
    }, {
        'name': 'Normalize',
        'mean': [0.0, 0.0, 0.0],
        'std': [255.0, 255.0, 255.0]
    }])

# training set
train_gt_floder = r"../work/RSdata_for_SR/trian_HR"  # path to the high-resolution images
train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4"  # path to the low-resolution images
num_workers = 6
batch_size = 32
scale = 4
train_dataset = pdrs.datasets.SRdataset(
    mode='train',
    gt_floder=train_gt_floder,
    lq_floder=train_lq_floder,
    transforms=train_transforms(),
    scale=scale,
    num_workers=num_workers,
    batch_size=batch_size)

# test set
test_gt_floder = r"../work/RSdata_for_SR/test_HR"
test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
test_dataset = pdrs.datasets.SRdataset(
    mode='test',
    gt_floder=test_gt_floder,
    lq_floder=test_lq_floder,
    transforms=test_transforms(),
    scale=scale)

# initialize the model; the network structure parameters can be adjusted
# with loss_type='gan', the perceptual, adversarial, and pixel losses are used
# with loss_type='pixel', only the pixel loss is used
model = pdrs.tasks.ESRGANet(loss_type='pixel')

model.train(
    total_iters=1000000,
    train_dataset=train_dataset(),
    test_dataset=test_dataset(),
    output_dir='output_dir',
    validate=5000,
    snapshot=5000,
    log=100,
    lr_rate=0.0001,
    periods=[250000, 250000, 250000, 250000],
    restart_weights=[1, 1, 1, 1])

@@ -0,0 +1,78 @@
import os
import sys
sys.path.append(os.path.abspath('../PaddleRS'))

import paddlers as pdrs

# transforms for training and validation
train_transforms = pdrs.datasets.ComposeTrans(
    input_keys=['lq', 'gt'],
    output_keys=['lq', 'gt'],
    pipelines=[{
        'name': 'SRPairedRandomCrop',
        'gt_patch_size': 192,
        'scale': 4
    }, {
        'name': 'PairedRandomHorizontalFlip'
    }, {
        'name': 'PairedRandomVerticalFlip'
    }, {
        'name': 'PairedRandomTransposeHW'
    }, {
        'name': 'Transpose'
    }, {
        'name': 'Normalize',
        'mean': [0.0, 0.0, 0.0],
        'std': [255.0, 255.0, 255.0]
    }])

test_transforms = pdrs.datasets.ComposeTrans(
    input_keys=['lq', 'gt'],
    output_keys=['lq', 'gt'],
    pipelines=[{
        'name': 'Transpose'
    }, {
        'name': 'Normalize',
        'mean': [0.0, 0.0, 0.0],
        'std': [255.0, 255.0, 255.0]
    }])

# training set
train_gt_floder = r"../work/RSdata_for_SR/trian_HR"  # path to the high-resolution images
train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4"  # path to the low-resolution images
num_workers = 4
batch_size = 16
scale = 4
train_dataset = pdrs.datasets.SRdataset(
    mode='train',
    gt_floder=train_gt_floder,
    lq_floder=train_lq_floder,
    transforms=train_transforms(),
    scale=scale,
    num_workers=num_workers,
    batch_size=batch_size)

# test set
test_gt_floder = r"../work/RSdata_for_SR/test_HR"
test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
test_dataset = pdrs.datasets.SRdataset(
    mode='test',
    gt_floder=test_gt_floder,
    lq_floder=test_lq_floder,
    transforms=test_transforms(),
    scale=scale)

# initialize the model; the network structure parameters can be adjusted
model = pdrs.tasks.LESRCNNet(scale=4, multi_scale=False, group=1)

model.train(
    total_iters=1000000,
    train_dataset=train_dataset(),
    test_dataset=test_dataset(),
    output_dir='output_dir',
    validate=5000,
    snapshot=5000,
    log=100,
    lr_rate=0.0001,
    periods=[250000, 250000, 250000, 250000],
    restart_weights=[1, 1, 1, 1])