diff --git a/paddlers/datasets/__init__.py b/paddlers/datasets/__init__.py index f611c54..20027e0 100644 --- a/paddlers/datasets/__init__.py +++ b/paddlers/datasets/__init__.py @@ -15,4 +15,5 @@ from .voc import VOCDetection from .seg_dataset import SegDataset from .cd_dataset import CDDataset -from .clas_dataset import ClasDataset \ No newline at end of file +from .clas_dataset import ClasDataset +from .sr_dataset import SRdataset, ComposeTrans \ No newline at end of file diff --git a/paddlers/datasets/sr_dataset.py b/paddlers/datasets/sr_dataset.py new file mode 100644 index 0000000..17748bf --- /dev/null +++ b/paddlers/datasets/sr_dataset.py @@ -0,0 +1,99 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# Super-resolution dataset definition.
class SRdataset(object):
    """Build the config dict describing a super-resolution dataset.

    The object does not load any data itself; calling the instance returns a
    plain dict consumed by ppgan's ``build_dataloader``.

    Args:
        mode (str): 'train' or 'test'.
        gt_floder (str): directory of high-resolution (ground-truth) images.
        lq_floder (str): directory of low-resolution images.
        transforms (dict): a transforms config dict (e.g. from ComposeTrans()).
        scale (int): super-resolution upscale factor.
        num_workers (int): dataloader workers, train mode only. Default 4.
        batch_size (int): batch size, train mode only. Default 8.

    Raises:
        ValueError: if ``mode`` is neither 'train' nor 'test'.
    """

    def __init__(self,
                 mode,
                 gt_floder,
                 lq_floder,
                 transforms,
                 scale,
                 num_workers=4,
                 batch_size=8):
        if mode == 'train':
            # Loading steps first, then the user-supplied transform config.
            preprocess = [
                {'name': 'LoadImageFromFile', 'key': 'lq'},
                {'name': 'LoadImageFromFile', 'key': 'gt'},
                transforms,
            ]
            self.dataset = {
                'name': 'SRDataset',
                'gt_folder': gt_floder,
                'lq_folder': lq_floder,
                'num_workers': num_workers,
                'batch_size': batch_size,
                'scale': scale,
                'preprocess': preprocess
            }
        elif mode == 'test':
            # Test config carries no num_workers/batch_size, matching the
            # evaluation dataloader defaults.
            preprocess = [
                {'name': 'LoadImageFromFile', 'key': 'lq'},
                {'name': 'LoadImageFromFile', 'key': 'gt'},
                transforms,
            ]
            self.dataset = {
                'name': 'SRDataset',
                'gt_folder': gt_floder,
                'lq_folder': lq_floder,
                'scale': scale,
                'preprocess': preprocess
            }
        else:
            # FIX: previously an unknown mode silently left self.dataset
            # unset, causing a confusing AttributeError later.
            raise ValueError(
                "mode must be 'train' or 'test', but received {}".format(mode))

    def __call__(self):
        return self.dataset


# Combine the configured transforms and return them as a config dict.
class ComposeTrans(object):
    """Assemble a ppgan 'Transforms' config dict from a pipeline list.

    Args:
        input_keys (list): keys of the inputs fed to the pipeline.
        output_keys (list): keys of the outputs; length 3 means DRN training.
        pipelines (list): list of transform config dicts.

    Raises:
        TypeError: if ``pipelines`` is not a list.
        ValueError: if ``pipelines`` is empty.
    """

    def __init__(self, input_keys, output_keys, pipelines):
        if not isinstance(pipelines, list):
            raise TypeError('Type of transforms is invalid. '
                            'Must be List, but received is {}'
                            .format(type(pipelines)))
        if len(pipelines) < 1:
            raise ValueError(
                'Length of transforms must not be less than 1, but received is {}'
                .format(len(pipelines)))
        self.transforms = pipelines
        # When len(output_keys) == 3 this is DRN training.
        self.output_length = len(output_keys)
        self.input_keys = input_keys
        self.output_keys = output_keys

    def __call__(self):
        pipeline = []
        for op in self.transforms:
            # FIX: copy each op before adding 'keys' — previously the
            # caller's dicts were mutated in place, so reusing a pipeline
            # list across datasets carried stale 'keys' entries.
            op = dict(op)
            if op['name'] == 'SRPairedRandomCrop':
                # The paired crop always operates on the (lq, gt) pair.
                op['keys'] = ['image'] * 2
            else:
                op['keys'] = ['image'] * self.output_length
            pipeline.append(op)
        if self.output_length == 2:
            transform_dict = {
                'name': 'Transforms',
                'input_keys': self.input_keys,
                'pipeline': pipeline
            }
        else:
            # DRN-style: output keys differ from input keys, so they must
            # be stated explicitly.
            transform_dict = {
                'name': 'Transforms',
                'input_keys': self.input_keys,
                'output_keys': self.output_keys,
                'pipeline': pipeline
            }

        return transform_dict
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time

