From 33b5f0fd3e8b5c57dca946a67e5395ac3c51a010 Mon Sep 17 00:00:00 2001
From: kongdebug <583816984@qq.com>
Date: Sun, 3 Apr 2022 01:20:35 +0800
Subject: [PATCH] add rcan model for sr

---
Review notes (text in this section is not part of the commit message):
- Line breaks and indentation were restored; the stored copy had them
  collapsed.
- Unresolved merge-conflict markers were dropped from gan/__init__.py.
- Hunk sizes and the diffstat were recomputed; the post-image blob ids on
  the "index" lines are those of the original patch and no longer match.
- Context indentation in the imagerestorer.py hunk is reconstructed; check
  it against the target file before applying.

 paddlers/custom_models/gan/__init__.py        |   2 +
 .../custom_models/gan/generators/__init__.py  |  15 ++
 .../custom_models/gan/generators/builder.py   |  31 +++
 paddlers/custom_models/gan/generators/rcan.py | 223 ++++++++++++++++++
 paddlers/custom_models/gan/rcan_model.py      | 103 ++++++++
 paddlers/tasks/imagerestorer.py               |  35 +++
 6 files changed, 409 insertions(+)
 create mode 100644 paddlers/custom_models/gan/generators/__init__.py
 create mode 100644 paddlers/custom_models/gan/generators/builder.py
 create mode 100644 paddlers/custom_models/gan/generators/rcan.py
 create mode 100644 paddlers/custom_models/gan/rcan_model.py

