add rcan model for sr

own
kongdebug 3 years ago
parent 343f646f7d
commit 33b5f0fd3e
  1. 5
      paddlers/custom_models/gan/__init__.py
  2. 15
      paddlers/custom_models/gan/generators/__init__.py
  3. 25
      paddlers/custom_models/gan/generators/builder.py
  4. 190
      paddlers/custom_models/gan/generators/rcan.py
  5. 93
      paddlers/custom_models/gan/rcan_model.py
  6. 34
      paddlers/tasks/imagerestorer.py

@ -11,3 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
<<<<<<< HEAD
from .rcan_model import RCANModel
=======
>>>>>>> 343f646f7dabf2ff08d80fab4ac5a37511260bd2

@ -0,0 +1,15 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .rcan import RCAN

@ -0,0 +1,25 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from ....models.ppgan.utils.registry import Registry
GENERATORS = Registry("GENERATOR")
def build_generator(cfg):
cfg_copy = copy.deepcopy(cfg)
name = cfg_copy.pop('name')
generator = GENERATORS.get(name)(**cfg_copy)
return generator

@ -0,0 +1,190 @@
# base on https://github.com/kongdebug/RCAN-Paddle
import math
import paddle
import paddle.nn as nn
from .builder import GENERATORS
def default_conv(in_channels, out_channels, kernel_size, bias=True):
return nn.Conv2D(
in_channels,
out_channels,
kernel_size,
padding=(kernel_size // 2),
bias_attr=bias)
class MeanShift(nn.Conv2D):
def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
super(MeanShift, self).__init__(3, 3, kernel_size=1)
std = paddle.to_tensor(rgb_std)
self.weight.set_value(paddle.eye(3).reshape([3, 3, 1, 1]))
self.weight.set_value(self.weight / (std.reshape([3, 1, 1, 1])))
mean = paddle.to_tensor(rgb_mean)
self.bias.set_value(sign * rgb_range * mean / std)
self.weight.trainable = False
self.bias.trainable = False
## Channel Attention (CA) Layer
class CALayer(nn.Layer):
def __init__(self, channel, reduction=16):
super(CALayer, self).__init__()
# global average pooling: feature --> point
self.avg_pool = nn.AdaptiveAvgPool2D(1)
# feature channel downscale and upscale --> channel weight
self.conv_du = nn.Sequential(
nn.Conv2D(
channel, channel // reduction, 1, padding=0, bias_attr=True),
nn.ReLU(),
nn.Conv2D(
channel // reduction, channel, 1, padding=0, bias_attr=True),
nn.Sigmoid())
def forward(self, x):
y = self.avg_pool(x)
y = self.conv_du(y)
return x * y
class RCAB(nn.Layer):
def __init__(self,
conv,
n_feat,
kernel_size,
reduction=16,
bias=True,
bn=False,
act=nn.ReLU(),
res_scale=1):
super(RCAB, self).__init__()
modules_body = []
for i in range(2):
modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
if bn: modules_body.append(nn.BatchNorm2D(n_feat))
if i == 0: modules_body.append(act)
modules_body.append(CALayer(n_feat, reduction))
self.body = nn.Sequential(*modules_body)
self.res_scale = res_scale
def forward(self, x):
res = self.body(x)
res += x
return res
## Residual Group (RG)
class ResidualGroup(nn.Layer):
def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale,
n_resblocks):
super(ResidualGroup, self).__init__()
modules_body = []
modules_body = [
RCAB(
conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(), res_scale=1) \
for _ in range(n_resblocks)]
modules_body.append(conv(n_feat, n_feat, kernel_size))
self.body = nn.Sequential(*modules_body)
def forward(self, x):
res = self.body(x)
res += x
return res
class Upsampler(nn.Sequential):
def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
m = []
if (scale & (scale - 1)) == 0: # Is scale = 2^n?
for _ in range(int(math.log(scale, 2))):
m.append(conv(n_feats, 4 * n_feats, 3, bias))
m.append(nn.PixelShuffle(2))
if bn: m.append(nn.BatchNorm2D(n_feats))
if act == 'relu':
m.append(nn.ReLU())
elif act == 'prelu':
m.append(nn.PReLU(n_feats))
elif scale == 3:
m.append(conv(n_feats, 9 * n_feats, 3, bias))
m.append(nn.PixelShuffle(3))
if bn: m.append(nn.BatchNorm2D(n_feats))
if act == 'relu':
m.append(nn.ReLU())
elif act == 'prelu':
m.append(nn.PReLU(n_feats))
else:
raise NotImplementedError
super(Upsampler, self).__init__(*m)
@GENERATORS.register()
class RCAN(nn.Layer):
def __init__(
self,
scale,
n_resgroups,
n_resblocks,
n_feats=64,
n_colors=3,
rgb_range=255,
kernel_size=3,
reduction=16,
conv=default_conv, ):
super(RCAN, self).__init__()
self.scale = scale
act = nn.ReLU()
n_resgroups = n_resgroups
n_resblocks = n_resblocks
n_feats = n_feats
kernel_size = kernel_size
reduction = reduction
scale = scale
act = nn.ReLU()
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_std = (1.0, 1.0, 1.0)
self.sub_mean = MeanShift(rgb_range, rgb_mean, rgb_std)
# define head module
modules_head = [conv(n_colors, n_feats, kernel_size)]
# define body module
modules_body = [
ResidualGroup(
conv, n_feats, kernel_size, reduction, act=act, res_scale= 1, n_resblocks=n_resblocks) \
for _ in range(n_resgroups)]
modules_body.append(conv(n_feats, n_feats, kernel_size))
# define tail module
modules_tail = [
Upsampler(
conv, scale, n_feats, act=False),
conv(n_feats, n_colors, kernel_size)
]
self.head = nn.Sequential(*modules_head)
self.body = nn.Sequential(*modules_body)
self.tail = nn.Sequential(*modules_tail)
self.add_mean = MeanShift(rgb_range, rgb_mean, rgb_std, 1)
def forward(self, x):
x = self.sub_mean(x)
x = self.head(x)
res = self.body(x)
res += x
x = self.tail(res)
x = self.add_mean(x)
return x

@ -0,0 +1,93 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from .generators.builder import build_generator
from ...models.ppgan.models.criterions.builder import build_criterion
from ...models.ppgan.models.base_model import BaseModel
from ...models.ppgan.models.builder import MODELS
from ...models.ppgan.utils.visual import tensor2img
from ...models.ppgan.modules.init import reset_parameters
@MODELS.register()
class RCANModel(BaseModel):
"""Base SR model for single image super-resolution.
"""
def __init__(self, generator, pixel_criterion=None, use_init_weight=False):
"""
Args:
generator (dict): config of generator.
pixel_criterion (dict): config of pixel criterion.
"""
super(RCANModel, self).__init__()
self.nets['generator'] = build_generator(generator)
if pixel_criterion:
self.pixel_criterion = build_criterion(pixel_criterion)
if use_init_weight:
init_sr_weight(self.nets['generator'])
def setup_input(self, input):
self.lq = paddle.to_tensor(input['lq'])
self.visual_items['lq'] = self.lq
if 'gt' in input:
self.gt = paddle.to_tensor(input['gt'])
self.visual_items['gt'] = self.gt
self.image_paths = input['lq_path']
def forward(self):
pass
def train_iter(self, optims=None):
optims['optim'].clear_grad()
self.output = self.nets['generator'](self.lq)
self.visual_items['output'] = self.output
# pixel loss
loss_pixel = self.pixel_criterion(self.output, self.gt)
self.losses['loss_pixel'] = loss_pixel
loss_pixel.backward()
optims['optim'].step()
def test_iter(self, metrics=None):
self.nets['generator'].eval()
with paddle.no_grad():
self.output = self.nets['generator'](self.lq)
self.visual_items['output'] = self.output
self.nets['generator'].train()
out_img = []
gt_img = []
for out_tensor, gt_tensor in zip(self.output, self.gt):
out_img.append(tensor2img(out_tensor, (0., 255.)))
gt_img.append(tensor2img(gt_tensor, (0., 255.)))
if metrics is not None:
for metric in metrics.values():
metric.update(out_img, gt_img)
def init_sr_weight(net):
def reset_func(m):
if hasattr(m, 'weight') and (
not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D))):
reset_parameters(m)
net.apply(reset_func)

@ -751,3 +751,37 @@ class ESRGANet(BasicSRNet):
'name': 'CosineAnnealingRestartLR',
'eta_min': 1e-07
}
# RCAN模型训练
class RCANet(BasicSRNet):
def __init__(
self,
scale=2,
n_resgroups=10,
n_resblocks=20, ):
super(RCANet, self).__init__()
self.min_max = '(0., 255.)'
self.generator = {
'name': 'RCAN',
'scale': scale,
'n_resgroups': n_resgroups,
'n_resblocks': n_resblocks
}
self.pixel_criterion = {'name': 'L1Loss'}
self.model = {
'name': 'RCANModel',
'generator': self.generator,
'pixel_criterion': self.pixel_criterion
}
self.optimizer = {
'name': 'Adam',
'net_names': ['generator'],
'beta1': 0.9,
'beta2': 0.99
}
self.lr_scheduler = {
'name': 'MultiStepDecay',
'milestones': [250000, 500000, 750000, 1000000],
'gamma': 0.5
}

Loading…
Cancel
Save