Merge pull request #24 from Bobholamovic/refactor_res
[Refactor] Refactor Models for Restoration Tasks
commit
6b02cf875f
61 changed files with 2038 additions and 1460 deletions
@@ -0,0 +1,83 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os.path as osp
import copy

from .base import BaseDataset
from paddlers.utils import logging, get_encoding, norm_path, is_pic


class ResDataset(BaseDataset):
    """
    Dataset for image restoration tasks.

    Args:
        data_dir (str): Root directory of the dataset.
        file_list (str): Path of the file that contains relative paths of source and target image files.
        transforms (paddlers.transforms.Compose): Data preprocessing and data augmentation operators to apply.
        num_workers (int|str, optional): Number of processes used for data loading. If `num_workers` is 'auto',
            the number of workers will be automatically determined according to the number of CPU cores: if
            there are more than 16 cores, 8 workers will be used. Otherwise, the number of workers will be half
            the number of CPU cores. Defaults to 'auto'.
        shuffle (bool, optional): Whether to shuffle the samples. Defaults to False.
        sr_factor (int|None, optional): Scaling factor of the image super-resolution task. None for other image
            restoration tasks. Defaults to None.
    """

    def __init__(self,
                 data_dir,
                 file_list,
                 transforms,
                 num_workers='auto',
                 shuffle=False,
                 sr_factor=None):
        super(ResDataset, self).__init__(data_dir, None, transforms,
                                         num_workers, shuffle)
        self.batch_transforms = None
        self.file_list = list()

        with open(file_list, encoding=get_encoding(file_list)) as f:
            for line in f:
                items = line.strip().split()
                if len(items) > 2:
                    raise ValueError(
                        "A space is used as the delimiter to separate the source and target image paths, "
                        "so neither path may contain a space. However, line [{}] of "
                        "file_list [{}] contains a space inside a path.".format(line, file_list))
                items[0] = norm_path(items[0])
                items[1] = norm_path(items[1])
                full_path_im = osp.join(data_dir, items[0])
                full_path_tar = osp.join(data_dir, items[1])
                if not is_pic(full_path_im) or not is_pic(full_path_tar):
                    continue
                if not osp.exists(full_path_im):
                    raise IOError("Source image file {} does not exist!".format(
                        full_path_im))
                if not osp.exists(full_path_tar):
                    raise IOError("Target image file {} does not exist!".format(
                        full_path_tar))
                sample = {
                    'image': full_path_im,
                    'target': full_path_tar,
                }
                if sr_factor is not None:
                    sample['sr_factor'] = sr_factor
                self.file_list.append(sample)
        self.num_samples = len(self.file_list)
        logging.info("{} samples in file {}".format(
            len(self.file_list), file_list))

    def __len__(self):
        return len(self.file_list)
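For readers skimming the diff, here is a minimal usage sketch of the new `ResDataset`. The directory layout, file names, and the transform pipeline are illustrative assumptions rather than part of this PR; only the constructor signature and the file-list format (one source path and one target path per line, separated by a single space) come from the class above.

```python
# Hypothetical file layout (not from this PR):
#   data/sr_data/train_list.txt containing lines such as:
#       lr/0001.png hr/0001.png
#       lr/0002.png hr/0002.png
import paddlers.transforms as T
from paddlers.datasets import ResDataset  # assumed import path

# Assumed minimal pipeline; a pipeline used for actual training would normally
# end with the arrange operator expected by the restorer trainer.
train_transforms = T.Compose([T.Resize(target_size=256)])

train_dataset = ResDataset(
    data_dir='data/sr_data',
    file_list='data/sr_data/train_list.txt',
    transforms=train_transforms,
    num_workers='auto',  # >16 CPU cores -> 8 workers, otherwise half the cores
    shuffle=True,
    sr_factor=4)         # only meaningful for super-resolution

print(len(train_dataset))  # number of valid (source, target) pairs
```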
@@ -1,99 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Super-resolution dataset definition
class SRdataset(object):
    def __init__(self,
                 mode,
                 gt_floder,
                 lq_floder,
                 transforms,
                 scale,
                 num_workers=4,
                 batch_size=8):
        if mode == 'train':
            preprocess = []
            preprocess.append({
                'name': 'LoadImageFromFile',
                'key': 'lq'
            })  # Image loading ops
            preprocess.append({'name': 'LoadImageFromFile', 'key': 'gt'})
            preprocess.append(transforms)  # Transform ops
            self.dataset = {
                'name': 'SRDataset',
                'gt_folder': gt_floder,
                'lq_folder': lq_floder,
                'num_workers': num_workers,
                'batch_size': batch_size,
                'scale': scale,
                'preprocess': preprocess
            }

        if mode == "test":
            preprocess = []
            preprocess.append({'name': 'LoadImageFromFile', 'key': 'lq'})
            preprocess.append({'name': 'LoadImageFromFile', 'key': 'gt'})
            preprocess.append(transforms)
            self.dataset = {
                'name': 'SRDataset',
                'gt_folder': gt_floder,
                'lq_folder': lq_floder,
                'scale': scale,
                'preprocess': preprocess
            }

    def __call__(self):
        return self.dataset


# Compose the configured transforms and return them as a dict
class ComposeTrans(object):
    def __init__(self, input_keys, output_keys, pipelines):
        if not isinstance(pipelines, list):
            raise TypeError(
                'Type of transforms is invalid. Must be a list, but received {}'
                .format(type(pipelines)))
        if len(pipelines) < 1:
            raise ValueError(
                'Length of transforms must not be less than 1, but received {}'
                .format(len(pipelines)))
        self.transforms = pipelines
        self.output_length = len(output_keys)  # A length of 3 indicates DRN training
        self.input_keys = input_keys
        self.output_keys = output_keys

    def __call__(self):
        pipeline = []
        for op in self.transforms:
            if op['name'] == 'SRPairedRandomCrop':
                op['keys'] = ['image'] * 2
            else:
                op['keys'] = ['image'] * self.output_length
            pipeline.append(op)
        if self.output_length == 2:
            transform_dict = {
                'name': 'Transforms',
                'input_keys': self.input_keys,
                'pipeline': pipeline
            }
        else:
            transform_dict = {
                'name': 'Transforms',
                'input_keys': self.input_keys,
                'output_keys': self.output_keys,
                'pipeline': pipeline
            }

        return transform_dict
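For contrast with the new `ResDataset`, the sketch below shows roughly how the removed config-dict classes were combined. The folder paths and crop parameters are made up for illustration; only the class interfaces and the `SRPairedRandomCrop` op name appear in the deleted code.

```python
# Illustrative only: parameter values are assumptions, not taken from this PR.
pipeline_ops = [
    {'name': 'SRPairedRandomCrop', 'gt_patch_size': 96, 'scale': 4},
    {'name': 'Transpose'},
]
compose = ComposeTrans(
    input_keys=['lq', 'gt'], output_keys=['lq', 'gt'], pipelines=pipeline_ops)

# SRdataset yields a plain config dict describing a ppgan 'SRDataset'.
train_cfg = SRdataset(
    mode='train',
    gt_floder='data/DIV2K/HR',    # hypothetical paths
    lq_floder='data/DIV2K/LR_x4',
    transforms=compose(),         # ComposeTrans() returns a 'Transforms' dict
    scale=4,
    num_workers=4,
    batch_size=8)()
print(train_cfg['name'])          # 'SRDataset'
```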
@@ -1,106 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn

from .generators.builder import build_generator
from ...models.ppgan.models.criterions.builder import build_criterion
from ...models.ppgan.models.base_model import BaseModel
from ...models.ppgan.models.builder import MODELS
from ...models.ppgan.utils.visual import tensor2img
from ...models.ppgan.modules.init import reset_parameters


@MODELS.register()
class RCANModel(BaseModel):
    """
    Base SR model for single image super-resolution.
    """

    def __init__(self, generator, pixel_criterion=None, use_init_weight=False):
        """
        Args:
            generator (dict): Config of the generator.
            pixel_criterion (dict): Config of the pixel criterion.
        """
        super(RCANModel, self).__init__()

        self.nets['generator'] = build_generator(generator)
        self.error_last = 1e8
        self.batch = 0
        if pixel_criterion:
            self.pixel_criterion = build_criterion(pixel_criterion)
        if use_init_weight:
            init_sr_weight(self.nets['generator'])

    def setup_input(self, input):
        self.lq = paddle.to_tensor(input['lq'])
        self.visual_items['lq'] = self.lq
        if 'gt' in input:
            self.gt = paddle.to_tensor(input['gt'])
            self.visual_items['gt'] = self.gt
        self.image_paths = input['lq_path']

    def forward(self):
        pass

    def train_iter(self, optims=None):
        optims['optim'].clear_grad()

        self.output = self.nets['generator'](self.lq)
        self.visual_items['output'] = self.output
        # pixel loss
        loss_pixel = self.pixel_criterion(self.output, self.gt)
        self.losses['loss_pixel'] = loss_pixel

        skip_threshold = 1e6

        if loss_pixel.item() < skip_threshold * self.error_last:
            loss_pixel.backward()
            optims['optim'].step()
        else:
            print('Skip this batch {}! (Loss: {})'.format(self.batch + 1,
                                                          loss_pixel.item()))
        self.batch += 1

        if self.batch % 1000 == 0:
            self.error_last = loss_pixel.item() / 1000
            print("update error_last: {}".format(self.error_last))

    def test_iter(self, metrics=None):
        self.nets['generator'].eval()
        with paddle.no_grad():
            self.output = self.nets['generator'](self.lq)
            self.visual_items['output'] = self.output
        self.nets['generator'].train()

        out_img = []
        gt_img = []
        for out_tensor, gt_tensor in zip(self.output, self.gt):
            out_img.append(tensor2img(out_tensor, (0., 255.)))
            gt_img.append(tensor2img(gt_tensor, (0., 255.)))

        if metrics is not None:
            for metric in metrics.values():
                metric.update(out_img, gt_img)


def init_sr_weight(net):
    def reset_func(m):
        if hasattr(m, 'weight') and (
                not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D))):
            reset_parameters(m)

    net.apply(reset_func)
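`train_iter` above skips the parameter update whenever the pixel loss explodes relative to a running reference error. The helper below distills that heuristic into a small framework-free sketch; it is illustrative, not code from this repository.

```python
class LossSkipGuard:
    """Mirrors the skip logic in RCANModel.train_iter (illustrative sketch)."""

    def __init__(self, skip_threshold=1e6, update_every=1000):
        self.skip_threshold = skip_threshold
        self.update_every = update_every
        self.error_last = 1e8  # large initial value, so nothing is skipped early on
        self.batch = 0

    def should_step(self, loss_value):
        """Return True if the optimizer should step on this batch."""
        ok = loss_value < self.skip_threshold * self.error_last
        self.batch += 1
        if self.batch % self.update_every == 0:
            # Refresh the reference error, as train_iter does every 1000 batches.
            self.error_last = loss_value / self.update_every
        return ok


guard = LossSkipGuard()
print(guard.should_step(0.05))  # True: the loss is far below the threshold
```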
@ -1,786 +0,0 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import os |
||||
import time |
||||
import datetime |
||||
|
||||
import paddle |
||||
from paddle.distributed import ParallelEnv |
||||
|
||||
from ..models.ppgan.datasets.builder import build_dataloader |
||||
from ..models.ppgan.models.builder import build_model |
||||
from ..models.ppgan.utils.visual import tensor2img, save_image |
||||
from ..models.ppgan.utils.filesystem import makedirs, save, load |
||||
from ..models.ppgan.utils.timer import TimeAverager |
||||
from ..models.ppgan.utils.profiler import add_profiler_step |
||||
from ..models.ppgan.utils.logger import setup_logger |
||||
|
||||
|
||||
# AttrDict: a dict subclass that allows attribute-style access to its keys
||||
class AttrDict(dict): |
||||
def __getattr__(self, key): |
||||
try: |
||||
return self[key] |
||||
except KeyError: |
||||
raise AttributeError(key) |
||||
|
||||
def __setattr__(self, key, value): |
||||
if key in self.__dict__: |
||||
self.__dict__[key] = value |
||||
else: |
||||
self[key] = value |
||||
|
||||
|
||||
# Recursively convert nested dicts into AttrDict instances
||||
def create_attr_dict(config_dict): |
||||
from ast import literal_eval |
||||
for key, value in config_dict.items(): |
||||
if type(value) is dict: |
||||
config_dict[key] = value = AttrDict(value) |
||||
if isinstance(value, str): |
||||
try: |
||||
value = literal_eval(value) |
||||
except BaseException: |
||||
pass |
||||
if isinstance(value, AttrDict): |
||||
create_attr_dict(config_dict[key]) |
||||
else: |
||||
config_dict[key] = value |
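A quick illustration of what `AttrDict` and `create_attr_dict` do to a nested configuration (the values below are arbitrary):

```python
cfg = AttrDict({
    'total_iters': '1000',  # strings are literal_eval'ed into Python values
    'model': {'name': 'RCANModel', 'generator': {'scale': 2}},
})
create_attr_dict(cfg)

print(cfg.total_iters + 1)            # 1001: the string became an int
print(cfg.model.generator.scale)      # 2: nested dicts now allow attribute access
cfg.log_config = {'interval': 10}     # new attributes are stored as dict keys
print(cfg['log_config']['interval'])  # 10
```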
||||
|
||||
|
||||
# Data loading helper that cycles over the dataloader across epochs
||||
class IterLoader: |
||||
def __init__(self, dataloader): |
||||
self._dataloader = dataloader |
||||
self.iter_loader = iter(self._dataloader) |
||||
self._epoch = 1 |
||||
|
||||
@property |
||||
def epoch(self): |
||||
return self._epoch |
||||
|
||||
def __next__(self): |
||||
try: |
||||
data = next(self.iter_loader) |
||||
except StopIteration: |
||||
self._epoch += 1 |
||||
self.iter_loader = iter(self._dataloader) |
||||
data = next(self.iter_loader) |
||||
|
||||
return data |
||||
|
||||
def __len__(self): |
||||
return len(self._dataloader) |
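`IterLoader` simply restarts iteration when the wrapped loader is exhausted and counts epochs as it wraps around. Any sized iterable works, so a plain list is enough to see the behaviour:

```python
loader = IterLoader([10, 20, 30])        # stand-in for a paddle.io.DataLoader
print(len(loader))                       # 3
print([next(loader) for _ in range(4)])  # [10, 20, 30, 10]: wraps around
print(loader.epoch)                      # 2: the wrap-around bumped the epoch counter
```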
||||
|
||||
|
||||
# Base trainer class
||||
class Restorer: |
||||
""" |
||||
# trainer calling logic: |
||||
# |
||||
# build_model || model(BaseModel) |
||||
# | || |
||||
# build_dataloader || dataloader |
||||
# | || |
||||
# model.setup_lr_schedulers || lr_scheduler |
||||
# | || |
||||
# model.setup_optimizers || optimizers |
||||
# | || |
||||
# train loop (model.setup_input + model.train_iter) || train loop |
||||
# | || |
||||
# print log (model.get_current_losses) || |
||||
# | || |
||||
# save checkpoint (model.nets) \/ |
||||
""" |
||||
|
||||
def __init__(self, cfg, logger): |
||||
# base config |
||||
# self.logger = logging.getLogger(__name__) |
||||
self.logger = logger |
||||
self.cfg = cfg |
||||
self.output_dir = cfg.output_dir |
||||
self.max_eval_steps = cfg.model.get('max_eval_steps', None) |
||||
|
||||
self.local_rank = ParallelEnv().local_rank |
||||
self.world_size = ParallelEnv().nranks |
||||
self.log_interval = cfg.log_config.interval |
||||
self.visual_interval = cfg.log_config.visiual_interval |
||||
self.weight_interval = cfg.snapshot_config.interval |
||||
|
||||
self.start_epoch = 1 |
||||
self.current_epoch = 1 |
||||
self.current_iter = 1 |
||||
self.inner_iter = 1 |
||||
self.batch_id = 0 |
||||
self.global_steps = 0 |
||||
|
||||
# build model |
||||
self.model = build_model(cfg.model) |
||||
# multiple gpus prepare |
||||
if ParallelEnv().nranks > 1: |
||||
self.distributed_data_parallel() |
||||
|
||||
# build metrics |
||||
self.metrics = None |
||||
self.is_save_img = True |
||||
validate_cfg = cfg.get('validate', None) |
||||
if validate_cfg and 'metrics' in validate_cfg: |
||||
self.metrics = self.model.setup_metrics(validate_cfg['metrics']) |
||||
if validate_cfg and 'save_img' in validate_cfg: |
||||
self.is_save_img = validate_cfg['save_img'] |
||||
|
||||
self.enable_visualdl = cfg.get('enable_visualdl', False) |
||||
if self.enable_visualdl: |
||||
import visualdl |
||||
self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir) |
||||
|
||||
# evaluate only |
||||
if not cfg.is_train: |
||||
return |
||||
|
||||
# build train dataloader |
||||
self.train_dataloader = build_dataloader(cfg.dataset.train) |
||||
self.iters_per_epoch = len(self.train_dataloader) |
||||
|
||||
# build lr scheduler |
||||
# TODO: has a better way? |
||||
if 'lr_scheduler' in cfg and 'iters_per_epoch' in cfg.lr_scheduler: |
||||
cfg.lr_scheduler.iters_per_epoch = self.iters_per_epoch |
||||
self.lr_schedulers = self.model.setup_lr_schedulers(cfg.lr_scheduler) |
||||
|
||||
# build optimizers |
||||
self.optimizers = self.model.setup_optimizers(self.lr_schedulers, |
||||
cfg.optimizer) |
||||
|
||||
self.epochs = cfg.get('epochs', None) |
||||
if self.epochs: |
||||
self.total_iters = self.epochs * self.iters_per_epoch |
||||
self.by_epoch = True |
||||
else: |
||||
self.by_epoch = False |
||||
self.total_iters = cfg.total_iters |
||||
|
||||
if self.by_epoch: |
||||
self.weight_interval *= self.iters_per_epoch |
||||
|
||||
self.validate_interval = -1 |
||||
if cfg.get('validate', None) is not None: |
||||
self.validate_interval = cfg.validate.get('interval', -1) |
||||
|
||||
self.time_count = {} |
||||
self.best_metric = {} |
||||
self.model.set_total_iter(self.total_iters) |
||||
self.profiler_options = cfg.profiler_options |
||||
|
||||
def distributed_data_parallel(self): |
||||
paddle.distributed.init_parallel_env() |
||||
find_unused_parameters = self.cfg.get('find_unused_parameters', False) |
||||
for net_name, net in self.model.nets.items(): |
||||
self.model.nets[net_name] = paddle.DataParallel( |
||||
net, find_unused_parameters=find_unused_parameters) |
||||
|
||||
def learning_rate_scheduler_step(self): |
||||
if isinstance(self.model.lr_scheduler, dict): |
||||
for lr_scheduler in self.model.lr_scheduler.values(): |
||||
lr_scheduler.step() |
||||
elif isinstance(self.model.lr_scheduler, |
||||
paddle.optimizer.lr.LRScheduler): |
||||
self.model.lr_scheduler.step() |
||||
else: |
||||
raise ValueError( |
||||
'lr scheduler must be a dict or an instance of LRScheduler')
||||
|
||||
def train(self): |
||||
reader_cost_averager = TimeAverager() |
||||
batch_cost_averager = TimeAverager() |
||||
|
||||
iter_loader = IterLoader(self.train_dataloader) |
||||
|
||||
# set model.is_train = True |
||||
self.model.setup_train_mode(is_train=True) |
||||
while self.current_iter < (self.total_iters + 1): |
||||
self.current_epoch = iter_loader.epoch |
||||
self.inner_iter = self.current_iter % self.iters_per_epoch |
||||
|
||||
add_profiler_step(self.profiler_options) |
||||
|
||||
start_time = step_start_time = time.time() |
||||
data = next(iter_loader) |
||||
reader_cost_averager.record(time.time() - step_start_time) |
||||
# unpack data from dataset and apply preprocessing |
||||
# data input should be dict |
||||
self.model.setup_input(data) |
||||
self.model.train_iter(self.optimizers) |
||||
|
||||
batch_cost_averager.record( |
||||
time.time() - step_start_time, |
||||
num_samples=self.cfg['dataset']['train'].get('batch_size', 1)) |
||||
|
||||
step_start_time = time.time() |
||||
|
||||
if self.current_iter % self.log_interval == 0: |
||||
self.data_time = reader_cost_averager.get_average() |
||||
self.step_time = batch_cost_averager.get_average() |
||||
self.ips = batch_cost_averager.get_ips_average() |
||||
self.print_log() |
||||
|
||||
reader_cost_averager.reset() |
||||
batch_cost_averager.reset() |
||||
|
||||
if self.current_iter % self.visual_interval == 0 and self.local_rank == 0: |
||||
self.visual('visual_train') |
||||
|
||||
self.learning_rate_scheduler_step() |
||||
|
||||
if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0: |
||||
self.test() |
||||
|
||||
if self.current_iter % self.weight_interval == 0: |
||||
self.save(self.current_iter, 'weight', keep=-1) |
||||
self.save(self.current_iter) |
||||
|
||||
self.current_iter += 1 |
||||
|
||||
def test(self): |
||||
if not hasattr(self, 'test_dataloader'): |
||||
self.test_dataloader = build_dataloader( |
||||
self.cfg.dataset.test, is_train=False) |
||||
iter_loader = IterLoader(self.test_dataloader) |
||||
if self.max_eval_steps is None: |
||||
self.max_eval_steps = len(self.test_dataloader) |
||||
|
||||
if self.metrics: |
||||
for metric in self.metrics.values(): |
||||
metric.reset() |
||||
|
||||
# set model.is_train = False |
||||
self.model.setup_train_mode(is_train=False) |
||||
|
||||
for i in range(self.max_eval_steps): |
||||
if self.max_eval_steps < self.log_interval or i % self.log_interval == 0: |
||||
self.logger.info('Test iter: [%d/%d]' % ( |
||||
i * self.world_size, self.max_eval_steps * self.world_size)) |
||||
|
||||
data = next(iter_loader) |
||||
self.model.setup_input(data) |
||||
self.model.test_iter(metrics=self.metrics) |
||||
|
||||
if self.is_save_img: |
||||
visual_results = {} |
||||
current_paths = self.model.get_image_paths() |
||||
current_visuals = self.model.get_current_visuals() |
||||
|
||||
if len(current_visuals) > 0 and list(current_visuals.values())[ |
||||
0].ndim == 4:  # check for a batched (NCHW) tensor
||||
num_samples = list(current_visuals.values())[0].shape[0] |
||||
else: |
||||
num_samples = 1 |
||||
|
||||
for j in range(num_samples): |
||||
if j < len(current_paths): |
||||
short_path = os.path.basename(current_paths[j]) |
||||
basename = os.path.splitext(short_path)[0] |
||||
else: |
||||
basename = '{:04d}_{:04d}'.format(i, j) |
||||
for k, img_tensor in current_visuals.items(): |
||||
name = '%s_%s' % (basename, k) |
||||
if len(img_tensor.shape) == 4: |
||||
visual_results.update({name: img_tensor[j]}) |
||||
else: |
||||
visual_results.update({name: img_tensor}) |
||||
|
||||
self.visual( |
||||
'visual_test', |
||||
visual_results=visual_results, |
||||
step=self.batch_id, |
||||
is_save_image=True) |
||||
|
||||
if self.metrics: |
||||
for metric_name, metric in self.metrics.items(): |
||||
self.logger.info("Metric {}: {:.4f}".format( |
||||
metric_name, metric.accumulate())) |
||||
|
||||
def print_log(self): |
||||
losses = self.model.get_current_losses() |
||||
|
||||
message = '' |
||||
if self.by_epoch: |
||||
message += 'Epoch: %d/%d, iter: %d/%d ' % ( |
||||
self.current_epoch, self.epochs, self.inner_iter, |
||||
self.iters_per_epoch) |
||||
else: |
||||
message += 'Iter: %d/%d ' % (self.current_iter, self.total_iters) |
||||
|
||||
message += f'lr: {self.current_learning_rate:.3e} ' |
||||
|
||||
for k, v in losses.items(): |
||||
message += '%s: %.3f ' % (k, v) |
||||
if self.enable_visualdl: |
||||
self.vdl_logger.add_scalar(k, v, step=self.global_steps) |
||||
|
||||
if hasattr(self, 'step_time'): |
||||
message += 'batch_cost: %.5f sec ' % self.step_time |
||||
|
||||
if hasattr(self, 'data_time'): |
||||
message += 'reader_cost: %.5f sec ' % self.data_time |
||||
|
||||
if hasattr(self, 'ips'): |
||||
message += 'ips: %.5f images/s ' % self.ips |
||||
|
||||
if hasattr(self, 'step_time'): |
||||
eta = self.step_time * (self.total_iters - self.current_iter) |
||||
eta = eta if eta > 0 else 0 |
||||
|
||||
eta_str = str(datetime.timedelta(seconds=int(eta))) |
||||
message += f'eta: {eta_str}' |
||||
|
||||
# print the message |
||||
self.logger.info(message) |
||||
|
||||
@property |
||||
def current_learning_rate(self): |
||||
for optimizer in self.model.optimizers.values(): |
||||
return optimizer.get_lr() |
||||
|
||||
def visual(self, |
||||
results_dir, |
||||
visual_results=None, |
||||
step=None, |
||||
is_save_image=False): |
||||
""" |
||||
Visualize the images: use VisualDL or write them directly to the directory.
||||
Parameters: |
||||
results_dir (str) -- directory name which contains saved images |
||||
visual_results (dict) -- the results images dict |
||||
step (int) -- global steps, used in visualdl |
||||
is_save_image (bool) -- whether to write images to the directory instead of VisualDL
||||
""" |
||||
self.model.compute_visuals() |
||||
|
||||
if visual_results is None: |
||||
visual_results = self.model.get_current_visuals() |
||||
|
||||
min_max = self.cfg.get('min_max', None) |
||||
if min_max is None: |
||||
min_max = (-1., 1.) |
||||
|
||||
image_num = self.cfg.get('image_num', None) |
||||
if (image_num is None) or (not self.enable_visualdl): |
||||
image_num = 1 |
||||
for label, image in visual_results.items(): |
||||
image_numpy = tensor2img(image, min_max, image_num) |
||||
if (not is_save_image) and self.enable_visualdl: |
||||
self.vdl_logger.add_image( |
||||
results_dir + '/' + label, |
||||
image_numpy, |
||||
step=step if step else self.global_steps, |
||||
dataformats="HWC" if image_num == 1 else "NCHW") |
||||
else: |
||||
if self.cfg.is_train: |
||||
if self.by_epoch: |
||||
msg = 'epoch%.3d_' % self.current_epoch |
||||
else: |
||||
msg = 'iter%.3d_' % self.current_iter |
||||
else: |
||||
msg = '' |
||||
makedirs(os.path.join(self.output_dir, results_dir)) |
||||
img_path = os.path.join(self.output_dir, results_dir, |
||||
msg + '%s.png' % (label)) |
||||
save_image(image_numpy, img_path) |
||||
|
||||
def save(self, epoch, name='checkpoint', keep=1): |
||||
if self.local_rank != 0: |
||||
return |
||||
|
||||
assert name in ['checkpoint', 'weight'] |
||||
|
||||
state_dicts = {} |
||||
if self.by_epoch: |
||||
save_filename = 'epoch_%s_%s.pdparams' % ( |
||||
epoch // self.iters_per_epoch, name) |
||||
else: |
||||
save_filename = 'iter_%s_%s.pdparams' % (epoch, name) |
||||
|
||||
os.makedirs(self.output_dir, exist_ok=True) |
||||
save_path = os.path.join(self.output_dir, save_filename) |
||||
for net_name, net in self.model.nets.items(): |
||||
state_dicts[net_name] = net.state_dict() |
||||
|
||||
if name == 'weight': |
||||
save(state_dicts, save_path) |
||||
return |
||||
|
||||
state_dicts['epoch'] = epoch |
||||
|
||||
for opt_name, opt in self.model.optimizers.items(): |
||||
state_dicts[opt_name] = opt.state_dict() |
||||
|
||||
save(state_dicts, save_path) |
||||
|
||||
if keep > 0: |
||||
try: |
||||
if self.by_epoch: |
||||
checkpoint_name_to_be_removed = os.path.join( |
||||
self.output_dir, 'epoch_%s_%s.pdparams' % ( |
||||
(epoch - keep * self.weight_interval) // |
||||
self.iters_per_epoch, name)) |
||||
else: |
||||
checkpoint_name_to_be_removed = os.path.join( |
||||
self.output_dir, 'iter_%s_%s.pdparams' % |
||||
(epoch - keep * self.weight_interval, name)) |
||||
|
||||
if os.path.exists(checkpoint_name_to_be_removed): |
||||
os.remove(checkpoint_name_to_be_removed) |
||||
|
||||
except Exception as e: |
||||
self.logger.info('remove old checkpoints error: {}'.format(e)) |
||||
|
||||
def resume(self, checkpoint_path): |
||||
state_dicts = load(checkpoint_path) |
||||
if state_dicts.get('epoch', None) is not None: |
||||
self.start_epoch = state_dicts['epoch'] + 1 |
||||
self.global_steps = self.iters_per_epoch * state_dicts['epoch'] |
||||
|
||||
self.current_iter = state_dicts['epoch'] + 1 |
||||
|
||||
for net_name, net in self.model.nets.items(): |
||||
net.set_state_dict(state_dicts[net_name]) |
||||
|
||||
for opt_name, opt in self.model.optimizers.items(): |
||||
opt.set_state_dict(state_dicts[opt_name]) |
||||
|
||||
def load(self, weight_path): |
||||
state_dicts = load(weight_path) |
||||
|
||||
for net_name, net in self.model.nets.items(): |
||||
if net_name in state_dicts: |
||||
net.set_state_dict(state_dicts[net_name]) |
||||
self.logger.info('Loaded pretrained weight for net {}'.format( |
||||
net_name)) |
||||
else: |
||||
self.logger.warning( |
||||
'Can not find state dict of net {}. Skip load pretrained weight for net {}' |
||||
.format(net_name, net_name)) |
||||
|
||||
def close(self): |
||||
""" |
||||
Close file handlers and other resources when training is finished.
||||
""" |
||||
if self.enable_visualdl: |
||||
self.vdl_logger.close() |
||||
|
||||
|
||||
# Base training class for super-resolution models
||||
class BasicSRNet: |
||||
def __init__(self): |
||||
self.model = {} |
||||
self.optimizer = {} |
||||
self.lr_scheduler = {} |
||||
self.min_max = '' |
||||
|
||||
def train( |
||||
self, |
||||
total_iters, |
||||
train_dataset, |
||||
test_dataset, |
||||
output_dir, |
||||
validate, |
||||
snapshot, |
||||
log, |
||||
lr_rate, |
||||
evaluate_weights='', |
||||
resume='', |
||||
pretrain_weights='', |
||||
periods=[100000], |
||||
restart_weights=[1], ): |
||||
self.lr_scheduler['learning_rate'] = lr_rate |
||||
|
||||
if self.lr_scheduler['name'] == 'CosineAnnealingRestartLR': |
||||
self.lr_scheduler['periods'] = periods |
||||
self.lr_scheduler['restart_weights'] = restart_weights |
||||
|
||||
validate = { |
||||
'interval': validate, |
||||
'save_img': False, |
||||
'metrics': { |
||||
'psnr': { |
||||
'name': 'PSNR', |
||||
'crop_border': 4, |
||||
'test_y_channel': True |
||||
}, |
||||
'ssim': { |
||||
'name': 'SSIM', |
||||
'crop_border': 4, |
||||
'test_y_channel': True |
||||
} |
||||
} |
||||
} |
||||
log_config = {'interval': log, 'visiual_interval': 500} |
||||
snapshot_config = {'interval': snapshot} |
||||
|
||||
cfg = { |
||||
'total_iters': total_iters, |
||||
'output_dir': output_dir, |
||||
'min_max': self.min_max, |
||||
'model': self.model, |
||||
'dataset': { |
||||
'train': train_dataset, |
||||
'test': test_dataset |
||||
}, |
||||
'lr_scheduler': self.lr_scheduler, |
||||
'optimizer': self.optimizer, |
||||
'validate': validate, |
||||
'log_config': log_config, |
||||
'snapshot_config': snapshot_config |
||||
} |
||||
|
||||
cfg = AttrDict(cfg) |
||||
create_attr_dict(cfg) |
||||
|
||||
cfg.is_train = True |
||||
cfg.profiler_options = None |
||||
cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime()) |
||||
|
||||
if cfg.model.name == 'BaseSRModel': |
||||
floderModelName = cfg.model.generator.name |
||||
else: |
||||
floderModelName = cfg.model.name |
||||
cfg.output_dir = os.path.join(cfg.output_dir, |
||||
floderModelName + cfg.timestamp) |
||||
|
||||
logger_cfg = setup_logger(cfg.output_dir) |
||||
logger_cfg.info('Configs: {}'.format(cfg)) |
||||
|
||||
if paddle.is_compiled_with_cuda(): |
||||
paddle.set_device('gpu') |
||||
else: |
||||
paddle.set_device('cpu') |
||||
|
||||
# build trainer |
||||
trainer = Restorer(cfg, logger_cfg) |
||||
|
||||
# Resume training or evaluation; the checkpoint must contain epoch and optimizer info
||||
if len(resume) > 0: |
||||
trainer.resume(resume) |
||||
# evaluate or fine-tune; only load generator weights
||||
elif len(pretrain_weights) > 0: |
||||
trainer.load(pretrain_weights) |
||||
if len(evaluate_weights) > 0: |
||||
trainer.load(evaluate_weights) |
||||
trainer.test() |
||||
return |
||||
# Train; save weights on keyboard interrupt
||||
try: |
||||
trainer.train() |
||||
except KeyboardInterrupt as e: |
||||
trainer.save(trainer.current_epoch) |
||||
|
||||
trainer.close() |
||||
|
||||
|
||||
# DRN model training
||||
class DRNet(BasicSRNet): |
||||
def __init__(self, |
||||
n_blocks=30, |
||||
n_feats=16, |
||||
n_colors=3, |
||||
rgb_range=255, |
||||
negval=0.2): |
||||
super(DRNet, self).__init__() |
||||
self.min_max = '(0., 255.)' |
||||
self.generator = { |
||||
'name': 'DRNGenerator', |
||||
'scale': (2, 4), |
||||
'n_blocks': n_blocks, |
||||
'n_feats': n_feats, |
||||
'n_colors': n_colors, |
||||
'rgb_range': rgb_range, |
||||
'negval': negval |
||||
} |
||||
self.pixel_criterion = {'name': 'L1Loss'} |
||||
self.model = { |
||||
'name': 'DRN', |
||||
'generator': self.generator, |
||||
'pixel_criterion': self.pixel_criterion |
||||
} |
||||
self.optimizer = { |
||||
'optimG': { |
||||
'name': 'Adam', |
||||
'net_names': ['generator'], |
||||
'weight_decay': 0.0, |
||||
'beta1': 0.9, |
||||
'beta2': 0.999 |
||||
}, |
||||
'optimD': { |
||||
'name': 'Adam', |
||||
'net_names': ['dual_model_0', 'dual_model_1'], |
||||
'weight_decay': 0.0, |
||||
'beta1': 0.9, |
||||
'beta2': 0.999 |
||||
} |
||||
} |
||||
self.lr_scheduler = { |
||||
'name': 'CosineAnnealingRestartLR', |
||||
'eta_min': 1e-07 |
||||
} |
||||
|
||||
|
||||
# Training for the lightweight super-resolution model LESRCNN
||||
class LESRCNNet(BasicSRNet): |
||||
def __init__(self, scale=4, multi_scale=False, group=1): |
||||
super(LESRCNNet, self).__init__() |
||||
self.min_max = '(0., 1.)' |
||||
self.generator = { |
||||
'name': 'LESRCNNGenerator', |
||||
'scale': scale, |
||||
'multi_scale': False, |
||||
'group': 1 |
||||
} |
||||
self.pixel_criterion = {'name': 'L1Loss'} |
||||
self.model = { |
||||
'name': 'BaseSRModel', |
||||
'generator': self.generator, |
||||
'pixel_criterion': self.pixel_criterion |
||||
} |
||||
self.optimizer = { |
||||
'name': 'Adam', |
||||
'net_names': ['generator'], |
||||
'beta1': 0.9, |
||||
'beta2': 0.99 |
||||
} |
||||
self.lr_scheduler = { |
||||
'name': 'CosineAnnealingRestartLR', |
||||
'eta_min': 1e-07 |
||||
} |
||||
|
||||
|
||||
# ESRGAN model training
# If loss_type='gan', use perceptual, adversarial, and pixel losses
# If loss_type='pixel', use only the pixel loss
||||
class ESRGANet(BasicSRNet): |
||||
def __init__(self, loss_type='gan', in_nc=3, out_nc=3, nf=64, nb=23): |
||||
super(ESRGANet, self).__init__() |
||||
self.min_max = '(0., 1.)' |
||||
self.generator = { |
||||
'name': 'RRDBNet', |
||||
'in_nc': in_nc, |
||||
'out_nc': out_nc, |
||||
'nf': nf, |
||||
'nb': nb |
||||
} |
||||
|
||||
if loss_type == 'gan': |
||||
# Define loss functions
||||
self.pixel_criterion = {'name': 'L1Loss', 'loss_weight': 0.01} |
||||
self.discriminator = { |
||||
'name': 'VGGDiscriminator128', |
||||
'in_channels': 3, |
||||
'num_feat': 64 |
||||
} |
||||
self.perceptual_criterion = { |
||||
'name': 'PerceptualLoss', |
||||
'layer_weights': { |
||||
'34': 1.0 |
||||
}, |
||||
'perceptual_weight': 1.0, |
||||
'style_weight': 0.0, |
||||
'norm_img': False |
||||
} |
||||
self.gan_criterion = { |
||||
'name': 'GANLoss', |
||||
'gan_mode': 'vanilla', |
||||
'loss_weight': 0.005 |
||||
} |
||||
# Define the model
||||
self.model = { |
||||
'name': 'ESRGAN', |
||||
'generator': self.generator, |
||||
'discriminator': self.discriminator, |
||||
'pixel_criterion': self.pixel_criterion, |
||||
'perceptual_criterion': self.perceptual_criterion, |
||||
'gan_criterion': self.gan_criterion |
||||
} |
||||
self.optimizer = { |
||||
'optimG': { |
||||
'name': 'Adam', |
||||
'net_names': ['generator'], |
||||
'weight_decay': 0.0, |
||||
'beta1': 0.9, |
||||
'beta2': 0.99 |
||||
}, |
||||
'optimD': { |
||||
'name': 'Adam', |
||||
'net_names': ['discriminator'], |
||||
'weight_decay': 0.0, |
||||
'beta1': 0.9, |
||||
'beta2': 0.99 |
||||
} |
||||
} |
||||
self.lr_scheduler = { |
||||
'name': 'MultiStepDecay', |
||||
'milestones': [50000, 100000, 200000, 300000], |
||||
'gamma': 0.5 |
||||
} |
||||
else: |
||||
self.pixel_criterion = {'name': 'L1Loss'} |
||||
self.model = { |
||||
'name': 'BaseSRModel', |
||||
'generator': self.generator, |
||||
'pixel_criterion': self.pixel_criterion |
||||
} |
||||
self.optimizer = { |
||||
'name': 'Adam', |
||||
'net_names': ['generator'], |
||||
'beta1': 0.9, |
||||
'beta2': 0.99 |
||||
} |
||||
self.lr_scheduler = { |
||||
'name': 'CosineAnnealingRestartLR', |
||||
'eta_min': 1e-07 |
||||
} |
||||
|
||||
|
||||
# RCAN model training
||||
class RCANet(BasicSRNet): |
||||
def __init__( |
||||
self, |
||||
scale=2, |
||||
n_resgroups=10, |
||||
n_resblocks=20, ): |
||||
super(RCANet, self).__init__() |
||||
self.min_max = '(0., 255.)' |
||||
self.generator = { |
||||
'name': 'RCAN', |
||||
'scale': scale, |
||||
'n_resgroups': n_resgroups, |
||||
'n_resblocks': n_resblocks |
||||
} |
||||
self.pixel_criterion = {'name': 'L1Loss'} |
||||
self.model = { |
||||
'name': 'RCANModel', |
||||
'generator': self.generator, |
||||
'pixel_criterion': self.pixel_criterion |
||||
} |
||||
self.optimizer = { |
||||
'name': 'Adam', |
||||
'net_names': ['generator'], |
||||
'beta1': 0.9, |
||||
'beta2': 0.99 |
||||
} |
||||
self.lr_scheduler = { |
||||
'name': 'MultiStepDecay', |
||||
'milestones': [250000, 500000, 750000, 1000000], |
||||
'gamma': 0.5 |
||||
} |
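Before this refactor, training a restoration model went through these wrapper classes roughly as sketched below, reusing a config dict like the `train_cfg` built in the `SRdataset` sketch earlier in this diff. All hyper-parameter values are placeholders; only the `DRNet` constructor and the `BasicSRNet.train` signature come from the deleted code.

```python
net = DRNet(n_blocks=30, n_feats=16, n_colors=3, rgb_range=255, negval=0.2)
net.train(
    total_iters=100000,        # placeholder values throughout
    train_dataset=train_cfg,   # dict built with SRdataset(mode='train', ...)()
    test_dataset=test_cfg,     # dict built the same way with mode='test'
    output_dir='output',
    validate=5000,             # validation interval, in iterations
    snapshot=5000,             # checkpoint interval, in iterations
    log=100,                   # logging interval, in iterations
    lr_rate=1e-4)
```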
@ -0,0 +1,934 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import os |
||||
import os.path as osp |
||||
from collections import OrderedDict |
||||
|
||||
import numpy as np |
||||
import cv2 |
||||
import paddle |
||||
import paddle.nn.functional as F |
||||
from paddle.static import InputSpec |
||||
|
||||
import paddlers |
||||
import paddlers.models.ppgan as ppgan |
||||
import paddlers.rs_models.res as cmres |
||||
import paddlers.models.ppgan.metrics as metrics |
||||
import paddlers.utils.logging as logging |
||||
from paddlers.models import res_losses |
||||
from paddlers.transforms import Resize, decode_image |
||||
from paddlers.transforms.functions import calc_hr_shape |
||||
from paddlers.utils import get_single_card_bs |
||||
from .base import BaseModel |
||||
from .utils.res_adapters import GANAdapter, OptimizerAdapter |
||||
from .utils.infer_nets import InferResNet |
||||
|
||||
__all__ = ["DRN", "LESRCNN", "ESRGAN"] |
||||
|
||||
|
||||
class BaseRestorer(BaseModel): |
||||
MIN_MAX = (0., 1.) |
||||
TEST_OUT_KEY = None |
||||
|
||||
def __init__(self, model_name, losses=None, sr_factor=None, **params): |
||||
self.init_params = locals() |
||||
if 'with_net' in self.init_params: |
||||
del self.init_params['with_net'] |
||||
super(BaseRestorer, self).__init__('restorer') |
||||
self.model_name = model_name |
||||
self.losses = losses |
||||
self.sr_factor = sr_factor |
||||
if params.get('with_net', True): |
||||
params.pop('with_net', None) |
||||
self.net = self.build_net(**params) |
||||
self.find_unused_parameters = True |
||||
|
||||
def build_net(self, **params): |
||||
# Currently, only use models from cmres. |
||||
if not hasattr(cmres, self.model_name): |
||||
raise ValueError("ERROR: There is no model named {}.".format( |
||||
model_name)) |
||||
net = dict(**cmres.__dict__)[self.model_name](**params) |
||||
return net |
||||
|
||||
def _build_inference_net(self): |
||||
# For GAN models, only the generator will be used for inference. |
||||
if isinstance(self.net, GANAdapter): |
||||
infer_net = InferResNet( |
||||
self.net.generator, out_key=self.TEST_OUT_KEY) |
||||
else: |
||||
infer_net = InferResNet(self.net, out_key=self.TEST_OUT_KEY) |
||||
infer_net.eval() |
||||
return infer_net |
||||
|
||||
def _fix_transforms_shape(self, image_shape): |
||||
if hasattr(self, 'test_transforms'): |
||||
if self.test_transforms is not None: |
||||
has_resize_op = False |
||||
resize_op_idx = -1 |
||||
normalize_op_idx = len(self.test_transforms.transforms) |
||||
for idx, op in enumerate(self.test_transforms.transforms): |
||||
name = op.__class__.__name__ |
||||
if name == 'Normalize': |
||||
normalize_op_idx = idx |
||||
if 'Resize' in name: |
||||
has_resize_op = True |
||||
resize_op_idx = idx |
||||
|
||||
if not has_resize_op: |
||||
self.test_transforms.transforms.insert( |
||||
normalize_op_idx, Resize(target_size=image_shape)) |
||||
else: |
||||
self.test_transforms.transforms[resize_op_idx] = Resize( |
||||
target_size=image_shape) |
||||
|
||||
def _get_test_inputs(self, image_shape): |
||||
if image_shape is not None: |
||||
if len(image_shape) == 2: |
||||
image_shape = [1, 3] + image_shape |
||||
self._fix_transforms_shape(image_shape[-2:]) |
||||
else: |
||||
image_shape = [None, 3, -1, -1] |
||||
self.fixed_input_shape = image_shape |
||||
input_spec = [ |
||||
InputSpec( |
||||
shape=image_shape, name='image', dtype='float32') |
||||
] |
||||
return input_spec |
||||
|
||||
def run(self, net, inputs, mode): |
||||
outputs = OrderedDict() |
||||
|
||||
if mode == 'test': |
||||
tar_shape = inputs[1] |
||||
if self.status == 'Infer': |
||||
net_out = net(inputs[0]) |
||||
res_map_list = self.postprocess( |
||||
net_out, tar_shape, transforms=inputs[2]) |
||||
else: |
||||
if isinstance(net, GANAdapter): |
||||
net_out = net.generator(inputs[0]) |
||||
else: |
||||
net_out = net(inputs[0]) |
||||
if self.TEST_OUT_KEY is not None: |
||||
net_out = net_out[self.TEST_OUT_KEY] |
||||
pred = self.postprocess( |
||||
net_out, tar_shape, transforms=inputs[2]) |
||||
res_map_list = [] |
||||
for res_map in pred: |
||||
res_map = self._tensor_to_images(res_map) |
||||
res_map_list.append(res_map) |
||||
outputs['res_map'] = res_map_list |
||||
|
||||
if mode == 'eval': |
||||
if isinstance(net, GANAdapter): |
||||
net_out = net.generator(inputs[0]) |
||||
else: |
||||
net_out = net(inputs[0]) |
||||
if self.TEST_OUT_KEY is not None: |
||||
net_out = net_out[self.TEST_OUT_KEY] |
||||
tar = inputs[1] |
||||
tar_shape = [tar.shape[-2:]] |
||||
pred = self.postprocess( |
||||
net_out, tar_shape, transforms=inputs[2])[0] # NCHW |
||||
pred = self._tensor_to_images(pred) |
||||
outputs['pred'] = pred |
||||
tar = self._tensor_to_images(tar) |
||||
outputs['tar'] = tar |
||||
|
||||
if mode == 'train': |
||||
# This is used by non-GAN models. |
||||
# For GAN models, self.run_gan() should be used. |
||||
net_out = net(inputs[0]) |
||||
loss = self.losses(net_out, inputs[1]) |
||||
outputs['loss'] = loss |
||||
return outputs |
||||
|
||||
def run_gan(self, net, inputs, mode, gan_mode): |
||||
raise NotImplementedError |
||||
|
||||
def default_loss(self): |
||||
return res_losses.L1Loss() |
||||
|
||||
def default_optimizer(self, |
||||
parameters, |
||||
learning_rate, |
||||
num_epochs, |
||||
num_steps_each_epoch, |
||||
lr_decay_power=0.9): |
||||
decay_step = num_epochs * num_steps_each_epoch |
||||
lr_scheduler = paddle.optimizer.lr.PolynomialDecay( |
||||
learning_rate, decay_step, end_lr=0, power=lr_decay_power) |
||||
optimizer = paddle.optimizer.Momentum( |
||||
learning_rate=lr_scheduler, |
||||
parameters=parameters, |
||||
momentum=0.9, |
||||
weight_decay=4e-5) |
||||
return optimizer |
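The default optimizer pairs a polynomial learning-rate decay with SGD momentum. A tiny standalone check of the schedule (the step counts are arbitrary):

```python
import paddle

# 10 epochs x 100 steps per epoch -> 1000 decay steps, as in default_optimizer above.
sched = paddle.optimizer.lr.PolynomialDecay(
    learning_rate=0.01, decay_steps=10 * 100, end_lr=0, power=0.9)

print(sched.get_lr())  # 0.01 at step 0
for _ in range(500):
    sched.step()
print(sched.get_lr())  # roughly half-decayed toward 0
```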
||||
|
||||
def train(self, |
||||
num_epochs, |
||||
train_dataset, |
||||
train_batch_size=2, |
||||
eval_dataset=None, |
||||
optimizer=None, |
||||
save_interval_epochs=1, |
||||
log_interval_steps=2, |
||||
save_dir='output', |
||||
pretrain_weights=None, |
||||
learning_rate=0.01, |
||||
lr_decay_power=0.9, |
||||
early_stop=False, |
||||
early_stop_patience=5, |
||||
use_vdl=True, |
||||
resume_checkpoint=None): |
||||
""" |
||||
Train the model. |
||||
|
||||
Args: |
||||
num_epochs (int): Number of epochs. |
||||
train_dataset (paddlers.datasets.ResDataset): Training dataset. |
||||
train_batch_size (int, optional): Total batch size among all cards used in |
||||
training. Defaults to 2. |
||||
eval_dataset (paddlers.datasets.ResDataset|None, optional): Evaluation dataset. |
||||
If None, the model will not be evaluated during training process. |
||||
Defaults to None. |
||||
optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in |
||||
training. If None, a default optimizer will be used. Defaults to None. |
||||
save_interval_epochs (int, optional): Epoch interval for saving the model. |
||||
Defaults to 1. |
||||
log_interval_steps (int, optional): Step interval for printing training |
||||
information. Defaults to 2. |
||||
save_dir (str, optional): Directory to save the model. Defaults to 'output'. |
||||
pretrain_weights (str|None, optional): None or name/path of pretrained |
||||
weights. If None, no pretrained weights will be loaded. |
||||
Defaults to None. |
||||
learning_rate (float, optional): Learning rate for training. Defaults to .01. |
||||
lr_decay_power (float, optional): Learning decay power. Defaults to .9. |
||||
early_stop (bool, optional): Whether to adopt early stop strategy. Defaults |
||||
to False. |
||||
early_stop_patience (int, optional): Early stop patience. Defaults to 5. |
||||
use_vdl (bool, optional): Whether to use VisualDL to monitor the training |
||||
process. Defaults to True. |
||||
resume_checkpoint (str|None, optional): Path of the checkpoint to resume |
||||
training from. If None, no training checkpoint will be resumed. At most |
||||
one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
||||
Defaults to None. |
||||
""" |
||||
|
||||
if self.status == 'Infer': |
||||
logging.error( |
||||
"Exported inference model does not support training.", |
||||
exit=True) |
||||
if pretrain_weights is not None and resume_checkpoint is not None: |
||||
logging.error( |
||||
"pretrain_weights and resume_checkpoint cannot be set simultaneously.", |
||||
exit=True) |
||||
|
||||
if self.losses is None: |
||||
self.losses = self.default_loss() |
||||
|
||||
if optimizer is None: |
||||
num_steps_each_epoch = train_dataset.num_samples // train_batch_size |
||||
if isinstance(self.net, GANAdapter): |
||||
parameters = {'params_g': [], 'params_d': []} |
||||
for net_g in self.net.generators: |
||||
parameters['params_g'].append(net_g.parameters()) |
||||
for net_d in self.net.discriminators: |
||||
parameters['params_d'].append(net_d.parameters()) |
||||
else: |
||||
parameters = self.net.parameters() |
||||
self.optimizer = self.default_optimizer( |
||||
parameters, learning_rate, num_epochs, num_steps_each_epoch, |
||||
lr_decay_power) |
||||
else: |
||||
self.optimizer = optimizer |
||||
|
||||
if pretrain_weights is not None and not osp.exists(pretrain_weights): |
||||
logging.warning("Path of pretrain_weights('{}') does not exist!". |
||||
format(pretrain_weights)) |
||||
elif pretrain_weights is not None and osp.exists(pretrain_weights): |
||||
if osp.splitext(pretrain_weights)[-1] != '.pdparams': |
||||
logging.error( |
||||
"Invalid pretrain weights. Please specify a '.pdparams' file.", |
||||
exit=True) |
||||
pretrained_dir = osp.join(save_dir, 'pretrain') |
||||
is_backbone_weights = pretrain_weights == 'IMAGENET' |
||||
self.net_initialize( |
||||
pretrain_weights=pretrain_weights, |
||||
save_dir=pretrained_dir, |
||||
resume_checkpoint=resume_checkpoint, |
||||
is_backbone_weights=is_backbone_weights) |
||||
|
||||
self.train_loop( |
||||
num_epochs=num_epochs, |
||||
train_dataset=train_dataset, |
||||
train_batch_size=train_batch_size, |
||||
eval_dataset=eval_dataset, |
||||
save_interval_epochs=save_interval_epochs, |
||||
log_interval_steps=log_interval_steps, |
||||
save_dir=save_dir, |
||||
early_stop=early_stop, |
||||
early_stop_patience=early_stop_patience, |
||||
use_vdl=use_vdl) |
||||
|
||||
def quant_aware_train(self, |
||||
num_epochs, |
||||
train_dataset, |
||||
train_batch_size=2, |
||||
eval_dataset=None, |
||||
optimizer=None, |
||||
save_interval_epochs=1, |
||||
log_interval_steps=2, |
||||
save_dir='output', |
||||
learning_rate=0.0001, |
||||
lr_decay_power=0.9, |
||||
early_stop=False, |
||||
early_stop_patience=5, |
||||
use_vdl=True, |
||||
resume_checkpoint=None, |
||||
quant_config=None): |
||||
""" |
||||
Quantization-aware training. |
||||
|
||||
Args: |
||||
num_epochs (int): Number of epochs. |
||||
train_dataset (paddlers.datasets.ResDataset): Training dataset. |
||||
train_batch_size (int, optional): Total batch size among all cards used in |
||||
training. Defaults to 2. |
||||
eval_dataset (paddlers.datasets.ResDataset|None, optional): Evaluation dataset. |
||||
If None, the model will not be evaluated during training process. |
||||
Defaults to None. |
||||
optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in |
||||
training. If None, a default optimizer will be used. Defaults to None. |
||||
save_interval_epochs (int, optional): Epoch interval for saving the model. |
||||
Defaults to 1. |
||||
log_interval_steps (int, optional): Step interval for printing training |
||||
information. Defaults to 2. |
||||
save_dir (str, optional): Directory to save the model. Defaults to 'output'. |
||||
learning_rate (float, optional): Learning rate for training. |
||||
Defaults to .0001. |
||||
lr_decay_power (float, optional): Learning decay power. Defaults to .9. |
||||
early_stop (bool, optional): Whether to adopt early stop strategy. |
||||
Defaults to False. |
||||
early_stop_patience (int, optional): Early stop patience. Defaults to 5. |
||||
use_vdl (bool, optional): Whether to use VisualDL to monitor the training |
||||
process. Defaults to True. |
||||
quant_config (dict|None, optional): Quantization configuration. If None, |
||||
a default rule of thumb configuration will be used. Defaults to None. |
||||
resume_checkpoint (str|None, optional): Path of the checkpoint to resume |
||||
quantization-aware training from. If None, no training checkpoint will |
||||
be resumed. Defaults to None. |
||||
""" |
||||
|
||||
self._prepare_qat(quant_config) |
||||
self.train( |
||||
num_epochs=num_epochs, |
||||
train_dataset=train_dataset, |
||||
train_batch_size=train_batch_size, |
||||
eval_dataset=eval_dataset, |
||||
optimizer=optimizer, |
||||
save_interval_epochs=save_interval_epochs, |
||||
log_interval_steps=log_interval_steps, |
||||
save_dir=save_dir, |
||||
pretrain_weights=None, |
||||
learning_rate=learning_rate, |
||||
lr_decay_power=lr_decay_power, |
||||
early_stop=early_stop, |
||||
early_stop_patience=early_stop_patience, |
||||
use_vdl=use_vdl, |
||||
resume_checkpoint=resume_checkpoint) |
||||
|
||||
def evaluate(self, eval_dataset, batch_size=1, return_details=False): |
||||
""" |
||||
Evaluate the model. |
||||
|
||||
Args: |
||||
eval_dataset (paddlers.datasets.ResDataset): Evaluation dataset. |
||||
batch_size (int, optional): Total batch size among all cards used for |
||||
evaluation. Defaults to 1. |
||||
return_details (bool, optional): Whether to return evaluation details. |
||||
Defaults to False. |
||||
|
||||
Returns: |
||||
If `return_details` is False, return collections.OrderedDict with |
||||
key-value pairs: |
||||
{"psnr": `peak signal-to-noise ratio`, |
||||
"ssim": `structural similarity`}. |
||||
|
||||
""" |
||||
|
||||
self._check_transforms(eval_dataset.transforms, 'eval') |
||||
|
||||
self.net.eval() |
||||
nranks = paddle.distributed.get_world_size() |
||||
local_rank = paddle.distributed.get_rank() |
||||
if nranks > 1: |
||||
# Initialize parallel environment if not done. |
||||
if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized( |
||||
): |
||||
paddle.distributed.init_parallel_env() |
||||
|
||||
# TODO: Distributed evaluation |
||||
if batch_size > 1: |
||||
logging.warning( |
||||
"Restorer only supports single card evaluation with batch_size=1 " |
||||
"during evaluation, so batch_size is forcibly set to 1.") |
||||
batch_size = 1 |
||||
|
||||
if nranks < 2 or local_rank == 0: |
||||
self.eval_data_loader = self.build_data_loader( |
||||
eval_dataset, batch_size=batch_size, mode='eval') |
||||
# XXX: Hard-code crop_border and test_y_channel |
||||
psnr = metrics.PSNR(crop_border=4, test_y_channel=True) |
||||
ssim = metrics.SSIM(crop_border=4, test_y_channel=True) |
||||
logging.info( |
||||
"Start to evaluate(total_samples={}, total_steps={})...".format( |
||||
eval_dataset.num_samples, eval_dataset.num_samples)) |
||||
with paddle.no_grad(): |
||||
for step, data in enumerate(self.eval_data_loader): |
||||
data.append(eval_dataset.transforms.transforms) |
||||
outputs = self.run(self.net, data, 'eval') |
||||
psnr.update(outputs['pred'], outputs['tar']) |
||||
ssim.update(outputs['pred'], outputs['tar']) |
||||
|
||||
# DO NOT use psnr.accumulate() here, otherwise the program hangs in multi-card training. |
||||
assert len(psnr.results) > 0 |
||||
assert len(ssim.results) > 0 |
||||
eval_metrics = OrderedDict( |
||||
zip(['psnr', 'ssim'], |
||||
[np.mean(psnr.results), np.mean(ssim.results)])) |
||||
|
||||
if return_details: |
||||
# TODO: Add details |
||||
return eval_metrics, None |
||||
|
||||
return eval_metrics |
||||
|
||||
def predict(self, img_file, transforms=None): |
||||
""" |
||||
Do inference. |
||||
|
||||
Args: |
||||
img_file (list[np.ndarray|str] | str | np.ndarray): Image path or decoded |
||||
image data, which also could constitute a list, meaning all images to be |
||||
predicted as a mini-batch. |
||||
transforms (paddlers.transforms.Compose|None, optional): Transforms for |
||||
inputs. If None, the transforms for evaluation process will be used. |
||||
Defaults to None. |
||||
|
||||
Returns: |
||||
If `img_file` is a tuple of string or np.array, the result is a dict with |
||||
the following key-value pairs: |
||||
res_map (np.ndarray): Restored image (HWC). |
||||
|
||||
If `img_file` is a list, the result is a list composed of dicts with the |
||||
above keys. |
||||
""" |
||||
|
||||
if transforms is None and not hasattr(self, 'test_transforms'): |
||||
raise ValueError("transforms must be defined; it is currently None.")
||||
if transforms is None: |
||||
transforms = self.test_transforms |
||||
if isinstance(img_file, (str, np.ndarray)): |
||||
images = [img_file] |
||||
else: |
||||
images = img_file |
||||
batch_im, batch_tar_shape = self.preprocess(images, transforms, |
||||
self.model_type) |
||||
self.net.eval() |
||||
data = (batch_im, batch_tar_shape, transforms.transforms) |
||||
outputs = self.run(self.net, data, 'test') |
||||
res_map_list = outputs['res_map'] |
||||
if isinstance(img_file, list): |
||||
prediction = [{'res_map': m} for m in res_map_list] |
||||
else: |
||||
prediction = {'res_map': res_map_list[0]} |
||||
return prediction |
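A short inference sketch for the `predict` API above. The image paths are placeholders, and `model` is assumed to be a trained `BaseRestorer` subclass (for example `DRN`) whose test transforms have already been set:

```python
# Single-image prediction returns a dict with the restored HWC array.
result = model.predict('demo_lr.png')
res_map = result['res_map']
print(res_map.shape, res_map.dtype)  # e.g. (H, W, 3) uint8

# Batch prediction returns one dict per input image.
results = model.predict(['a.png', 'b.png'])
print(len(results), list(results[0].keys()))  # 2 ['res_map']
```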
||||
|
||||
def preprocess(self, images, transforms, to_tensor=True): |
||||
self._check_transforms(transforms, 'test') |
||||
batch_im = list() |
||||
batch_tar_shape = list() |
||||
for im in images: |
||||
if isinstance(im, str): |
||||
im = decode_image(im, to_rgb=False) |
||||
ori_shape = im.shape[:2] |
||||
sample = {'image': im} |
||||
im = transforms(sample)[0] |
||||
batch_im.append(im) |
||||
batch_tar_shape.append(self._get_target_shape(ori_shape)) |
||||
if to_tensor: |
||||
batch_im = paddle.to_tensor(batch_im) |
||||
else: |
||||
batch_im = np.asarray(batch_im) |
||||
|
||||
return batch_im, batch_tar_shape |
||||
|
||||
def _get_target_shape(self, ori_shape): |
||||
if self.sr_factor is None: |
||||
return ori_shape |
||||
else: |
||||
return calc_hr_shape(ori_shape, self.sr_factor) |
||||
|
||||
@staticmethod |
||||
def get_transforms_shape_info(batch_tar_shape, transforms): |
||||
batch_restore_list = list() |
||||
for tar_shape in batch_tar_shape: |
||||
restore_list = list() |
||||
h, w = tar_shape[0], tar_shape[1] |
||||
for op in transforms: |
||||
if op.__class__.__name__ == 'Resize': |
||||
restore_list.append(('resize', (h, w))) |
||||
h, w = op.target_size |
||||
elif op.__class__.__name__ == 'ResizeByShort': |
||||
restore_list.append(('resize', (h, w))) |
||||
im_short_size = min(h, w) |
||||
im_long_size = max(h, w) |
||||
scale = float(op.short_size) / float(im_short_size) |
||||
if 0 < op.max_size < np.round(scale * im_long_size): |
||||
scale = float(op.max_size) / float(im_long_size) |
||||
h = int(round(h * scale)) |
||||
w = int(round(w * scale)) |
||||
elif op.__class__.__name__ == 'ResizeByLong': |
||||
restore_list.append(('resize', (h, w))) |
||||
im_long_size = max(h, w) |
||||
scale = float(op.long_size) / float(im_long_size) |
||||
h = int(round(h * scale)) |
||||
w = int(round(w * scale)) |
||||
elif op.__class__.__name__ == 'Pad': |
||||
if op.target_size: |
||||
target_h, target_w = op.target_size |
||||
else: |
||||
target_h = int( |
||||
(np.ceil(h / op.size_divisor) * op.size_divisor)) |
||||
target_w = int( |
||||
(np.ceil(w / op.size_divisor) * op.size_divisor)) |
||||
|
||||
if op.pad_mode == -1: |
||||
offsets = op.offsets |
||||
elif op.pad_mode == 0: |
||||
offsets = [0, 0] |
||||
elif op.pad_mode == 1: |
||||
offsets = [(target_h - h) // 2, (target_w - w) // 2] |
||||
else: |
||||
offsets = [target_h - h, target_w - w] |
||||
restore_list.append(('padding', (h, w), offsets)) |
||||
h, w = target_h, target_w |
||||
|
||||
batch_restore_list.append(restore_list) |
||||
return batch_restore_list |
||||
|
||||
def postprocess(self, batch_pred, batch_tar_shape, transforms): |
||||
batch_restore_list = BaseRestorer.get_transforms_shape_info( |
||||
batch_tar_shape, transforms) |
||||
if self.status == 'Infer': |
||||
return self._infer_postprocess( |
||||
batch_res_map=batch_pred, batch_restore_list=batch_restore_list) |
||||
results = [] |
||||
if batch_pred.dtype == paddle.float32: |
||||
mode = 'bilinear' |
||||
else: |
||||
mode = 'nearest' |
||||
for pred, restore_list in zip(batch_pred, batch_restore_list): |
||||
pred = paddle.unsqueeze(pred, axis=0) |
||||
for item in restore_list[::-1]: |
||||
h, w = item[1][0], item[1][1] |
||||
if item[0] == 'resize': |
||||
pred = F.interpolate( |
||||
pred, (h, w), mode=mode, data_format='NCHW') |
||||
elif item[0] == 'padding': |
||||
x, y = item[2] |
||||
pred = pred[:, :, y:y + h, x:x + w] |
||||
else: |
||||
pass |
||||
results.append(pred) |
||||
return results |
||||
|
||||
def _infer_postprocess(self, batch_res_map, batch_restore_list): |
||||
res_maps = [] |
||||
for res_map, restore_list in zip(batch_res_map, batch_restore_list): |
||||
if not isinstance(res_map, np.ndarray): |
||||
res_map = paddle.unsqueeze(res_map, axis=0) |
||||
for item in restore_list[::-1]: |
||||
h, w = item[1][0], item[1][1] |
||||
if item[0] == 'resize': |
||||
if isinstance(res_map, np.ndarray): |
||||
res_map = cv2.resize( |
||||
res_map, (w, h), interpolation=cv2.INTER_LINEAR) |
||||
else: |
||||
res_map = F.interpolate( |
||||
res_map, (h, w), |
||||
mode='bilinear', |
||||
data_format='NHWC') |
||||
elif item[0] == 'padding': |
||||
x, y = item[2] |
||||
if isinstance(res_map, np.ndarray): |
||||
res_map = res_map[y:y + h, x:x + w] |
||||
else: |
||||
res_map = res_map[:, y:y + h, x:x + w, :] |
||||
else: |
||||
pass |
||||
res_map = res_map.squeeze() |
||||
if not isinstance(res_map, np.ndarray): |
||||
res_map = res_map.numpy() |
||||
res_map = self._normalize(res_map) |
||||
res_maps.append(res_map.squeeze()) |
||||
return res_maps |
||||
|
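As a plain NumPy/OpenCV illustration of what the two postprocessing paths above do, the sketch below replays a `restore_list` of the form produced by `get_transforms_shape_info` in reverse order. `undo_transforms` is a hypothetical helper written for this note only, not part of the class; it mirrors the `resize`/`padding` handling in `postprocess` and `_infer_postprocess`.

```python
import cv2
import numpy as np

def undo_transforms(pred, restore_list):
    # Walk the recorded ops backwards: crop away padding using the stored
    # offsets, and resize back to the shape recorded before each resize op.
    for item in restore_list[::-1]:
        h, w = item[1]
        if item[0] == 'resize':
            pred = cv2.resize(pred, (w, h), interpolation=cv2.INTER_LINEAR)
        elif item[0] == 'padding':
            x, y = item[2]
            pred = pred[y:y + h, x:x + w]
    return pred

# A 200x300 input that was resized to 256x256 and then padded to 288x288
# is first cropped back to 256x256 and finally resized to 200x300.
restore_list = [('resize', (200, 300)), ('padding', (256, 256), [0, 0])]
print(undo_transforms(np.zeros((288, 288, 3), np.float32), restore_list).shape)
# (200, 300, 3)
```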
||||
def _check_transforms(self, transforms, mode): |
||||
super()._check_transforms(transforms, mode) |
||||
if not isinstance(transforms.arrange, |
||||
paddlers.transforms.ArrangeRestorer): |
||||
raise TypeError( |
||||
"`transforms.arrange` must be an ArrangeRestorer object.") |
||||
|
||||
def build_data_loader(self, dataset, batch_size, mode='train'): |
||||
if dataset.num_samples < batch_size: |
||||
raise ValueError( |
||||
'The number of samples in the dataset ({}) must be no less than the batch size ({}).' |
||||
.format(dataset.num_samples, batch_size)) |
||||
|
||||
if mode != 'train': |
||||
return paddle.io.DataLoader( |
||||
dataset, |
||||
batch_size=batch_size, |
||||
shuffle=dataset.shuffle, |
||||
drop_last=False, |
||||
collate_fn=dataset.batch_transforms, |
||||
num_workers=dataset.num_workers, |
||||
return_list=True, |
||||
use_shared_memory=False) |
||||
else: |
||||
return super(BaseRestorer, self).build_data_loader(dataset, |
||||
batch_size, mode) |
||||
|
||||
def set_losses(self, losses): |
||||
self.losses = losses |
||||
|
||||
def _tensor_to_images(self, |
||||
tensor, |
||||
transpose=True, |
||||
squeeze=True, |
||||
quantize=True): |
||||
if transpose: |
||||
tensor = paddle.transpose(tensor, perm=[0, 2, 3, 1]) # NHWC |
||||
if squeeze: |
||||
tensor = tensor.squeeze() |
||||
images = tensor.numpy().astype('float32') |
||||
images = self._normalize( |
||||
images, copy=True, clip=True, quantize=quantize) |
||||
return images |
||||
|
||||
def _normalize(self, im, copy=False, clip=True, quantize=True): |
||||
if copy: |
||||
im = im.copy() |
||||
if clip: |
||||
im = np.clip(im, self.MIN_MAX[0], self.MIN_MAX[1]) |
||||
im -= im.min() |
||||
im /= im.max() + 1e-32 |
||||
if quantize: |
||||
im *= 255 |
||||
im = im.astype('uint8') |
||||
return im |
||||
|
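For reference, a standalone sketch of the min-max scaling performed by `_normalize` (and used by `_tensor_to_images`). The helper name is illustrative, and the `(0.0, 1.0)` clipping range stands in for `self.MIN_MAX`, which is assumed here rather than taken from the class body shown above.

```python
import numpy as np

def normalize_to_uint8(im, min_max=(0.0, 1.0)):
    # Clip to the expected dynamic range, rescale to [0, 1], then quantize.
    im = np.clip(im.astype('float32'), min_max[0], min_max[1])
    im -= im.min()
    im /= im.max() + 1e-32
    return (im * 255).astype('uint8')

print(normalize_to_uint8(np.array([[-0.2, 0.5], [0.8, 1.3]])))
# [[  0 127]
#  [204 255]]
```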
||||
|
||||
class DRN(BaseRestorer): |
||||
TEST_OUT_KEY = -1 |
||||
|
||||
def __init__(self, |
||||
losses=None, |
||||
sr_factor=4, |
||||
scale=(2, 4), |
||||
n_blocks=30, |
||||
n_feats=16, |
||||
n_colors=3, |
||||
rgb_range=1.0, |
||||
negval=0.2, |
||||
lq_loss_weight=0.1, |
||||
dual_loss_weight=0.1, |
||||
**params): |
||||
if sr_factor != max(scale): |
||||
raise ValueError(f"`sr_factor` ({sr_factor}) must be equal to `max(scale)` ({max(scale)}).") |
||||
params.update({ |
||||
'scale': scale, |
||||
'n_blocks': n_blocks, |
||||
'n_feats': n_feats, |
||||
'n_colors': n_colors, |
||||
'rgb_range': rgb_range, |
||||
'negval': negval |
||||
}) |
||||
self.lq_loss_weight = lq_loss_weight |
||||
self.dual_loss_weight = dual_loss_weight |
||||
super(DRN, self).__init__( |
||||
model_name='DRN', losses=losses, sr_factor=sr_factor, **params) |
||||
|
||||
def build_net(self, **params): |
||||
from ppgan.modules.init import init_weights |
||||
generators = [ppgan.models.generators.DRNGenerator(**params)] |
||||
init_weights(generators[-1]) |
||||
for scale in params['scale']: |
||||
dual_model = ppgan.models.generators.drn.DownBlock( |
||||
params['negval'], params['n_feats'], params['n_colors'], 2) |
||||
generators.append(dual_model) |
||||
init_weights(generators[-1]) |
||||
return GANAdapter(generators, []) |
||||
|
||||
def default_optimizer(self, parameters, *args, **kwargs): |
||||
optims_g = [ |
||||
super(DRN, self).default_optimizer(params_g, *args, **kwargs) |
||||
for params_g in parameters['params_g'] |
||||
] |
||||
return OptimizerAdapter(*optims_g) |
||||
|
||||
def run_gan(self, net, inputs, mode, gan_mode='forward_primary'): |
||||
if mode != 'train': |
||||
raise ValueError("`mode` is not 'train'.") |
||||
outputs = OrderedDict() |
||||
if gan_mode == 'forward_primary': |
||||
sr = net.generator(inputs[0]) |
||||
lr = [inputs[0]] |
||||
lr.extend([ |
||||
F.interpolate( |
||||
inputs[0], scale_factor=s, mode='bicubic') |
||||
for s in net.generator.scale[:-1] |
||||
]) |
||||
loss = self.losses(sr[-1], inputs[1]) |
||||
for i in range(1, len(sr)): |
||||
if self.lq_loss_weight > 0: |
||||
loss += self.losses(sr[i - 1 - len(sr)], |
||||
lr[i - len(sr)]) * self.lq_loss_weight |
||||
outputs['loss_prim'] = loss |
||||
outputs['sr'] = sr |
||||
outputs['lr'] = lr |
||||
elif gan_mode == 'forward_dual': |
||||
sr, lr = inputs[0], inputs[1] |
||||
sr2lr = [] |
||||
n_scales = len(net.generator.scale) |
||||
for i in range(n_scales): |
||||
sr2lr_i = net.generators[1 + i](sr[i - n_scales]) |
||||
sr2lr.append(sr2lr_i) |
||||
loss = self.losses(sr2lr[0], lr[0]) |
||||
for i in range(1, n_scales): |
||||
if self.dual_loss_weight > 0.0: |
||||
loss += self.losses(sr2lr[i], lr[i]) * self.dual_loss_weight |
||||
outputs['loss_dual'] = loss |
||||
else: |
||||
raise ValueError("Invalid `gan_mode`!") |
||||
return outputs |
||||
|
||||
def train_step(self, step, data, net): |
||||
outputs = self.run_gan( |
||||
net, data, mode='train', gan_mode='forward_primary') |
||||
outputs.update( |
||||
self.run_gan( |
||||
net, (outputs['sr'], outputs['lr']), |
||||
mode='train', |
||||
gan_mode='forward_dual')) |
||||
self.optimizer.clear_grad() |
||||
(outputs['loss_prim'] + outputs['loss_dual']).backward() |
||||
self.optimizer.step() |
||||
return { |
||||
'loss_prim': outputs['loss_prim'], |
||||
'loss_dual': outputs['loss_dual'] |
||||
} |
||||
|
||||
|
||||
class LESRCNN(BaseRestorer): |
||||
def __init__(self, |
||||
losses=None, |
||||
sr_factor=4, |
||||
multi_scale=False, |
||||
group=1, |
||||
**params): |
||||
params.update({ |
||||
'scale': sr_factor, |
||||
'multi_scale': multi_scale, |
||||
'group': group |
||||
}) |
||||
super(LESRCNN, self).__init__( |
||||
model_name='LESRCNN', losses=losses, sr_factor=sr_factor, **params) |
||||
|
||||
def build_net(self, **params): |
||||
net = ppgan.models.generators.LESRCNNGenerator(**params) |
||||
return net |
||||
|
||||
|
||||
class ESRGAN(BaseRestorer): |
||||
def __init__(self, |
||||
losses=None, |
||||
sr_factor=4, |
||||
use_gan=True, |
||||
in_channels=3, |
||||
out_channels=3, |
||||
nf=64, |
||||
nb=23, |
||||
**params): |
||||
if sr_factor != 4: |
||||
raise ValueError("`sr_factor` must be 4.") |
||||
params.update({ |
||||
'in_nc': in_channels, |
||||
'out_nc': out_channels, |
||||
'nf': nf, |
||||
'nb': nb |
||||
}) |
||||
self.use_gan = use_gan |
||||
super(ESRGAN, self).__init__( |
||||
model_name='ESRGAN', losses=losses, sr_factor=sr_factor, **params) |
||||
|
||||
def build_net(self, **params): |
||||
from ppgan.modules.init import init_weights |
||||
generator = ppgan.models.generators.RRDBNet(**params) |
||||
init_weights(generator) |
||||
if self.use_gan: |
||||
discriminator = ppgan.models.discriminators.VGGDiscriminator128( |
||||
in_channels=params['out_nc'], num_feat=64) |
||||
net = GANAdapter( |
||||
generators=[generator], discriminators=[discriminator]) |
||||
else: |
||||
net = generator |
||||
return net |
||||
|
||||
def default_loss(self): |
||||
if self.use_gan: |
||||
return { |
||||
'pixel': res_losses.L1Loss(loss_weight=0.01), |
||||
'perceptual': res_losses.PerceptualLoss( |
||||
layer_weights={'34': 1.0}, |
||||
perceptual_weight=1.0, |
||||
style_weight=0.0, |
||||
norm_img=False), |
||||
'gan': res_losses.GANLoss( |
||||
gan_mode='vanilla', loss_weight=0.005) |
||||
} |
||||
else: |
||||
return res_losses.L1Loss() |
||||
|
||||
def default_optimizer(self, parameters, *args, **kwargs): |
||||
if self.use_gan: |
||||
optim_g = super(ESRGAN, self).default_optimizer( |
||||
parameters['params_g'][0], *args, **kwargs) |
||||
optim_d = super(ESRGAN, self).default_optimizer( |
||||
parameters['params_d'][0], *args, **kwargs) |
||||
return OptimizerAdapter(optim_g, optim_d) |
||||
else: |
||||
return super(ESRGAN, self).default_optimizer(parameters, *args, |
||||
**kwargs) |
||||
|
||||
def run_gan(self, net, inputs, mode, gan_mode='forward_g'): |
||||
if mode != 'train': |
||||
raise ValueError("`mode` is not 'train'.") |
||||
outputs = OrderedDict() |
||||
if gan_mode == 'forward_g': |
||||
loss_g = 0 |
||||
g_pred = net.generator(inputs[0]) |
||||
loss_pix = self.losses['pixel'](g_pred, inputs[1]) |
||||
loss_perc, loss_sty = self.losses['perceptual'](g_pred, inputs[1]) |
||||
loss_g += loss_pix |
||||
if loss_perc is not None: |
||||
loss_g += loss_perc |
||||
if loss_sty is not None: |
||||
loss_g += loss_sty |
||||
self._set_requires_grad(net.discriminator, False) |
||||
real_d_pred = net.discriminator(inputs[1]).detach() |
||||
fake_g_pred = net.discriminator(g_pred) |
||||
loss_g_real = self.losses['gan']( |
||||
real_d_pred - paddle.mean(fake_g_pred), False, |
||||
is_disc=False) * 0.5 |
||||
loss_g_fake = self.losses['gan']( |
||||
fake_g_pred - paddle.mean(real_d_pred), True, |
||||
is_disc=False) * 0.5 |
||||
loss_g_gan = loss_g_real + loss_g_fake |
||||
outputs['g_pred'] = g_pred.detach() |
||||
outputs['loss_g_pps'] = loss_g |
||||
outputs['loss_g_gan'] = loss_g_gan |
||||
elif gan_mode == 'forward_d': |
||||
self._set_requires_grad(net.discriminator, True) |
||||
# Real |
||||
fake_d_pred = net.discriminator(inputs[0]).detach() |
||||
real_d_pred = net.discriminator(inputs[1]) |
||||
loss_d_real = self.losses['gan']( |
||||
real_d_pred - paddle.mean(fake_d_pred), True, |
||||
is_disc=True) * 0.5 |
||||
# Fake |
||||
fake_d_pred = net.discriminator(inputs[0].detach()) |
||||
loss_d_fake = self.losses['gan']( |
||||
fake_d_pred - paddle.mean(real_d_pred.detach()), |
||||
False, |
||||
is_disc=True) * 0.5 |
||||
outputs['loss_d'] = loss_d_real + loss_d_fake |
||||
else: |
||||
raise ValueError("Invalid `gan_mode`!") |
||||
return outputs |
||||
|
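The `forward_g`/`forward_d` branches above use the relativistic-average formulation: each logit is compared against the mean logit of the opposite class before the GAN loss is applied. Under the assumption that `GANLoss(gan_mode='vanilla')` reduces to binary cross-entropy with logits, the discriminator-side arithmetic looks like this toy sketch (logit values are made up):

```python
import paddle
import paddle.nn.functional as F

real_logits = paddle.to_tensor([1.2, 0.8])   # D(real) on a toy batch
fake_logits = paddle.to_tensor([-0.5, 0.1])  # D(fake) on the same batch

# Compare each logit against the mean of the other class, apply BCE with
# logits, and average the two halves (the 0.5 factors mirror the code above).
loss_d_real = F.binary_cross_entropy_with_logits(
    real_logits - fake_logits.mean(), paddle.ones_like(real_logits)) * 0.5
loss_d_fake = F.binary_cross_entropy_with_logits(
    fake_logits - real_logits.mean(), paddle.zeros_like(fake_logits)) * 0.5
print(float(loss_d_real + loss_d_fake))
```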
||||
def train_step(self, step, data, net): |
||||
if self.use_gan: |
||||
optim_g, optim_d = self.optimizer |
||||
|
||||
outputs = self.run_gan( |
||||
net, data, mode='train', gan_mode='forward_g') |
||||
optim_g.clear_grad() |
||||
(outputs['loss_g_pps'] + outputs['loss_g_gan']).backward() |
||||
optim_g.step() |
||||
|
||||
outputs.update( |
||||
self.run_gan( |
||||
net, (outputs['g_pred'], data[1]), |
||||
mode='train', |
||||
gan_mode='forward_d')) |
||||
optim_d.clear_grad() |
||||
outputs['loss_d'].backward() |
||||
optim_d.step() |
||||
|
||||
outputs['loss'] = outputs['loss_g_pps'] + outputs[ |
||||
'loss_g_gan'] + outputs['loss_d'] |
||||
|
||||
return { |
||||
'loss': outputs['loss'], |
||||
'loss_g_pps': outputs['loss_g_pps'], |
||||
'loss_g_gan': outputs['loss_g_gan'], |
||||
'loss_d': outputs['loss_d'] |
||||
} |
||||
else: |
||||
return super(ESRGAN, self).train_step(step, data, net) |
||||
|
||||
def _set_requires_grad(self, net, requires_grad): |
||||
for p in net.parameters(): |
||||
p.trainable = requires_grad |
||||
|
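When adversarial training is not wanted, `use_gan=False` skips the discriminator entirely: `build_net` returns the bare RRDBNet, `default_loss` falls back to a single L1 loss, and the base class's plain `train_step` is used. A minimal construction sketch, keeping all other defaults:

```python
import paddlers as pdrs

# Pixel-loss-only ESRGAN: no discriminator, no GAN or perceptual terms.
model = pdrs.tasks.res.ESRGAN(use_gan=False)
```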
||||
|
||||
class RCAN(BaseRestorer): |
||||
def __init__(self, |
||||
losses=None, |
||||
sr_factor=4, |
||||
n_resgroups=10, |
||||
n_resblocks=20, |
||||
n_feats=64, |
||||
n_colors=3, |
||||
rgb_range=1.0, |
||||
kernel_size=3, |
||||
reduction=16, |
||||
**params): |
||||
params.update({ |
||||
'n_resgroups': n_resgroups, |
||||
'n_resblocks': n_resblocks, |
||||
'n_feats': n_feats, |
||||
'n_colors': n_colors, |
||||
'rgb_range': rgb_range, |
||||
'kernel_size': kernel_size, |
||||
'reduction': reduction |
||||
}) |
||||
super(RCAN, self).__init__( |
||||
model_name='RCAN', losses=losses, sr_factor=sr_factor, **params) |
@ -0,0 +1,132 @@ |
||||
from functools import wraps |
||||
from inspect import isfunction, isgeneratorfunction, getmembers |
||||
from collections.abc import Sequence |
||||
from abc import ABC |
||||
|
||||
import paddle |
||||
import paddle.nn as nn |
||||
|
||||
__all__ = ['GANAdapter', 'OptimizerAdapter'] |
||||
|
||||
|
||||
class _AttrDesc: |
||||
def __init__(self, key): |
||||
self.key = key |
||||
|
||||
def __get__(self, instance, owner): |
||||
return tuple(getattr(ele, self.key) for ele in instance) |
||||
|
||||
def __set__(self, instance, value): |
||||
for ele in instance: |
||||
setattr(ele, self.key, value) |
||||
|
||||
|
||||
def _func_deco(cls, func_name): |
||||
@wraps(getattr(cls.__ducktype__, func_name)) |
||||
def _wrapper(self, *args, **kwargs): |
||||
return tuple(getattr(ele, func_name)(*args, **kwargs) for ele in self) |
||||
|
||||
return _wrapper |
||||
|
||||
|
||||
def _generator_deco(cls, func_name): |
||||
@wraps(getattr(cls.__ducktype__, func_name)) |
||||
def _wrapper(self, *args, **kwargs): |
||||
for ele in self: |
||||
yield from getattr(ele, func_name)(*args, **kwargs) |
||||
|
||||
return _wrapper |
||||
|
||||
|
||||
class Adapter(Sequence, ABC): |
||||
__ducktype__ = object |
||||
__ava__ = () |
||||
|
||||
def __init__(self, *args): |
||||
if not all(map(self._check, args)): |
||||
raise TypeError("Invalid input: each argument must provide the attributes listed in `__ava__`.") |
||||
self._seq = tuple(args) |
||||
|
||||
def __getitem__(self, key): |
||||
return self._seq[key] |
||||
|
||||
def __len__(self): |
||||
return len(self._seq) |
||||
|
||||
def __repr__(self): |
||||
return repr(self._seq) |
||||
|
||||
@classmethod |
||||
def _check(cls, obj): |
||||
for attr in cls.__ava__: |
||||
try: |
||||
getattr(obj, attr) |
||||
# TODO: Check function signature |
||||
except AttributeError: |
||||
return False |
||||
return True |
||||
|
||||
|
||||
def make_adapter(cls): |
||||
members = dict(getmembers(cls.__ducktype__)) |
||||
for k in cls.__ava__: |
||||
if hasattr(cls, k): |
||||
continue |
||||
if k in members: |
||||
v = members[k] |
||||
if isgeneratorfunction(v): |
||||
setattr(cls, k, _generator_deco(cls, k)) |
||||
elif isfunction(v): |
||||
setattr(cls, k, _func_deco(cls, k)) |
||||
else: |
||||
setattr(cls, k, _AttrDesc(k)) |
||||
return cls |
||||
|
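A toy illustration of the machinery above, using a hypothetical `Counter` duck type defined only for this example: `make_adapter` fills in the names listed in `__ava__` so that method calls fan out to every wrapped object and attribute reads come back as a tuple.

```python
class Counter:
    def __init__(self):
        self.value = 0

    def step(self):
        self.value += 1

@make_adapter
class CounterAdapter(Adapter):
    __ducktype__ = Counter
    __ava__ = ('step', 'value')

counters = CounterAdapter(Counter(), Counter())
counters.step()        # dispatched to every wrapped Counter
print(counters.value)  # (1, 1) -- attribute reads are gathered into a tuple
```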
||||
|
||||
class GANAdapter(nn.Layer): |
||||
__ducktype__ = nn.Layer |
||||
__ava__ = ('state_dict', 'set_state_dict', 'train', 'eval') |
||||
|
||||
def __init__(self, generators, discriminators): |
||||
super(GANAdapter, self).__init__() |
||||
self.generators = nn.LayerList(generators) |
||||
self.discriminators = nn.LayerList(discriminators) |
||||
self._m = [*generators, *discriminators] |
||||
|
||||
def __len__(self): |
||||
return len(self._m) |
||||
|
||||
def __getitem__(self, key): |
||||
return self._m[key] |
||||
|
||||
def __contains__(self, m): |
||||
return m in self._m |
||||
|
||||
def __repr__(self): |
||||
return repr(self._m) |
||||
|
||||
@property |
||||
def generator(self): |
||||
return self.generators[0] |
||||
|
||||
@property |
||||
def discriminator(self): |
||||
return self.discriminators[0] |
||||
|
||||
|
||||
Adapter.register(GANAdapter) |
||||
|
||||
|
||||
@make_adapter |
||||
class OptimizerAdapter(Adapter): |
||||
__ducktype__ = paddle.optimizer.Optimizer |
||||
__ava__ = ('state_dict', 'set_state_dict', 'clear_grad', 'step', 'get_lr') |
||||
|
||||
def set_state_dict(self, state_dicts): |
||||
# Special dispatching rule |
||||
for optim, state_dict in zip(self, state_dicts): |
||||
optim.set_state_dict(state_dict) |
||||
|
||||
def get_lr(self): |
||||
# Return the lr of the first optimizer |
||||
return self[0].get_lr() |
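Putting the two adapters together, roughly the way the DRN and ESRGAN trainers use them. This is only a sketch: the `nn.Linear` layers stand in for real generator networks, and the shapes and learning rates are arbitrary.

```python
import paddle
import paddle.nn as nn

g_a, g_b = nn.Linear(4, 4), nn.Linear(4, 4)
net = GANAdapter(generators=[g_a, g_b], discriminators=[])
optim = OptimizerAdapter(
    paddle.optimizer.Adam(learning_rate=1e-3, parameters=g_a.parameters()),
    paddle.optimizer.Adam(learning_rate=1e-3, parameters=g_b.parameters()))

x = paddle.rand([2, 4])
loss = net.generators[0](x).mean() + net.generators[1](x).mean()
loss.backward()
optim.step()           # fans out to every wrapped optimizer
optim.clear_grad()     # likewise
print(optim.get_lr())  # lr of the first optimizer (0.001)
```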
@ -0,0 +1,46 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddlers |
||||
from rs_models.test_model import TestModel |
||||
|
||||
__all__ = [] |
||||
|
||||
|
||||
class TestResModel(TestModel): |
||||
def check_output(self, output, target): |
||||
output = output.numpy() |
||||
self.check_output_equal(output.shape, target.shape) |
||||
|
||||
def set_inputs(self): |
||||
def _gen_data(specs): |
||||
for spec in specs: |
||||
c = spec.get('in_channels', 3) |
||||
yield self.get_randn_tensor(c) |
||||
|
||||
self.inputs = _gen_data(self.specs) |
||||
|
||||
def set_targets(self): |
||||
def _gen_data(specs): |
||||
for spec in specs: |
||||
# XXX: Hard coding |
||||
if 'out_channels' in spec: |
||||
c = spec['out_channels'] |
||||
elif 'in_channels' in spec: |
||||
c = spec['in_channels'] |
||||
else: |
||||
c = 3 |
||||
yield [self.get_zeros_array(c)] |
||||
|
||||
self.targets = _gen_data(self.specs) |
@ -0,0 +1,3 @@ |
||||
*.zip |
||||
*.tar.gz |
||||
rssr/ |
@ -0,0 +1,89 @@ |
||||
#!/usr/bin/env python |
||||
|
||||
# Example script for training the DRN image restoration model |
||||
# Before running this script, make sure PaddleRS has been installed correctly |
||||
|
||||
import paddlers as pdrs |
||||
from paddlers import transforms as T |
||||
|
||||
# Directory where the dataset is stored |
||||
DATA_DIR = './data/rssr/' |
||||
# Path of the `file_list` file of the training set |
||||
TRAIN_FILE_LIST_PATH = './data/rssr/train.txt' |
||||
# Path of the `file_list` file of the validation set |
||||
EVAL_FILE_LIST_PATH = './data/rssr/val.txt' |
||||
# Experiment directory, where output model weights and results are saved |
||||
EXP_DIR = './output/drn/' |
||||
|
||||
# Download and decompress the remote sensing image super-resolution dataset |
||||
pdrs.utils.download_and_decompress( |
||||
'https://paddlers.bj.bcebos.com/datasets/rssr.zip', path='./data/') |
||||
|
||||
# Define the data transforms used for training and validation (data augmentation, preprocessing, etc.) |
||||
# Use Compose to combine multiple transforms; the transforms in Compose are executed sequentially in order |
||||
# API reference: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md |
||||
train_transforms = T.Compose([ |
||||
# Read the image |
||||
T.DecodeImg(), |
||||
# Crop 96x96 patches from the input image |
||||
T.RandomCrop(crop_size=96), |
||||
# Apply random horizontal flipping with a probability of 50% |
||||
T.RandomHorizontalFlip(prob=0.5), |
||||
# Apply random vertical flipping with a probability of 50% |
||||
T.RandomVerticalFlip(prob=0.5), |
||||
# Normalize the data to [0,1] |
||||
T.Normalize( |
||||
mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]), |
||||
T.ArrangeRestorer('train') |
||||
]) |
||||
|
||||
eval_transforms = T.Compose([ |
||||
T.DecodeImg(), |
||||
# Resize the input image to 256x256 |
||||
T.Resize(target_size=256), |
||||
# The normalization used for validation must be the same as that used for training |
||||
T.Normalize( |
||||
mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]), |
||||
T.ArrangeRestorer('eval') |
||||
]) |
||||
|
||||
# Build the datasets used for training and validation, respectively |
||||
train_dataset = pdrs.datasets.ResDataset( |
||||
data_dir=DATA_DIR, |
||||
file_list=TRAIN_FILE_LIST_PATH, |
||||
transforms=train_transforms, |
||||
num_workers=0, |
||||
shuffle=True, |
||||
sr_factor=4) |
||||
|
||||
eval_dataset = pdrs.datasets.ResDataset( |
||||
data_dir=DATA_DIR, |
||||
file_list=EVAL_FILE_LIST_PATH, |
||||
transforms=eval_transforms, |
||||
num_workers=0, |
||||
shuffle=False, |
||||
sr_factor=4) |
||||
|
||||
# Build the DRN model with default parameters |
||||
# For the list of currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md |
||||
# For model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py |
||||
model = pdrs.tasks.res.DRN() |
||||
|
||||
# Run model training |
||||
model.train( |
||||
num_epochs=10, |
||||
train_dataset=train_dataset, |
||||
train_batch_size=8, |
||||
eval_dataset=eval_dataset, |
||||
save_interval_epochs=5, |
||||
# Log once every this many iterations |
||||
log_interval_steps=10, |
||||
save_dir=EXP_DIR, |
||||
# Initial learning rate |
||||
learning_rate=0.001, |
||||
# Whether to use the early stopping strategy, which terminates training early when accuracy no longer improves |
||||
early_stop=False, |
||||
# Whether to enable VisualDL logging |
||||
use_vdl=True, |
||||
# Resume training from a specified checkpoint |
||||
resume_checkpoint=None) |
@ -1,80 +0,0 @@ |
||||
import os |
||||
import sys |
||||
sys.path.append(os.path.abspath('../PaddleRS')) |
||||
|
||||
import paddle |
||||
import paddlers as pdrs |
||||
|
||||
# Define the transforms used for training and validation |
||||
train_transforms = pdrs.datasets.ComposeTrans( |
||||
input_keys=['lq', 'gt'], |
||||
output_keys=['lq', 'lqx2', 'gt'], |
||||
pipelines=[{ |
||||
'name': 'SRPairedRandomCrop', |
||||
'gt_patch_size': 192, |
||||
'scale': 4, |
||||
'scale_list': True |
||||
}, { |
||||
'name': 'PairedRandomHorizontalFlip' |
||||
}, { |
||||
'name': 'PairedRandomVerticalFlip' |
||||
}, { |
||||
'name': 'PairedRandomTransposeHW' |
||||
}, { |
||||
'name': 'Transpose' |
||||
}, { |
||||
'name': 'Normalize', |
||||
'mean': [0.0, 0.0, 0.0], |
||||
'std': [1.0, 1.0, 1.0] |
||||
}]) |
||||
|
||||
test_transforms = pdrs.datasets.ComposeTrans( |
||||
input_keys=['lq', 'gt'], |
||||
output_keys=['lq', 'gt'], |
||||
pipelines=[{ |
||||
'name': 'Transpose' |
||||
}, { |
||||
'name': 'Normalize', |
||||
'mean': [0.0, 0.0, 0.0], |
||||
'std': [1.0, 1.0, 1.0] |
||||
}]) |
||||
|
||||
# Define the training set |
||||
train_gt_floder = r"../work/RSdata_for_SR/trian_HR" # Directory of the high-resolution images |
||||
train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4" # Directory of the low-resolution images |
||||
num_workers = 4 |
||||
batch_size = 8 |
||||
scale = 4 |
||||
train_dataset = pdrs.datasets.SRdataset( |
||||
mode='train', |
||||
gt_floder=train_gt_floder, |
||||
lq_floder=train_lq_floder, |
||||
transforms=train_transforms(), |
||||
scale=scale, |
||||
num_workers=num_workers, |
||||
batch_size=batch_size) |
||||
train_dict = train_dataset() |
||||
|
||||
# Define the test set |
||||
test_gt_floder = r"../work/RSdata_for_SR/test_HR" |
||||
test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4" |
||||
test_dataset = pdrs.datasets.SRdataset( |
||||
mode='test', |
||||
gt_floder=test_gt_floder, |
||||
lq_floder=test_lq_floder, |
||||
transforms=test_transforms(), |
||||
scale=scale) |
||||
|
||||
# Initialize the model; the network architecture parameters can be adjusted |
||||
model = pdrs.tasks.res.DRNet( |
||||
n_blocks=30, n_feats=16, n_colors=3, rgb_range=255, negval=0.2) |
||||
|
||||
model.train( |
||||
total_iters=100000, |
||||
train_dataset=train_dataset(), |
||||
test_dataset=test_dataset(), |
||||
output_dir='output_dir', |
||||
validate=5000, |
||||
snapshot=5000, |
||||
lr_rate=0.0001, |
||||
log=10) |
@ -0,0 +1,89 @@ |
||||
#!/usr/bin/env python |
||||
|
||||
# Example script for training the ESRGAN image restoration model |
||||
# Before running this script, make sure PaddleRS has been installed correctly |
||||
|
||||
import paddlers as pdrs |
||||
from paddlers import transforms as T |
||||
|
||||
# Directory where the dataset is stored |
||||
DATA_DIR = './data/rssr/' |
||||
# Path of the `file_list` file of the training set |
||||
TRAIN_FILE_LIST_PATH = './data/rssr/train.txt' |
||||
# Path of the `file_list` file of the validation set |
||||
EVAL_FILE_LIST_PATH = './data/rssr/val.txt' |
||||
# Experiment directory, where output model weights and results are saved |
||||
EXP_DIR = './output/esrgan/' |
||||
|
||||
# Download and decompress the remote sensing image super-resolution dataset |
||||
pdrs.utils.download_and_decompress( |
||||
'https://paddlers.bj.bcebos.com/datasets/rssr.zip', path='./data/') |
||||
|
||||
# Define the data transforms used for training and validation (data augmentation, preprocessing, etc.) |
||||
# Use Compose to combine multiple transforms; the transforms in Compose are executed sequentially in order |
||||
# API reference: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md |
||||
train_transforms = T.Compose([ |
||||
# Read the image |
||||
T.DecodeImg(), |
||||
# Crop 32x32 patches from the input image |
||||
T.RandomCrop(crop_size=32), |
||||
# Apply random horizontal flipping with a probability of 50% |
||||
T.RandomHorizontalFlip(prob=0.5), |
||||
# Apply random vertical flipping with a probability of 50% |
||||
T.RandomVerticalFlip(prob=0.5), |
||||
# Normalize the data to [0,1] |
||||
T.Normalize( |
||||
mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]), |
||||
T.ArrangeRestorer('train') |
||||
]) |
||||
|
||||
eval_transforms = T.Compose([ |
||||
T.DecodeImg(), |
||||
# Resize the input image to 256x256 |
||||
T.Resize(target_size=256), |
||||
# The normalization used for validation must be the same as that used for training |
||||
T.Normalize( |
||||
mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]), |
||||
T.ArrangeRestorer('eval') |
||||
]) |
||||
|
||||
# Build the datasets used for training and validation, respectively |
||||
train_dataset = pdrs.datasets.ResDataset( |
||||
data_dir=DATA_DIR, |
||||
file_list=TRAIN_FILE_LIST_PATH, |
||||
transforms=train_transforms, |
||||
num_workers=0, |
||||
shuffle=True, |
||||
sr_factor=4) |
||||
|
||||
eval_dataset = pdrs.datasets.ResDataset( |
||||
data_dir=DATA_DIR, |
||||
file_list=EVAL_FILE_LIST_PATH, |
||||
transforms=eval_transforms, |
||||
num_workers=0, |
||||
shuffle=False, |
||||
sr_factor=4) |
||||
|
||||
# Build the ESRGAN model with default parameters |
||||
# For the list of currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md |
||||
# For model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py |
||||
model = pdrs.tasks.res.ESRGAN() |
||||
|
||||
# Run model training |
||||
model.train( |
||||
num_epochs=10, |
||||
train_dataset=train_dataset, |
||||
train_batch_size=8, |
||||
eval_dataset=eval_dataset, |
||||
save_interval_epochs=5, |
||||
# Log once every this many iterations |
||||
log_interval_steps=10, |
||||
save_dir=EXP_DIR, |
||||
# Initial learning rate |
||||
learning_rate=0.001, |
||||
# Whether to use the early stopping strategy, which terminates training early when accuracy no longer improves |
||||
early_stop=False, |
||||
# Whether to enable VisualDL logging |
||||
use_vdl=True, |
||||
# Resume training from a specified checkpoint |
||||
resume_checkpoint=None) |
@ -1,80 +0,0 @@ |
||||
import os |
||||
import sys |
||||
sys.path.append(os.path.abspath('../PaddleRS')) |
||||
|
||||
import paddlers as pdrs |
||||
|
||||
# Define the transforms used for training and validation |
||||
train_transforms = pdrs.datasets.ComposeTrans( |
||||
input_keys=['lq', 'gt'], |
||||
output_keys=['lq', 'gt'], |
||||
pipelines=[{ |
||||
'name': 'SRPairedRandomCrop', |
||||
'gt_patch_size': 128, |
||||
'scale': 4 |
||||
}, { |
||||
'name': 'PairedRandomHorizontalFlip' |
||||
}, { |
||||
'name': 'PairedRandomVerticalFlip' |
||||
}, { |
||||
'name': 'PairedRandomTransposeHW' |
||||
}, { |
||||
'name': 'Transpose' |
||||
}, { |
||||
'name': 'Normalize', |
||||
'mean': [0.0, 0.0, 0.0], |
||||
'std': [255.0, 255.0, 255.0] |
||||
}]) |
||||
|
||||
test_transforms = pdrs.datasets.ComposeTrans( |
||||
input_keys=['lq', 'gt'], |
||||
output_keys=['lq', 'gt'], |
||||
pipelines=[{ |
||||
'name': 'Transpose' |
||||
}, { |
||||
'name': 'Normalize', |
||||
'mean': [0.0, 0.0, 0.0], |
||||
'std': [255.0, 255.0, 255.0] |
||||
}]) |
||||
|
||||
# Define the training set |
||||
train_gt_floder = r"../work/RSdata_for_SR/trian_HR" # Directory of the high-resolution images |
||||
train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4" # Directory of the low-resolution images |
||||
num_workers = 6 |
||||
batch_size = 32 |
||||
scale = 4 |
||||
train_dataset = pdrs.datasets.SRdataset( |
||||
mode='train', |
||||
gt_floder=train_gt_floder, |
||||
lq_floder=train_lq_floder, |
||||
transforms=train_transforms(), |
||||
scale=scale, |
||||
num_workers=num_workers, |
||||
batch_size=batch_size) |
||||
|
||||
# Define the test set |
||||
test_gt_floder = r"../work/RSdata_for_SR/test_HR" |
||||
test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4" |
||||
test_dataset = pdrs.datasets.SRdataset( |
||||
mode='test', |
||||
gt_floder=test_gt_floder, |
||||
lq_floder=test_lq_floder, |
||||
transforms=test_transforms(), |
||||
scale=scale) |
||||
|
||||
# Initialize the model; the network architecture parameters can be adjusted |
||||
# If loss_type='gan', perceptual, adversarial, and pixel losses are used |
||||
# If loss_type = 'pixel', only the pixel loss is used |
||||
model = pdrs.tasks.res.ESRGANet(loss_type='pixel') |
||||
|
||||
model.train( |
||||
total_iters=1000000, |
||||
train_dataset=train_dataset(), |
||||
test_dataset=test_dataset(), |
||||
output_dir='output_dir', |
||||
validate=5000, |
||||
snapshot=5000, |
||||
log=100, |
||||
lr_rate=0.0001, |
||||
periods=[250000, 250000, 250000, 250000], |
||||
restart_weights=[1, 1, 1, 1]) |
@ -0,0 +1,89 @@ |
||||
#!/usr/bin/env python |
||||
|
||||
# Example script for training the LESRCNN image restoration model |
||||
# Before running this script, make sure PaddleRS has been installed correctly |
||||
|
||||
import paddlers as pdrs |
||||
from paddlers import transforms as T |
||||
|
||||
# Directory where the dataset is stored |
||||
DATA_DIR = './data/rssr/' |
||||
# Path of the `file_list` file of the training set |
||||
TRAIN_FILE_LIST_PATH = './data/rssr/train.txt' |
||||
# Path of the `file_list` file of the validation set |
||||
EVAL_FILE_LIST_PATH = './data/rssr/val.txt' |
||||
# Experiment directory, where output model weights and results are saved |
||||
EXP_DIR = './output/lesrcnn/' |
||||
|
||||
# Download and decompress the remote sensing image super-resolution dataset |
||||
pdrs.utils.download_and_decompress( |
||||
'https://paddlers.bj.bcebos.com/datasets/rssr.zip', path='./data/') |
||||
|
||||
# Define the data transforms used for training and validation (data augmentation, preprocessing, etc.) |
||||
# Use Compose to combine multiple transforms; the transforms in Compose are executed sequentially in order |
||||
# API reference: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md |
||||
train_transforms = T.Compose([ |
||||
# Read the image |
||||
T.DecodeImg(), |
||||
# Crop 32x32 patches from the input image |
||||
T.RandomCrop(crop_size=32), |
||||
# Apply random horizontal flipping with a probability of 50% |
||||
T.RandomHorizontalFlip(prob=0.5), |
||||
# Apply random vertical flipping with a probability of 50% |
||||
T.RandomVerticalFlip(prob=0.5), |
||||
# Normalize the data to [0,1] |
||||
T.Normalize( |
||||
mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]), |
||||
T.ArrangeRestorer('train') |
||||
]) |
||||
|
||||
eval_transforms = T.Compose([ |
||||
T.DecodeImg(), |
||||
# Resize the input image to 256x256 |
||||
T.Resize(target_size=256), |
||||
# The normalization used for validation must be the same as that used for training |
||||
T.Normalize( |
||||
mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]), |
||||
T.ArrangeRestorer('eval') |
||||
]) |
||||
|
||||
# Build the datasets used for training and validation, respectively |
||||
train_dataset = pdrs.datasets.ResDataset( |
||||
data_dir=DATA_DIR, |
||||
file_list=TRAIN_FILE_LIST_PATH, |
||||
transforms=train_transforms, |
||||
num_workers=0, |
||||
shuffle=True, |
||||
sr_factor=4) |
||||
|
||||
eval_dataset = pdrs.datasets.ResDataset( |
||||
data_dir=DATA_DIR, |
||||
file_list=EVAL_FILE_LIST_PATH, |
||||
transforms=eval_transforms, |
||||
num_workers=0, |
||||
shuffle=False, |
||||
sr_factor=4) |
||||
|
||||
# Build the LESRCNN model with default parameters |
||||
# For the list of currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md |
||||
# For model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/restorer.py |
||||
model = pdrs.tasks.res.LESRCNN() |
||||
|
||||
# Run model training |
||||
model.train( |
||||
num_epochs=10, |
||||
train_dataset=train_dataset, |
||||
train_batch_size=8, |
||||
eval_dataset=eval_dataset, |
||||
save_interval_epochs=5, |
||||
# Log once every this many iterations |
||||
log_interval_steps=10, |
||||
save_dir=EXP_DIR, |
||||
# Initial learning rate |
||||
learning_rate=0.001, |
||||
# Whether to use the early stopping strategy, which terminates training early when accuracy no longer improves |
||||
early_stop=False, |
||||
# Whether to enable VisualDL logging |
||||
use_vdl=True, |
||||
# Resume training from a specified checkpoint |
||||
resume_checkpoint=None) |
@ -1,78 +0,0 @@ |
||||
import os |
||||
import sys |
||||
sys.path.append(os.path.abspath('../PaddleRS')) |
||||
|
||||
import paddlers as pdrs |
||||
|
||||
# Define the transforms used for training and validation |
||||
train_transforms = pdrs.datasets.ComposeTrans( |
||||
input_keys=['lq', 'gt'], |
||||
output_keys=['lq', 'gt'], |
||||
pipelines=[{ |
||||
'name': 'SRPairedRandomCrop', |
||||
'gt_patch_size': 192, |
||||
'scale': 4 |
||||
}, { |
||||
'name': 'PairedRandomHorizontalFlip' |
||||
}, { |
||||
'name': 'PairedRandomVerticalFlip' |
||||
}, { |
||||
'name': 'PairedRandomTransposeHW' |
||||
}, { |
||||
'name': 'Transpose' |
||||
}, { |
||||
'name': 'Normalize', |
||||
'mean': [0.0, 0.0, 0.0], |
||||
'std': [255.0, 255.0, 255.0] |
||||
}]) |
||||
|
||||
test_transforms = pdrs.datasets.ComposeTrans( |
||||
input_keys=['lq', 'gt'], |
||||
output_keys=['lq', 'gt'], |
||||
pipelines=[{ |
||||
'name': 'Transpose' |
||||
}, { |
||||
'name': 'Normalize', |
||||
'mean': [0.0, 0.0, 0.0], |
||||
'std': [255.0, 255.0, 255.0] |
||||
}]) |
||||
|
||||
# Define the training set |
||||
train_gt_floder = r"../work/RSdata_for_SR/trian_HR" # Directory of the high-resolution images |
||||
train_lq_floder = r"../work/RSdata_for_SR/train_LR/x4" # Directory of the low-resolution images |
||||
num_workers = 4 |
||||
batch_size = 16 |
||||
scale = 4 |
||||
train_dataset = pdrs.datasets.SRdataset( |
||||
mode='train', |
||||
gt_floder=train_gt_floder, |
||||
lq_floder=train_lq_floder, |
||||
transforms=train_transforms(), |
||||
scale=scale, |
||||
num_workers=num_workers, |
||||
batch_size=batch_size) |
||||
|
||||
# Define the test set |
||||
test_gt_floder = r"../work/RSdata_for_SR/test_HR" |
||||
test_lq_floder = r"../work/RSdata_for_SR/test_LR/x4" |
||||
test_dataset = pdrs.datasets.SRdataset( |
||||
mode='test', |
||||
gt_floder=test_gt_floder, |
||||
lq_floder=test_lq_floder, |
||||
transforms=test_transforms(), |
||||
scale=scale) |
||||
|
||||
# Initialize the model; the network architecture parameters can be adjusted |
||||
model = pdrs.tasks.res.LESRCNNet(scale=4, multi_scale=False, group=1) |
||||
|
||||
model.train( |
||||
total_iters=1000000, |
||||
train_dataset=train_dataset(), |
||||
test_dataset=test_dataset(), |
||||
output_dir='output_dir', |
||||
validate=5000, |
||||
snapshot=5000, |
||||
log=100, |
||||
lr_rate=0.0001, |
||||
periods=[250000, 250000, 250000, 250000], |
||||
restart_weights=[1, 1, 1, 1]) |