diff --git a/paddlers/custom_models/gan/__init__.py b/paddlers/custom_models/gan/__init__.py
index c18cdef..49a16a7 100644
--- a/paddlers/custom_models/gan/__init__.py
+++ b/paddlers/custom_models/gan/__init__.py
@@ -11,3 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from .rcan_model import RCANModel
diff --git a/paddlers/custom_models/gan/generators/__init__.py b/paddlers/custom_models/gan/generators/__init__.py
new file mode 100644
index 0000000..78caf7e
--- /dev/null
+++ b/paddlers/custom_models/gan/generators/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .rcan import RCAN
diff --git a/paddlers/custom_models/gan/generators/builder.py b/paddlers/custom_models/gan/generators/builder.py
new file mode 100644
index 0000000..b64766f
--- /dev/null
+++ b/paddlers/custom_models/gan/generators/builder.py
@@ -0,0 +1,31 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from ....models.ppgan.utils.registry import Registry
+
+GENERATORS = Registry("GENERATOR")
+
+
+def build_generator(cfg):
+    """Build the generator that ``cfg['name']`` selects from GENERATORS.
+
+    ``cfg`` is deep-copied before ``'name'`` is popped, so the caller's dict
+    is left untouched; the remaining items are passed on as keyword
+    arguments.
+    """
+    cfg_copy = copy.deepcopy(cfg)
+    name = cfg_copy.pop('name')
+    generator = GENERATORS.get(name)(**cfg_copy)
+    return generator
diff --git a/paddlers/custom_models/gan/generators/rcan.py b/paddlers/custom_models/gan/generators/rcan.py
new file mode 100644
index 0000000..6724c10
--- /dev/null
+++ b/paddlers/custom_models/gan/generators/rcan.py
@@ -0,0 +1,223 @@
+# Based on https://github.com/kongdebug/RCAN-Paddle
+import math
+import paddle
+import paddle.nn as nn
+
+from .builder import GENERATORS
+
+
+def default_conv(in_channels, out_channels, kernel_size, bias=True):
+    """Return a Conv2D whose padding is ``kernel_size // 2``."""
+    return nn.Conv2D(
+        in_channels,
+        out_channels,
+        kernel_size,
+        padding=(kernel_size // 2),
+        bias_attr=bias)
+
+
+class MeanShift(nn.Conv2D):
+    """Frozen 1x1 conv: divides each channel by ``rgb_std`` and adds
+    ``sign * rgb_range * rgb_mean / rgb_std``."""
+
+    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
+        super(MeanShift, self).__init__(3, 3, kernel_size=1)
+        std = paddle.to_tensor(rgb_std)
+        self.weight.set_value(paddle.eye(3).reshape([3, 3, 1, 1]))
+        self.weight.set_value(self.weight / (std.reshape([3, 1, 1, 1])))
+
+        mean = paddle.to_tensor(rgb_mean)
+        self.bias.set_value(sign * rgb_range * mean / std)
+
+        self.weight.trainable = False
+        self.bias.trainable = False
+
+
+## Channel Attention (CA) Layer
+class CALayer(nn.Layer):
+    """Scale ``x`` by a per-channel sigmoid gate computed from its global
+    average pool."""
+
+    def __init__(self, channel, reduction=16):
+        super(CALayer, self).__init__()
+        # global average pooling: feature --> point
+        self.avg_pool = nn.AdaptiveAvgPool2D(1)
+        # feature channel downscale and upscale --> channel weight
+        self.conv_du = nn.Sequential(
+            nn.Conv2D(
+                channel, channel // reduction, 1, padding=0, bias_attr=True),
+            nn.ReLU(),
+            nn.Conv2D(
+                channel // reduction, channel, 1, padding=0, bias_attr=True),
+            nn.Sigmoid())
+
+    def forward(self, x):
+        y = self.avg_pool(x)
+        y = self.conv_du(y)
+        return x * y
+
+
+## Residual Channel Attention Block (RCAB)
+class RCAB(nn.Layer):
+    """Two convs (``act`` after the first, optional BatchNorm2D after each)
+    followed by a CALayer; the block input is added to the result.
+
+    ``res_scale`` is stored on the instance but not used in ``forward``.
+    """
+
+    def __init__(self,
+                 conv,
+                 n_feat,
+                 kernel_size,
+                 reduction=16,
+                 bias=True,
+                 bn=False,
+                 act=nn.ReLU(),
+                 res_scale=1):
+        super(RCAB, self).__init__()
+        modules_body = []
+        for i in range(2):
+            modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
+            if bn: modules_body.append(nn.BatchNorm2D(n_feat))
+            if i == 0: modules_body.append(act)
+        modules_body.append(CALayer(n_feat, reduction))
+        self.body = nn.Sequential(*modules_body)
+        self.res_scale = res_scale
+
+    def forward(self, x):
+        res = self.body(x)
+        res += x
+        return res
+
+
+## Residual Group (RG)
+class ResidualGroup(nn.Layer):
+    """``n_resblocks`` RCABs plus a trailing conv; the group input is added
+    to the result."""
+
+    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale,
+                 n_resblocks):
+        super(ResidualGroup, self).__init__()
+        # Forward `act` and `res_scale` to every block; they used to be
+        # accepted here and then replaced by hard-coded values.
+        modules_body = [
+            RCAB(
+                conv,
+                n_feat,
+                kernel_size,
+                reduction,
+                bias=True,
+                bn=False,
+                act=act,
+                res_scale=res_scale) for _ in range(n_resblocks)
+        ]
+        modules_body.append(conv(n_feat, n_feat, kernel_size))
+        self.body = nn.Sequential(*modules_body)
+
+    def forward(self, x):
+        res = self.body(x)
+        res += x
+        return res
+
+
+class Upsampler(nn.Sequential):
+    """Upscale by ``scale`` with conv + PixelShuffle stages.
+
+    Powers of two use repeated x2 stages and ``scale == 3`` uses one x3
+    stage; any other scale raises NotImplementedError.
+    """
+
+    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
+        m = []
+        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?
+            # math.log2 is more accurate than math.log(scale, 2) here.
+            for _ in range(int(math.log2(scale))):
+                m.append(conv(n_feats, 4 * n_feats, 3, bias))
+                m.append(nn.PixelShuffle(2))
+                if bn: m.append(nn.BatchNorm2D(n_feats))
+
+                if act == 'relu':
+                    m.append(nn.ReLU())
+                elif act == 'prelu':
+                    m.append(nn.PReLU(n_feats))
+
+        elif scale == 3:
+            m.append(conv(n_feats, 9 * n_feats, 3, bias))
+            m.append(nn.PixelShuffle(3))
+            if bn: m.append(nn.BatchNorm2D(n_feats))
+
+            if act == 'relu':
+                m.append(nn.ReLU())
+            elif act == 'prelu':
+                m.append(nn.PReLU(n_feats))
+        else:
+            raise NotImplementedError
+
+        super(Upsampler, self).__init__(*m)
+
+
+@GENERATORS.register()
+class RCAN(nn.Layer):
+    """RCAN super-resolution generator: mean shift, head conv, residual
+    groups with a long skip, upsampling tail, inverse mean shift."""
+
+    def __init__(
+            self,
+            scale,
+            n_resgroups,
+            n_resblocks,
+            n_feats=64,
+            n_colors=3,
+            rgb_range=255,
+            kernel_size=3,
+            reduction=16,
+            conv=default_conv, ):
+        super(RCAN, self).__init__()
+        self.scale = scale
+        act = nn.ReLU()
+
+        rgb_mean = (0.4488, 0.4371, 0.4040)
+        rgb_std = (1.0, 1.0, 1.0)
+        self.sub_mean = MeanShift(rgb_range, rgb_mean, rgb_std)
+
+        # define head module
+        modules_head = [conv(n_colors, n_feats, kernel_size)]
+
+        # define body module
+        modules_body = [
+            ResidualGroup(
+                conv,
+                n_feats,
+                kernel_size,
+                reduction,
+                act=act,
+                res_scale=1,
+                n_resblocks=n_resblocks) for _ in range(n_resgroups)
+        ]
+
+        modules_body.append(conv(n_feats, n_feats, kernel_size))
+
+        # define tail module
+        modules_tail = [
+            Upsampler(
+                conv, scale, n_feats, act=False),
+            conv(n_feats, n_colors, kernel_size)
+        ]
+
+        self.head = nn.Sequential(*modules_head)
+        self.body = nn.Sequential(*modules_body)
+        self.tail = nn.Sequential(*modules_tail)
+
+        self.add_mean = MeanShift(rgb_range, rgb_mean, rgb_std, 1)
+
+    def forward(self, x):
+        x = self.sub_mean(x)
+        x = self.head(x)
+
+        res = self.body(x)
+        res += x
+
+        x = self.tail(res)
+        x = self.add_mean(x)
+
+        return x
diff --git a/paddlers/custom_models/gan/rcan_model.py b/paddlers/custom_models/gan/rcan_model.py
new file mode 100644
index 0000000..5ad1e3e
--- /dev/null
+++ b/paddlers/custom_models/gan/rcan_model.py
@@ -0,0 +1,103 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+
+from .generators.builder import build_generator
+from ...models.ppgan.models.criterions.builder import build_criterion
+from ...models.ppgan.models.base_model import BaseModel
+from ...models.ppgan.models.builder import MODELS
+from ...models.ppgan.utils.visual import tensor2img
+from ...models.ppgan.modules.init import reset_parameters
+
+
+@MODELS.register()
+class RCANModel(BaseModel):
+    """Base SR model for single image super-resolution.
+    """
+
+    def __init__(self, generator, pixel_criterion=None, use_init_weight=False):
+        """
+        Args:
+            generator (dict): config of generator.
+            pixel_criterion (dict): config of pixel criterion.
+            use_init_weight (bool): if True, re-initialise the generator
+                with ``init_sr_weight``.
+        """
+        super(RCANModel, self).__init__()
+
+        self.nets['generator'] = build_generator(generator)
+
+        if pixel_criterion:
+            self.pixel_criterion = build_criterion(pixel_criterion)
+        if use_init_weight:
+            init_sr_weight(self.nets['generator'])
+
+    def setup_input(self, input):
+        self.lq = paddle.to_tensor(input['lq'])
+        self.visual_items['lq'] = self.lq
+        if 'gt' in input:
+            self.gt = paddle.to_tensor(input['gt'])
+            self.visual_items['gt'] = self.gt
+        self.image_paths = input['lq_path']
+
+    def forward(self):
+        pass
+
+    def train_iter(self, optims=None):
+        optims['optim'].clear_grad()
+
+        self.output = self.nets['generator'](self.lq)
+        self.visual_items['output'] = self.output
+        # pixel loss
+        loss_pixel = self.pixel_criterion(self.output, self.gt)
+        self.losses['loss_pixel'] = loss_pixel
+
+        loss_pixel.backward()
+        optims['optim'].step()
+
+    def test_iter(self, metrics=None):
+        self.nets['generator'].eval()
+        with paddle.no_grad():
+            self.output = self.nets['generator'](self.lq)
+            self.visual_items['output'] = self.output
+        self.nets['generator'].train()
+
+        # The images below are only consumed by the metrics, and building
+        # them reads self.gt, which setup_input() sets only when the batch
+        # has a 'gt' entry. Stop here when there is nothing to update.
+        if metrics is None:
+            return
+
+        out_img = []
+        gt_img = []
+        for out_tensor, gt_tensor in zip(self.output, self.gt):
+            out_img.append(tensor2img(out_tensor, (0., 255.)))
+            gt_img.append(tensor2img(gt_tensor, (0., 255.)))
+
+        for metric in metrics.values():
+            metric.update(out_img, gt_img)
+
+
+def init_sr_weight(net):
+    """Call ``reset_parameters`` on every layer ``net.apply`` visits that has
+    a ``weight`` attribute and is not a BatchNorm/BatchNorm2D layer."""
+
+    def reset_func(m):
+        if hasattr(m, 'weight') and (
+                not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D))):
+            reset_parameters(m)
+
+    net.apply(reset_func)
diff --git a/paddlers/tasks/imagerestorer.py b/paddlers/tasks/imagerestorer.py
index d05447a..10ffb25 100644
--- a/paddlers/tasks/imagerestorer.py
+++ b/paddlers/tasks/imagerestorer.py
@@ -751,3 +751,38 @@ class ESRGANet(BasicSRNet):
             'name': 'CosineAnnealingRestartLR',
             'eta_min': 1e-07
         }
+
+
+# RCAN model training
+class RCANet(BasicSRNet):
+    """Training configuration for RCAN (L1Loss, Adam, MultiStepDecay)."""
+    def __init__(
+            self,
+            scale=2,
+            n_resgroups=10,
+            n_resblocks=20, ):
+        super(RCANet, self).__init__()
+        self.min_max = '(0., 255.)'
+        self.generator = {
+            'name': 'RCAN',
+            'scale': scale,
+            'n_resgroups': n_resgroups,
+            'n_resblocks': n_resblocks
+        }
+        self.pixel_criterion = {'name': 'L1Loss'}
+        self.model = {
+            'name': 'RCANModel',
+            'generator': self.generator,
+            'pixel_criterion': self.pixel_criterion
+        }
+        self.optimizer = {
+            'name': 'Adam',
+            'net_names': ['generator'],
+            'beta1': 0.9,
+            'beta2': 0.99
+        }
+        self.lr_scheduler = {
+            'name': 'MultiStepDecay',
+            'milestones': [250000, 500000, 750000, 1000000],
+            'gamma': 0.5
+        }