import datetime

import paddle
from paddle.distributed import ParallelEnv

from ..models.ppgan.datasets.builder import build_dataloader
from ..models.ppgan.models.builder import build_model
from ..models.ppgan.utils.visual import tensor2img, save_image
from ..models.ppgan.utils.filesystem import makedirs, save, load
from ..models.ppgan.utils.timer import TimeAverager
from ..models.ppgan.utils.profiler import add_profiler_step
from ..models.ppgan.utils.logger import setup_logger


# Dict subclass whose keys can also be read/written as attributes
# (cfg.model is cfg['model']); used for config objects below.
class AttrDict(dict):
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            # Surface a missing key as AttributeError so hasattr()/getattr()
            # behave normally on config objects.
            raise AttributeError(key)

    def __setattr__(self, key, value):
        # Real instance attributes keep normal semantics; everything else is
        # stored as a dict entry.
        if key in self.__dict__:
            self.__dict__[key] = value
        else:
            self[key] = value


# Recursively convert nested plain dicts into AttrDict, literal-eval'ing any
# string values that parse as Python literals (e.g. "(0., 255.)").
def create_attr_dict(config_dict):
    from ast import literal_eval
    for key, value in config_dict.items():
        if type(value) is dict:
            config_dict[key] = value = AttrDict(value)
        if isinstance(value, str):
            try:
                value = literal_eval(value)
            except BaseException:
                # Not a literal: keep the raw string.
                pass
        if isinstance(value, AttrDict):
            create_attr_dict(config_dict[key])
        else:
            config_dict[key] = value


# Endless iterator over a dataloader: restarts it on StopIteration and counts
# epochs, so the training loop can run by total iterations.
class IterLoader:
    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._epoch = 1

    @property
    def epoch(self):
        # Number of completed passes over the dataloader (1-based).
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            # Dataloader exhausted: bump the epoch counter and restart.
            self._epoch += 1
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        return data

    def __len__(self):
        return len(self._dataloader)


