parent d2e80e829d
commit d29af2909c
7 changed files with 1094 additions and 1 deletions
@@ -0,0 +1,99 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Super-resolution dataset definition
class SRdataset(object):
    def __init__(self,
                 mode,
                 gt_folder,
                 lq_folder,
                 transforms,
                 scale,
                 num_workers=4,
                 batch_size=8):
        if mode == 'train':
            preprocess = []
            preprocess.append({
                'name': 'LoadImageFromFile',
                'key': 'lq'
            })  # how images are loaded
            preprocess.append({'name': 'LoadImageFromFile', 'key': 'gt'})
            preprocess.append(transforms)  # transforms to apply
            self.dataset = {
                'name': 'SRDataset',
                'gt_folder': gt_folder,
                'lq_folder': lq_folder,
                'num_workers': num_workers,
                'batch_size': batch_size,
                'scale': scale,
                'preprocess': preprocess
            }

        if mode == "test":
            preprocess = []
            preprocess.append({'name': 'LoadImageFromFile', 'key': 'lq'})
            preprocess.append({'name': 'LoadImageFromFile', 'key': 'gt'})
            preprocess.append(transforms)
            self.dataset = {
                'name': 'SRDataset',
                'gt_folder': gt_folder,
                'lq_folder': lq_folder,
                'scale': scale,
                'preprocess': preprocess
            }

    def __call__(self):
        return self.dataset


# Compose the configured transforms and return them as a dict
class ComposeTrans(object):
    def __init__(self, input_keys, output_keys, pipelines):
        if not isinstance(pipelines, list):
            raise TypeError(
                'Type of transforms is invalid. Must be List, but received is {}'
                .format(type(pipelines)))
        if len(pipelines) < 1:
            raise ValueError(
                'Length of transforms must not be less than 1, but received is {}'
                .format(len(pipelines)))
        self.transforms = pipelines
        self.output_length = len(output_keys)  # an output_keys of length 3 means DRN training
        self.input_keys = input_keys
        self.output_keys = output_keys

    def __call__(self):
        pipeline = []
        for op in self.transforms:
            if op['name'] == 'SRPairedRandomCrop':
                op['keys'] = ['image'] * 2
            else:
                op['keys'] = ['image'] * self.output_length
            pipeline.append(op)
        if self.output_length == 2:
            transform_dict = {
                'name': 'Transforms',
                'input_keys': self.input_keys,
                'pipeline': pipeline
            }
        else:
            transform_dict = {
                'name': 'Transforms',
                'input_keys': self.input_keys,
                'output_keys': self.output_keys,
                'pipeline': pipeline
            }

        return transform_dict
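For orientation, a minimal sketch of how the two classes above fit together; the folder paths and pipeline values are placeholders, not part of this commit:

```python
# Build a transform config, then a dataset config dict from it.
transforms = ComposeTrans(
    input_keys=['lq', 'gt'],
    output_keys=['lq', 'gt'],
    pipelines=[{'name': 'Transpose'},
               {'name': 'Normalize',
                'mean': [0.0, 0.0, 0.0],
                'std': [255.0, 255.0, 255.0]}])

dataset = SRdataset(
    mode='train',
    gt_folder='data/HR',       # placeholder path
    lq_folder='data/LR/x4',    # placeholder path
    transforms=transforms(),   # note: ComposeTrans is called to get the dict
    scale=4)

# Both classes are config builders: calling them yields plain dicts
# that the trainer's build_dataloader consumes.
print(dataset()['preprocess'][-1]['name'])  # 'Transforms'
```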
@@ -0,0 +1,753 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time

import datetime

import paddle
from paddle.distributed import ParallelEnv

from ..models.ppgan.datasets.builder import build_dataloader
from ..models.ppgan.models.builder import build_model
from ..models.ppgan.utils.visual import tensor2img, save_image
from ..models.ppgan.utils.filesystem import makedirs, save, load
from ..models.ppgan.utils.timer import TimeAverager
from ..models.ppgan.utils.profiler import add_profiler_step
from ..models.ppgan.utils.logger import setup_logger


# AttrDict provides attribute-style access to dict entries
class AttrDict(dict):
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        if key in self.__dict__:
            self.__dict__[key] = value
        else:
            self[key] = value


# Recursively convert a nested config dict into AttrDict,
# literal-evaluating string values where possible
def create_attr_dict(config_dict):
    from ast import literal_eval
    for key, value in config_dict.items():
        if type(value) is dict:
            config_dict[key] = value = AttrDict(value)
        if isinstance(value, str):
            try:
                value = literal_eval(value)
            except BaseException:
                pass
        if isinstance(value, AttrDict):
            create_attr_dict(config_dict[key])
        else:
            config_dict[key] = value
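A quick sketch of what these two helpers do to a config (values are illustrative):

```python
cfg = AttrDict({'model': {'name': 'DRN'}, 'epochs': '10'})
create_attr_dict(cfg)

print(cfg.model.name)   # 'DRN' -- nested dicts were converted to AttrDict
print(cfg.epochs + 1)   # 11   -- the string '10' was literal_eval'd to an int
cfg.output_dir = 'out'  # unknown attributes are stored as dict keys
```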


# Wraps a dataloader so iteration never stops: on StopIteration it
# restarts the loader and advances the epoch counter
class IterLoader:
    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._epoch = 1

    @property
    def epoch(self):
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            self._epoch += 1
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        return data

    def __len__(self):
        return len(self._dataloader)
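Because `__next__` restarts the exhausted loader instead of raising `StopIteration`, the iteration-based training loop below can call `next()` a fixed number of times. A sketch with a plain list standing in for the dataloader:

```python
loader = IterLoader([1, 2, 3])   # any iterable works for illustration
batches = [next(loader) for _ in range(7)]
print(batches)        # [1, 2, 3, 1, 2, 3, 1] -- wraps around
print(loader.epoch)   # 3: the counter advanced at each restart
```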


# Base trainer class
class Restorer:
    """
    # trainer calling logic:
    #
    #     build_model                                        ||  model(BaseModel)
    #          |                                             ||
    #     build_dataloader                                   ||  dataloader
    #          |                                             ||
    #     model.setup_lr_schedulers                          ||  lr_scheduler
    #          |                                             ||
    #     model.setup_optimizers                             ||  optimizers
    #          |                                             ||
    #     train loop (model.setup_input + model.train_iter)  ||  train loop
    #          |                                             ||
    #     print log (model.get_current_losses)               ||
    #          |                                             ||
    #     save checkpoint (model.nets)                       \/
    """

    def __init__(self, cfg, logger):
        # base config
        self.logger = logger
        self.cfg = cfg
        self.output_dir = cfg.output_dir
        self.max_eval_steps = cfg.model.get('max_eval_steps', None)

        self.local_rank = ParallelEnv().local_rank
        self.world_size = ParallelEnv().nranks
        self.log_interval = cfg.log_config.interval
        self.visual_interval = cfg.log_config.visual_interval
        self.weight_interval = cfg.snapshot_config.interval

        self.start_epoch = 1
        self.current_epoch = 1
        self.current_iter = 1
        self.inner_iter = 1
        self.batch_id = 0
        self.global_steps = 0

        # build model
        self.model = build_model(cfg.model)
        # multiple gpus prepare
        if ParallelEnv().nranks > 1:
            self.distributed_data_parallel()

        # build metrics
        self.metrics = None
        self.is_save_img = True
        validate_cfg = cfg.get('validate', None)
        if validate_cfg and 'metrics' in validate_cfg:
            self.metrics = self.model.setup_metrics(validate_cfg['metrics'])
        if validate_cfg and 'save_img' in validate_cfg:
            self.is_save_img = validate_cfg['save_img']

        self.enable_visualdl = cfg.get('enable_visualdl', False)
        if self.enable_visualdl:
            import visualdl
            self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir)

        # evaluate only
        if not cfg.is_train:
            return

        # build train dataloader
        self.train_dataloader = build_dataloader(cfg.dataset.train)
        self.iters_per_epoch = len(self.train_dataloader)

        # build lr scheduler
        # TODO: is there a better way?
        if 'lr_scheduler' in cfg and 'iters_per_epoch' in cfg.lr_scheduler:
            cfg.lr_scheduler.iters_per_epoch = self.iters_per_epoch
        self.lr_schedulers = self.model.setup_lr_schedulers(cfg.lr_scheduler)

        # build optimizers
        self.optimizers = self.model.setup_optimizers(self.lr_schedulers,
                                                      cfg.optimizer)

        self.epochs = cfg.get('epochs', None)
        if self.epochs:
            self.total_iters = self.epochs * self.iters_per_epoch
            self.by_epoch = True
        else:
            self.by_epoch = False
            self.total_iters = cfg.total_iters

        if self.by_epoch:
            self.weight_interval *= self.iters_per_epoch

        self.validate_interval = -1
        if cfg.get('validate', None) is not None:
            self.validate_interval = cfg.validate.get('interval', -1)

        self.time_count = {}
        self.best_metric = {}
        self.model.set_total_iter(self.total_iters)
        self.profiler_options = cfg.profiler_options

    def distributed_data_parallel(self):
        paddle.distributed.init_parallel_env()
        find_unused_parameters = self.cfg.get('find_unused_parameters', False)
        for net_name, net in self.model.nets.items():
            self.model.nets[net_name] = paddle.DataParallel(
                net, find_unused_parameters=find_unused_parameters)

    def learning_rate_scheduler_step(self):
        if isinstance(self.model.lr_scheduler, dict):
            for lr_scheduler in self.model.lr_scheduler.values():
                lr_scheduler.step()
        elif isinstance(self.model.lr_scheduler,
                        paddle.optimizer.lr.LRScheduler):
            self.model.lr_scheduler.step()
        else:
            raise ValueError(
                'lr scheduler must be a dict or an instance of LRScheduler')

    def train(self):
        reader_cost_averager = TimeAverager()
        batch_cost_averager = TimeAverager()

        iter_loader = IterLoader(self.train_dataloader)

        # set model.is_train = True
        self.model.setup_train_mode(is_train=True)
        while self.current_iter < (self.total_iters + 1):
            self.current_epoch = iter_loader.epoch
            self.inner_iter = self.current_iter % self.iters_per_epoch

            add_profiler_step(self.profiler_options)

            start_time = step_start_time = time.time()
            data = next(iter_loader)
            reader_cost_averager.record(time.time() - step_start_time)
            # unpack data from dataset and apply preprocessing
            # data input should be dict
            self.model.setup_input(data)
            self.model.train_iter(self.optimizers)

            batch_cost_averager.record(
                time.time() - step_start_time,
                num_samples=self.cfg['dataset']['train'].get('batch_size', 1))

            step_start_time = time.time()

            if self.current_iter % self.log_interval == 0:
                self.data_time = reader_cost_averager.get_average()
                self.step_time = batch_cost_averager.get_average()
                self.ips = batch_cost_averager.get_ips_average()
                self.print_log()

                reader_cost_averager.reset()
                batch_cost_averager.reset()

            if self.current_iter % self.visual_interval == 0 and self.local_rank == 0:
                self.visual('visual_train')

            self.learning_rate_scheduler_step()

            if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0:
                self.test()

            if self.current_iter % self.weight_interval == 0:
                self.save(self.current_iter, 'weight', keep=-1)
                self.save(self.current_iter)

            self.current_iter += 1

    def test(self):
        if not hasattr(self, 'test_dataloader'):
            self.test_dataloader = build_dataloader(
                self.cfg.dataset.test, is_train=False)
        iter_loader = IterLoader(self.test_dataloader)
        if self.max_eval_steps is None:
            self.max_eval_steps = len(self.test_dataloader)

        if self.metrics:
            for metric in self.metrics.values():
                metric.reset()

        # set model.is_train = False
        self.model.setup_train_mode(is_train=False)

        for i in range(self.max_eval_steps):
            if self.max_eval_steps < self.log_interval or i % self.log_interval == 0:
                self.logger.info('Test iter: [%d/%d]' % (
                    i * self.world_size, self.max_eval_steps * self.world_size))

            data = next(iter_loader)
            self.model.setup_input(data)
            self.model.test_iter(metrics=self.metrics)

            if self.is_save_img:
                visual_results = {}
                current_paths = self.model.get_image_paths()
                current_visuals = self.model.get_current_visuals()

                if len(current_visuals) > 0 and len(
                        list(current_visuals.values())[0].shape) == 4:
                    num_samples = list(current_visuals.values())[0].shape[0]
                else:
                    num_samples = 1

                for j in range(num_samples):
                    if j < len(current_paths):
                        short_path = os.path.basename(current_paths[j])
                        basename = os.path.splitext(short_path)[0]
                    else:
                        basename = '{:04d}_{:04d}'.format(i, j)
                    for k, img_tensor in current_visuals.items():
                        name = '%s_%s' % (basename, k)
                        if len(img_tensor.shape) == 4:
                            visual_results.update({name: img_tensor[j]})
                        else:
                            visual_results.update({name: img_tensor})

                self.visual(
                    'visual_test',
                    visual_results=visual_results,
                    step=self.batch_id,
                    is_save_image=True)

        if self.metrics:
            for metric_name, metric in self.metrics.items():
                self.logger.info("Metric {}: {:.4f}".format(
                    metric_name, metric.accumulate()))

    def print_log(self):
        losses = self.model.get_current_losses()

        message = ''
        if self.by_epoch:
            message += 'Epoch: %d/%d, iter: %d/%d ' % (
                self.current_epoch, self.epochs, self.inner_iter,
                self.iters_per_epoch)
        else:
            message += 'Iter: %d/%d ' % (self.current_iter, self.total_iters)

        message += f'lr: {self.current_learning_rate:.3e} '

        for k, v in losses.items():
            message += '%s: %.3f ' % (k, v)
            if self.enable_visualdl:
                self.vdl_logger.add_scalar(k, v, step=self.global_steps)

        if hasattr(self, 'step_time'):
            message += 'batch_cost: %.5f sec ' % self.step_time

        if hasattr(self, 'data_time'):
            message += 'reader_cost: %.5f sec ' % self.data_time

        if hasattr(self, 'ips'):
            message += 'ips: %.5f images/s ' % self.ips

        if hasattr(self, 'step_time'):
            eta = self.step_time * (self.total_iters - self.current_iter)
            eta = eta if eta > 0 else 0

            eta_str = str(datetime.timedelta(seconds=int(eta)))
            message += f'eta: {eta_str}'

        # print the message
        self.logger.info(message)

    @property
    def current_learning_rate(self):
        for optimizer in self.model.optimizers.values():
            return optimizer.get_lr()

    def visual(self,
               results_dir,
               visual_results=None,
               step=None,
               is_save_image=False):
        """
        Visualize images, either through visualdl or by writing them directly
        to a directory.

        Parameters:
            results_dir (str) -- directory name which contains saved images
            visual_results (dict) -- the results images dict
            step (int) -- global steps, used in visualdl
            is_save_image (bool) -- whether to write to the directory instead of visualdl
        """
        self.model.compute_visuals()

        if visual_results is None:
            visual_results = self.model.get_current_visuals()

        min_max = self.cfg.get('min_max', None)
        if min_max is None:
            min_max = (-1., 1.)

        image_num = self.cfg.get('image_num', None)
        if (image_num is None) or (not self.enable_visualdl):
            image_num = 1
        for label, image in visual_results.items():
            image_numpy = tensor2img(image, min_max, image_num)
            if (not is_save_image) and self.enable_visualdl:
                self.vdl_logger.add_image(
                    results_dir + '/' + label,
                    image_numpy,
                    step=step if step else self.global_steps,
                    dataformats="HWC" if image_num == 1 else "NCHW")
            else:
                if self.cfg.is_train:
                    if self.by_epoch:
                        msg = 'epoch%.3d_' % self.current_epoch
                    else:
                        msg = 'iter%.3d_' % self.current_iter
                else:
                    msg = ''
                makedirs(os.path.join(self.output_dir, results_dir))
                img_path = os.path.join(self.output_dir, results_dir,
                                        msg + '%s.png' % (label))
                save_image(image_numpy, img_path)

    def save(self, epoch, name='checkpoint', keep=1):
        if self.local_rank != 0:
            return

        assert name in ['checkpoint', 'weight']

        state_dicts = {}
        if self.by_epoch:
            save_filename = 'epoch_%s_%s.pdparams' % (
                epoch // self.iters_per_epoch, name)
        else:
            save_filename = 'iter_%s_%s.pdparams' % (epoch, name)

        os.makedirs(self.output_dir, exist_ok=True)
        save_path = os.path.join(self.output_dir, save_filename)
        for net_name, net in self.model.nets.items():
            state_dicts[net_name] = net.state_dict()

        if name == 'weight':
            save(state_dicts, save_path)
            return

        state_dicts['epoch'] = epoch

        for opt_name, opt in self.model.optimizers.items():
            state_dicts[opt_name] = opt.state_dict()

        save(state_dicts, save_path)

        if keep > 0:
            try:
                if self.by_epoch:
                    checkpoint_name_to_be_removed = os.path.join(
                        self.output_dir, 'epoch_%s_%s.pdparams' % (
                            (epoch - keep * self.weight_interval) //
                            self.iters_per_epoch, name))
                else:
                    checkpoint_name_to_be_removed = os.path.join(
                        self.output_dir, 'iter_%s_%s.pdparams' %
                        (epoch - keep * self.weight_interval, name))

                if os.path.exists(checkpoint_name_to_be_removed):
                    os.remove(checkpoint_name_to_be_removed)

            except Exception as e:
                self.logger.info('remove old checkpoints error: {}'.format(e))

    def resume(self, checkpoint_path):
        state_dicts = load(checkpoint_path)
        if state_dicts.get('epoch', None) is not None:
            self.start_epoch = state_dicts['epoch'] + 1
            self.global_steps = self.iters_per_epoch * state_dicts['epoch']

            self.current_iter = state_dicts['epoch'] + 1

        for net_name, net in self.model.nets.items():
            net.set_state_dict(state_dicts[net_name])

        for opt_name, opt in self.model.optimizers.items():
            opt.set_state_dict(state_dicts[opt_name])

    def load(self, weight_path):
        state_dicts = load(weight_path)

        for net_name, net in self.model.nets.items():
            if net_name in state_dicts:
                net.set_state_dict(state_dicts[net_name])
                self.logger.info(
                    'Loaded pretrained weight for net {}'.format(net_name))
            else:
                self.logger.warning(
                    'Can not find state dict of net {}. Skip load pretrained weight for net {}'
                    .format(net_name, net_name))

    def close(self):
        """
        Close file handlers and other resources when training finishes.
        """
        if self.enable_visualdl:
            self.vdl_logger.close()


# Base super-resolution training class
class BasicSRNet:
    def __init__(self):
        self.model = {}
        self.optimizer = {}
        self.lr_scheduler = {}
        self.min_max = ''

    def train(
            self,
            total_iters,
            train_dataset,
            test_dataset,
            output_dir,
            validate,
            snapshot,
            log,
            lr_rate,
            evaluate_weights='',
            resume='',
            pretrain_weights='',
            periods=[100000],
            restart_weights=[1]):
        self.lr_scheduler['learning_rate'] = lr_rate

        if self.lr_scheduler['name'] == 'CosineAnnealingRestartLR':
            self.lr_scheduler['periods'] = periods
            self.lr_scheduler['restart_weights'] = restart_weights

        validate = {
            'interval': validate,
            'save_img': False,
            'metrics': {
                'psnr': {
                    'name': 'PSNR',
                    'crop_border': 4,
                    'test_y_channel': True
                },
                'ssim': {
                    'name': 'SSIM',
                    'crop_border': 4,
                    'test_y_channel': True
                }
            }
        }
        log_config = {'interval': log, 'visual_interval': 500}
        snapshot_config = {'interval': snapshot}

        cfg = {
            'total_iters': total_iters,
            'output_dir': output_dir,
            'min_max': self.min_max,
            'model': self.model,
            'dataset': {
                'train': train_dataset,
                'test': test_dataset
            },
            'lr_scheduler': self.lr_scheduler,
            'optimizer': self.optimizer,
            'validate': validate,
            'log_config': log_config,
            'snapshot_config': snapshot_config
        }

        cfg = AttrDict(cfg)
        create_attr_dict(cfg)

        cfg.is_train = True
        cfg.profiler_options = None
        cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime())

        if cfg.model.name == 'BaseSRModel':
            folder_model_name = cfg.model.generator.name
        else:
            folder_model_name = cfg.model.name
        cfg.output_dir = os.path.join(cfg.output_dir,
                                      folder_model_name + cfg.timestamp)

        logger_cfg = setup_logger(cfg.output_dir)
        logger_cfg.info('Configs: {}'.format(cfg))

        if paddle.is_compiled_with_cuda():
            paddle.set_device('gpu')
        else:
            paddle.set_device('cpu')

        # build trainer
        trainer = Restorer(cfg, logger_cfg)

        # to continue training or to evaluate, the checkpoint must contain
        # epoch and optimizer info
        if len(resume) > 0:
            trainer.resume(resume)
        # to evaluate or finetune, only load the generator weights
        elif len(pretrain_weights) > 0:
            trainer.load(pretrain_weights)
        if len(evaluate_weights) > 0:
            trainer.load(evaluate_weights)
            trainer.test()
            return
        # during training, save the weights on keyboard interrupt
        try:
            trainer.train()
        except KeyboardInterrupt:
            trainer.save(trainer.current_epoch)

        trainer.close()


# DRN model training
class DRNet(BasicSRNet):
    def __init__(self,
                 n_blocks=30,
                 n_feats=16,
                 n_colors=3,
                 rgb_range=255,
                 negval=0.2):
        super(DRNet, self).__init__()
        self.min_max = '(0., 255.)'
        self.generator = {
            'name': 'DRNGenerator',
            'scale': (2, 4),
            'n_blocks': n_blocks,
            'n_feats': n_feats,
            'n_colors': n_colors,
            'rgb_range': rgb_range,
            'negval': negval
        }
        self.pixel_criterion = {'name': 'L1Loss'}
        self.model = {
            'name': 'DRN',
            'generator': self.generator,
            'pixel_criterion': self.pixel_criterion
        }
        self.optimizer = {
            'optimG': {
                'name': 'Adam',
                'net_names': ['generator'],
                'weight_decay': 0.0,
                'beta1': 0.9,
                'beta2': 0.999
            },
            'optimD': {
                'name': 'Adam',
                'net_names': ['dual_model_0', 'dual_model_1'],
                'weight_decay': 0.0,
                'beta1': 0.9,
                'beta2': 0.999
            }
        }
        self.lr_scheduler = {
            'name': 'CosineAnnealingRestartLR',
            'eta_min': 1e-07
        }


# Training class for the lightweight SR model LESRCNN
class LESRCNNet(BasicSRNet):
    def __init__(self, scale=4, multi_scale=False, group=1):
        super(LESRCNNet, self).__init__()
        self.min_max = '(0., 1.)'
        self.generator = {
            'name': 'LESRCNNGenerator',
            'scale': scale,
            'multi_scale': multi_scale,
            'group': group
        }
        self.pixel_criterion = {'name': 'L1Loss'}
        self.model = {
            'name': 'BaseSRModel',
            'generator': self.generator,
            'pixel_criterion': self.pixel_criterion
        }
        self.optimizer = {
            'name': 'Adam',
            'net_names': ['generator'],
            'beta1': 0.9,
            'beta2': 0.99
        }
        self.lr_scheduler = {
            'name': 'CosineAnnealingRestartLR',
            'eta_min': 1e-07
        }


# ESRGAN model training
# loss_type='gan': use perceptual, adversarial and pixel losses
# loss_type='pixel': use the pixel loss only
class ESRGANet(BasicSRNet):
    def __init__(self, loss_type='gan', in_nc=3, out_nc=3, nf=64, nb=23):
        super(ESRGANet, self).__init__()
        self.min_max = '(0., 1.)'
        self.generator = {
            'name': 'RRDBNet',
            'in_nc': in_nc,
            'out_nc': out_nc,
            'nf': nf,
            'nb': nb
        }

        if loss_type == 'gan':
            # define the loss functions
            self.pixel_criterion = {'name': 'L1Loss', 'loss_weight': 0.01}
            self.discriminator = {
                'name': 'VGGDiscriminator128',
                'in_channels': 3,
                'num_feat': 64
            }
            self.perceptual_criterion = {
                'name': 'PerceptualLoss',
                'layer_weights': {
                    '34': 1.0
                },
                'perceptual_weight': 1.0,
                'style_weight': 0.0,
                'norm_img': False
            }
            self.gan_criterion = {
                'name': 'GANLoss',
                'gan_mode': 'vanilla',
                'loss_weight': 0.005
            }
            # define the model
            self.model = {
                'name': 'ESRGAN',
                'generator': self.generator,
                'discriminator': self.discriminator,
                'pixel_criterion': self.pixel_criterion,
                'perceptual_criterion': self.perceptual_criterion,
                'gan_criterion': self.gan_criterion
            }
            self.optimizer = {
                'optimG': {
                    'name': 'Adam',
                    'net_names': ['generator'],
                    'weight_decay': 0.0,
                    'beta1': 0.9,
                    'beta2': 0.99
                },
                'optimD': {
                    'name': 'Adam',
                    'net_names': ['discriminator'],
                    'weight_decay': 0.0,
                    'beta1': 0.9,
                    'beta2': 0.99
                }
            }
            self.lr_scheduler = {
                'name': 'MultiStepDecay',
                'milestones': [50000, 100000, 200000, 300000],
                'gamma': 0.5
            }
        else:
            self.pixel_criterion = {'name': 'L1Loss'}
            self.model = {
                'name': 'BaseSRModel',
                'generator': self.generator,
                'pixel_criterion': self.pixel_criterion
            }
            self.optimizer = {
                'name': 'Adam',
                'net_names': ['generator'],
                'beta1': 0.9,
                'beta2': 0.99
            }
            self.lr_scheduler = {
                'name': 'CosineAnnealingRestartLR',
                'eta_min': 1e-07
            }
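Each of the task classes above is a config builder rather than a network; a quick, illustrative way to inspect what a given configuration will hand the trainer:

```python
net = ESRGANet(loss_type='pixel')
print(net.model['name'])  # 'BaseSRModel' -- no discriminator in pixel-only mode
print(net.lr_scheduler)   # {'name': 'CosineAnnealingRestartLR', 'eta_min': 1e-07}
```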
@@ -0,0 +1,81 @@
import os
import sys
sys.path.append(os.path.abspath('../PaddleRS'))

import paddle
import paddlers as pdrs

if __name__ == "__main__":

    # define the transforms for training and validation
    train_transforms = pdrs.datasets.ComposeTrans(
        input_keys=['lq', 'gt'],
        output_keys=['lq', 'lqx2', 'gt'],  # three output keys select the DRN branch
        pipelines=[{
            'name': 'SRPairedRandomCrop',
            'gt_patch_size': 192,
            'scale': 4,
            'scale_list': True
        }, {
            'name': 'PairedRandomHorizontalFlip'
        }, {
            'name': 'PairedRandomVerticalFlip'
        }, {
            'name': 'PairedRandomTransposeHW'
        }, {
            'name': 'Transpose'
        }, {
            'name': 'Normalize',
            'mean': [0.0, 0.0, 0.0],
            'std': [1.0, 1.0, 1.0]
        }])

    test_transforms = pdrs.datasets.ComposeTrans(
        input_keys=['lq', 'gt'],
        output_keys=['lq', 'gt'],
        pipelines=[{
            'name': 'Transpose'
        }, {
            'name': 'Normalize',
            'mean': [0.0, 0.0, 0.0],
            'std': [1.0, 1.0, 1.0]
        }])

    # define the training dataset
    train_gt_folder = r"../work/RSdata_for_SR/trian_HR"  # path to the high-resolution images
    train_lq_folder = r"../work/RSdata_for_SR/train_LR/x4"  # path to the low-resolution images
    num_workers = 4
    batch_size = 8
    scale = 4
    train_dataset = pdrs.datasets.SRdataset(
        mode='train',
        gt_folder=train_gt_folder,
        lq_folder=train_lq_folder,
        transforms=train_transforms(),
        scale=scale,
        num_workers=num_workers,
        batch_size=batch_size)
    train_dict = train_dataset()

    # define the test dataset
    test_gt_folder = r"../work/RSdata_for_SR/test_HR"
    test_lq_folder = r"../work/RSdata_for_SR/test_LR/x4"
    test_dataset = pdrs.datasets.SRdataset(
        mode='test',
        gt_folder=test_gt_folder,
        lq_folder=test_lq_folder,
        transforms=test_transforms(),
        scale=scale)

    # initialize the model; the network architecture parameters can be adjusted
    model = pdrs.tasks.DRNet(
        n_blocks=30, n_feats=16, n_colors=3, rgb_range=255, negval=0.2)

    model.train(
        total_iters=100000,
        train_dataset=train_dataset(),
        test_dataset=test_dataset(),
        output_dir='output_dir',
        validate=5000,
        snapshot=5000,
        log=100,  # logging interval (required by BasicSRNet.train)
        lr_rate=0.0001)
@@ -0,0 +1,80 @@
import os
import sys
sys.path.append(os.path.abspath('../PaddleRS'))

import paddlers as pdrs

# define the transforms for training and validation
train_transforms = pdrs.datasets.ComposeTrans(
    input_keys=['lq', 'gt'],
    output_keys=['lq', 'gt'],
    pipelines=[{
        'name': 'SRPairedRandomCrop',
        'gt_patch_size': 128,
        'scale': 4
    }, {
        'name': 'PairedRandomHorizontalFlip'
    }, {
        'name': 'PairedRandomVerticalFlip'
    }, {
        'name': 'PairedRandomTransposeHW'
    }, {
        'name': 'Transpose'
    }, {
        'name': 'Normalize',
        'mean': [0.0, 0.0, 0.0],
        'std': [255.0, 255.0, 255.0]
    }])

test_transforms = pdrs.datasets.ComposeTrans(
    input_keys=['lq', 'gt'],
    output_keys=['lq', 'gt'],
    pipelines=[{
        'name': 'Transpose'
    }, {
        'name': 'Normalize',
        'mean': [0.0, 0.0, 0.0],
        'std': [255.0, 255.0, 255.0]
    }])

# define the training dataset
train_gt_folder = r"../work/RSdata_for_SR/trian_HR"  # path to the high-resolution images
train_lq_folder = r"../work/RSdata_for_SR/train_LR/x4"  # path to the low-resolution images
num_workers = 6
batch_size = 32
scale = 4
train_dataset = pdrs.datasets.SRdataset(
    mode='train',
    gt_folder=train_gt_folder,
    lq_folder=train_lq_folder,
    transforms=train_transforms(),
    scale=scale,
    num_workers=num_workers,
    batch_size=batch_size)

# define the test dataset
test_gt_folder = r"../work/RSdata_for_SR/test_HR"
test_lq_folder = r"../work/RSdata_for_SR/test_LR/x4"
test_dataset = pdrs.datasets.SRdataset(
    mode='test',
    gt_folder=test_gt_folder,
    lq_folder=test_lq_folder,
    transforms=test_transforms(),
    scale=scale)

# initialize the model; the network architecture parameters can be adjusted
# loss_type='gan': use perceptual, adversarial and pixel losses
# loss_type='pixel': use the pixel loss only
model = pdrs.tasks.ESRGANet(loss_type='pixel')

model.train(
    total_iters=1000000,
    train_dataset=train_dataset(),
    test_dataset=test_dataset(),
    output_dir='output_dir',
    validate=5000,
    snapshot=5000,
    log=100,
    lr_rate=0.0001,
    periods=[250000, 250000, 250000, 250000],
    restart_weights=[1, 1, 1, 1])
@@ -0,0 +1,78 @@
import os
import sys
sys.path.append(os.path.abspath('../PaddleRS'))

import paddlers as pdrs

# define the transforms for training and validation
train_transforms = pdrs.datasets.ComposeTrans(
    input_keys=['lq', 'gt'],
    output_keys=['lq', 'gt'],
    pipelines=[{
        'name': 'SRPairedRandomCrop',
        'gt_patch_size': 192,
        'scale': 4
    }, {
        'name': 'PairedRandomHorizontalFlip'
    }, {
        'name': 'PairedRandomVerticalFlip'
    }, {
        'name': 'PairedRandomTransposeHW'
    }, {
        'name': 'Transpose'
    }, {
        'name': 'Normalize',
        'mean': [0.0, 0.0, 0.0],
        'std': [255.0, 255.0, 255.0]
    }])

test_transforms = pdrs.datasets.ComposeTrans(
    input_keys=['lq', 'gt'],
    output_keys=['lq', 'gt'],
    pipelines=[{
        'name': 'Transpose'
    }, {
        'name': 'Normalize',
        'mean': [0.0, 0.0, 0.0],
        'std': [255.0, 255.0, 255.0]
    }])

# define the training dataset
train_gt_folder = r"../work/RSdata_for_SR/trian_HR"  # path to the high-resolution images
train_lq_folder = r"../work/RSdata_for_SR/train_LR/x4"  # path to the low-resolution images
num_workers = 4
batch_size = 16
scale = 4
train_dataset = pdrs.datasets.SRdataset(
    mode='train',
    gt_folder=train_gt_folder,
    lq_folder=train_lq_folder,
    transforms=train_transforms(),
    scale=scale,
    num_workers=num_workers,
    batch_size=batch_size)

# define the test dataset
test_gt_folder = r"../work/RSdata_for_SR/test_HR"
test_lq_folder = r"../work/RSdata_for_SR/test_LR/x4"
test_dataset = pdrs.datasets.SRdataset(
    mode='test',
    gt_folder=test_gt_folder,
    lq_folder=test_lq_folder,
    transforms=test_transforms(),
    scale=scale)

# initialize the model; the network architecture parameters can be adjusted
model = pdrs.tasks.LESRCNNet(scale=4, multi_scale=False, group=1)

model.train(
    total_iters=1000000,
    train_dataset=train_dataset(),
    test_dataset=test_dataset(),
    output_dir='output_dir',
    validate=5000,
    snapshot=5000,
    log=100,
    lr_rate=0.0001,
    periods=[250000, 250000, 250000, 250000],
    restart_weights=[1, 1, 1, 1])