# Base trainer for restoration (super-resolution) models.
class Restorer:
    """
    # trainer calling logic:
    #
    #                build_model                               ||    model(BaseModel)
    #                     |                                    ||
    #               build_dataloader                           ||    dataloader
    #                     |                                    ||
    #               model.setup_lr_schedulers                  ||    lr_scheduler
    #                     |                                    ||
    #               model.setup_optimizers                     ||    optimizers
    #                     |                                    ||
    #     train loop (model.setup_input + model.train_iter)    ||    train loop
    #                     |                                    ||
    #         print log (model.get_current_losses)             ||
    #                     |                                    ||
    #         save checkpoint (model.nets)                     \/
    """

    def __init__(self, cfg, logger):
        """Build model/dataloader/schedulers/optimizers from an AttrDict cfg.

        When cfg.is_train is False only the model (and optional metrics) are
        set up; the dataloader/optimizer construction is skipped.
        """
        # base config
        # self.logger = logging.getLogger(__name__)
        self.logger = logger
        self.cfg = cfg
        self.output_dir = cfg.output_dir
        self.max_eval_steps = cfg.model.get('max_eval_steps', None)

        self.local_rank = ParallelEnv().local_rank
        self.world_size = ParallelEnv().nranks
        self.log_interval = cfg.log_config.interval
        # NOTE(review): 'visiual_interval' is misspelled but must match the
        # key produced by BasicSRNet.train's log_config.
        self.visual_interval = cfg.log_config.visiual_interval
        self.weight_interval = cfg.snapshot_config.interval

        # Bookkeeping counters (all 1-based except batch/global steps).
        self.start_epoch = 1
        self.current_epoch = 1
        self.current_iter = 1
        self.inner_iter = 1
        self.batch_id = 0
        self.global_steps = 0

        # build model
        self.model = build_model(cfg.model)
        # multiple gpus prepare
        if ParallelEnv().nranks > 1:
            self.distributed_data_parallel()

        # build metrics
        self.metrics = None
        self.is_save_img = True
        validate_cfg = cfg.get('validate', None)
        if validate_cfg and 'metrics' in validate_cfg:
            self.metrics = self.model.setup_metrics(validate_cfg['metrics'])
        if validate_cfg and 'save_img' in validate_cfg:
            self.is_save_img = validate_cfg['save_img']

        self.enable_visualdl = cfg.get('enable_visualdl', False)
        if self.enable_visualdl:
            import visualdl
            self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir)

        # evaluate only
        if not cfg.is_train:
            return

        # build train dataloader
        self.train_dataloader = build_dataloader(cfg.dataset.train)
        self.iters_per_epoch = len(self.train_dataloader)

        # build lr scheduler
        # TODO: has a better way?
        if 'lr_scheduler' in cfg and 'iters_per_epoch' in cfg.lr_scheduler:
            cfg.lr_scheduler.iters_per_epoch = self.iters_per_epoch
        self.lr_schedulers = self.model.setup_lr_schedulers(cfg.lr_scheduler)

        # build optimizers
        self.optimizers = self.model.setup_optimizers(self.lr_schedulers,
                                                      cfg.optimizer)

        # Either epoch-driven (epochs given) or iteration-driven training.
        self.epochs = cfg.get('epochs', None)
        if self.epochs:
            self.total_iters = self.epochs * self.iters_per_epoch
            self.by_epoch = True
        else:
            self.by_epoch = False
            self.total_iters = cfg.total_iters

        if self.by_epoch:
            # Snapshot interval was given in epochs; convert to iterations.
            self.weight_interval *= self.iters_per_epoch

        self.validate_interval = -1
        if cfg.get('validate', None) is not None:
            self.validate_interval = cfg.validate.get('interval', -1)

        self.time_count = {}
        self.best_metric = {}
        self.model.set_total_iter(self.total_iters)
        self.profiler_options = cfg.profiler_options

    def distributed_data_parallel(self):
        """Wrap every sub-network in paddle.DataParallel for multi-GPU runs."""
        paddle.distributed.init_parallel_env()
        find_unused_parameters = self.cfg.get('find_unused_parameters', False)
        for net_name, net in self.model.nets.items():
            self.model.nets[net_name] = paddle.DataParallel(
                net, find_unused_parameters=find_unused_parameters)

    def learning_rate_scheduler_step(self):
        """Advance the model's LR scheduler(s) by one step."""
        if isinstance(self.model.lr_scheduler, dict):
            for lr_scheduler in self.model.lr_scheduler.values():
                lr_scheduler.step()
        elif isinstance(self.model.lr_scheduler,
                        paddle.optimizer.lr.LRScheduler):
            self.model.lr_scheduler.step()
        else:
            raise ValueError(
                'lr schedulter must be a dict or an instance of LRScheduler')

    def train(self):
        """Iteration-driven training loop with periodic log/visual/val/save."""
        reader_cost_averager = TimeAverager()
        batch_cost_averager = TimeAverager()

        iter_loader = IterLoader(self.train_dataloader)

        # set model.is_train = True
        self.model.setup_train_mode(is_train=True)
        while self.current_iter < (self.total_iters + 1):
            self.current_epoch = iter_loader.epoch
            self.inner_iter = self.current_iter % self.iters_per_epoch

            add_profiler_step(self.profiler_options)

            start_time = step_start_time = time.time()
            data = next(iter_loader)
            # Time spent waiting on the data pipeline only.
            reader_cost_averager.record(time.time() - step_start_time)
            # unpack data from dataset and apply preprocessing
            # data input should be dict
            self.model.setup_input(data)
            self.model.train_iter(self.optimizers)

            batch_cost_averager.record(
                time.time() - step_start_time,
                num_samples=self.cfg['dataset']['train'].get('batch_size', 1))

            step_start_time = time.time()

            if self.current_iter % self.log_interval == 0:
                self.data_time = reader_cost_averager.get_average()
                self.step_time = batch_cost_averager.get_average()
                self.ips = batch_cost_averager.get_ips_average()
                self.print_log()

                reader_cost_averager.reset()
                batch_cost_averager.reset()

            # Only rank 0 writes visual snapshots.
            if self.current_iter % self.visual_interval == 0 and self.local_rank == 0:
                self.visual('visual_train')

            self.learning_rate_scheduler_step()

            if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0:
                self.test()

            if self.current_iter % self.weight_interval == 0:
                # 'weight' saves nets only (keep=-1: never pruned);
                # the plain save also records epoch + optimizer state.
                self.save(self.current_iter, 'weight', keep=-1)
                self.save(self.current_iter)

            self.current_iter += 1

    def test(self):
        """Run evaluation over the test dataloader; log metrics, save images."""
        if not hasattr(self, 'test_dataloader'):
            # Lazily build the test dataloader on first call.
            self.test_dataloader = build_dataloader(
                self.cfg.dataset.test, is_train=False)
        iter_loader = IterLoader(self.test_dataloader)
        if self.max_eval_steps is None:
            self.max_eval_steps = len(self.test_dataloader)

        if self.metrics:
            for metric in self.metrics.values():
                metric.reset()

        # set model.is_train = False
        self.model.setup_train_mode(is_train=False)

        for i in range(self.max_eval_steps):
            if self.max_eval_steps < self.log_interval or i % self.log_interval == 0:
                self.logger.info('Test iter: [%d/%d]' % (
                    i * self.world_size, self.max_eval_steps * self.world_size))

            data = next(iter_loader)
            self.model.setup_input(data)
            self.model.test_iter(metrics=self.metrics)

            if self.is_save_img:
                visual_results = {}
                current_paths = self.model.get_image_paths()
                current_visuals = self.model.get_current_visuals()

                # NOTE(review): `.shape == 4` compares a shape tuple/list to
                # the int 4 and is always False, so num_samples is always 1;
                # this likely meant `len(...shape) == 4` — confirm upstream.
                if len(current_visuals) > 0 and list(current_visuals.values())[
                        0].shape == 4:
                    num_samples = list(current_visuals.values())[0].shape[0]
                else:
                    num_samples = 1

                for j in range(num_samples):
                    if j < len(current_paths):
                        short_path = os.path.basename(current_paths[j])
                        basename = os.path.splitext(short_path)[0]
                    else:
                        # No source path available: synthesize a name.
                        basename = '{:04d}_{:04d}'.format(i, j)
                    for k, img_tensor in current_visuals.items():
                        name = '%s_%s' % (basename, k)
                        if len(img_tensor.shape) == 4:
                            visual_results.update({name: img_tensor[j]})
                        else:
                            visual_results.update({name: img_tensor})

                self.visual(
                    'visual_test',
                    visual_results=visual_results,
                    step=self.batch_id,
                    is_save_image=True)

        if self.metrics:
            for metric_name, metric in self.metrics.items():
                self.logger.info("Metric {}: {:.4f}".format(
                    metric_name, metric.accumulate()))

    def print_log(self):
        """Format and log progress, lr, losses, timings and ETA."""
        losses = self.model.get_current_losses()

        message = ''
        if self.by_epoch:
            message += 'Epoch: %d/%d, iter: %d/%d ' % (
                self.current_epoch, self.epochs, self.inner_iter,
                self.iters_per_epoch)
        else:
            message += 'Iter: %d/%d ' % (self.current_iter, self.total_iters)

        message += f'lr: {self.current_learning_rate:.3e} '

        for k, v in losses.items():
            message += '%s: %.3f ' % (k, v)
            if self.enable_visualdl:
                self.vdl_logger.add_scalar(k, v, step=self.global_steps)

        if hasattr(self, 'step_time'):
            message += 'batch_cost: %.5f sec ' % self.step_time

        if hasattr(self, 'data_time'):
            message += 'reader_cost: %.5f sec ' % self.data_time

        if hasattr(self, 'ips'):
            message += 'ips: %.5f images/s ' % self.ips

        if hasattr(self, 'step_time'):
            # Rough ETA from the last averaged step time.
            eta = self.step_time * (self.total_iters - self.current_iter)
            eta = eta if eta > 0 else 0

            eta_str = str(datetime.timedelta(seconds=int(eta)))
            message += f'eta: {eta_str}'

        # print the message
        self.logger.info(message)

    @property
    def current_learning_rate(self):
        # Returns the lr of the first optimizer only (single-lr assumption).
        for optimizer in self.model.optimizers.values():
            return optimizer.get_lr()

    def visual(self,
               results_dir,
               visual_results=None,
               step=None,
               is_save_image=False):
        """
        visual the images, use visualdl or directly write to the directory
        Parameters:
            results_dir (str)     -- directory name which contains saved images
            visual_results (dict) -- the results images dict
            step (int)            -- global steps, used in visualdl
            is_save_image (bool)  -- whether write to the directory or visualdl
        """
        self.model.compute_visuals()

        if visual_results is None:
            visual_results = self.model.get_current_visuals()

        min_max = self.cfg.get('min_max', None)
        if min_max is None:
            min_max = (-1., 1.)

        image_num = self.cfg.get('image_num', None)
        if (image_num is None) or (not self.enable_visualdl):
            image_num = 1
        for label, image in visual_results.items():
            image_numpy = tensor2img(image, min_max, image_num)
            if (not is_save_image) and self.enable_visualdl:
                self.vdl_logger.add_image(
                    results_dir + '/' + label,
                    image_numpy,
                    step=step if step else self.global_steps,
                    dataformats="HWC" if image_num == 1 else "NCHW")
            else:
                if self.cfg.is_train:
                    # Prefix saved files with the current epoch/iteration.
                    if self.by_epoch:
                        msg = 'epoch%.3d_' % self.current_epoch
                    else:
                        msg = 'iter%.3d_' % self.current_iter
                else:
                    msg = ''
                makedirs(os.path.join(self.output_dir, results_dir))
                img_path = os.path.join(self.output_dir, results_dir,
                                        msg + '%s.png' % (label))
                save_image(image_numpy, img_path)

    def save(self, epoch, name='checkpoint', keep=1):
        """Save nets ('weight') or nets+epoch+optimizers ('checkpoint').

        keep > 0 also removes the checkpoint from `keep` intervals ago.
        Only rank 0 writes to disk.
        """
        if self.local_rank != 0:
            return

        assert name in ['checkpoint', 'weight']

        state_dicts = {}
        if self.by_epoch:
            save_filename = 'epoch_%s_%s.pdparams' % (
                epoch // self.iters_per_epoch, name)
        else:
            save_filename = 'iter_%s_%s.pdparams' % (epoch, name)

        os.makedirs(self.output_dir, exist_ok=True)
        save_path = os.path.join(self.output_dir, save_filename)
        for net_name, net in self.model.nets.items():
            state_dicts[net_name] = net.state_dict()

        if name == 'weight':
            # Weights-only snapshot: no epoch or optimizer state.
            save(state_dicts, save_path)
            return

        state_dicts['epoch'] = epoch

        for opt_name, opt in self.model.optimizers.items():
            state_dicts[opt_name] = opt.state_dict()

        save(state_dicts, save_path)

        if keep > 0:
            try:
                if self.by_epoch:
                    checkpoint_name_to_be_removed = os.path.join(
                        self.output_dir, 'epoch_%s_%s.pdparams' % (
                            (epoch - keep * self.weight_interval) //
                            self.iters_per_epoch, name))
                else:
                    checkpoint_name_to_be_removed = os.path.join(
                        self.output_dir, 'iter_%s_%s.pdparams' %
                        (epoch - keep * self.weight_interval, name))

                if os.path.exists(checkpoint_name_to_be_removed):
                    os.remove(checkpoint_name_to_be_removed)

            except Exception as e:
                # Best-effort cleanup: never fail training over a stale file.
                self.logger.info('remove old checkpoints error: {}'.format(e))

    def resume(self, checkpoint_path):
        """Restore nets, optimizers and progress counters from a checkpoint."""
        state_dicts = load(checkpoint_path)
        if state_dicts.get('epoch', None) is not None:
            self.start_epoch = state_dicts['epoch'] + 1
            self.global_steps = self.iters_per_epoch * state_dicts['epoch']

            # NOTE(review): current_iter is set from 'epoch' + 1; with
            # iteration-driven saves 'epoch' actually holds an iteration
            # count — confirm the intended resume semantics.
            self.current_iter = state_dicts['epoch'] + 1

        for net_name, net in self.model.nets.items():
            net.set_state_dict(state_dicts[net_name])

        for opt_name, opt in self.model.optimizers.items():
            opt.set_state_dict(state_dicts[opt_name])

    def load(self, weight_path):
        """Load pretrained net weights only; missing nets are skipped."""
        state_dicts = load(weight_path)

        for net_name, net in self.model.nets.items():
            if net_name in state_dicts:
                net.set_state_dict(state_dicts[net_name])
                self.logger.info('Loaded pretrained weight for net {}'.format(
                    net_name))
            else:
                self.logger.warning(
                    'Can not find state dict of net {}. Skip load pretrained weight for net {}'
                    .format(net_name, net_name))

    def close(self):
        """
        when finish the training need close file handler or other.
        """
        if self.enable_visualdl:
            self.vdl_logger.close()
+ """ + if self.enable_visualdl: + self.vdl_logger.close() + + +# 基础超分模型训练类 +class BasicSRNet: + def __init__(self): + self.model = {} + self.optimizer = {} + self.lr_scheduler = {} + self.min_max = '' + + def train( + self, + total_iters, + train_dataset, + test_dataset, + output_dir, + validate, + snapshot, + log, + lr_rate, + evaluate_weights='', + resume='', + pretrain_weights='', + periods=[100000], + restart_weights=[1], ): + self.lr_scheduler['learning_rate'] = lr_rate + + if self.lr_scheduler['name'] == 'CosineAnnealingRestartLR': + self.lr_scheduler['periods'] = periods + self.lr_scheduler['restart_weights'] = restart_weights + + validate = { + 'interval': validate, + 'save_img': False, + 'metrics': { + 'psnr': { + 'name': 'PSNR', + 'crop_border': 4, + 'test_y_channel': True + }, + 'ssim': { + 'name': 'SSIM', + 'crop_border': 4, + 'test_y_channel': True + } + } + } + log_config = {'interval': log, 'visiual_interval': 500} + snapshot_config = {'interval': snapshot} + + cfg = { + 'total_iters': total_iters, + 'output_dir': output_dir, + 'min_max': self.min_max, + 'model': self.model, + 'dataset': { + 'train': train_dataset, + 'test': test_dataset + }, + 'lr_scheduler': self.lr_scheduler, + 'optimizer': self.optimizer, + 'validate': validate, + 'log_config': log_config, + 'snapshot_config': snapshot_config + } + + cfg = AttrDict(cfg) + create_attr_dict(cfg) + + cfg.is_train = True + cfg.profiler_options = None + cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime()) + + if cfg.model.name == 'BaseSRModel': + floderModelName = cfg.model.generator.name + else: + floderModelName = cfg.model.name + cfg.output_dir = os.path.join(cfg.output_dir, + floderModelName + cfg.timestamp) + + logger_cfg = setup_logger(cfg.output_dir) + logger_cfg.info('Configs: {}'.format(cfg)) + + if paddle.is_compiled_with_cuda(): + paddle.set_device('gpu') + else: + paddle.set_device('cpu') + + # build trainer + trainer = Restorer(cfg, logger_cfg) + + # continue train or 
evaluate, checkpoint need contain epoch and optimizer info + if len(resume) > 0: + trainer.resume(resume) + # evaluate or finute, only load generator weights + elif len(pretrain_weights) > 0: + trainer.load(pretrain_weights) + if len(evaluate_weights) > 0: + trainer.load(evaluate_weights) + trainer.test() + return + # training, when keyboard interrupt save weights + try: + trainer.train() + except KeyboardInterrupt as e: + trainer.save(trainer.current_epoch) + + trainer.close() + + +# DRN模型训练 +class DRNet(BasicSRNet): + def __init__(self, + n_blocks=30, + n_feats=16, + n_colors=3, + rgb_range=255, + negval=0.2): + super(DRNet, self).__init__() + self.min_max = '(0., 255.)' + self.generator = { + 'name': 'DRNGenerator', + 'scale': (2, 4), + 'n_blocks': n_blocks, + 'n_feats': n_feats, + 'n_colors': n_colors, + 'rgb_range': rgb_range, + 'negval': negval + } + self.pixel_criterion = {'name': 'L1Loss'} + self.model = { + 'name': 'DRN', + 'generator': self.generator, + 'pixel_criterion': self.pixel_criterion + } + self.optimizer = { + 'optimG': { + 'name': 'Adam', + 'net_names': ['generator'], + 'weight_decay': 0.0, + 'beta1': 0.9, + 'beta2': 0.999 + }, + 'optimD': { + 'name': 'Adam', + 'net_names': ['dual_model_0', 'dual_model_1'], + 'weight_decay': 0.0, + 'beta1': 0.9, + 'beta2': 0.999 + } + } + self.lr_scheduler = { + 'name': 'CosineAnnealingRestartLR', + 'eta_min': 1e-07 + } + + +# 轻量化超分模型LESRCNN训练 +class LESRCNNet(BasicSRNet): + def __init__(self, scale=4, multi_scale=False, group=1): + super(LESRCNNet, self).__init__() + self.min_max = '(0., 1.)' + self.generator = { + 'name': 'LESRCNNGenerator', + 'scale': scale, + 'multi_scale': False, + 'group': 1 + } + self.pixel_criterion = {'name': 'L1Loss'} + self.model = { + 'name': 'BaseSRModel', + 'generator': self.generator, + 'pixel_criterion': self.pixel_criterion + } + self.optimizer = { + 'name': 'Adam', + 'net_names': ['generator'], + 'beta1': 0.9, + 'beta2': 0.99 + } + self.lr_scheduler = { + 'name': 
'CosineAnnealingRestartLR', + 'eta_min': 1e-07 + } + + +# ESRGAN模型训练 +# 若loss_type='gan' 使用感知损失、对抗损失和像素损失 +# 若loss_type = 'pixel' 只使用像素损失 +class ESRGANet(BasicSRNet): + def __init__(self, loss_type='gan', in_nc=3, out_nc=3, nf=64, nb=23): + super(ESRGANet, self).__init__() + self.min_max = '(0., 1.)' + self.generator = { + 'name': 'RRDBNet', + 'in_nc': in_nc, + 'out_nc': out_nc, + 'nf': nf, + 'nb': nb + } + + if loss_type == 'gan': + # 定义损失函数 + self.pixel_criterion = {'name': 'L1Loss', 'loss_weight': 0.01} + self.discriminator = { + 'name': 'VGGDiscriminator128', + 'in_channels': 3, + 'num_feat': 64 + } + self.perceptual_criterion = { + 'name': 'PerceptualLoss', + 'layer_weights': { + '34': 1.0 + }, + 'perceptual_weight': 1.0, + 'style_weight': 0.0, + 'norm_img': False + } + self.gan_criterion = { + 'name': 'GANLoss', + 'gan_mode': 'vanilla', + 'loss_weight': 0.005 + } + # 定义模型 + self.model = { + 'name': 'ESRGAN', + 'generator': self.generator, + 'discriminator': self.discriminator, + 'pixel_criterion': self.pixel_criterion, + 'perceptual_criterion': self.perceptual_criterion, + 'gan_criterion': self.gan_criterion + } + self.optimizer = { + 'optimG': { + 'name': 'Adam', + 'net_names': ['generator'], + 'weight_decay': 0.0, + 'beta1': 0.9, + 'beta2': 0.99 + }, + 'optimD': { + 'name': 'Adam', + 'net_names': ['discriminator'], + 'weight_decay': 0.0, + 'beta1': 0.9, + 'beta2': 0.99 + } + } + self.lr_scheduler = { + 'name': 'MultiStepDecay', + 'milestones': [50000, 100000, 200000, 300000], + 'gamma': 0.5 + } + else: + self.pixel_criterion = {'name': 'L1Loss'} + self.model = { + 'name': 'BaseSRModel', + 'generator': self.generator, + 'pixel_criterion': self.pixel_criterion + } + self.optimizer = { + 'name': 'Adam', + 'net_names': ['generator'], + 'beta1': 0.9, + 'beta2': 0.99 + } + self.lr_scheduler = { + 'name': 'CosineAnnealingRestartLR', + 'eta_min': 1e-07 + } diff --git a/tutorials/train/image_restoration/drn_train.py b/tutorials/train/image_restoration/drn_train.py 
import os
import sys
sys.path.append(os.path.abspath('../PaddleRS'))

import paddle
import paddlers as pdrs

if __name__ == "__main__":

    # Transforms used for training; output_keys has length 3 (lq, lqx2, gt),
    # which selects the DRN multi-scale training path.
    train_transforms = pdrs.datasets.ComposeTrans(
        input_keys=['lq', 'gt'],
        output_keys=['lq', 'lqx2', 'gt'],
        pipelines=[{
            'name': 'SRPairedRandomCrop',
            'gt_patch_size': 192,
            'scale': 4,
            'scale_list': True
        }, {
            'name': 'PairedRandomHorizontalFlip'
        }, {
            'name': 'PairedRandomVerticalFlip'
        }, {
            'name': 'PairedRandomTransposeHW'
        }, {
            'name': 'Transpose'
        }, {
            'name': 'Normalize',
            'mean': [0.0, 0.0, 0.0],
            'std': [1.0, 1.0, 1.0]
        }])

    # Transforms used for evaluation (no augmentation).
    test_transforms = pdrs.datasets.ComposeTrans(
        input_keys=['lq', 'gt'],
        output_keys=['lq', 'gt'],
        pipelines=[{
            'name': 'Transpose'
        }, {
            'name': 'Normalize',
            'mean': [0.0, 0.0, 0.0],
            'std': [1.0, 1.0, 1.0]
        }])

    # Training set.
    # NOTE(review): 'trian_HR' looks like a typo for 'train_HR' — confirm it
    # matches the actual directory name on disk before changing.
    train_gt_floder = r"../work/RSdata_for_SR/trian_HR"  # high-resolution images
    train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4"  # low-resolution images
    num_workers = 4
    batch_size = 8
    scale = 4
    train_dataset = pdrs.datasets.SRdataset(
        mode='train',
        gt_floder=train_gt_floder,
        lq_floder=train_lq_floder,
        transforms=train_transforms(),
        scale=scale,
        num_workers=num_workers,
        batch_size=batch_size)

    # Test set.
    test_gt_floder = r"../work/RSdata_for_SR/test_HR"
    test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
    test_dataset = pdrs.datasets.SRdataset(
        mode='test',
        gt_floder=test_gt_floder,
        lq_floder=test_lq_floder,
        transforms=test_transforms(),
        scale=scale)

    # Build the model; network hyper-parameters are adjustable here.
    model = pdrs.tasks.DRNet(
        n_blocks=30, n_feats=16, n_colors=3, rgb_range=255, negval=0.2)

    model.train(
        total_iters=100000,
        train_dataset=train_dataset(),
        test_dataset=test_dataset(),
        output_dir='output_dir',
        validate=5000,
        snapshot=5000,
        log=100,  # FIX: the required `log` argument was missing -> TypeError
        lr_rate=0.0001)
import os
import sys
sys.path.append(os.path.abspath('../PaddleRS'))

import paddlers as pdrs

# Transforms used for training: paired crop + flips/transpose augmentation,
# then HWC->CHW and scaling to [0, 1] (std 255).
train_transforms = pdrs.datasets.ComposeTrans(
    input_keys=['lq', 'gt'],
    output_keys=['lq', 'gt'],
    pipelines=[{
        'name': 'SRPairedRandomCrop',
        'gt_patch_size': 128,
        'scale': 4
    }, {
        'name': 'PairedRandomHorizontalFlip'
    }, {
        'name': 'PairedRandomVerticalFlip'
    }, {
        'name': 'PairedRandomTransposeHW'
    }, {
        'name': 'Transpose'
    }, {
        'name': 'Normalize',
        'mean': [0.0, 0.0, 0.0],
        'std': [255.0, 255.0, 255.0]
    }])

# Transforms used for evaluation (no augmentation).
test_transforms = pdrs.datasets.ComposeTrans(
    input_keys=['lq', 'gt'],
    output_keys=['lq', 'gt'],
    pipelines=[{
        'name': 'Transpose'
    }, {
        'name': 'Normalize',
        'mean': [0.0, 0.0, 0.0],
        'std': [255.0, 255.0, 255.0]
    }])

# Training set.
# NOTE(review): 'trian_HR' looks like a typo for 'train_HR' — confirm it
# matches the actual directory name on disk before changing.
train_gt_floder = r"../work/RSdata_for_SR/trian_HR"  # high-resolution images
train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4"  # low-resolution images
num_workers = 6
batch_size = 32
scale = 4
train_dataset = pdrs.datasets.SRdataset(
    mode='train',
    gt_floder=train_gt_floder,
    lq_floder=train_lq_floder,
    transforms=train_transforms(),
    scale=scale,
    num_workers=num_workers,
    batch_size=batch_size)

# Test set.
test_gt_floder = r"../work/RSdata_for_SR/test_HR"
test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
test_dataset = pdrs.datasets.SRdataset(
    mode='test',
    gt_floder=test_gt_floder,
    lq_floder=test_lq_floder,
    transforms=test_transforms(),
    scale=scale)

# Build the model; network hyper-parameters are adjustable here.
# loss_type='gan'  : perceptual + adversarial + pixel losses
# loss_type='pixel': pixel loss only
model = pdrs.tasks.ESRGANet(loss_type='pixel')

model.train(
    total_iters=1000000,
    train_dataset=train_dataset(),
    test_dataset=test_dataset(),
    output_dir='output_dir',
    validate=5000,
    snapshot=5000,
    log=100,
    lr_rate=0.0001,
    periods=[250000, 250000, 250000, 250000],
    restart_weights=[1, 1, 1, 1])
import os
import sys
sys.path.append(os.path.abspath('../PaddleRS'))

import paddlers as pdrs

# Transforms used for training: paired crop + flips/transpose augmentation,
# then HWC->CHW and scaling to [0, 1] (std 255).
train_transforms = pdrs.datasets.ComposeTrans(
    input_keys=['lq', 'gt'],
    output_keys=['lq', 'gt'],
    pipelines=[{
        'name': 'SRPairedRandomCrop',
        'gt_patch_size': 192,
        'scale': 4
    }, {
        'name': 'PairedRandomHorizontalFlip'
    }, {
        'name': 'PairedRandomVerticalFlip'
    }, {
        'name': 'PairedRandomTransposeHW'
    }, {
        'name': 'Transpose'
    }, {
        'name': 'Normalize',
        'mean': [0.0, 0.0, 0.0],
        'std': [255.0, 255.0, 255.0]
    }])

# Transforms used for evaluation (no augmentation).
test_transforms = pdrs.datasets.ComposeTrans(
    input_keys=['lq', 'gt'],
    output_keys=['lq', 'gt'],
    pipelines=[{
        'name': 'Transpose'
    }, {
        'name': 'Normalize',
        'mean': [0.0, 0.0, 0.0],
        'std': [255.0, 255.0, 255.0]
    }])

# Training set.
# NOTE(review): 'trian_HR' looks like a typo for 'train_HR' — confirm it
# matches the actual directory name on disk before changing.
train_gt_floder = r"../work/RSdata_for_SR/trian_HR"  # high-resolution images
train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4"  # low-resolution images
num_workers = 4
batch_size = 16
scale = 4
train_dataset = pdrs.datasets.SRdataset(
    mode='train',
    gt_floder=train_gt_floder,
    lq_floder=train_lq_floder,
    transforms=train_transforms(),
    scale=scale,
    num_workers=num_workers,
    batch_size=batch_size)

# Test set.
test_gt_floder = r"../work/RSdata_for_SR/test_HR"
test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4"
test_dataset = pdrs.datasets.SRdataset(
    mode='test',
    gt_floder=test_gt_floder,
    lq_floder=test_lq_floder,
    transforms=test_transforms(),
    scale=scale)

# Build the model; network hyper-parameters are adjustable here.
model = pdrs.tasks.LESRCNNet(scale=4, multi_scale=False, group=1)

model.train(
    total_iters=1000000,
    train_dataset=train_dataset(),
    test_dataset=test_dataset(),
    output_dir='output_dir',
    validate=5000,
    snapshot=5000,
    log=100,
    lr_rate=0.0001,
    periods=[250000, 250000, 250000, 250000],
    restart_weights=[1, 1, 1, 1])