[Test] Add CE (#2)
parent
a373e11835
commit
feebf47002
161 changed files with 4566 additions and 1590 deletions
@ -0,0 +1,83 @@ |
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
|
||||||
|
import os.path as osp |
||||||
|
import copy |
||||||
|
|
||||||
|
from .base import BaseDataset |
||||||
|
from paddlers.utils import logging, get_encoding, norm_path, is_pic |
||||||
|
|
||||||
|
|
||||||
|
class ResDataset(BaseDataset): |
||||||
|
""" |
||||||
|
Dataset for image restoration tasks. |
||||||
|
|
||||||
|
Args: |
||||||
|
data_dir (str): Root directory of the dataset. |
||||||
|
file_list (str): Path of the file that contains relative paths of source and target image files. |
||||||
|
transforms (paddlers.transforms.Compose): Data preprocessing and data augmentation operators to apply. |
||||||
|
num_workers (int|str, optional): Number of processes used for data loading. If `num_workers` is 'auto', |
||||||
|
the number of workers will be automatically determined according to the number of CPU cores: If |
||||||
|
there are more than 16 cores,8 workers will be used. Otherwise, the number of workers will be half |
||||||
|
the number of CPU cores. Defaults: 'auto'. |
||||||
|
shuffle (bool, optional): Whether to shuffle the samples. Defaults to False. |
||||||
|
sr_factor (int|None, optional): Scaling factor of image super-resolution task. None for other image |
||||||
|
restoration tasks. Defaults to None. |
||||||
|
""" |
||||||
|
|
||||||
|
def __init__(self, |
||||||
|
data_dir, |
||||||
|
file_list, |
||||||
|
transforms, |
||||||
|
num_workers='auto', |
||||||
|
shuffle=False, |
||||||
|
sr_factor=None): |
||||||
|
super(ResDataset, self).__init__(data_dir, None, transforms, |
||||||
|
num_workers, shuffle) |
||||||
|
self.batch_transforms = None |
||||||
|
self.file_list = list() |
||||||
|
|
||||||
|
with open(file_list, encoding=get_encoding(file_list)) as f: |
||||||
|
for line in f: |
||||||
|
items = line.strip().split() |
||||||
|
if len(items) > 2: |
||||||
|
raise ValueError( |
||||||
|
"A space is defined as the delimiter to separate the source and target image path, " \ |
||||||
|
"so the space cannot be in the source image or target image path, but the line[{}] of " \ |
||||||
|
" file_list[{}] has a space in the two paths.".format(line, file_list)) |
||||||
|
items[0] = norm_path(items[0]) |
||||||
|
items[1] = norm_path(items[1]) |
||||||
|
full_path_im = osp.join(data_dir, items[0]) |
||||||
|
full_path_tar = osp.join(data_dir, items[1]) |
||||||
|
if not is_pic(full_path_im) or not is_pic(full_path_tar): |
||||||
|
continue |
||||||
|
if not osp.exists(full_path_im): |
||||||
|
raise IOError("Source image file {} does not exist!".format( |
||||||
|
full_path_im)) |
||||||
|
if not osp.exists(full_path_tar): |
||||||
|
raise IOError("Target image file {} does not exist!".format( |
||||||
|
full_path_tar)) |
||||||
|
sample = { |
||||||
|
'image': full_path_im, |
||||||
|
'target': full_path_tar, |
||||||
|
} |
||||||
|
if sr_factor is not None: |
||||||
|
sample['sr_factor'] = sr_factor |
||||||
|
self.file_list.append(sample) |
||||||
|
self.num_samples = len(self.file_list) |
||||||
|
logging.info("{} samples in file {}".format( |
||||||
|
len(self.file_list), file_list)) |
||||||
|
|
||||||
|
def __len__(self): |
||||||
|
return len(self.file_list) |
@ -1,99 +0,0 @@ |
|||||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
|
||||||
# |
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
||||||
# you may not use this file except in compliance with the License. |
|
||||||
# You may obtain a copy of the License at |
|
||||||
# |
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0 |
|
||||||
# |
|
||||||
# Unless required by applicable law or agreed to in writing, software |
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS, |
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
||||||
# See the License for the specific language governing permissions and |
|
||||||
# limitations under the License. |
|
||||||
|
|
||||||
|
|
||||||
# 超分辨率数据集定义 |
|
||||||
class SRdataset(object): |
|
||||||
def __init__(self, |
|
||||||
mode, |
|
||||||
gt_floder, |
|
||||||
lq_floder, |
|
||||||
transforms, |
|
||||||
scale, |
|
||||||
num_workers=4, |
|
||||||
batch_size=8): |
|
||||||
if mode == 'train': |
|
||||||
preprocess = [] |
|
||||||
preprocess.append({ |
|
||||||
'name': 'LoadImageFromFile', |
|
||||||
'key': 'lq' |
|
||||||
}) # 加载方式 |
|
||||||
preprocess.append({'name': 'LoadImageFromFile', 'key': 'gt'}) |
|
||||||
preprocess.append(transforms) # 变换方式 |
|
||||||
self.dataset = { |
|
||||||
'name': 'SRDataset', |
|
||||||
'gt_folder': gt_floder, |
|
||||||
'lq_folder': lq_floder, |
|
||||||
'num_workers': num_workers, |
|
||||||
'batch_size': batch_size, |
|
||||||
'scale': scale, |
|
||||||
'preprocess': preprocess |
|
||||||
} |
|
||||||
|
|
||||||
if mode == "test": |
|
||||||
preprocess = [] |
|
||||||
preprocess.append({'name': 'LoadImageFromFile', 'key': 'lq'}) |
|
||||||
preprocess.append({'name': 'LoadImageFromFile', 'key': 'gt'}) |
|
||||||
preprocess.append(transforms) |
|
||||||
self.dataset = { |
|
||||||
'name': 'SRDataset', |
|
||||||
'gt_folder': gt_floder, |
|
||||||
'lq_folder': lq_floder, |
|
||||||
'scale': scale, |
|
||||||
'preprocess': preprocess |
|
||||||
} |
|
||||||
|
|
||||||
def __call__(self): |
|
||||||
return self.dataset |
|
||||||
|
|
||||||
|
|
||||||
# 对定义的transforms处理方式组合,返回字典 |
|
||||||
class ComposeTrans(object): |
|
||||||
def __init__(self, input_keys, output_keys, pipelines): |
|
||||||
if not isinstance(pipelines, list): |
|
||||||
raise TypeError( |
|
||||||
'Type of transforms is invalid. Must be List, but received is {}' |
|
||||||
.format(type(pipelines))) |
|
||||||
if len(pipelines) < 1: |
|
||||||
raise ValueError( |
|
||||||
'Length of transforms must not be less than 1, but received is {}' |
|
||||||
.format(len(pipelines))) |
|
||||||
self.transforms = pipelines |
|
||||||
self.output_length = len(output_keys) # 当output_keys的长度为3时,是DRN训练 |
|
||||||
self.input_keys = input_keys |
|
||||||
self.output_keys = output_keys |
|
||||||
|
|
||||||
def __call__(self): |
|
||||||
pipeline = [] |
|
||||||
for op in self.transforms: |
|
||||||
if op['name'] == 'SRPairedRandomCrop': |
|
||||||
op['keys'] = ['image'] * 2 |
|
||||||
else: |
|
||||||
op['keys'] = ['image'] * self.output_length |
|
||||||
pipeline.append(op) |
|
||||||
if self.output_length == 2: |
|
||||||
transform_dict = { |
|
||||||
'name': 'Transforms', |
|
||||||
'input_keys': self.input_keys, |
|
||||||
'pipeline': pipeline |
|
||||||
} |
|
||||||
else: |
|
||||||
transform_dict = { |
|
||||||
'name': 'Transforms', |
|
||||||
'input_keys': self.input_keys, |
|
||||||
'output_keys': self.output_keys, |
|
||||||
'pipeline': pipeline |
|
||||||
} |
|
||||||
|
|
||||||
return transform_dict |
|
@ -0,0 +1,478 @@ |
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
|
||||||
|
import paddle |
||||||
|
import paddle.nn as nn |
||||||
|
import paddle.nn.functional as F |
||||||
|
|
||||||
|
from .layers import BasicConv, MaxPool2x2, Conv1x1, Conv3x3 |
||||||
|
|
||||||
|
bn_mom = 1 - 0.0003 |
||||||
|
|
||||||
|
|
||||||
|
class NLBlock(nn.Layer): |
||||||
|
def __init__(self, in_channels): |
||||||
|
super(NLBlock, self).__init__() |
||||||
|
self.conv_v = BasicConv( |
||||||
|
in_ch=in_channels, |
||||||
|
out_ch=in_channels, |
||||||
|
kernel_size=3, |
||||||
|
norm=nn.BatchNorm2D( |
||||||
|
in_channels, momentum=0.9)) |
||||||
|
self.W = BasicConv( |
||||||
|
in_ch=in_channels, |
||||||
|
out_ch=in_channels, |
||||||
|
kernel_size=3, |
||||||
|
norm=nn.BatchNorm2D( |
||||||
|
in_channels, momentum=0.9), |
||||||
|
act=nn.ReLU()) |
||||||
|
|
||||||
|
def forward(self, x): |
||||||
|
batch_size, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3] |
||||||
|
value = self.conv_v(x) |
||||||
|
value = value.reshape([batch_size, c, value.shape[2] * value.shape[3]]) |
||||||
|
value = value.transpose([0, 2, 1]) # B * (H*W) * value_channels |
||||||
|
key = x.reshape([batch_size, c, h * w]) # B * key_channels * (H*W) |
||||||
|
query = x.reshape([batch_size, c, h * w]) |
||||||
|
query = query.transpose([0, 2, 1]) |
||||||
|
|
||||||
|
sim_map = paddle.matmul(query, key) # B * (H*W) * (H*W) |
||||||
|
sim_map = (c**-.5) * sim_map # B * (H*W) * (H*W) |
||||||
|
sim_map = nn.functional.softmax(sim_map, axis=-1) # B * (H*W) * (H*W) |
||||||
|
|
||||||
|
context = paddle.matmul(sim_map, value) |
||||||
|
context = context.transpose([0, 2, 1]) |
||||||
|
context = context.reshape([batch_size, c, *x.shape[2:]]) |
||||||
|
context = self.W(context) |
||||||
|
|
||||||
|
return context |
||||||
|
|
||||||
|
|
||||||
|
class NLFPN(nn.Layer): |
||||||
|
""" Non-local feature parymid network""" |
||||||
|
|
||||||
|
def __init__(self, in_dim, reduction=True): |
||||||
|
super(NLFPN, self).__init__() |
||||||
|
if reduction: |
||||||
|
self.reduction = BasicConv( |
||||||
|
in_ch=in_dim, |
||||||
|
out_ch=in_dim // 4, |
||||||
|
kernel_size=1, |
||||||
|
norm=nn.BatchNorm2D( |
||||||
|
in_dim // 4, momentum=bn_mom), |
||||||
|
act=nn.ReLU()) |
||||||
|
self.re_reduction = BasicConv( |
||||||
|
in_ch=in_dim // 4, |
||||||
|
out_ch=in_dim, |
||||||
|
kernel_size=1, |
||||||
|
norm=nn.BatchNorm2D( |
||||||
|
in_dim, momentum=bn_mom), |
||||||
|
act=nn.ReLU()) |
||||||
|
in_dim = in_dim // 4 |
||||||
|
else: |
||||||
|
self.reduction = None |
||||||
|
self.re_reduction = None |
||||||
|
self.conv_e1 = BasicConv( |
||||||
|
in_dim, |
||||||
|
in_dim, |
||||||
|
kernel_size=3, |
||||||
|
norm=nn.BatchNorm2D( |
||||||
|
in_dim, momentum=bn_mom), |
||||||
|
act=nn.ReLU()) |
||||||
|
self.conv_e2 = BasicConv( |
||||||
|
in_dim, |
||||||
|
in_dim * 2, |
||||||
|
kernel_size=3, |
||||||
|
norm=nn.BatchNorm2D( |
||||||
|
in_dim * 2, momentum=bn_mom), |
||||||
|
act=nn.ReLU()) |
||||||
|
self.conv_e3 = BasicConv( |
||||||
|
in_dim * 2, |
||||||
|
in_dim * 4, |
||||||
|
kernel_size=3, |
||||||
|
norm=nn.BatchNorm2D( |
||||||
|
in_dim * 4, momentum=bn_mom), |
||||||
|
act=nn.ReLU()) |
||||||
|
self.conv_d1 = BasicConv( |
||||||
|
in_dim, |
||||||
|
in_dim, |
||||||
|
kernel_size=3, |
||||||
|
norm=nn.BatchNorm2D( |
||||||
|
in_dim, momentum=bn_mom), |
||||||
|
act=nn.ReLU()) |
||||||
|
self.conv_d2 = BasicConv( |
||||||
|
in_dim * 2, |
||||||
|
in_dim, |
||||||
|
kernel_size=3, |
||||||
|
norm=nn.BatchNorm2D( |
||||||
|
in_dim, momentum=bn_mom), |
||||||
|
act=nn.ReLU()) |
||||||
|
self.conv_d3 = BasicConv( |
||||||
|
in_dim * 4, |
||||||
|
in_dim * 2, |
||||||
|
kernel_size=3, |
||||||
|
norm=nn.BatchNorm2D( |
||||||
|
in_dim * 2, momentum=bn_mom), |
||||||
|
act=nn.ReLU()) |
||||||
|
self.nl3 = NLBlock(in_dim * 2) |
||||||
|
self.nl2 = NLBlock(in_dim) |
||||||
|
self.nl1 = NLBlock(in_dim) |
||||||
|
|
||||||
|
self.downsample_x2 = nn.MaxPool2D(stride=2, kernel_size=2) |
||||||
|
self.upsample_x2 = nn.UpsamplingBilinear2D(scale_factor=2) |
||||||
|
|
||||||
|
def forward(self, x): |
||||||
|
if self.reduction is not None: |
||||||
|
x = self.reduction(x) |
||||||
|
e1 = self.conv_e1(x) # C,H,W |
||||||
|
e2 = self.conv_e2(self.downsample_x2(e1)) # 2C,H/2,W/2 |
||||||
|
e3 = self.conv_e3(self.downsample_x2(e2)) # 4C,H/4,W/4 |
||||||
|
|
||||||
|
d3 = self.conv_d3(e3) # 2C,H/4,W/4 |
||||||
|
nl = self.nl3(d3) |
||||||
|
d3 = self.upsample_x2(paddle.multiply(d3, nl)) ##2C,H/2,W/2 |
||||||
|
d2 = self.conv_d2(e2 + d3) # C,H/2,W/2 |
||||||
|
nl = self.nl2(d2) |
||||||
|
d2 = self.upsample_x2(paddle.multiply(d2, nl)) # C,H,W |
||||||
|
d1 = self.conv_d1(e1 + d2) |
||||||
|
nl = self.nl1(d1) |
||||||
|
d1 = paddle.multiply(d1, nl) # C,H,W |
||||||
|
if self.re_reduction is not None: |
||||||
|
d1 = self.re_reduction(d1) |
||||||
|
|
||||||
|
return d1 |
||||||
|
|
||||||
|
|
||||||
|
class Cat(nn.Layer): |
||||||
|
def __init__(self, in_chn_high, in_chn_low, out_chn, upsample=False): |
||||||
|
super(Cat, self).__init__() |
||||||
|
self.do_upsample = upsample |
||||||
|
self.upsample = nn.Upsample(scale_factor=2, mode="nearest") |
||||||
|
self.conv2d = BasicConv( |
||||||
|
in_chn_high + in_chn_low, |
||||||
|
out_chn, |
||||||
|
kernel_size=1, |
||||||
|
norm=nn.BatchNorm2D( |
||||||
|
out_chn, momentum=bn_mom), |
||||||
|
act=nn.ReLU()) |
||||||
|
|
||||||
|
def forward(self, x, y): |
||||||
|
if self.do_upsample: |
||||||
|
x = self.upsample(x) |
||||||
|
|
||||||
|
x = paddle.concat((x, y), 1) |
||||||
|
|
||||||
|
return self.conv2d(x) |
||||||
|
|
||||||
|
|
||||||
|
class DoubleConv(nn.Layer): |
||||||
|
def __init__(self, in_chn, out_chn, stride=1, dilation=1): |
||||||
|
super(DoubleConv, self).__init__() |
||||||
|
self.conv = nn.Sequential( |
||||||
|
nn.Conv2D( |
||||||
|
in_chn, |
||||||
|
out_chn, |
||||||
|
kernel_size=3, |
||||||
|
stride=stride, |
||||||
|
dilation=dilation, |
||||||
|
padding=dilation), |
||||||
|
nn.BatchNorm2D( |
||||||
|
out_chn, momentum=bn_mom), |
||||||
|
nn.ReLU(), |
||||||
|
nn.Conv2D( |
||||||
|
out_chn, out_chn, kernel_size=3, stride=1, padding=1), |
||||||
|
nn.BatchNorm2D( |
||||||
|
out_chn, momentum=bn_mom), |
||||||
|
nn.ReLU()) |
||||||
|
|
||||||
|
def forward(self, x): |
||||||
|
x = self.conv(x) |
||||||
|
return x |
||||||
|
|
||||||
|
|
||||||
|
class SEModule(nn.Layer): |
||||||
|
def __init__(self, channels, reduction_channels): |
||||||
|
super(SEModule, self).__init__() |
||||||
|
self.fc1 = nn.Conv2D( |
||||||
|
channels, |
||||||
|
reduction_channels, |
||||||
|
kernel_size=1, |
||||||
|
padding=0, |
||||||
|
bias_attr=True) |
||||||
|
self.ReLU = nn.ReLU() |
||||||
|
self.fc2 = nn.Conv2D( |
||||||
|
reduction_channels, |
||||||
|
channels, |
||||||
|
kernel_size=1, |
||||||
|
padding=0, |
||||||
|
bias_attr=True) |
||||||
|
|
||||||
|
def forward(self, x): |
||||||
|
x_se = x.reshape( |
||||||
|
[x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]).mean(-1).reshape( |
||||||
|
[x.shape[0], x.shape[1], 1, 1]) |
||||||
|
|
||||||
|
x_se = self.fc1(x_se) |
||||||
|
x_se = self.ReLU(x_se) |
||||||
|
x_se = self.fc2(x_se) |
||||||
|
return x * F.sigmoid(x_se) |
||||||
|
|
||||||
|
|
||||||
|
class BasicBlock(nn.Layer): |
||||||
|
expansion = 1 |
||||||
|
|
||||||
|
def __init__(self, |
||||||
|
inplanes, |
||||||
|
planes, |
||||||
|
downsample=None, |
||||||
|
use_se=False, |
||||||
|
stride=1, |
||||||
|
dilation=1): |
||||||
|
super(BasicBlock, self).__init__() |
||||||
|
first_planes = planes |
||||||
|
outplanes = planes * self.expansion |
||||||
|
|
||||||
|
self.conv1 = DoubleConv(inplanes, first_planes) |
||||||
|
self.conv2 = DoubleConv( |
||||||
|
first_planes, outplanes, stride=stride, dilation=dilation) |
||||||
|
self.se = SEModule(outplanes, planes // 4) if use_se else None |
||||||
|
self.downsample = MaxPool2x2() if downsample else None |
||||||
|
self.ReLU = nn.ReLU() |
||||||
|
|
||||||
|
def forward(self, x): |
||||||
|
out = self.conv1(x) |
||||||
|
residual = out |
||||||
|
out = self.conv2(out) |
||||||
|
|
||||||
|
if self.se is not None: |
||||||
|
out = self.se(out) |
||||||
|
|
||||||
|
if self.downsample is not None: |
||||||
|
residual = self.downsample(residual) |
||||||
|
|
||||||
|
out = out + residual |
||||||
|
out = self.ReLU(out) |
||||||
|
return out |
||||||
|
|
||||||
|
|
||||||
|
class DenseCatAdd(nn.Layer): |
||||||
|
def __init__(self, in_chn, out_chn): |
||||||
|
super(DenseCatAdd, self).__init__() |
||||||
|
self.conv1 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) |
||||||
|
self.conv2 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) |
||||||
|
self.conv3 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) |
||||||
|
self.conv_out = BasicConv( |
||||||
|
in_chn, |
||||||
|
out_chn, |
||||||
|
kernel_size=1, |
||||||
|
norm=nn.BatchNorm2D( |
||||||
|
out_chn, momentum=bn_mom), |
||||||
|
act=nn.ReLU()) |
||||||
|
|
||||||
|
def forward(self, x, y): |
||||||
|
x1 = self.conv1(x) |
||||||
|
x2 = self.conv2(x1) |
||||||
|
x3 = self.conv3(x2 + x1) |
||||||
|
|
||||||
|
y1 = self.conv1(y) |
||||||
|
y2 = self.conv2(y1) |
||||||
|
y3 = self.conv3(y2 + y1) |
||||||
|
|
||||||
|
return self.conv_out(x1 + x2 + x3 + y1 + y2 + y3) |
||||||
|
|
||||||
|
|
||||||
|
class DenseCatDiff(nn.Layer): |
||||||
|
def __init__(self, in_chn, out_chn): |
||||||
|
super(DenseCatDiff, self).__init__() |
||||||
|
self.conv1 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) |
||||||
|
self.conv2 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) |
||||||
|
self.conv3 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) |
||||||
|
self.conv_out = BasicConv( |
||||||
|
in_ch=in_chn, |
||||||
|
out_ch=out_chn, |
||||||
|
kernel_size=1, |
||||||
|
norm=nn.BatchNorm2D( |
||||||
|
out_chn, momentum=bn_mom), |
||||||
|
act=nn.ReLU()) |
||||||
|
|
||||||
|
def forward(self, x, y): |
||||||
|
x1 = self.conv1(x) |
||||||
|
x2 = self.conv2(x1) |
||||||
|
x3 = self.conv3(x2 + x1) |
||||||
|
|
||||||
|
y1 = self.conv1(y) |
||||||
|
y2 = self.conv2(y1) |
||||||
|
y3 = self.conv3(y2 + y1) |
||||||
|
out = self.conv_out(paddle.abs(x1 + x2 + x3 - y1 - y2 - y3)) |
||||||
|
return out |
||||||
|
|
||||||
|
|
||||||
|
class DFModule(nn.Layer): |
||||||
|
"""Dense connection-based feature fusion module""" |
||||||
|
|
||||||
|
def __init__(self, dim_in, dim_out, reduction=True): |
||||||
|
super(DFModule, self).__init__() |
||||||
|
if reduction: |
||||||
|
self.reduction = Conv1x1( |
||||||
|
dim_in, |
||||||
|
dim_in // 2, |
||||||
|
norm=nn.BatchNorm2D( |
||||||
|
dim_in // 2, momentum=bn_mom), |
||||||
|
act=nn.ReLU()) |
||||||
|
dim_in = dim_in // 2 |
||||||
|
else: |
||||||
|
self.reduction = None |
||||||
|
self.cat1 = DenseCatAdd(dim_in, dim_out) |
||||||
|
self.cat2 = DenseCatDiff(dim_in, dim_out) |
||||||
|
self.conv1 = Conv3x3( |
||||||
|
dim_out, |
||||||
|
dim_out, |
||||||
|
norm=nn.BatchNorm2D( |
||||||
|
dim_out, momentum=bn_mom), |
||||||
|
act=nn.ReLU()) |
||||||
|
|
||||||
|
def forward(self, x1, x2): |
||||||
|
if self.reduction is not None: |
||||||
|
x1 = self.reduction(x1) |
||||||
|
x2 = self.reduction(x2) |
||||||
|
x_add = self.cat1(x1, x2) |
||||||
|
x_diff = self.cat2(x1, x2) |
||||||
|
y = self.conv1(x_diff) + x_add |
||||||
|
return y |
||||||
|
|
||||||
|
|
||||||
|
class FCCDN(nn.Layer): |
||||||
|
""" |
||||||
|
The FCCDN implementation based on PaddlePaddle. |
||||||
|
|
||||||
|
The original article refers to |
||||||
|
Pan Chen, et al., "FCCDN: Feature Constraint Network for VHR Image Change Detection" |
||||||
|
(https://arxiv.org/pdf/2105.10860.pdf). |
||||||
|
|
||||||
|
Args: |
||||||
|
in_channels (int): Number of input channels. Default: 3. |
||||||
|
num_classes (int): Number of target classes. Default: 2. |
||||||
|
os (int): Number of output stride. Default: 16. |
||||||
|
use_se (bool): Whether to use SEModule. Default: True. |
||||||
|
""" |
||||||
|
|
||||||
|
def __init__(self, in_channels=3, num_classes=2, os=16, use_se=True): |
||||||
|
super(FCCDN, self).__init__() |
||||||
|
if os >= 16: |
||||||
|
dilation_list = [1, 1, 1, 1] |
||||||
|
stride_list = [2, 2, 2, 2] |
||||||
|
pool_list = [True, True, True, True] |
||||||
|
elif os == 8: |
||||||
|
dilation_list = [2, 1, 1, 1] |
||||||
|
stride_list = [1, 2, 2, 2] |
||||||
|
pool_list = [False, True, True, True] |
||||||
|
else: |
||||||
|
dilation_list = [2, 2, 1, 1] |
||||||
|
stride_list = [1, 1, 2, 2] |
||||||
|
pool_list = [False, False, True, True] |
||||||
|
se_list = [use_se, use_se, use_se, use_se] |
||||||
|
channel_list = [256, 128, 64, 32] |
||||||
|
# Encoder |
||||||
|
self.block1 = BasicBlock(in_channels, channel_list[3], pool_list[3], |
||||||
|
se_list[3], stride_list[3], dilation_list[3]) |
||||||
|
self.block2 = BasicBlock(channel_list[3], channel_list[2], pool_list[2], |
||||||
|
se_list[2], stride_list[2], dilation_list[2]) |
||||||
|
self.block3 = BasicBlock(channel_list[2], channel_list[1], pool_list[1], |
||||||
|
se_list[1], stride_list[1], dilation_list[1]) |
||||||
|
self.block4 = BasicBlock(channel_list[1], channel_list[0], pool_list[0], |
||||||
|
se_list[0], stride_list[0], dilation_list[0]) |
||||||
|
|
||||||
|
# Center |
||||||
|
self.center = NLFPN(channel_list[0], True) |
||||||
|
|
||||||
|
# Decoder |
||||||
|
self.decoder3 = Cat(channel_list[0], |
||||||
|
channel_list[1], |
||||||
|
channel_list[1], |
||||||
|
upsample=pool_list[0]) |
||||||
|
self.decoder2 = Cat(channel_list[1], |
||||||
|
channel_list[2], |
||||||
|
channel_list[2], |
||||||
|
upsample=pool_list[1]) |
||||||
|
self.decoder1 = Cat(channel_list[2], |
||||||
|
channel_list[3], |
||||||
|
channel_list[3], |
||||||
|
upsample=pool_list[2]) |
||||||
|
|
||||||
|
self.df1 = DFModule(channel_list[3], channel_list[3], True) |
||||||
|
self.df2 = DFModule(channel_list[2], channel_list[2], True) |
||||||
|
self.df3 = DFModule(channel_list[1], channel_list[1], True) |
||||||
|
self.df4 = DFModule(channel_list[0], channel_list[0], True) |
||||||
|
|
||||||
|
self.catc3 = Cat(channel_list[0], |
||||||
|
channel_list[1], |
||||||
|
channel_list[1], |
||||||
|
upsample=pool_list[0]) |
||||||
|
self.catc2 = Cat(channel_list[1], |
||||||
|
channel_list[2], |
||||||
|
channel_list[2], |
||||||
|
upsample=pool_list[1]) |
||||||
|
self.catc1 = Cat(channel_list[2], |
||||||
|
channel_list[3], |
||||||
|
channel_list[3], |
||||||
|
upsample=pool_list[2]) |
||||||
|
|
||||||
|
self.upsample_x2 = nn.Sequential( |
||||||
|
nn.Conv2D( |
||||||
|
channel_list[3], 8, kernel_size=3, stride=1, padding=1), |
||||||
|
nn.BatchNorm2D( |
||||||
|
8, momentum=bn_mom), |
||||||
|
nn.ReLU(), |
||||||
|
nn.UpsamplingBilinear2D(scale_factor=2)) |
||||||
|
|
||||||
|
self.conv_out = nn.Conv2D( |
||||||
|
8, num_classes, kernel_size=3, stride=1, padding=1) |
||||||
|
self.conv_out_class = nn.Conv2D( |
||||||
|
channel_list[3], 1, kernel_size=1, stride=1, padding=0) |
||||||
|
|
||||||
|
def forward(self, t1, t2): |
||||||
|
e1_1 = self.block1(t1) |
||||||
|
e2_1 = self.block2(e1_1) |
||||||
|
e3_1 = self.block3(e2_1) |
||||||
|
y1 = self.block4(e3_1) |
||||||
|
|
||||||
|
e1_2 = self.block1(t2) |
||||||
|
e2_2 = self.block2(e1_2) |
||||||
|
e3_2 = self.block3(e2_2) |
||||||
|
y2 = self.block4(e3_2) |
||||||
|
|
||||||
|
y1 = self.center(y1) |
||||||
|
y2 = self.center(y2) |
||||||
|
c = self.df4(y1, y2) |
||||||
|
|
||||||
|
y1 = self.decoder3(y1, e3_1) |
||||||
|
y2 = self.decoder3(y2, e3_2) |
||||||
|
c = self.catc3(c, self.df3(y1, y2)) |
||||||
|
|
||||||
|
y1 = self.decoder2(y1, e2_1) |
||||||
|
y2 = self.decoder2(y2, e2_2) |
||||||
|
c = self.catc2(c, self.df2(y1, y2)) |
||||||
|
|
||||||
|
y1 = self.decoder1(y1, e1_1) |
||||||
|
y2 = self.decoder1(y2, e1_2) |
||||||
|
|
||||||
|
c = self.catc1(c, self.df1(y1, y2)) |
||||||
|
y = self.conv_out(self.upsample_x2(c)) |
||||||
|
|
||||||
|
if self.training: |
||||||
|
y1 = self.conv_out_class(y1) |
||||||
|
y2 = self.conv_out_class(y2) |
||||||
|
return [y, [y1, y2]] |
||||||
|
else: |
||||||
|
return [y] |
@ -0,0 +1,170 @@ |
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
|
||||||
|
import paddle |
||||||
|
import paddle.nn as nn |
||||||
|
import paddle.nn.functional as F |
||||||
|
|
||||||
|
|
||||||
|
class DiceLoss(nn.Layer): |
||||||
|
def __init__(self, batch=True): |
||||||
|
super(DiceLoss, self).__init__() |
||||||
|
self.batch = batch |
||||||
|
|
||||||
|
def soft_dice_coeff(self, y_pred, y_true): |
||||||
|
smooth = 0.00001 |
||||||
|
if self.batch: |
||||||
|
i = paddle.sum(y_true) |
||||||
|
j = paddle.sum(y_pred) |
||||||
|
intersection = paddle.sum(y_true * y_pred) |
||||||
|
else: |
||||||
|
i = y_true.sum(1).sum(1).sum(1) |
||||||
|
j = y_pred.sum(1).sum(1).sum(1) |
||||||
|
intersection = (y_true * y_pred).sum(1).sum(1).sum(1) |
||||||
|
score = (2. * intersection + smooth) / (i + j + smooth) |
||||||
|
return score.mean() |
||||||
|
|
||||||
|
def soft_dice_loss(self, y_pred, y_true): |
||||||
|
loss = 1 - self.soft_dice_coeff(y_pred, y_true) |
||||||
|
return loss |
||||||
|
|
||||||
|
def forward(self, y_pred, y_true): |
||||||
|
return self.soft_dice_loss(y_pred.astype(paddle.float32), y_true) |
||||||
|
|
||||||
|
|
||||||
|
class MultiClassDiceLoss(nn.Layer): |
||||||
|
def __init__( |
||||||
|
self, |
||||||
|
weight, |
||||||
|
batch=True, |
||||||
|
ignore_index=-1, |
||||||
|
do_softmax=False, |
||||||
|
**kwargs, ): |
||||||
|
super(MultiClassDiceLoss, self).__init__() |
||||||
|
self.ignore_index = ignore_index |
||||||
|
self.weight = weight |
||||||
|
self.do_softmax = do_softmax |
||||||
|
self.binary_diceloss = DiceLoss(batch) |
||||||
|
|
||||||
|
def forward(self, y_pred, y_true): |
||||||
|
if self.do_softmax: |
||||||
|
y_pred = paddle.nn.functional.softmax(y_pred, axis=1) |
||||||
|
y_true = F.one_hot(y_true.long(), y_pred.shape[1]).permute(0, 3, 1, 2) |
||||||
|
total_loss = 0.0 |
||||||
|
tmp_i = 0.0 |
||||||
|
for i in range(y_pred.shape[1]): |
||||||
|
if i != self.ignore_index: |
||||||
|
diceloss = self.binary_diceloss(y_pred[:, i, :, :], |
||||||
|
y_true[:, i, :, :]) |
||||||
|
total_loss += paddle.multiply(diceloss, self.weight[i]) |
||||||
|
tmp_i += 1.0 |
||||||
|
return total_loss / tmp_i |
||||||
|
|
||||||
|
|
||||||
|
class DiceBCELoss(nn.Layer): |
||||||
|
"""Binary change detection task loss""" |
||||||
|
|
||||||
|
def __init__(self): |
||||||
|
super(DiceBCELoss, self).__init__() |
||||||
|
self.bce_loss = nn.BCELoss() |
||||||
|
self.binnary_dice = DiceLoss() |
||||||
|
|
||||||
|
def forward(self, scores, labels, do_sigmoid=True): |
||||||
|
if len(scores.shape) > 3: |
||||||
|
scores = scores.squeeze(1) |
||||||
|
if len(labels.shape) > 3: |
||||||
|
labels = labels.squeeze(1) |
||||||
|
if do_sigmoid: |
||||||
|
scores = paddle.nn.functional.sigmoid(scores.clone()) |
||||||
|
diceloss = self.binnary_dice(scores, labels) |
||||||
|
bceloss = self.bce_loss(scores, labels) |
||||||
|
return diceloss + bceloss |
||||||
|
|
||||||
|
|
||||||
|
class McDiceBCELoss(nn.Layer): |
||||||
|
"""Multi-class change detection task loss""" |
||||||
|
|
||||||
|
def __init__(self, weight, do_sigmoid=True): |
||||||
|
super(McDiceBCELoss, self).__init__() |
||||||
|
self.ce_loss = nn.CrossEntropyLoss(weight) |
||||||
|
self.dice = MultiClassDiceLoss(weight, do_sigmoid) |
||||||
|
|
||||||
|
def forward(self, scores, labels): |
||||||
|
if len(scores.shape) < 4: |
||||||
|
scores = scores.unsqueeze(1) |
||||||
|
if len(labels.shape) < 4: |
||||||
|
labels = labels.unsqueeze(1) |
||||||
|
diceloss = self.dice(scores, labels) |
||||||
|
bceloss = self.ce_loss(scores, labels) |
||||||
|
return diceloss + bceloss |
||||||
|
|
||||||
|
|
||||||
|
def fccdn_ssl_loss(logits_list, labels): |
||||||
|
""" |
||||||
|
Self-supervised learning loss for change detection. |
||||||
|
|
||||||
|
The original article refers to |
||||||
|
Pan Chen, et al., "FCCDN: Feature Constraint Network for VHR Image Change Detection" |
||||||
|
(https://arxiv.org/pdf/2105.10860.pdf). |
||||||
|
|
||||||
|
Args: |
||||||
|
logits_list (list[paddle.Tensor]): Single-channel segmentation logit maps for each of the two temporal phases. |
||||||
|
labels (paddle.Tensor): Binary change labels. |
||||||
|
""" |
||||||
|
|
||||||
|
# Create loss |
||||||
|
criterion_ssl = DiceBCELoss() |
||||||
|
|
||||||
|
# Get downsampled change map |
||||||
|
h, w = logits_list[0].shape[-2], logits_list[0].shape[-1] |
||||||
|
labels_downsample = F.interpolate(x=labels.unsqueeze(1), size=[h, w]) |
||||||
|
labels_type = str(labels_downsample.dtype) |
||||||
|
assert "int" in labels_type or "bool" in labels_type,\ |
||||||
|
f"Expected dtype of labels to be int or bool, but got {labels_type}" |
||||||
|
|
||||||
|
# Seg map |
||||||
|
out1 = paddle.nn.functional.sigmoid(logits_list[0]).clone() |
||||||
|
out2 = paddle.nn.functional.sigmoid(logits_list[1]).clone() |
||||||
|
out3 = out1.clone() |
||||||
|
out4 = out2.clone() |
||||||
|
|
||||||
|
out1 = paddle.where(labels_downsample == 1, paddle.zeros_like(out1), out1) |
||||||
|
out2 = paddle.where(labels_downsample == 1, paddle.zeros_like(out2), out2) |
||||||
|
out3 = paddle.where(labels_downsample != 1, paddle.zeros_like(out3), out3) |
||||||
|
out4 = paddle.where(labels_downsample != 1, paddle.zeros_like(out4), out4) |
||||||
|
|
||||||
|
pred_seg_pre_tmp1 = paddle.where(out1 <= 0.5, |
||||||
|
paddle.zeros_like(out1), |
||||||
|
paddle.ones_like(out1)) |
||||||
|
pred_seg_post_tmp1 = paddle.where(out2 <= 0.5, |
||||||
|
paddle.zeros_like(out2), |
||||||
|
paddle.ones_like(out2)) |
||||||
|
|
||||||
|
pred_seg_pre_tmp2 = paddle.where(out3 <= 0.5, |
||||||
|
paddle.zeros_like(out3), |
||||||
|
paddle.ones_like(out3)) |
||||||
|
pred_seg_post_tmp2 = paddle.where(out4 <= 0.5, |
||||||
|
paddle.zeros_like(out4), |
||||||
|
paddle.ones_like(out4)) |
||||||
|
|
||||||
|
# Seg loss |
||||||
|
labels_downsample = labels_downsample.astype(paddle.float32) |
||||||
|
loss_aux = 0.2 * criterion_ssl(out1, pred_seg_post_tmp1, False) |
||||||
|
loss_aux += 0.2 * criterion_ssl(out2, pred_seg_pre_tmp1, False) |
||||||
|
loss_aux += 0.2 * criterion_ssl( |
||||||
|
out3, labels_downsample - pred_seg_post_tmp2, False) |
||||||
|
loss_aux += 0.2 * criterion_ssl(out4, labels_downsample - pred_seg_pre_tmp2, |
||||||
|
False) |
||||||
|
|
||||||
|
return loss_aux |
@ -0,0 +1,27 @@ |
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
|
||||||
|
import paddle |
||||||
|
import paddle.nn as nn |
||||||
|
|
||||||
|
from paddlers.models.ppgan.modules.init import reset_parameters |
||||||
|
|
||||||
|
|
||||||
|
def init_sr_weight(net): |
||||||
|
def reset_func(m): |
||||||
|
if hasattr(m, 'weight') and ( |
||||||
|
not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D))): |
||||||
|
reset_parameters(m) |
||||||
|
|
||||||
|
net.apply(reset_func) |
@ -1,106 +0,0 @@ |
|||||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. |
|
||||||
# |
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
||||||
# you may not use this file except in compliance with the License. |
|
||||||
# You may obtain a copy of the License at |
|
||||||
# |
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0 |
|
||||||
# |
|
||||||
# Unless required by applicable law or agreed to in writing, software |
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS, |
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
||||||
# See the License for the specific language governing permissions and |
|
||||||
# limitations under the License. |
|
||||||
|
|
||||||
import paddle |
|
||||||
import paddle.nn as nn |
|
||||||
|
|
||||||
from .generators.builder import build_generator |
|
||||||
from ...models.ppgan.models.criterions.builder import build_criterion |
|
||||||
from ...models.ppgan.models.base_model import BaseModel |
|
||||||
from ...models.ppgan.models.builder import MODELS |
|
||||||
from ...models.ppgan.utils.visual import tensor2img |
|
||||||
from ...models.ppgan.modules.init import reset_parameters |
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register() |
|
||||||
class RCANModel(BaseModel): |
|
||||||
""" |
|
||||||
Base SR model for single image super-resolution. |
|
||||||
""" |
|
||||||
|
|
||||||
def __init__(self, generator, pixel_criterion=None, use_init_weight=False): |
|
||||||
""" |
|
||||||
Args: |
|
||||||
generator (dict): config of generator. |
|
||||||
pixel_criterion (dict): config of pixel criterion. |
|
||||||
""" |
|
||||||
super(RCANModel, self).__init__() |
|
||||||
|
|
||||||
self.nets['generator'] = build_generator(generator) |
|
||||||
self.error_last = 1e8 |
|
||||||
self.batch = 0 |
|
||||||
if pixel_criterion: |
|
||||||
self.pixel_criterion = build_criterion(pixel_criterion) |
|
||||||
if use_init_weight: |
|
||||||
init_sr_weight(self.nets['generator']) |
|
||||||
|
|
||||||
def setup_input(self, input): |
|
||||||
self.lq = paddle.to_tensor(input['lq']) |
|
||||||
self.visual_items['lq'] = self.lq |
|
||||||
if 'gt' in input: |
|
||||||
self.gt = paddle.to_tensor(input['gt']) |
|
||||||
self.visual_items['gt'] = self.gt |
|
||||||
self.image_paths = input['lq_path'] |
|
||||||
|
|
||||||
def forward(self): |
|
||||||
pass |
|
||||||
|
|
||||||
def train_iter(self, optims=None): |
|
||||||
optims['optim'].clear_grad() |
|
||||||
|
|
||||||
self.output = self.nets['generator'](self.lq) |
|
||||||
self.visual_items['output'] = self.output |
|
||||||
# pixel loss |
|
||||||
loss_pixel = self.pixel_criterion(self.output, self.gt) |
|
||||||
self.losses['loss_pixel'] = loss_pixel |
|
||||||
|
|
||||||
skip_threshold = 1e6 |
|
||||||
|
|
||||||
if loss_pixel.item() < skip_threshold * self.error_last: |
|
||||||
loss_pixel.backward() |
|
||||||
optims['optim'].step() |
|
||||||
else: |
|
||||||
print('Skip this batch {}! (Loss: {})'.format(self.batch + 1, |
|
||||||
loss_pixel.item())) |
|
||||||
self.batch += 1 |
|
||||||
|
|
||||||
if self.batch % 1000 == 0: |
|
||||||
self.error_last = loss_pixel.item() / 1000 |
|
||||||
print("update error_last:{}".format(self.error_last)) |
|
||||||
|
|
||||||
def test_iter(self, metrics=None): |
|
||||||
self.nets['generator'].eval() |
|
||||||
with paddle.no_grad(): |
|
||||||
self.output = self.nets['generator'](self.lq) |
|
||||||
self.visual_items['output'] = self.output |
|
||||||
self.nets['generator'].train() |
|
||||||
|
|
||||||
out_img = [] |
|
||||||
gt_img = [] |
|
||||||
for out_tensor, gt_tensor in zip(self.output, self.gt): |
|
||||||
out_img.append(tensor2img(out_tensor, (0., 255.))) |
|
||||||
gt_img.append(tensor2img(gt_tensor, (0., 255.))) |
|
||||||
|
|
||||||
if metrics is not None: |
|
||||||
for metric in metrics.values(): |
|
||||||
metric.update(out_img, gt_img) |
|
||||||
|
|
||||||
|
|
||||||
def init_sr_weight(net): |
|
||||||
def reset_func(m): |
|
||||||
if hasattr(m, 'weight') and ( |
|
||||||
not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D))): |
|
||||||
reset_parameters(m) |
|
||||||
|
|
||||||
net.apply(reset_func) |
|
@ -1,786 +0,0 @@ |
|||||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
|
||||||
# |
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
||||||
# you may not use this file except in compliance with the License. |
|
||||||
# You may obtain a copy of the License at |
|
||||||
# |
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0 |
|
||||||
# |
|
||||||
# Unless required by applicable law or agreed to in writing, software |
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS, |
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
||||||
# See the License for the specific language governing permissions and |
|
||||||
# limitations under the License. |
|
||||||
|
|
||||||
import os |
|
||||||
import time |
|
||||||
import datetime |
|
||||||
|
|
||||||
import paddle |
|
||||||
from paddle.distributed import ParallelEnv |
|
||||||
|
|
||||||
from ..models.ppgan.datasets.builder import build_dataloader |
|
||||||
from ..models.ppgan.models.builder import build_model |
|
||||||
from ..models.ppgan.utils.visual import tensor2img, save_image |
|
||||||
from ..models.ppgan.utils.filesystem import makedirs, save, load |
|
||||||
from ..models.ppgan.utils.timer import TimeAverager |
|
||||||
from ..models.ppgan.utils.profiler import add_profiler_step |
|
||||||
from ..models.ppgan.utils.logger import setup_logger |
|
||||||
|
|
||||||
|
|
||||||
# 定义AttrDict类实现动态属性 |
|
||||||
class AttrDict(dict): |
|
||||||
def __getattr__(self, key): |
|
||||||
try: |
|
||||||
return self[key] |
|
||||||
except KeyError: |
|
||||||
raise AttributeError(key) |
|
||||||
|
|
||||||
def __setattr__(self, key, value): |
|
||||||
if key in self.__dict__: |
|
||||||
self.__dict__[key] = value |
|
||||||
else: |
|
||||||
self[key] = value |
|
||||||
|
|
||||||
|
|
||||||
# 创建AttrDict类 |
|
||||||
def create_attr_dict(config_dict): |
|
||||||
from ast import literal_eval |
|
||||||
for key, value in config_dict.items(): |
|
||||||
if type(value) is dict: |
|
||||||
config_dict[key] = value = AttrDict(value) |
|
||||||
if isinstance(value, str): |
|
||||||
try: |
|
||||||
value = literal_eval(value) |
|
||||||
except BaseException: |
|
||||||
pass |
|
||||||
if isinstance(value, AttrDict): |
|
||||||
create_attr_dict(config_dict[key]) |
|
||||||
else: |
|
||||||
config_dict[key] = value |
|
||||||
|
|
||||||
|
|
||||||
# 数据加载类 |
|
||||||
class IterLoader: |
|
||||||
def __init__(self, dataloader): |
|
||||||
self._dataloader = dataloader |
|
||||||
self.iter_loader = iter(self._dataloader) |
|
||||||
self._epoch = 1 |
|
||||||
|
|
||||||
@property |
|
||||||
def epoch(self): |
|
||||||
return self._epoch |
|
||||||
|
|
||||||
def __next__(self): |
|
||||||
try: |
|
||||||
data = next(self.iter_loader) |
|
||||||
except StopIteration: |
|
||||||
self._epoch += 1 |
|
||||||
self.iter_loader = iter(self._dataloader) |
|
||||||
data = next(self.iter_loader) |
|
||||||
|
|
||||||
return data |
|
||||||
|
|
||||||
def __len__(self): |
|
||||||
return len(self._dataloader) |
|
||||||
|
|
||||||
|
|
||||||
# 基础训练类 |
|
||||||
class Restorer: |
|
||||||
""" |
|
||||||
# trainer calling logic: |
|
||||||
# |
|
||||||
# build_model || model(BaseModel) |
|
||||||
# | || |
|
||||||
# build_dataloader || dataloader |
|
||||||
# | || |
|
||||||
# model.setup_lr_schedulers || lr_scheduler |
|
||||||
# | || |
|
||||||
# model.setup_optimizers || optimizers |
|
||||||
# | || |
|
||||||
# train loop (model.setup_input + model.train_iter) || train loop |
|
||||||
# | || |
|
||||||
# print log (model.get_current_losses) || |
|
||||||
# | || |
|
||||||
# save checkpoint (model.nets) \/ |
|
||||||
""" |
|
||||||
|
|
||||||
def __init__(self, cfg, logger): |
|
||||||
# base config |
|
||||||
# self.logger = logging.getLogger(__name__) |
|
||||||
self.logger = logger |
|
||||||
self.cfg = cfg |
|
||||||
self.output_dir = cfg.output_dir |
|
||||||
self.max_eval_steps = cfg.model.get('max_eval_steps', None) |
|
||||||
|
|
||||||
self.local_rank = ParallelEnv().local_rank |
|
||||||
self.world_size = ParallelEnv().nranks |
|
||||||
self.log_interval = cfg.log_config.interval |
|
||||||
self.visual_interval = cfg.log_config.visiual_interval |
|
||||||
self.weight_interval = cfg.snapshot_config.interval |
|
||||||
|
|
||||||
self.start_epoch = 1 |
|
||||||
self.current_epoch = 1 |
|
||||||
self.current_iter = 1 |
|
||||||
self.inner_iter = 1 |
|
||||||
self.batch_id = 0 |
|
||||||
self.global_steps = 0 |
|
||||||
|
|
||||||
# build model |
|
||||||
self.model = build_model(cfg.model) |
|
||||||
# multiple gpus prepare |
|
||||||
if ParallelEnv().nranks > 1: |
|
||||||
self.distributed_data_parallel() |
|
||||||
|
|
||||||
# build metrics |
|
||||||
self.metrics = None |
|
||||||
self.is_save_img = True |
|
||||||
validate_cfg = cfg.get('validate', None) |
|
||||||
if validate_cfg and 'metrics' in validate_cfg: |
|
||||||
self.metrics = self.model.setup_metrics(validate_cfg['metrics']) |
|
||||||
if validate_cfg and 'save_img' in validate_cfg: |
|
||||||
self.is_save_img = validate_cfg['save_img'] |
|
||||||
|
|
||||||
self.enable_visualdl = cfg.get('enable_visualdl', False) |
|
||||||
if self.enable_visualdl: |
|
||||||
import visualdl |
|
||||||
self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir) |
|
||||||
|
|
||||||
# evaluate only |
|
||||||
if not cfg.is_train: |
|
||||||
return |
|
||||||
|
|
||||||
# build train dataloader |
|
||||||
self.train_dataloader = build_dataloader(cfg.dataset.train) |
|
||||||
self.iters_per_epoch = len(self.train_dataloader) |
|
||||||
|
|
||||||
# build lr scheduler |
|
||||||
# TODO: has a better way? |
|
||||||
if 'lr_scheduler' in cfg and 'iters_per_epoch' in cfg.lr_scheduler: |
|
||||||
cfg.lr_scheduler.iters_per_epoch = self.iters_per_epoch |
|
||||||
self.lr_schedulers = self.model.setup_lr_schedulers(cfg.lr_scheduler) |
|
||||||
|
|
||||||
# build optimizers |
|
||||||
self.optimizers = self.model.setup_optimizers(self.lr_schedulers, |
|
||||||
cfg.optimizer) |
|
||||||
|
|
||||||
self.epochs = cfg.get('epochs', None) |
|
||||||
if self.epochs: |
|
||||||
self.total_iters = self.epochs * self.iters_per_epoch |
|
||||||
self.by_epoch = True |
|
||||||
else: |
|
||||||
self.by_epoch = False |
|
||||||
self.total_iters = cfg.total_iters |
|
||||||
|
|
||||||
if self.by_epoch: |
|
||||||
self.weight_interval *= self.iters_per_epoch |
|
||||||
|
|
||||||
self.validate_interval = -1 |
|
||||||
if cfg.get('validate', None) is not None: |
|
||||||
self.validate_interval = cfg.validate.get('interval', -1) |
|
||||||
|
|
||||||
self.time_count = {} |
|
||||||
self.best_metric = {} |
|
||||||
self.model.set_total_iter(self.total_iters) |
|
||||||
self.profiler_options = cfg.profiler_options |
|
||||||
|
|
||||||
def distributed_data_parallel(self): |
|
||||||
paddle.distributed.init_parallel_env() |
|
||||||
find_unused_parameters = self.cfg.get('find_unused_parameters', False) |
|
||||||
for net_name, net in self.model.nets.items(): |
|
||||||
self.model.nets[net_name] = paddle.DataParallel( |
|
||||||
net, find_unused_parameters=find_unused_parameters) |
|
||||||
|
|
||||||
def learning_rate_scheduler_step(self): |
|
||||||
if isinstance(self.model.lr_scheduler, dict): |
|
||||||
for lr_scheduler in self.model.lr_scheduler.values(): |
|
||||||
lr_scheduler.step() |
|
||||||
elif isinstance(self.model.lr_scheduler, |
|
||||||
paddle.optimizer.lr.LRScheduler): |
|
||||||
self.model.lr_scheduler.step() |
|
||||||
else: |
|
||||||
raise ValueError( |
|
||||||
'lr schedulter must be a dict or an instance of LRScheduler') |
|
||||||
|
|
||||||
def train(self): |
|
||||||
reader_cost_averager = TimeAverager() |
|
||||||
batch_cost_averager = TimeAverager() |
|
||||||
|
|
||||||
iter_loader = IterLoader(self.train_dataloader) |
|
||||||
|
|
||||||
# set model.is_train = True |
|
||||||
self.model.setup_train_mode(is_train=True) |
|
||||||
while self.current_iter < (self.total_iters + 1): |
|
||||||
self.current_epoch = iter_loader.epoch |
|
||||||
self.inner_iter = self.current_iter % self.iters_per_epoch |
|
||||||
|
|
||||||
add_profiler_step(self.profiler_options) |
|
||||||
|
|
||||||
start_time = step_start_time = time.time() |
|
||||||
data = next(iter_loader) |
|
||||||
reader_cost_averager.record(time.time() - step_start_time) |
|
||||||
# unpack data from dataset and apply preprocessing |
|
||||||
# data input should be dict |
|
||||||
self.model.setup_input(data) |
|
||||||
self.model.train_iter(self.optimizers) |
|
||||||
|
|
||||||
batch_cost_averager.record( |
|
||||||
time.time() - step_start_time, |
|
||||||
num_samples=self.cfg['dataset']['train'].get('batch_size', 1)) |
|
||||||
|
|
||||||
step_start_time = time.time() |
|
||||||
|
|
||||||
if self.current_iter % self.log_interval == 0: |
|
||||||
self.data_time = reader_cost_averager.get_average() |
|
||||||
self.step_time = batch_cost_averager.get_average() |
|
||||||
self.ips = batch_cost_averager.get_ips_average() |
|
||||||
self.print_log() |
|
||||||
|
|
||||||
reader_cost_averager.reset() |
|
||||||
batch_cost_averager.reset() |
|
||||||
|
|
||||||
if self.current_iter % self.visual_interval == 0 and self.local_rank == 0: |
|
||||||
self.visual('visual_train') |
|
||||||
|
|
||||||
self.learning_rate_scheduler_step() |
|
||||||
|
|
||||||
if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0: |
|
||||||
self.test() |
|
||||||
|
|
||||||
if self.current_iter % self.weight_interval == 0: |
|
||||||
self.save(self.current_iter, 'weight', keep=-1) |
|
||||||
self.save(self.current_iter) |
|
||||||
|
|
||||||
self.current_iter += 1 |
|
||||||
|
|
||||||
def test(self): |
|
||||||
if not hasattr(self, 'test_dataloader'): |
|
||||||
self.test_dataloader = build_dataloader( |
|
||||||
self.cfg.dataset.test, is_train=False) |
|
||||||
iter_loader = IterLoader(self.test_dataloader) |
|
||||||
if self.max_eval_steps is None: |
|
||||||
self.max_eval_steps = len(self.test_dataloader) |
|
||||||
|
|
||||||
if self.metrics: |
|
||||||
for metric in self.metrics.values(): |
|
||||||
metric.reset() |
|
||||||
|
|
||||||
# set model.is_train = False |
|
||||||
self.model.setup_train_mode(is_train=False) |
|
||||||
|
|
||||||
for i in range(self.max_eval_steps): |
|
||||||
if self.max_eval_steps < self.log_interval or i % self.log_interval == 0: |
|
||||||
self.logger.info('Test iter: [%d/%d]' % ( |
|
||||||
i * self.world_size, self.max_eval_steps * self.world_size)) |
|
||||||
|
|
||||||
data = next(iter_loader) |
|
||||||
self.model.setup_input(data) |
|
||||||
self.model.test_iter(metrics=self.metrics) |
|
||||||
|
|
||||||
if self.is_save_img: |
|
||||||
visual_results = {} |
|
||||||
current_paths = self.model.get_image_paths() |
|
||||||
current_visuals = self.model.get_current_visuals() |
|
||||||
|
|
||||||
if len(current_visuals) > 0 and list(current_visuals.values())[ |
|
||||||
0].shape == 4: |
|
||||||
num_samples = list(current_visuals.values())[0].shape[0] |
|
||||||
else: |
|
||||||
num_samples = 1 |
|
||||||
|
|
||||||
for j in range(num_samples): |
|
||||||
if j < len(current_paths): |
|
||||||
short_path = os.path.basename(current_paths[j]) |
|
||||||
basename = os.path.splitext(short_path)[0] |
|
||||||
else: |
|
||||||
basename = '{:04d}_{:04d}'.format(i, j) |
|
||||||
for k, img_tensor in current_visuals.items(): |
|
||||||
name = '%s_%s' % (basename, k) |
|
||||||
if len(img_tensor.shape) == 4: |
|
||||||
visual_results.update({name: img_tensor[j]}) |
|
||||||
else: |
|
||||||
visual_results.update({name: img_tensor}) |
|
||||||
|
|
||||||
self.visual( |
|
||||||
'visual_test', |
|
||||||
visual_results=visual_results, |
|
||||||
step=self.batch_id, |
|
||||||
is_save_image=True) |
|
||||||
|
|
||||||
if self.metrics: |
|
||||||
for metric_name, metric in self.metrics.items(): |
|
||||||
self.logger.info("Metric {}: {:.4f}".format( |
|
||||||
metric_name, metric.accumulate())) |
|
||||||
|
|
||||||
def print_log(self): |
|
||||||
losses = self.model.get_current_losses() |
|
||||||
|
|
||||||
message = '' |
|
||||||
if self.by_epoch: |
|
||||||
message += 'Epoch: %d/%d, iter: %d/%d ' % ( |
|
||||||
self.current_epoch, self.epochs, self.inner_iter, |
|
||||||
self.iters_per_epoch) |
|
||||||
else: |
|
||||||
message += 'Iter: %d/%d ' % (self.current_iter, self.total_iters) |
|
||||||
|
|
||||||
message += f'lr: {self.current_learning_rate:.3e} ' |
|
||||||
|
|
||||||
for k, v in losses.items(): |
|
||||||
message += '%s: %.3f ' % (k, v) |
|
||||||
if self.enable_visualdl: |
|
||||||
self.vdl_logger.add_scalar(k, v, step=self.global_steps) |
|
||||||
|
|
||||||
if hasattr(self, 'step_time'): |
|
||||||
message += 'batch_cost: %.5f sec ' % self.step_time |
|
||||||
|
|
||||||
if hasattr(self, 'data_time'): |
|
||||||
message += 'reader_cost: %.5f sec ' % self.data_time |
|
||||||
|
|
||||||
if hasattr(self, 'ips'): |
|
||||||
message += 'ips: %.5f images/s ' % self.ips |
|
||||||
|
|
||||||
if hasattr(self, 'step_time'): |
|
||||||
eta = self.step_time * (self.total_iters - self.current_iter) |
|
||||||
eta = eta if eta > 0 else 0 |
|
||||||
|
|
||||||
eta_str = str(datetime.timedelta(seconds=int(eta))) |
|
||||||
message += f'eta: {eta_str}' |
|
||||||
|
|
||||||
# print the message |
|
||||||
self.logger.info(message) |
|
||||||
|
|
||||||
@property |
|
||||||
def current_learning_rate(self): |
|
||||||
for optimizer in self.model.optimizers.values(): |
|
||||||
return optimizer.get_lr() |
|
||||||
|
|
||||||
def visual(self, |
|
||||||
results_dir, |
|
||||||
visual_results=None, |
|
||||||
step=None, |
|
||||||
is_save_image=False): |
|
||||||
""" |
|
||||||
visual the images, use visualdl or directly write to the directory |
|
||||||
Parameters: |
|
||||||
results_dir (str) -- directory name which contains saved images |
|
||||||
visual_results (dict) -- the results images dict |
|
||||||
step (int) -- global steps, used in visualdl |
|
||||||
is_save_image (bool) -- weather write to the directory or visualdl |
|
||||||
""" |
|
||||||
self.model.compute_visuals() |
|
||||||
|
|
||||||
if visual_results is None: |
|
||||||
visual_results = self.model.get_current_visuals() |
|
||||||
|
|
||||||
min_max = self.cfg.get('min_max', None) |
|
||||||
if min_max is None: |
|
||||||
min_max = (-1., 1.) |
|
||||||
|
|
||||||
image_num = self.cfg.get('image_num', None) |
|
||||||
if (image_num is None) or (not self.enable_visualdl): |
|
||||||
image_num = 1 |
|
||||||
for label, image in visual_results.items(): |
|
||||||
image_numpy = tensor2img(image, min_max, image_num) |
|
||||||
if (not is_save_image) and self.enable_visualdl: |
|
||||||
self.vdl_logger.add_image( |
|
||||||
results_dir + '/' + label, |
|
||||||
image_numpy, |
|
||||||
step=step if step else self.global_steps, |
|
||||||
dataformats="HWC" if image_num == 1 else "NCHW") |
|
||||||
else: |
|
||||||
if self.cfg.is_train: |
|
||||||
if self.by_epoch: |
|
||||||
msg = 'epoch%.3d_' % self.current_epoch |
|
||||||
else: |
|
||||||
msg = 'iter%.3d_' % self.current_iter |
|
||||||
else: |
|
||||||
msg = '' |
|
||||||
makedirs(os.path.join(self.output_dir, results_dir)) |
|
||||||
img_path = os.path.join(self.output_dir, results_dir, |
|
||||||
msg + '%s.png' % (label)) |
|
||||||
save_image(image_numpy, img_path) |
|
||||||
|
|
||||||
def save(self, epoch, name='checkpoint', keep=1): |
|
||||||
if self.local_rank != 0: |
|
||||||
return |
|
||||||
|
|
||||||
assert name in ['checkpoint', 'weight'] |
|
||||||
|
|
||||||
state_dicts = {} |
|
||||||
if self.by_epoch: |
|
||||||
save_filename = 'epoch_%s_%s.pdparams' % ( |
|
||||||
epoch // self.iters_per_epoch, name) |
|
||||||
else: |
|
||||||
save_filename = 'iter_%s_%s.pdparams' % (epoch, name) |
|
||||||
|
|
||||||
os.makedirs(self.output_dir, exist_ok=True) |
|
||||||
save_path = os.path.join(self.output_dir, save_filename) |
|
||||||
for net_name, net in self.model.nets.items(): |
|
||||||
state_dicts[net_name] = net.state_dict() |
|
||||||
|
|
||||||
if name == 'weight': |
|
||||||
save(state_dicts, save_path) |
|
||||||
return |
|
||||||
|
|
||||||
state_dicts['epoch'] = epoch |
|
||||||
|
|
||||||
for opt_name, opt in self.model.optimizers.items(): |
|
||||||
state_dicts[opt_name] = opt.state_dict() |
|
||||||
|
|
||||||
save(state_dicts, save_path) |
|
||||||
|
|
||||||
if keep > 0: |
|
||||||
try: |
|
||||||
if self.by_epoch: |
|
||||||
checkpoint_name_to_be_removed = os.path.join( |
|
||||||
self.output_dir, 'epoch_%s_%s.pdparams' % ( |
|
||||||
(epoch - keep * self.weight_interval) // |
|
||||||
self.iters_per_epoch, name)) |
|
||||||
else: |
|
||||||
checkpoint_name_to_be_removed = os.path.join( |
|
||||||
self.output_dir, 'iter_%s_%s.pdparams' % |
|
||||||
(epoch - keep * self.weight_interval, name)) |
|
||||||
|
|
||||||
if os.path.exists(checkpoint_name_to_be_removed): |
|
||||||
os.remove(checkpoint_name_to_be_removed) |
|
||||||
|
|
||||||
except Exception as e: |
|
||||||
self.logger.info('remove old checkpoints error: {}'.format(e)) |
|
||||||
|
|
||||||
def resume(self, checkpoint_path): |
|
||||||
state_dicts = load(checkpoint_path) |
|
||||||
if state_dicts.get('epoch', None) is not None: |
|
||||||
self.start_epoch = state_dicts['epoch'] + 1 |
|
||||||
self.global_steps = self.iters_per_epoch * state_dicts['epoch'] |
|
||||||
|
|
||||||
self.current_iter = state_dicts['epoch'] + 1 |
|
||||||
|
|
||||||
for net_name, net in self.model.nets.items(): |
|
||||||
net.set_state_dict(state_dicts[net_name]) |
|
||||||
|
|
||||||
for opt_name, opt in self.model.optimizers.items(): |
|
||||||
opt.set_state_dict(state_dicts[opt_name]) |
|
||||||
|
|
||||||
def load(self, weight_path): |
|
||||||
state_dicts = load(weight_path) |
|
||||||
|
|
||||||
for net_name, net in self.model.nets.items(): |
|
||||||
if net_name in state_dicts: |
|
||||||
net.set_state_dict(state_dicts[net_name]) |
|
||||||
self.logger.info('Loaded pretrained weight for net {}'.format( |
|
||||||
net_name)) |
|
||||||
else: |
|
||||||
self.logger.warning( |
|
||||||
'Can not find state dict of net {}. Skip load pretrained weight for net {}' |
|
||||||
.format(net_name, net_name)) |
|
||||||
|
|
||||||
def close(self): |
|
||||||
""" |
|
||||||
when finish the training need close file handler or other. |
|
||||||
""" |
|
||||||
if self.enable_visualdl: |
|
||||||
self.vdl_logger.close() |
|
||||||
|
|
||||||
|
|
||||||
# 基础超分模型训练类 |
|
||||||
class BasicSRNet: |
|
||||||
def __init__(self): |
|
||||||
self.model = {} |
|
||||||
self.optimizer = {} |
|
||||||
self.lr_scheduler = {} |
|
||||||
self.min_max = '' |
|
||||||
|
|
||||||
def train( |
|
||||||
self, |
|
||||||
total_iters, |
|
||||||
train_dataset, |
|
||||||
test_dataset, |
|
||||||
output_dir, |
|
||||||
validate, |
|
||||||
snapshot, |
|
||||||
log, |
|
||||||
lr_rate, |
|
||||||
evaluate_weights='', |
|
||||||
resume='', |
|
||||||
pretrain_weights='', |
|
||||||
periods=[100000], |
|
||||||
restart_weights=[1], ): |
|
||||||
self.lr_scheduler['learning_rate'] = lr_rate |
|
||||||
|
|
||||||
if self.lr_scheduler['name'] == 'CosineAnnealingRestartLR': |
|
||||||
self.lr_scheduler['periods'] = periods |
|
||||||
self.lr_scheduler['restart_weights'] = restart_weights |
|
||||||
|
|
||||||
validate = { |
|
||||||
'interval': validate, |
|
||||||
'save_img': False, |
|
||||||
'metrics': { |
|
||||||
'psnr': { |
|
||||||
'name': 'PSNR', |
|
||||||
'crop_border': 4, |
|
||||||
'test_y_channel': True |
|
||||||
}, |
|
||||||
'ssim': { |
|
||||||
'name': 'SSIM', |
|
||||||
'crop_border': 4, |
|
||||||
'test_y_channel': True |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
log_config = {'interval': log, 'visiual_interval': 500} |
|
||||||
snapshot_config = {'interval': snapshot} |
|
||||||
|
|
||||||
cfg = { |
|
||||||
'total_iters': total_iters, |
|
||||||
'output_dir': output_dir, |
|
||||||
'min_max': self.min_max, |
|
||||||
'model': self.model, |
|
||||||
'dataset': { |
|
||||||
'train': train_dataset, |
|
||||||
'test': test_dataset |
|
||||||
}, |
|
||||||
'lr_scheduler': self.lr_scheduler, |
|
||||||
'optimizer': self.optimizer, |
|
||||||
'validate': validate, |
|
||||||
'log_config': log_config, |
|
||||||
'snapshot_config': snapshot_config |
|
||||||
} |
|
||||||
|
|
||||||
cfg = AttrDict(cfg) |
|
||||||
create_attr_dict(cfg) |
|
||||||
|
|
||||||
cfg.is_train = True |
|
||||||
cfg.profiler_options = None |
|
||||||
cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime()) |
|
||||||
|
|
||||||
if cfg.model.name == 'BaseSRModel': |
|
||||||
floderModelName = cfg.model.generator.name |
|
||||||
else: |
|
||||||
floderModelName = cfg.model.name |
|
||||||
cfg.output_dir = os.path.join(cfg.output_dir, |
|
||||||
floderModelName + cfg.timestamp) |
|
||||||
|
|
||||||
logger_cfg = setup_logger(cfg.output_dir) |
|
||||||
logger_cfg.info('Configs: {}'.format(cfg)) |
|
||||||
|
|
||||||
if paddle.is_compiled_with_cuda(): |
|
||||||
paddle.set_device('gpu') |
|
||||||
else: |
|
||||||
paddle.set_device('cpu') |
|
||||||
|
|
||||||
# build trainer |
|
||||||
trainer = Restorer(cfg, logger_cfg) |
|
||||||
|
|
||||||
# continue train or evaluate, checkpoint need contain epoch and optimizer info |
|
||||||
if len(resume) > 0: |
|
||||||
trainer.resume(resume) |
|
||||||
# evaluate or finute, only load generator weights |
|
||||||
elif len(pretrain_weights) > 0: |
|
||||||
trainer.load(pretrain_weights) |
|
||||||
if len(evaluate_weights) > 0: |
|
||||||
trainer.load(evaluate_weights) |
|
||||||
trainer.test() |
|
||||||
return |
|
||||||
# training, when keyboard interrupt save weights |
|
||||||
try: |
|
||||||
trainer.train() |
|
||||||
except KeyboardInterrupt as e: |
|
||||||
trainer.save(trainer.current_epoch) |
|
||||||
|
|
||||||
trainer.close() |
|
||||||
|
|
||||||
|
|
||||||
# DRN模型训练 |
|
||||||
class DRNet(BasicSRNet): |
|
||||||
def __init__(self, |
|
||||||
n_blocks=30, |
|
||||||
n_feats=16, |
|
||||||
n_colors=3, |
|
||||||
rgb_range=255, |
|
||||||
negval=0.2): |
|
||||||
super(DRNet, self).__init__() |
|
||||||
self.min_max = '(0., 255.)' |
|
||||||
self.generator = { |
|
||||||
'name': 'DRNGenerator', |
|
||||||
'scale': (2, 4), |
|
||||||
'n_blocks': n_blocks, |
|
||||||
'n_feats': n_feats, |
|
||||||
'n_colors': n_colors, |
|
||||||
'rgb_range': rgb_range, |
|
||||||
'negval': negval |
|
||||||
} |
|
||||||
self.pixel_criterion = {'name': 'L1Loss'} |
|
||||||
self.model = { |
|
||||||
'name': 'DRN', |
|
||||||
'generator': self.generator, |
|
||||||
'pixel_criterion': self.pixel_criterion |
|
||||||
} |
|
||||||
self.optimizer = { |
|
||||||
'optimG': { |
|
||||||
'name': 'Adam', |
|
||||||
'net_names': ['generator'], |
|
||||||
'weight_decay': 0.0, |
|
||||||
'beta1': 0.9, |
|
||||||
'beta2': 0.999 |
|
||||||
}, |
|
||||||
'optimD': { |
|
||||||
'name': 'Adam', |
|
||||||
'net_names': ['dual_model_0', 'dual_model_1'], |
|
||||||
'weight_decay': 0.0, |
|
||||||
'beta1': 0.9, |
|
||||||
'beta2': 0.999 |
|
||||||
} |
|
||||||
} |
|
||||||
self.lr_scheduler = { |
|
||||||
'name': 'CosineAnnealingRestartLR', |
|
||||||
'eta_min': 1e-07 |
|
||||||
} |
|
||||||
|
|
||||||
|
|
||||||
# 轻量化超分模型LESRCNN训练 |
|
||||||
class LESRCNNet(BasicSRNet): |
|
||||||
def __init__(self, scale=4, multi_scale=False, group=1): |
|
||||||
super(LESRCNNet, self).__init__() |
|
||||||
self.min_max = '(0., 1.)' |
|
||||||
self.generator = { |
|
||||||
'name': 'LESRCNNGenerator', |
|
||||||
'scale': scale, |
|
||||||
'multi_scale': False, |
|
||||||
'group': 1 |
|
||||||
} |
|
||||||
self.pixel_criterion = {'name': 'L1Loss'} |
|
||||||
self.model = { |
|
||||||
'name': 'BaseSRModel', |
|
||||||
'generator': self.generator, |
|
||||||
'pixel_criterion': self.pixel_criterion |
|
||||||
} |
|
||||||
self.optimizer = { |
|
||||||
'name': 'Adam', |
|
||||||
'net_names': ['generator'], |
|
||||||
'beta1': 0.9, |
|
||||||
'beta2': 0.99 |
|
||||||
} |
|
||||||
self.lr_scheduler = { |
|
||||||
'name': 'CosineAnnealingRestartLR', |
|
||||||
'eta_min': 1e-07 |
|
||||||
} |
|
||||||
|
|
||||||
|
|
||||||
# ESRGAN模型训练 |
|
||||||
# 若loss_type='gan' 使用感知损失、对抗损失和像素损失 |
|
||||||
# 若loss_type = 'pixel' 只使用像素损失 |
|
||||||
class ESRGANet(BasicSRNet): |
|
||||||
def __init__(self, loss_type='gan', in_nc=3, out_nc=3, nf=64, nb=23): |
|
||||||
super(ESRGANet, self).__init__() |
|
||||||
self.min_max = '(0., 1.)' |
|
||||||
self.generator = { |
|
||||||
'name': 'RRDBNet', |
|
||||||
'in_nc': in_nc, |
|
||||||
'out_nc': out_nc, |
|
||||||
'nf': nf, |
|
||||||
'nb': nb |
|
||||||
} |
|
||||||
|
|
||||||
if loss_type == 'gan': |
|
||||||
# 定义损失函数 |
|
||||||
self.pixel_criterion = {'name': 'L1Loss', 'loss_weight': 0.01} |
|
||||||
self.discriminator = { |
|
||||||
'name': 'VGGDiscriminator128', |
|
||||||
'in_channels': 3, |
|
||||||
'num_feat': 64 |
|
||||||
} |
|
||||||
self.perceptual_criterion = { |
|
||||||
'name': 'PerceptualLoss', |
|
||||||
'layer_weights': { |
|
||||||
'34': 1.0 |
|
||||||
}, |
|
||||||
'perceptual_weight': 1.0, |
|
||||||
'style_weight': 0.0, |
|
||||||
'norm_img': False |
|
||||||
} |
|
||||||
self.gan_criterion = { |
|
||||||
'name': 'GANLoss', |
|
||||||
'gan_mode': 'vanilla', |
|
||||||
'loss_weight': 0.005 |
|
||||||
} |
|
||||||
# 定义模型 |
|
||||||
self.model = { |
|
||||||
'name': 'ESRGAN', |
|
||||||
'generator': self.generator, |
|
||||||
'discriminator': self.discriminator, |
|
||||||
'pixel_criterion': self.pixel_criterion, |
|
||||||
'perceptual_criterion': self.perceptual_criterion, |
|
||||||
'gan_criterion': self.gan_criterion |
|
||||||
} |
|
||||||
self.optimizer = { |
|
||||||
'optimG': { |
|
||||||
'name': 'Adam', |
|
||||||
'net_names': ['generator'], |
|
||||||
'weight_decay': 0.0, |
|
||||||
'beta1': 0.9, |
|
||||||
'beta2': 0.99 |
|
||||||
}, |
|
||||||
'optimD': { |
|
||||||
'name': 'Adam', |
|
||||||
'net_names': ['discriminator'], |
|
||||||
'weight_decay': 0.0, |
|
||||||
'beta1': 0.9, |
|
||||||
'beta2': 0.99 |
|
||||||
} |
|
||||||
} |
|
||||||
self.lr_scheduler = { |
|
||||||
'name': 'MultiStepDecay', |
|
||||||
'milestones': [50000, 100000, 200000, 300000], |
|
||||||
'gamma': 0.5 |
|
||||||
} |
|
||||||
else: |
|
||||||
self.pixel_criterion = {'name': 'L1Loss'} |
|
||||||
self.model = { |
|
||||||
'name': 'BaseSRModel', |
|
||||||
'generator': self.generator, |
|
||||||
'pixel_criterion': self.pixel_criterion |
|
||||||
} |
|
||||||
self.optimizer = { |
|
||||||
'name': 'Adam', |
|
||||||
'net_names': ['generator'], |
|
||||||
'beta1': 0.9, |
|
||||||
'beta2': 0.99 |
|
||||||
} |
|
||||||
self.lr_scheduler = { |
|
||||||
'name': 'CosineAnnealingRestartLR', |
|
||||||
'eta_min': 1e-07 |
|
||||||
} |
|
||||||
|
|
||||||
|
|
||||||
# RCAN模型训练 |
|
||||||
class RCANet(BasicSRNet): |
|
||||||
def __init__( |
|
||||||
self, |
|
||||||
scale=2, |
|
||||||
n_resgroups=10, |
|
||||||
n_resblocks=20, ): |
|
||||||
super(RCANet, self).__init__() |
|
||||||
self.min_max = '(0., 255.)' |
|
||||||
self.generator = { |
|
||||||
'name': 'RCAN', |
|
||||||
'scale': scale, |
|
||||||
'n_resgroups': n_resgroups, |
|
||||||
'n_resblocks': n_resblocks |
|
||||||
} |
|
||||||
self.pixel_criterion = {'name': 'L1Loss'} |
|
||||||
self.model = { |
|
||||||
'name': 'RCANModel', |
|
||||||
'generator': self.generator, |
|
||||||
'pixel_criterion': self.pixel_criterion |
|
||||||
} |
|
||||||
self.optimizer = { |
|
||||||
'name': 'Adam', |
|
||||||
'net_names': ['generator'], |
|
||||||
'beta1': 0.9, |
|
||||||
'beta2': 0.99 |
|
||||||
} |
|
||||||
self.lr_scheduler = { |
|
||||||
'name': 'MultiStepDecay', |
|
||||||
'milestones': [250000, 500000, 750000, 1000000], |
|
||||||
'gamma': 0.5 |
|
||||||
} |
|
@ -0,0 +1,936 @@ |
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||||
|
# |
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||||
|
# you may not use this file except in compliance with the License. |
||||||
|
# You may obtain a copy of the License at |
||||||
|
# |
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||||
|
# |
||||||
|
# Unless required by applicable law or agreed to in writing, software |
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||||
|
# See the License for the specific language governing permissions and |
||||||
|
# limitations under the License. |
||||||
|
|
||||||
|
import os |
||||||
|
import os.path as osp |
||||||
|
from collections import OrderedDict |
||||||
|
|
||||||
|
import numpy as np |
||||||
|
import cv2 |
||||||
|
import paddle |
||||||
|
import paddle.nn.functional as F |
||||||
|
from paddle.static import InputSpec |
||||||
|
|
||||||
|
import paddlers |
||||||
|
import paddlers.models.ppgan as ppgan |
||||||
|
import paddlers.rs_models.res as cmres |
||||||
|
import paddlers.models.ppgan.metrics as metrics |
||||||
|
import paddlers.utils.logging as logging |
||||||
|
from paddlers.models import res_losses |
||||||
|
from paddlers.transforms import Resize, decode_image |
||||||
|
from paddlers.transforms.functions import calc_hr_shape |
||||||
|
from paddlers.utils import get_single_card_bs |
||||||
|
from .base import BaseModel |
||||||
|
from .utils.res_adapters import GANAdapter, OptimizerAdapter |
||||||
|
from .utils.infer_nets import InferResNet |
||||||
|
|
||||||
|
__all__ = ["DRN", "LESRCNN", "ESRGAN"] |
||||||
|
|
||||||
|
|
||||||
|
class BaseRestorer(BaseModel): |
||||||
|
MIN_MAX = (0., 1.) |
||||||
|
TEST_OUT_KEY = None |
||||||
|
|
||||||
|
def __init__(self, model_name, losses=None, sr_factor=None, **params): |
||||||
|
self.init_params = locals() |
||||||
|
if 'with_net' in self.init_params: |
||||||
|
del self.init_params['with_net'] |
||||||
|
super(BaseRestorer, self).__init__('restorer') |
||||||
|
self.model_name = model_name |
||||||
|
self.losses = losses |
||||||
|
self.sr_factor = sr_factor |
||||||
|
if params.get('with_net', True): |
||||||
|
params.pop('with_net', None) |
||||||
|
self.net = self.build_net(**params) |
||||||
|
self.find_unused_parameters = True |
||||||
|
|
||||||
|
def build_net(self, **params): |
||||||
|
# Currently, only use models from cmres. |
||||||
|
if not hasattr(cmres, self.model_name): |
||||||
|
raise ValueError("ERROR: There is no model named {}.".format( |
||||||
|
model_name)) |
||||||
|
net = dict(**cmres.__dict__)[self.model_name](**params) |
||||||
|
return net |
||||||
|
|
||||||
|
def _build_inference_net(self): |
||||||
|
# For GAN models, only the generator will be used for inference. |
||||||
|
if isinstance(self.net, GANAdapter): |
||||||
|
infer_net = InferResNet( |
||||||
|
self.net.generator, out_key=self.TEST_OUT_KEY) |
||||||
|
else: |
||||||
|
infer_net = InferResNet(self.net, out_key=self.TEST_OUT_KEY) |
||||||
|
infer_net.eval() |
||||||
|
return infer_net |
||||||
|
|
||||||
|
def _fix_transforms_shape(self, image_shape): |
||||||
|
if hasattr(self, 'test_transforms'): |
||||||
|
if self.test_transforms is not None: |
||||||
|
has_resize_op = False |
||||||
|
resize_op_idx = -1 |
||||||
|
normalize_op_idx = len(self.test_transforms.transforms) |
||||||
|
for idx, op in enumerate(self.test_transforms.transforms): |
||||||
|
name = op.__class__.__name__ |
||||||
|
if name == 'Normalize': |
||||||
|
normalize_op_idx = idx |
||||||
|
if 'Resize' in name: |
||||||
|
has_resize_op = True |
||||||
|
resize_op_idx = idx |
||||||
|
|
||||||
|
if not has_resize_op: |
||||||
|
self.test_transforms.transforms.insert( |
||||||
|
normalize_op_idx, Resize(target_size=image_shape)) |
||||||
|
else: |
||||||
|
self.test_transforms.transforms[resize_op_idx] = Resize( |
||||||
|
target_size=image_shape) |
||||||
|
|
||||||
|
def _get_test_inputs(self, image_shape): |
||||||
|
if image_shape is not None: |
||||||
|
if len(image_shape) == 2: |
||||||
|
image_shape = [1, 3] + image_shape |
||||||
|
self._fix_transforms_shape(image_shape[-2:]) |
||||||
|
else: |
||||||
|
image_shape = [None, 3, -1, -1] |
||||||
|
self.fixed_input_shape = image_shape |
||||||
|
input_spec = [ |
||||||
|
InputSpec( |
||||||
|
shape=image_shape, name='image', dtype='float32') |
||||||
|
] |
||||||
|
return input_spec |
||||||
|
|
||||||
|
def run(self, net, inputs, mode): |
||||||
|
outputs = OrderedDict() |
||||||
|
|
||||||
|
if mode == 'test': |
||||||
|
tar_shape = inputs[1] |
||||||
|
if self.status == 'Infer': |
||||||
|
net_out = net(inputs[0]) |
||||||
|
res_map_list = self.postprocess( |
||||||
|
net_out, tar_shape, transforms=inputs[2]) |
||||||
|
else: |
||||||
|
if isinstance(net, GANAdapter): |
||||||
|
net_out = net.generator(inputs[0]) |
||||||
|
else: |
||||||
|
net_out = net(inputs[0]) |
||||||
|
if self.TEST_OUT_KEY is not None: |
||||||
|
net_out = net_out[self.TEST_OUT_KEY] |
||||||
|
pred = self.postprocess( |
||||||
|
net_out, tar_shape, transforms=inputs[2]) |
||||||
|
res_map_list = [] |
||||||
|
for res_map in pred: |
||||||
|
res_map = self._tensor_to_images(res_map) |
||||||
|
res_map_list.append(res_map) |
||||||
|
outputs['res_map'] = res_map_list |
||||||
|
|
||||||
|
if mode == 'eval': |
||||||
|
if isinstance(net, GANAdapter): |
||||||
|
net_out = net.generator(inputs[0]) |
||||||
|
else: |
||||||
|
net_out = net(inputs[0]) |
||||||
|
if self.TEST_OUT_KEY is not None: |
||||||
|
net_out = net_out[self.TEST_OUT_KEY] |
||||||
|
tar = inputs[1] |
||||||
|
tar_shape = [tar.shape[-2:]] |
||||||
|
pred = self.postprocess( |
||||||
|
net_out, tar_shape, transforms=inputs[2])[0] # NCHW |
||||||
|
pred = self._tensor_to_images(pred) |
||||||
|
outputs['pred'] = pred |
||||||
|
tar = self._tensor_to_images(tar) |
||||||
|
outputs['tar'] = tar |
||||||
|
|
||||||
|
if mode == 'train': |
||||||
|
# This is used by non-GAN models. |
||||||
|
# For GAN models, self.run_gan() should be used. |
||||||
|
net_out = net(inputs[0]) |
||||||
|
loss = self.losses(net_out, inputs[1]) |
||||||
|
outputs['loss'] = loss |
||||||
|
return outputs |
||||||
|
|
||||||
|
def run_gan(self, net, inputs, mode, gan_mode): |
||||||
|
raise NotImplementedError |
||||||
|
|
||||||
|
def default_loss(self): |
||||||
|
return res_losses.L1Loss() |
||||||
|
|
||||||
|
def default_optimizer(self, |
||||||
|
parameters, |
||||||
|
learning_rate, |
||||||
|
num_epochs, |
||||||
|
num_steps_each_epoch, |
||||||
|
lr_decay_power=0.9): |
||||||
|
decay_step = num_epochs * num_steps_each_epoch |
||||||
|
lr_scheduler = paddle.optimizer.lr.PolynomialDecay( |
||||||
|
learning_rate, decay_step, end_lr=0, power=lr_decay_power) |
||||||
|
optimizer = paddle.optimizer.Momentum( |
||||||
|
learning_rate=lr_scheduler, |
||||||
|
parameters=parameters, |
||||||
|
momentum=0.9, |
||||||
|
weight_decay=4e-5) |
||||||
|
return optimizer |
||||||
|
|
||||||
|
def train(self, |
||||||
|
num_epochs, |
||||||
|
train_dataset, |
||||||
|
train_batch_size=2, |
||||||
|
eval_dataset=None, |
||||||
|
optimizer=None, |
||||||
|
save_interval_epochs=1, |
||||||
|
log_interval_steps=2, |
||||||
|
save_dir='output', |
||||||
|
pretrain_weights=None, |
||||||
|
learning_rate=0.01, |
||||||
|
lr_decay_power=0.9, |
||||||
|
early_stop=False, |
||||||
|
early_stop_patience=5, |
||||||
|
use_vdl=True, |
||||||
|
resume_checkpoint=None): |
||||||
|
""" |
||||||
|
Train the model. |
||||||
|
|
||||||
|
Args: |
||||||
|
num_epochs (int): Number of epochs. |
||||||
|
train_dataset (paddlers.datasets.ResDataset): Training dataset. |
||||||
|
train_batch_size (int, optional): Total batch size among all cards used in |
||||||
|
training. Defaults to 2. |
||||||
|
eval_dataset (paddlers.datasets.ResDataset|None, optional): Evaluation dataset. |
||||||
|
If None, the model will not be evaluated during training process. |
||||||
|
Defaults to None. |
||||||
|
optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in |
||||||
|
training. If None, a default optimizer will be used. Defaults to None. |
||||||
|
save_interval_epochs (int, optional): Epoch interval for saving the model. |
||||||
|
Defaults to 1. |
||||||
|
log_interval_steps (int, optional): Step interval for printing training |
||||||
|
information. Defaults to 2. |
||||||
|
save_dir (str, optional): Directory to save the model. Defaults to 'output'. |
||||||
|
pretrain_weights (str|None, optional): None or name/path of pretrained |
||||||
|
weights. If None, no pretrained weights will be loaded. |
||||||
|
Defaults to None. |
||||||
|
learning_rate (float, optional): Learning rate for training. Defaults to .01. |
||||||
|
lr_decay_power (float, optional): Learning decay power. Defaults to .9. |
||||||
|
early_stop (bool, optional): Whether to adopt early stop strategy. Defaults |
||||||
|
to False. |
||||||
|
early_stop_patience (int, optional): Early stop patience. Defaults to 5. |
||||||
|
use_vdl (bool, optional): Whether to use VisualDL to monitor the training |
||||||
|
process. Defaults to True. |
||||||
|
resume_checkpoint (str|None, optional): Path of the checkpoint to resume |
||||||
|
training from. If None, no training checkpoint will be resumed. At most |
||||||
|
Aone of `resume_checkpoint` and `pretrain_weights` can be set simultaneously. |
||||||
|
Defaults to None. |
||||||
|
""" |
||||||
|
|
||||||
|
if self.status == 'Infer': |
||||||
|
logging.error( |
||||||
|
"Exported inference model does not support training.", |
||||||
|
exit=True) |
||||||
|
if pretrain_weights is not None and resume_checkpoint is not None: |
||||||
|
logging.error( |
||||||
|
"pretrain_weights and resume_checkpoint cannot be set simultaneously.", |
||||||
|
exit=True) |
||||||
|
|
||||||
|
if self.losses is None: |
||||||
|
self.losses = self.default_loss() |
||||||
|
|
||||||
|
if optimizer is None: |
||||||
|
num_steps_each_epoch = train_dataset.num_samples // train_batch_size |
||||||
|
if isinstance(self.net, GANAdapter): |
||||||
|
parameters = {'params_g': [], 'params_d': []} |
||||||
|
for net_g in self.net.generators: |
||||||
|
parameters['params_g'].append(net_g.parameters()) |
||||||
|
for net_d in self.net.discriminators: |
||||||
|
parameters['params_d'].append(net_d.parameters()) |
||||||
|
else: |
||||||
|
parameters = self.net.parameters() |
||||||
|
self.optimizer = self.default_optimizer( |
||||||
|
parameters, learning_rate, num_epochs, num_steps_each_epoch, |
||||||
|
lr_decay_power) |
||||||
|
else: |
||||||
|
self.optimizer = optimizer |
||||||
|
|
||||||
|
if pretrain_weights is not None and not osp.exists(pretrain_weights): |
||||||
|
logging.warning("Path of pretrain_weights('{}') does not exist!". |
||||||
|
format(pretrain_weights)) |
||||||
|
elif pretrain_weights is not None and osp.exists(pretrain_weights): |
||||||
|
if osp.splitext(pretrain_weights)[-1] != '.pdparams': |
||||||
|
logging.error( |
||||||
|
"Invalid pretrain weights. Please specify a '.pdparams' file.", |
||||||
|
exit=True) |
||||||
|
pretrained_dir = osp.join(save_dir, 'pretrain') |
||||||
|
is_backbone_weights = pretrain_weights == 'IMAGENET' |
||||||
|
self.net_initialize( |
||||||
|
pretrain_weights=pretrain_weights, |
||||||
|
save_dir=pretrained_dir, |
||||||
|
resume_checkpoint=resume_checkpoint, |
||||||
|
is_backbone_weights=is_backbone_weights) |
||||||
|
|
||||||
|
self.train_loop( |
||||||
|
num_epochs=num_epochs, |
||||||
|
train_dataset=train_dataset, |
||||||
|
train_batch_size=train_batch_size, |
||||||
|
eval_dataset=eval_dataset, |
||||||
|
save_interval_epochs=save_interval_epochs, |
||||||
|
log_interval_steps=log_interval_steps, |
||||||
|
save_dir=save_dir, |
||||||
|
early_stop=early_stop, |
||||||
|
early_stop_patience=early_stop_patience, |
||||||
|
use_vdl=use_vdl) |
||||||
|
|
||||||
|
def quant_aware_train(self, |
||||||
|
num_epochs, |
||||||
|
train_dataset, |
||||||
|
train_batch_size=2, |
||||||
|
eval_dataset=None, |
||||||
|
optimizer=None, |
||||||
|
save_interval_epochs=1, |
||||||
|
log_interval_steps=2, |
||||||
|
save_dir='output', |
||||||
|
learning_rate=0.0001, |
||||||
|
lr_decay_power=0.9, |
||||||
|
early_stop=False, |
||||||
|
early_stop_patience=5, |
||||||
|
use_vdl=True, |
||||||
|
resume_checkpoint=None, |
||||||
|
quant_config=None): |
||||||
|
""" |
||||||
|
Quantization-aware training. |
||||||
|
|
||||||
|
Args: |
||||||
|
num_epochs (int): Number of epochs. |
||||||
|
train_dataset (paddlers.datasets.ResDataset): Training dataset. |
||||||
|
train_batch_size (int, optional): Total batch size among all cards used in |
||||||
|
training. Defaults to 2. |
||||||
|
eval_dataset (paddlers.datasets.ResDataset|None, optional): Evaluation dataset. |
||||||
|
If None, the model will not be evaluated during training process. |
||||||
|
Defaults to None. |
||||||
|
optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in |
||||||
|
training. If None, a default optimizer will be used. Defaults to None. |
||||||
|
save_interval_epochs (int, optional): Epoch interval for saving the model. |
||||||
|
Defaults to 1. |
||||||
|
log_interval_steps (int, optional): Step interval for printing training |
||||||
|
information. Defaults to 2. |
||||||
|
save_dir (str, optional): Directory to save the model. Defaults to 'output'. |
||||||
|
learning_rate (float, optional): Learning rate for training. |
||||||
|
Defaults to .0001. |
||||||
|
lr_decay_power (float, optional): Learning decay power. Defaults to .9. |
||||||
|
early_stop (bool, optional): Whether to adopt early stop strategy. |
||||||
|
Defaults to False. |
||||||
|
early_stop_patience (int, optional): Early stop patience. Defaults to 5. |
||||||
|
use_vdl (bool, optional): Whether to use VisualDL to monitor the training |
||||||
|
process. Defaults to True. |
||||||
|
quant_config (dict|None, optional): Quantization configuration. If None, |
||||||
|
a default rule of thumb configuration will be used. Defaults to None. |
||||||
|
resume_checkpoint (str|None, optional): Path of the checkpoint to resume |
||||||
|
quantization-aware training from. If None, no training checkpoint will |
||||||
|
be resumed. Defaults to None. |
||||||
|
""" |
||||||
|
|
||||||
|
self._prepare_qat(quant_config) |
||||||
|
self.train( |
||||||
|
num_epochs=num_epochs, |
||||||
|
train_dataset=train_dataset, |
||||||
|
train_batch_size=train_batch_size, |
||||||
|
eval_dataset=eval_dataset, |
||||||
|
optimizer=optimizer, |
||||||
|
save_interval_epochs=save_interval_epochs, |
||||||
|
log_interval_steps=log_interval_steps, |
||||||
|
save_dir=save_dir, |
||||||
|
pretrain_weights=None, |
||||||
|
learning_rate=learning_rate, |
||||||
|
lr_decay_power=lr_decay_power, |
||||||
|
early_stop=early_stop, |
||||||
|
early_stop_patience=early_stop_patience, |
||||||
|
use_vdl=use_vdl, |
||||||
|
resume_checkpoint=resume_checkpoint) |
||||||
|
|
||||||
|
def evaluate(self, eval_dataset, batch_size=1, return_details=False): |
||||||
|
""" |
||||||
|
Evaluate the model. |
||||||
|
|
||||||
|
Args: |
||||||
|
eval_dataset (paddlers.datasets.ResDataset): Evaluation dataset. |
||||||
|
batch_size (int, optional): Total batch size among all cards used for |
||||||
|
evaluation. Defaults to 1. |
||||||
|
return_details (bool, optional): Whether to return evaluation details. |
||||||
|
Defaults to False. |
||||||
|
|
||||||
|
Returns: |
||||||
|
If `return_details` is False, return collections.OrderedDict with |
||||||
|
key-value pairs: |
||||||
|
{"psnr": `peak signal-to-noise ratio`, |
||||||
|
"ssim": `structural similarity`}. |
||||||
|
|
||||||
|
""" |
||||||
|
|
||||||
|
self._check_transforms(eval_dataset.transforms, 'eval') |
||||||
|
|
||||||
|
self.net.eval() |
||||||
|
nranks = paddle.distributed.get_world_size() |
||||||
|
local_rank = paddle.distributed.get_rank() |
||||||
|
if nranks > 1: |
||||||
|
# Initialize parallel environment if not done. |
||||||
|
if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized( |
||||||
|
): |
||||||
|
paddle.distributed.init_parallel_env() |
||||||
|
|
||||||
|
# TODO: Distributed evaluation |
||||||
|
if batch_size > 1: |
||||||
|
logging.warning( |
||||||
|
"Restorer only supports single card evaluation with batch_size=1 " |
||||||
|
"during evaluation, so batch_size is forcibly set to 1.") |
||||||
|
batch_size = 1 |
||||||
|
|
||||||
|
if nranks < 2 or local_rank == 0: |
||||||
|
self.eval_data_loader = self.build_data_loader( |
||||||
|
eval_dataset, batch_size=batch_size, mode='eval') |
||||||
|
# XXX: Hard-code crop_border and test_y_channel |
||||||
|
psnr = metrics.PSNR(crop_border=4, test_y_channel=True) |
||||||
|
ssim = metrics.SSIM(crop_border=4, test_y_channel=True) |
||||||
|
logging.info( |
||||||
|
"Start to evaluate(total_samples={}, total_steps={})...".format( |
||||||
|
eval_dataset.num_samples, eval_dataset.num_samples)) |
||||||
|
with paddle.no_grad(): |
||||||
|
for step, data in enumerate(self.eval_data_loader): |
||||||
|
data.append(eval_dataset.transforms.transforms) |
||||||
|
outputs = self.run(self.net, data, 'eval') |
||||||
|
psnr.update(outputs['pred'], outputs['tar']) |
||||||
|
ssim.update(outputs['pred'], outputs['tar']) |
||||||
|
|
||||||
|
# DO NOT use psnr.accumulate() here, otherwise the program hangs in multi-card training. |
||||||
|
assert len(psnr.results) > 0 |
||||||
|
assert len(ssim.results) > 0 |
||||||
|
eval_metrics = OrderedDict( |
||||||
|
zip(['psnr', 'ssim'], |
||||||
|
[np.mean(psnr.results), np.mean(ssim.results)])) |
||||||
|
|
||||||
|
if return_details: |
||||||
|
# TODO: Add details |
||||||
|
return eval_metrics, None |
||||||
|
|
||||||
|
return eval_metrics |
||||||
|
|
||||||
|
def predict(self, img_file, transforms=None): |
||||||
|
""" |
||||||
|
Do inference. |
||||||
|
|
||||||
|
Args: |
||||||
|
img_file (list[np.ndarray|str] | str | np.ndarray): Image path or decoded |
||||||
|
image data, which also could constitute a list, meaning all images to be |
||||||
|
predicted as a mini-batch. |
||||||
|
transforms (paddlers.transforms.Compose|None, optional): Transforms for |
||||||
|
inputs. If None, the transforms for evaluation process will be used. |
||||||
|
Defaults to None. |
||||||
|
|
||||||
|
Returns: |
||||||
|
If `img_file` is a tuple of string or np.array, the result is a dict with |
||||||
|
the following key-value pairs: |
||||||
|
res_map (np.ndarray): Restored image (HWC). |
||||||
|
|
||||||
|
If `img_file` is a list, the result is a list composed of dicts with the |
||||||
|
above keys. |
||||||
|
""" |
||||||
|
|
||||||
|
if transforms is None and not hasattr(self, 'test_transforms'): |
||||||
|
raise ValueError("transforms need to be defined, now is None.") |
||||||
|
if transforms is None: |
||||||
|
transforms = self.test_transforms |
||||||
|
if isinstance(img_file, (str, np.ndarray)): |
||||||
|
images = [img_file] |
||||||
|
else: |
||||||
|
images = img_file |
||||||
|
batch_im, batch_tar_shape = self.preprocess(images, transforms, |
||||||
|
self.model_type) |
||||||
|
self.net.eval() |
||||||
|
data = (batch_im, batch_tar_shape, transforms.transforms) |
||||||
|
outputs = self.run(self.net, data, 'test') |
||||||
|
res_map_list = outputs['res_map'] |
||||||
|
if isinstance(img_file, list): |
||||||
|
prediction = [{'res_map': m} for m in res_map_list] |
||||||
|
else: |
||||||
|
prediction = {'res_map': res_map_list[0]} |
||||||
|
return prediction |
||||||
|
|
||||||
|
def preprocess(self, images, transforms, to_tensor=True): |
||||||
|
self._check_transforms(transforms, 'test') |
||||||
|
batch_im = list() |
||||||
|
batch_tar_shape = list() |
||||||
|
for im in images: |
||||||
|
if isinstance(im, str): |
||||||
|
im = decode_image(im, to_rgb=False) |
||||||
|
ori_shape = im.shape[:2] |
||||||
|
sample = {'image': im} |
||||||
|
im = transforms(sample)[0] |
||||||
|
batch_im.append(im) |
||||||
|
batch_tar_shape.append(self._get_target_shape(ori_shape)) |
||||||
|
if to_tensor: |
||||||
|
batch_im = paddle.to_tensor(batch_im) |
||||||
|
else: |
||||||
|
batch_im = np.asarray(batch_im) |
||||||
|
|
||||||
|
return batch_im, batch_tar_shape |
||||||
|
|
||||||
|
def _get_target_shape(self, ori_shape): |
||||||
|
if self.sr_factor is None: |
||||||
|
return ori_shape |
||||||
|
else: |
||||||
|
return calc_hr_shape(ori_shape, self.sr_factor) |
||||||
|
|
||||||
|
@staticmethod |
||||||
|
def get_transforms_shape_info(batch_tar_shape, transforms): |
||||||
|
batch_restore_list = list() |
||||||
|
for tar_shape in batch_tar_shape: |
||||||
|
restore_list = list() |
||||||
|
h, w = tar_shape[0], tar_shape[1] |
||||||
|
for op in transforms: |
||||||
|
if op.__class__.__name__ == 'Resize': |
||||||
|
restore_list.append(('resize', (h, w))) |
||||||
|
h, w = op.target_size |
||||||
|
elif op.__class__.__name__ == 'ResizeByShort': |
||||||
|
restore_list.append(('resize', (h, w))) |
||||||
|
im_short_size = min(h, w) |
||||||
|
im_long_size = max(h, w) |
||||||
|
scale = float(op.short_size) / float(im_short_size) |
||||||
|
if 0 < op.max_size < np.round(scale * im_long_size): |
||||||
|
scale = float(op.max_size) / float(im_long_size) |
||||||
|
h = int(round(h * scale)) |
||||||
|
w = int(round(w * scale)) |
||||||
|
elif op.__class__.__name__ == 'ResizeByLong': |
||||||
|
restore_list.append(('resize', (h, w))) |
||||||
|
im_long_size = max(h, w) |
||||||
|
scale = float(op.long_size) / float(im_long_size) |
||||||
|
h = int(round(h * scale)) |
||||||
|
w = int(round(w * scale)) |
||||||
|
elif op.__class__.__name__ == 'Pad': |
||||||
|
if op.target_size: |
||||||
|
target_h, target_w = op.target_size |
||||||
|
else: |
||||||
|
target_h = int( |
||||||
|
(np.ceil(h / op.size_divisor) * op.size_divisor)) |
||||||
|
target_w = int( |
||||||
|
(np.ceil(w / op.size_divisor) * op.size_divisor)) |
||||||
|
|
||||||
|
if op.pad_mode == -1: |
||||||
|
offsets = op.offsets |
||||||
|
elif op.pad_mode == 0: |
||||||
|
offsets = [0, 0] |
||||||
|
elif op.pad_mode == 1: |
||||||
|
offsets = [(target_h - h) // 2, (target_w - w) // 2] |
||||||
|
else: |
||||||
|
offsets = [target_h - h, target_w - w] |
||||||
|
restore_list.append(('padding', (h, w), offsets)) |
||||||
|
h, w = target_h, target_w |
||||||
|
|
||||||
|
batch_restore_list.append(restore_list) |
||||||
|
return batch_restore_list |
||||||
|
|
||||||
|
def postprocess(self, batch_pred, batch_tar_shape, transforms): |
||||||
|
batch_restore_list = BaseRestorer.get_transforms_shape_info( |
||||||
|
batch_tar_shape, transforms) |
||||||
|
if self.status == 'Infer': |
||||||
|
return self._infer_postprocess( |
||||||
|
batch_res_map=batch_pred, batch_restore_list=batch_restore_list) |
||||||
|
results = [] |
||||||
|
if batch_pred.dtype == paddle.float32: |
||||||
|
mode = 'bilinear' |
||||||
|
else: |
||||||
|
mode = 'nearest' |
||||||
|
for pred, restore_list in zip(batch_pred, batch_restore_list): |
||||||
|
pred = paddle.unsqueeze(pred, axis=0) |
||||||
|
for item in restore_list[::-1]: |
||||||
|
h, w = item[1][0], item[1][1] |
||||||
|
if item[0] == 'resize': |
||||||
|
pred = F.interpolate( |
||||||
|
pred, (h, w), mode=mode, data_format='NCHW') |
||||||
|
elif item[0] == 'padding': |
||||||
|
x, y = item[2] |
||||||
|
pred = pred[:, :, y:y + h, x:x + w] |
||||||
|
else: |
||||||
|
pass |
||||||
|
results.append(pred) |
||||||
|
return results |
||||||
|
|
||||||
|
def _infer_postprocess(self, batch_res_map, batch_restore_list): |
||||||
|
res_maps = [] |
||||||
|
for res_map, restore_list in zip(batch_res_map, batch_restore_list): |
||||||
|
if not isinstance(res_map, np.ndarray): |
||||||
|
res_map = paddle.unsqueeze(res_map, axis=0) |
||||||
|
for item in restore_list[::-1]: |
||||||
|
h, w = item[1][0], item[1][1] |
||||||
|
if item[0] == 'resize': |
||||||
|
if isinstance(res_map, np.ndarray): |
||||||
|
res_map = cv2.resize( |
||||||
|
res_map, (w, h), interpolation=cv2.INTER_LINEAR) |
||||||
|
else: |
||||||
|
res_map = F.interpolate( |
||||||
|
res_map, (h, w), |
||||||
|
mode='bilinear', |
||||||
|
data_format='NHWC') |
||||||
|
elif item[0] == 'padding': |
||||||
|
x, y = item[2] |
||||||
|
if isinstance(res_map, np.ndarray): |
||||||
|
res_map = res_map[y:y + h, x:x + w] |
||||||
|
else: |
||||||
|
res_map = res_map[:, y:y + h, x:x + w, :] |
||||||
|
else: |
||||||
|
pass |
||||||
|
res_map = res_map.squeeze() |
||||||
|
if not isinstance(res_map, np.ndarray): |
||||||
|
res_map = res_map.numpy() |
||||||
|
res_map = self._normalize(res_map) |
||||||
|
res_maps.append(res_map.squeeze()) |
||||||
|
return res_maps |
||||||
|
|
||||||
|
def _check_transforms(self, transforms, mode): |
||||||
|
super()._check_transforms(transforms, mode) |
||||||
|
if not isinstance(transforms.arrange, |
||||||
|
paddlers.transforms.ArrangeRestorer): |
||||||
|
raise TypeError( |
||||||
|
"`transforms.arrange` must be an ArrangeRestorer object.") |
||||||
|
|
||||||
|
def build_data_loader(self, dataset, batch_size, mode='train'): |
||||||
|
if dataset.num_samples < batch_size: |
||||||
|
raise ValueError( |
||||||
|
'The volume of dataset({}) must be larger than batch size({}).' |
||||||
|
.format(dataset.num_samples, batch_size)) |
||||||
|
|
||||||
|
if mode != 'train': |
||||||
|
return paddle.io.DataLoader( |
||||||
|
dataset, |
||||||
|
batch_size=batch_size, |
||||||
|
shuffle=dataset.shuffle, |
||||||
|
drop_last=False, |
||||||
|
collate_fn=dataset.batch_transforms, |
||||||
|
num_workers=dataset.num_workers, |
||||||
|
return_list=True, |
||||||
|
use_shared_memory=False) |
||||||
|
else: |
||||||
|
return super(BaseRestorer, self).build_data_loader(dataset, |
||||||
|
batch_size, mode) |
||||||
|
|
||||||
|
def set_losses(self, losses): |
||||||
|
self.losses = losses |
||||||
|
|
||||||
|
def _tensor_to_images(self, |
||||||
|
tensor, |
||||||
|
transpose=True, |
||||||
|
squeeze=True, |
||||||
|
quantize=True): |
||||||
|
if transpose: |
||||||
|
tensor = paddle.transpose(tensor, perm=[0, 2, 3, 1]) # NHWC |
||||||
|
if squeeze: |
||||||
|
tensor = tensor.squeeze() |
||||||
|
images = tensor.numpy().astype('float32') |
||||||
|
images = self._normalize( |
||||||
|
images, copy=True, clip=True, quantize=quantize) |
||||||
|
return images |
||||||
|
|
||||||
|
def _normalize(self, im, copy=False, clip=True, quantize=True): |
||||||
|
if copy: |
||||||
|
im = im.copy() |
||||||
|
if clip: |
||||||
|
im = np.clip(im, self.MIN_MAX[0], self.MIN_MAX[1]) |
||||||
|
im -= im.min() |
||||||
|
im /= im.max() + 1e-32 |
||||||
|
if quantize: |
||||||
|
im *= 255 |
||||||
|
im = im.astype('uint8') |
||||||
|
return im |
||||||
|
|
||||||
|
|
||||||
|
class DRN(BaseRestorer): |
||||||
|
TEST_OUT_KEY = -1 |
||||||
|
|
||||||
|
def __init__(self, |
||||||
|
losses=None, |
||||||
|
sr_factor=4, |
||||||
|
scales=(2, 4), |
||||||
|
n_blocks=30, |
||||||
|
n_feats=16, |
||||||
|
n_colors=3, |
||||||
|
rgb_range=1.0, |
||||||
|
negval=0.2, |
||||||
|
lq_loss_weight=0.1, |
||||||
|
dual_loss_weight=0.1, |
||||||
|
**params): |
||||||
|
if sr_factor != max(scales): |
||||||
|
raise ValueError(f"`sr_factor` must be equal to `max(scales)`.") |
||||||
|
params.update({ |
||||||
|
'scale': scales, |
||||||
|
'n_blocks': n_blocks, |
||||||
|
'n_feats': n_feats, |
||||||
|
'n_colors': n_colors, |
||||||
|
'rgb_range': rgb_range, |
||||||
|
'negval': negval |
||||||
|
}) |
||||||
|
self.lq_loss_weight = lq_loss_weight |
||||||
|
self.dual_loss_weight = dual_loss_weight |
||||||
|
self.scales = scales |
||||||
|
super(DRN, self).__init__( |
||||||
|
model_name='DRN', losses=losses, sr_factor=sr_factor, **params) |
||||||
|
|
||||||
|
def build_net(self, **params): |
||||||
|
from ppgan.modules.init import init_weights |
||||||
|
generators = [ppgan.models.generators.DRNGenerator(**params)] |
||||||
|
init_weights(generators[-1]) |
||||||
|
for scale in params['scale']: |
||||||
|
dual_model = ppgan.models.generators.drn.DownBlock( |
||||||
|
params['negval'], params['n_feats'], params['n_colors'], 2) |
||||||
|
generators.append(dual_model) |
||||||
|
init_weights(generators[-1]) |
||||||
|
return GANAdapter(generators, []) |
||||||
|
|
||||||
|
def default_optimizer(self, parameters, *args, **kwargs): |
||||||
|
optims_g = [ |
||||||
|
super(DRN, self).default_optimizer(params_g, *args, **kwargs) |
||||||
|
for params_g in parameters['params_g'] |
||||||
|
] |
||||||
|
return OptimizerAdapter(*optims_g) |
||||||
|
|
||||||
|
def run_gan(self, net, inputs, mode, gan_mode='forward_primary'): |
||||||
|
if mode != 'train': |
||||||
|
raise ValueError("`mode` is not 'train'.") |
||||||
|
outputs = OrderedDict() |
||||||
|
if gan_mode == 'forward_primary': |
||||||
|
sr = net.generator(inputs[0]) |
||||||
|
lr = [inputs[0]] |
||||||
|
lr.extend([ |
||||||
|
F.interpolate( |
||||||
|
inputs[0], scale_factor=s, mode='bicubic') |
||||||
|
for s in self.scales[:-1] |
||||||
|
]) |
||||||
|
loss = self.losses(sr[-1], inputs[1]) |
||||||
|
for i in range(1, len(sr)): |
||||||
|
if self.lq_loss_weight > 0: |
||||||
|
loss += self.losses(sr[i - 1 - len(sr)], |
||||||
|
lr[i - len(sr)]) * self.lq_loss_weight |
||||||
|
outputs['loss_prim'] = loss |
||||||
|
outputs['sr'] = sr |
||||||
|
outputs['lr'] = lr |
||||||
|
elif gan_mode == 'forward_dual': |
||||||
|
sr, lr = inputs[0], inputs[1] |
||||||
|
sr2lr = [] |
||||||
|
n_scales = len(self.scales) |
||||||
|
for i in range(n_scales): |
||||||
|
sr2lr_i = net.generators[1 + i](sr[i - n_scales]) |
||||||
|
sr2lr.append(sr2lr_i) |
||||||
|
loss = self.losses(sr2lr[0], lr[0]) |
||||||
|
for i in range(1, n_scales): |
||||||
|
if self.dual_loss_weight > 0.0: |
||||||
|
loss += self.losses(sr2lr[i], lr[i]) * self.dual_loss_weight |
||||||
|
outputs['loss_dual'] = loss |
||||||
|
else: |
||||||
|
raise ValueError("Invalid `gan_mode`!") |
||||||
|
return outputs |
||||||
|
|
||||||
|
def train_step(self, step, data, net): |
||||||
|
outputs = self.run_gan( |
||||||
|
net, data, mode='train', gan_mode='forward_primary') |
||||||
|
outputs.update( |
||||||
|
self.run_gan( |
||||||
|
net, (outputs['sr'], outputs['lr']), |
||||||
|
mode='train', |
||||||
|
gan_mode='forward_dual')) |
||||||
|
self.optimizer.clear_grad() |
||||||
|
(outputs['loss_prim'] + outputs['loss_dual']).backward() |
||||||
|
self.optimizer.step() |
||||||
|
return { |
||||||
|
'loss': outputs['loss_prim'] + outputs['loss_dual'], |
||||||
|
'loss_prim': outputs['loss_prim'], |
||||||
|
'loss_dual': outputs['loss_dual'] |
||||||
|
} |
||||||
|
|
||||||
|
|
||||||
|
class LESRCNN(BaseRestorer): |
||||||
|
def __init__(self, |
||||||
|
losses=None, |
||||||
|
sr_factor=4, |
||||||
|
multi_scale=False, |
||||||
|
group=1, |
||||||
|
**params): |
||||||
|
params.update({ |
||||||
|
'scale': sr_factor, |
||||||
|
'multi_scale': multi_scale, |
||||||
|
'group': group |
||||||
|
}) |
||||||
|
super(LESRCNN, self).__init__( |
||||||
|
model_name='LESRCNN', losses=losses, sr_factor=sr_factor, **params) |
||||||
|
|
||||||
|
def build_net(self, **params): |
||||||
|
net = ppgan.models.generators.LESRCNNGenerator(**params) |
||||||
|
return net |
||||||
|
|
||||||
|
|
||||||
|
class ESRGAN(BaseRestorer): |
||||||
|
def __init__(self, |
||||||
|
losses=None, |
||||||
|
sr_factor=4, |
||||||
|
use_gan=True, |
||||||
|
in_channels=3, |
||||||
|
out_channels=3, |
||||||
|
nf=64, |
||||||
|
nb=23, |
||||||
|
**params): |
||||||
|
if sr_factor != 4: |
||||||
|
raise ValueError("`sr_factor` must be 4.") |
||||||
|
params.update({ |
||||||
|
'in_nc': in_channels, |
||||||
|
'out_nc': out_channels, |
||||||
|
'nf': nf, |
||||||
|
'nb': nb |
||||||
|
}) |
||||||
|
self.use_gan = use_gan |
||||||
|
super(ESRGAN, self).__init__( |
||||||
|
model_name='ESRGAN', losses=losses, sr_factor=sr_factor, **params) |
||||||
|
|
||||||
|
def build_net(self, **params): |
||||||
|
from ppgan.modules.init import init_weights |
||||||
|
generator = ppgan.models.generators.RRDBNet(**params) |
||||||
|
init_weights(generator) |
||||||
|
if self.use_gan: |
||||||
|
discriminator = ppgan.models.discriminators.VGGDiscriminator128( |
||||||
|
in_channels=params['out_nc'], num_feat=64) |
||||||
|
net = GANAdapter( |
||||||
|
generators=[generator], discriminators=[discriminator]) |
||||||
|
else: |
||||||
|
net = generator |
||||||
|
return net |
||||||
|
|
||||||
|
def default_loss(self): |
||||||
|
if self.use_gan: |
||||||
|
return { |
||||||
|
'pixel': res_losses.L1Loss(loss_weight=0.01), |
||||||
|
'perceptual': res_losses.PerceptualLoss( |
||||||
|
layer_weights={'34': 1.0}, |
||||||
|
perceptual_weight=1.0, |
||||||
|
style_weight=0.0, |
||||||
|
norm_img=False), |
||||||
|
'gan': res_losses.GANLoss( |
||||||
|
gan_mode='vanilla', loss_weight=0.005) |
||||||
|
} |
||||||
|
else: |
||||||
|
return res_losses.L1Loss() |
||||||
|
|
||||||
|
def default_optimizer(self, parameters, *args, **kwargs): |
||||||
|
if self.use_gan: |
||||||
|
optim_g = super(ESRGAN, self).default_optimizer( |
||||||
|
parameters['params_g'][0], *args, **kwargs) |
||||||
|
optim_d = super(ESRGAN, self).default_optimizer( |
||||||
|
parameters['params_d'][0], *args, **kwargs) |
||||||
|
return OptimizerAdapter(optim_g, optim_d) |
||||||
|
else: |
||||||
|
return super(ESRGAN, self).default_optimizer(parameters, *args, |
||||||
|
**kwargs) |
||||||
|
|
||||||
|
def run_gan(self, net, inputs, mode, gan_mode='forward_g'): |
||||||
|
if mode != 'train': |
||||||
|
raise ValueError("`mode` is not 'train'.") |
||||||
|
outputs = OrderedDict() |
||||||
|
if gan_mode == 'forward_g': |
||||||
|
loss_g = 0 |
||||||
|
g_pred = net.generator(inputs[0]) |
||||||
|
loss_pix = self.losses['pixel'](g_pred, inputs[1]) |
||||||
|
loss_perc, loss_sty = self.losses['perceptual'](g_pred, inputs[1]) |
||||||
|
loss_g += loss_pix |
||||||
|
if loss_perc is not None: |
||||||
|
loss_g += loss_perc |
||||||
|
if loss_sty is not None: |
||||||
|
loss_g += loss_sty |
||||||
|
self._set_requires_grad(net.discriminator, False) |
||||||
|
real_d_pred = net.discriminator(inputs[1]).detach() |
||||||
|
fake_g_pred = net.discriminator(g_pred) |
||||||
|
loss_g_real = self.losses['gan']( |
||||||
|
real_d_pred - paddle.mean(fake_g_pred), False, |
||||||
|
is_disc=False) * 0.5 |
||||||
|
loss_g_fake = self.losses['gan']( |
||||||
|
fake_g_pred - paddle.mean(real_d_pred), True, |
||||||
|
is_disc=False) * 0.5 |
||||||
|
loss_g_gan = loss_g_real + loss_g_fake |
||||||
|
outputs['g_pred'] = g_pred.detach() |
||||||
|
outputs['loss_g_pps'] = loss_g |
||||||
|
outputs['loss_g_gan'] = loss_g_gan |
||||||
|
elif gan_mode == 'forward_d': |
||||||
|
self._set_requires_grad(net.discriminator, True) |
||||||
|
# Real |
||||||
|
fake_d_pred = net.discriminator(inputs[0]).detach() |
||||||
|
real_d_pred = net.discriminator(inputs[1]) |
||||||
|
loss_d_real = self.losses['gan']( |
||||||
|
real_d_pred - paddle.mean(fake_d_pred), True, |
||||||
|
is_disc=True) * 0.5 |
||||||
|
# Fake |
||||||
|
fake_d_pred = net.discriminator(inputs[0].detach()) |
||||||
|
loss_d_fake = self.losses['gan']( |
||||||
|
fake_d_pred - paddle.mean(real_d_pred.detach()), |
||||||
|
False, |
||||||
|
is_disc=True) * 0.5 |
||||||
|
outputs['loss_d'] = loss_d_real + loss_d_fake |
||||||
|
else: |
||||||
|
raise ValueError("Invalid `gan_mode`!") |
||||||
|
return outputs |
||||||
|
|
||||||
|
def train_step(self, step, data, net): |
||||||
|
if self.use_gan: |
||||||
|
optim_g, optim_d = self.optimizer |
||||||
|
|
||||||
|
outputs = self.run_gan( |
||||||
|
net, data, mode='train', gan_mode='forward_g') |
||||||
|
optim_g.clear_grad() |
||||||
|
(outputs['loss_g_pps'] + outputs['loss_g_gan']).backward() |
||||||
|
optim_g.step() |
||||||
|
|
||||||
|
outputs.update( |
||||||
|
self.run_gan( |
||||||
|
net, (outputs['g_pred'], data[1]), |
||||||
|
mode='train', |
||||||
|
gan_mode='forward_d')) |
||||||
|
optim_d.clear_grad() |
||||||
|
outputs['loss_d'].backward() |
||||||
|
optim_d.step() |
||||||
|
|
||||||
|
outputs['loss'] = outputs['loss_g_pps'] + outputs[ |
||||||
|
'loss_g_gan'] + outputs['loss_d'] |
||||||
|
|
||||||
|
return { |
||||||
|
'loss': outputs['loss'], |
||||||
|
'loss_g_pps': outputs['loss_g_pps'], |
||||||
|
'loss_g_gan': outputs['loss_g_gan'], |
||||||
|
'loss_d': outputs['loss_d'] |
||||||
|
} |
||||||
|
else: |
||||||
|
return super(ESRGAN, self).train_step(step, data, net) |
||||||
|
|
||||||
|
def _set_requires_grad(self, net, requires_grad): |
||||||
|
for p in net.parameters(): |
||||||
|
p.trainable = requires_grad |
||||||
|
|
||||||
|
|
||||||
|
class RCAN(BaseRestorer): |
||||||
|
def __init__(self, |
||||||
|
losses=None, |
||||||
|
sr_factor=4, |
||||||
|
n_resgroups=10, |
||||||
|
n_resblocks=20, |
||||||
|
n_feats=64, |
||||||
|
n_colors=3, |
||||||
|
rgb_range=1.0, |
||||||
|
kernel_size=3, |
||||||
|
reduction=16, |
||||||
|
**params): |
||||||
|
params.update({ |
||||||
|
'n_resgroups': n_resgroups, |
||||||
|
'n_resblocks': n_resblocks, |
||||||
|
'n_feats': n_feats, |
||||||
|
'n_colors': n_colors, |
||||||
|
'rgb_range': rgb_range, |
||||||
|
'kernel_size': kernel_size, |
||||||
|
'reduction': reduction |
||||||
|
}) |
||||||
|
super(RCAN, self).__init__( |
||||||
|
model_name='RCAN', losses=losses, sr_factor=sr_factor, **params) |
@ -0,0 +1,132 @@ |
|||||||
|
from functools import wraps |
||||||
|
from inspect import isfunction, isgeneratorfunction, getmembers |
||||||
|
from collections.abc import Sequence |
||||||
|
from abc import ABC |
||||||
|
|
||||||
|
import paddle |
||||||
|
import paddle.nn as nn |
||||||
|
|
||||||
|
__all__ = ['GANAdapter', 'OptimizerAdapter'] |
||||||
|
|
||||||
|
|
||||||
|
class _AttrDesc: |
||||||
|
def __init__(self, key): |
||||||
|
self.key = key |
||||||
|
|
||||||
|
def __get__(self, instance, owner): |
||||||
|
return tuple(getattr(ele, self.key) for ele in instance) |
||||||
|
|
||||||
|
def __set__(self, instance, value): |
||||||
|
for ele in instance: |
||||||
|
setattr(ele, self.key, value) |
||||||
|
|
||||||
|
|
||||||
|
def _func_deco(cls, func_name): |
||||||
|
@wraps(getattr(cls.__ducktype__, func_name)) |
||||||
|
def _wrapper(self, *args, **kwargs): |
||||||
|
return tuple(getattr(ele, func_name)(*args, **kwargs) for ele in self) |
||||||
|
|
||||||
|
return _wrapper |
||||||
|
|
||||||
|
|
||||||
|
def _generator_deco(cls, func_name): |
||||||
|
@wraps(getattr(cls.__ducktype__, func_name)) |
||||||
|
def _wrapper(self, *args, **kwargs): |
||||||
|
for ele in self: |
||||||
|
yield from getattr(ele, func_name)(*args, **kwargs) |
||||||
|
|
||||||
|
return _wrapper |
||||||
|
|
||||||
|
|
||||||
|
class Adapter(Sequence, ABC): |
||||||
|
__ducktype__ = object |
||||||
|
__ava__ = () |
||||||
|
|
||||||
|
def __init__(self, *args): |
||||||
|
if not all(map(self._check, args)): |
||||||
|
raise TypeError("Please check the input type.") |
||||||
|
self._seq = tuple(args) |
||||||
|
|
||||||
|
def __getitem__(self, key): |
||||||
|
return self._seq[key] |
||||||
|
|
||||||
|
def __len__(self): |
||||||
|
return len(self._seq) |
||||||
|
|
||||||
|
def __repr__(self): |
||||||
|
return repr(self._seq) |
||||||
|
|
||||||
|
@classmethod |
||||||
|
def _check(cls, obj): |
||||||
|
for attr in cls.__ava__: |
||||||
|
try: |
||||||
|
getattr(obj, attr) |
||||||
|
# TODO: Check function signature |
||||||
|
except AttributeError: |
||||||
|
return False |
||||||
|
return True |
||||||
|
|
||||||
|
|
||||||
|
def make_adapter(cls): |
||||||
|
members = dict(getmembers(cls.__ducktype__)) |
||||||
|
for k in cls.__ava__: |
||||||
|
if hasattr(cls, k): |
||||||
|
continue |
||||||
|
if k in members: |
||||||
|
v = members[k] |
||||||
|
if isgeneratorfunction(v): |
||||||
|
setattr(cls, k, _generator_deco(cls, k)) |
||||||
|
elif isfunction(v): |
||||||
|
setattr(cls, k, _func_deco(cls, k)) |
||||||
|
else: |
||||||
|
setattr(cls, k, _AttrDesc(k)) |
||||||
|
return cls |
||||||
|
|
||||||
|
|
||||||
|
class GANAdapter(nn.Layer): |
||||||
|
__ducktype__ = nn.Layer |
||||||
|
__ava__ = ('state_dict', 'set_state_dict', 'train', 'eval') |
||||||
|
|
||||||
|
def __init__(self, generators, discriminators): |
||||||
|
super(GANAdapter, self).__init__() |
||||||
|
self.generators = nn.LayerList(generators) |
||||||
|
self.discriminators = nn.LayerList(discriminators) |
||||||
|
self._m = [*generators, *discriminators] |
||||||
|
|
||||||
|
def __len__(self): |
||||||
|
return len(self._m) |
||||||
|
|
||||||
|
def __getitem__(self, key): |
||||||
|
return self._m[key] |
||||||
|
|
||||||
|
def __contains__(self, m): |
||||||
|
return m in self._m |
||||||
|
|
||||||
|
def __repr__(self): |
||||||
|
return repr(self._m) |
||||||
|
|
||||||
|
@property |
||||||
|
def generator(self): |
||||||
|
return self.generators[0] |
||||||
|
|
||||||
|
@property |
||||||
|
def discriminator(self): |
||||||
|
return self.discriminators[0] |
||||||
|
|
||||||
|
|
||||||
|
Adapter.register(GANAdapter) |
||||||
|
|
||||||
|
|
||||||
|
@make_adapter |
||||||
|
class OptimizerAdapter(Adapter): |
||||||
|
__ducktype__ = paddle.optimizer.Optimizer |
||||||
|
__ava__ = ('state_dict', 'set_state_dict', 'clear_grad', 'step', 'get_lr') |
||||||
|
|
||||||
|
def set_state_dict(self, state_dicts): |
||||||
|
# Special dispatching rule |
||||||
|
for optim, state_dict in zip(self, state_dicts): |
||||||
|
optim.set_state_dict(state_dict) |
||||||
|
|
||||||
|
def get_lr(self): |
||||||
|
# Return the lr of the first optimizer |
||||||
|
return self[0].get_lr() |
@ -0,0 +1,8 @@ |
|||||||
|
# Configurations of CDNet with AirChange dataset |
||||||
|
|
||||||
|
_base_: ../_base_/airchange.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/cd/cdnet/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: CDNet |
@ -0,0 +1,8 @@ |
|||||||
|
# Configurations of cdnet with LEVIR-CD dataset |
||||||
|
|
||||||
|
_base_: ../_base_/levircd.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/cd/cdnet/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: CDNet |
@ -0,0 +1,53 @@ |
|||||||
|
===========================train_params=========================== |
||||||
|
model_name:cd:cdnet |
||||||
|
python:python |
||||||
|
gpu_list:0|0,1 |
||||||
|
use_gpu:null|null |
||||||
|
--precision:null |
||||||
|
--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=10 |
||||||
|
--save_dir:adaptive |
||||||
|
--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=8 |
||||||
|
--model_path:null |
||||||
|
--config:lite_train_lite_infer=./test_tipc/configs/cd/cdnet/cdnet_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/cdnet/cdnet_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/cdnet/cdnet_levircd.yaml |
||||||
|
train_model_name:best_model |
||||||
|
null:null |
||||||
|
## |
||||||
|
trainer:norm |
||||||
|
norm_train:test_tipc/run_task.py train cd |
||||||
|
pact_train:null |
||||||
|
fpgm_train:null |
||||||
|
distill_train:null |
||||||
|
null:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================eval_params=========================== |
||||||
|
eval:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================export_params=========================== |
||||||
|
--save_dir:adaptive |
||||||
|
--model_dir:adaptive |
||||||
|
--fixed_input_shape:[-1,3,256,256] |
||||||
|
norm_export:deploy/export/export_model.py |
||||||
|
quant_export:null |
||||||
|
fpgm_export:null |
||||||
|
distill_export:null |
||||||
|
export1:null |
||||||
|
export2:null |
||||||
|
===========================infer_params=========================== |
||||||
|
infer_model:null |
||||||
|
infer_export:null |
||||||
|
infer_quant:False |
||||||
|
inference:test_tipc/infer.py |
||||||
|
--device:cpu|gpu |
||||||
|
--enable_mkldnn:True |
||||||
|
--cpu_threads:6 |
||||||
|
--batch_size:1 |
||||||
|
--use_trt:False |
||||||
|
--precision:fp32 |
||||||
|
--model_dir:null |
||||||
|
--config:null |
||||||
|
--save_log_path:null |
||||||
|
--benchmark:True |
||||||
|
--model_name:cdnet |
||||||
|
null:null |
@ -0,0 +1,8 @@ |
|||||||
|
# Configurations of ChangeFormer with AirChange dataset |
||||||
|
|
||||||
|
_base_: ../_base_/airchange.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/cd/changeformer/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: ChangeFormer |
@ -0,0 +1,8 @@ |
|||||||
|
# Configurations of ChangeFormer with LEVIR-CD dataset |
||||||
|
|
||||||
|
_base_: ../_base_/levircd.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/cd/changeformer/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: ChangeFormer |
@ -0,0 +1,8 @@ |
|||||||
|
# Configurations of DSAMNet with AirChange dataset |
||||||
|
|
||||||
|
_base_: ../_base_/airchange.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/cd/dsamnet/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: DSAMNet |
@ -0,0 +1,8 @@ |
|||||||
|
# Configurations of DSAMNet with LEVIR-CD dataset |
||||||
|
|
||||||
|
_base_: ../_base_/levircd.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/cd/dsamnet/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: DSAMNet |
@ -0,0 +1,53 @@ |
|||||||
|
===========================train_params=========================== |
||||||
|
model_name:cd:dsamnet |
||||||
|
python:python |
||||||
|
gpu_list:0|0,1 |
||||||
|
use_gpu:null|null |
||||||
|
--precision:null |
||||||
|
--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=10 |
||||||
|
--save_dir:adaptive |
||||||
|
--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=8 |
||||||
|
--model_path:null |
||||||
|
--config:lite_train_lite_infer=./test_tipc/configs/cd/dsamnet/dsamnet_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/dsamnet/dsamnet_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/dsamnet/dsamnet_levircd.yaml |
||||||
|
train_model_name:best_model |
||||||
|
null:null |
||||||
|
## |
||||||
|
trainer:norm |
||||||
|
norm_train:test_tipc/run_task.py train cd |
||||||
|
pact_train:null |
||||||
|
fpgm_train:null |
||||||
|
distill_train:null |
||||||
|
null:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================eval_params=========================== |
||||||
|
eval:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================export_params=========================== |
||||||
|
--save_dir:adaptive |
||||||
|
--model_dir:adaptive |
||||||
|
--fixed_input_shape:[-1,3,256,256] |
||||||
|
norm_export:deploy/export/export_model.py |
||||||
|
quant_export:null |
||||||
|
fpgm_export:null |
||||||
|
distill_export:null |
||||||
|
export1:null |
||||||
|
export2:null |
||||||
|
===========================infer_params=========================== |
||||||
|
infer_model:null |
||||||
|
infer_export:null |
||||||
|
infer_quant:False |
||||||
|
inference:test_tipc/infer.py |
||||||
|
--device:cpu|gpu |
||||||
|
--enable_mkldnn:True |
||||||
|
--cpu_threads:6 |
||||||
|
--batch_size:1 |
||||||
|
--use_trt:False |
||||||
|
--precision:fp32 |
||||||
|
--model_dir:null |
||||||
|
--config:null |
||||||
|
--save_log_path:null |
||||||
|
--benchmark:True |
||||||
|
--model_name:dsamnet |
||||||
|
null:null |
@ -0,0 +1,8 @@ |
|||||||
|
# Configurations of DSIFN with AirChange dataset |
||||||
|
|
||||||
|
_base_: ../_base_/airchange.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/cd/dsifn/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: DSIFN |
@ -0,0 +1,8 @@ |
|||||||
|
# Configurations of DSIFN with LEVIR-CD dataset |
||||||
|
|
||||||
|
_base_: ../_base_/levircd.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/cd/dsifn/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: DSIFN |
@ -0,0 +1,53 @@ |
|||||||
|
===========================train_params=========================== |
||||||
|
model_name:cd:dsifn |
||||||
|
python:python |
||||||
|
gpu_list:0|0,1 |
||||||
|
use_gpu:null|null |
||||||
|
--precision:null |
||||||
|
--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=10 |
||||||
|
--save_dir:adaptive |
||||||
|
--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=8 |
||||||
|
--model_path:null |
||||||
|
--config:lite_train_lite_infer=./test_tipc/configs/cd/dsifn/dsifn_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/dsifn/dsifn_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/dsifn/dsifn_levircd.yaml |
||||||
|
train_model_name:best_model |
||||||
|
null:null |
||||||
|
## |
||||||
|
trainer:norm |
||||||
|
norm_train:test_tipc/run_task.py train cd |
||||||
|
pact_train:null |
||||||
|
fpgm_train:null |
||||||
|
distill_train:null |
||||||
|
null:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================eval_params=========================== |
||||||
|
eval:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================export_params=========================== |
||||||
|
--save_dir:adaptive |
||||||
|
--model_dir:adaptive |
||||||
|
--fixed_input_shape:[-1,3,256,256] |
||||||
|
norm_export:deploy/export/export_model.py |
||||||
|
quant_export:null |
||||||
|
fpgm_export:null |
||||||
|
distill_export:null |
||||||
|
export1:null |
||||||
|
export2:null |
||||||
|
===========================infer_params=========================== |
||||||
|
infer_model:null |
||||||
|
infer_export:null |
||||||
|
infer_quant:False |
||||||
|
inference:test_tipc/infer.py |
||||||
|
--device:cpu|gpu |
||||||
|
--enable_mkldnn:True |
||||||
|
--cpu_threads:6 |
||||||
|
--batch_size:1 |
||||||
|
--use_trt:False |
||||||
|
--precision:fp32 |
||||||
|
--model_dir:null |
||||||
|
--config:null |
||||||
|
--save_log_path:null |
||||||
|
--benchmark:True |
||||||
|
--model_name:dsifn |
||||||
|
null:null |
@ -0,0 +1,8 @@ |
|||||||
|
# Configurations of FC-EF with AirChange dataset |
||||||
|
|
||||||
|
_base_: ../_base_/airchange.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/cd/fc_ef/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: FCEarlyFusion |
@ -0,0 +1,8 @@ |
|||||||
|
# Configurations of FC-EF with LEVIR-CD dataset |
||||||
|
|
||||||
|
_base_: ../_base_/levircd.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/cd/fc_ef/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: FCEarlyFusion |
@ -0,0 +1,53 @@ |
|||||||
|
===========================train_params=========================== |
||||||
|
model_name:cd:fc_ef |
||||||
|
python:python |
||||||
|
gpu_list:0|0,1 |
||||||
|
use_gpu:null|null |
||||||
|
--precision:null |
||||||
|
--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=20 |
||||||
|
--save_dir:adaptive |
||||||
|
--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=8 |
||||||
|
--model_path:null |
||||||
|
--config:lite_train_lite_infer=./test_tipc/configs/cd/fc_ef/fc_ef_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/fc_ef/fc_ef_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/fc_ef/fc_ef_levircd.yaml |
||||||
|
train_model_name:best_model |
||||||
|
null:null |
||||||
|
## |
||||||
|
trainer:norm |
||||||
|
norm_train:test_tipc/run_task.py train cd |
||||||
|
pact_train:null |
||||||
|
fpgm_train:null |
||||||
|
distill_train:null |
||||||
|
null:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================eval_params=========================== |
||||||
|
eval:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================export_params=========================== |
||||||
|
--save_dir:adaptive |
||||||
|
--model_dir:adaptive |
||||||
|
--fixed_input_shape:[-1,3,256,256] |
||||||
|
norm_export:deploy/export/export_model.py |
||||||
|
quant_export:null |
||||||
|
fpgm_export:null |
||||||
|
distill_export:null |
||||||
|
export1:null |
||||||
|
export2:null |
||||||
|
===========================infer_params=========================== |
||||||
|
infer_model:null |
||||||
|
infer_export:null |
||||||
|
infer_quant:False |
||||||
|
inference:test_tipc/infer.py |
||||||
|
--device:cpu|gpu |
||||||
|
--enable_mkldnn:True |
||||||
|
--cpu_threads:6 |
||||||
|
--batch_size:1 |
||||||
|
--use_trt:False |
||||||
|
--precision:fp32 |
||||||
|
--model_dir:null |
||||||
|
--config:null |
||||||
|
--save_log_path:null |
||||||
|
--benchmark:True |
||||||
|
--model_name:fc_ef |
||||||
|
null:null |
@ -0,0 +1,8 @@ |
|||||||
|
# Configurations of FC-Siam-conc with AirChange dataset |
||||||
|
|
||||||
|
_base_: ../_base_/airchange.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/cd/fc_siam_conc/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: FCSiamConc |
@ -0,0 +1,8 @@ |
|||||||
|
# Configurations of FC-Siam-conc with LEVIR-CD dataset |
||||||
|
|
||||||
|
_base_: ../_base_/levircd.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/cd/fc_siam_conc/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: FCSiamConc |
@ -0,0 +1,53 @@ |
|||||||
|
===========================train_params=========================== |
||||||
|
model_name:cd:fc_siam_conc |
||||||
|
python:python |
||||||
|
gpu_list:0|0,1 |
||||||
|
use_gpu:null|null |
||||||
|
--precision:null |
||||||
|
--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=20 |
||||||
|
--save_dir:adaptive |
||||||
|
--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=8 |
||||||
|
--model_path:null |
||||||
|
--config:lite_train_lite_infer=./test_tipc/configs/cd/fc_siam_conc/fc_siam_conc_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/fc_siam_conc/fc_siam_conc_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/fc_siam_conc/fc_siam_conc_levircd.yaml |
||||||
|
train_model_name:best_model |
||||||
|
null:null |
||||||
|
## |
||||||
|
trainer:norm |
||||||
|
norm_train:test_tipc/run_task.py train cd |
||||||
|
pact_train:null |
||||||
|
fpgm_train:null |
||||||
|
distill_train:null |
||||||
|
null:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================eval_params=========================== |
||||||
|
eval:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================export_params=========================== |
||||||
|
--save_dir:adaptive |
||||||
|
--model_dir:adaptive |
||||||
|
--fixed_input_shape:[-1,3,256,256] |
||||||
|
norm_export:deploy/export/export_model.py |
||||||
|
quant_export:null |
||||||
|
fpgm_export:null |
||||||
|
distill_export:null |
||||||
|
export1:null |
||||||
|
export2:null |
||||||
|
===========================infer_params=========================== |
||||||
|
infer_model:null |
||||||
|
infer_export:null |
||||||
|
infer_quant:False |
||||||
|
inference:test_tipc/infer.py |
||||||
|
--device:cpu|gpu |
||||||
|
--enable_mkldnn:True |
||||||
|
--cpu_threads:6 |
||||||
|
--batch_size:1 |
||||||
|
--use_trt:False |
||||||
|
--precision:fp32 |
||||||
|
--model_dir:null |
||||||
|
--config:null |
||||||
|
--save_log_path:null |
||||||
|
--benchmark:True |
||||||
|
--model_name:fc_siam_conc |
||||||
|
null:null |
@ -0,0 +1,8 @@ |
|||||||
|
# Configurations of FC-Siam-diff with AirChange dataset |
||||||
|
|
||||||
|
_base_: ../_base_/airchange.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/cd/fc_siam_diff/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: FCSiamDiff |
@ -0,0 +1,8 @@ |
|||||||
|
# Configurations of FC-Siam-diff with LEVIR-CD dataset |
||||||
|
|
||||||
|
_base_: ../_base_/levircd.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/cd/fc_siam_diff/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: FCSiamDiff |
@ -0,0 +1,53 @@ |
|||||||
|
===========================train_params=========================== |
||||||
|
model_name:cd:fc_siam_diff |
||||||
|
python:python |
||||||
|
gpu_list:0|0,1 |
||||||
|
use_gpu:null|null |
||||||
|
--precision:null |
||||||
|
--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=20 |
||||||
|
--save_dir:adaptive |
||||||
|
--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=8 |
||||||
|
--model_path:null |
||||||
|
--config:lite_train_lite_infer=./test_tipc/configs/cd/fc_siam_diff/fc_siam_diff_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/fc_siam_diff/fc_siam_diff_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/fc_siam_diff/fc_siam_diff_levircd.yaml |
||||||
|
train_model_name:best_model |
||||||
|
null:null |
||||||
|
## |
||||||
|
trainer:norm |
||||||
|
norm_train:test_tipc/run_task.py train cd |
||||||
|
pact_train:null |
||||||
|
fpgm_train:null |
||||||
|
distill_train:null |
||||||
|
null:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================eval_params=========================== |
||||||
|
eval:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================export_params=========================== |
||||||
|
--save_dir:adaptive |
||||||
|
--model_dir:adaptive |
||||||
|
--fixed_input_shape:[-1,3,256,256] |
||||||
|
norm_export:deploy/export/export_model.py |
||||||
|
quant_export:null |
||||||
|
fpgm_export:null |
||||||
|
distill_export:null |
||||||
|
export1:null |
||||||
|
export2:null |
||||||
|
===========================infer_params=========================== |
||||||
|
infer_model:null |
||||||
|
infer_export:null |
||||||
|
infer_quant:False |
||||||
|
inference:test_tipc/infer.py |
||||||
|
--device:cpu|gpu |
||||||
|
--enable_mkldnn:True |
||||||
|
--cpu_threads:6 |
||||||
|
--batch_size:1 |
||||||
|
--use_trt:False |
||||||
|
--precision:fp32 |
||||||
|
--model_dir:null |
||||||
|
--config:null |
||||||
|
--save_log_path:null |
||||||
|
--benchmark:True |
||||||
|
--model_name:fc_siam_diff |
||||||
|
null:null |
@ -0,0 +1,13 @@ |
|||||||
|
# Configurations of FCCDN with AirChange dataset |
||||||
|
|
||||||
|
_base_: ../_base_/airchange.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/cd/fccdn/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: FCCDN |
||||||
|
|
||||||
|
learning_rate: 0.07 |
||||||
|
lr_decay_power: 0.6 |
||||||
|
log_interval_steps: 100 |
||||||
|
save_interval_epochs: 3 |
@ -0,0 +1,8 @@ |
|||||||
|
# Configurations of FCCDN with LEVIR-CD dataset |
||||||
|
|
||||||
|
_base_: ../_base_/levircd.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/cd/fccdn/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: FCCDN |
@ -0,0 +1,53 @@ |
|||||||
|
===========================train_params=========================== |
||||||
|
model_name:cd:fccdn |
||||||
|
python:python |
||||||
|
gpu_list:0|0,1 |
||||||
|
use_gpu:null|null |
||||||
|
--precision:null |
||||||
|
--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=10 |
||||||
|
--save_dir:adaptive |
||||||
|
--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=8 |
||||||
|
--model_path:null |
||||||
|
--config:lite_train_lite_infer=./test_tipc/configs/cd/fccdn/fccdn_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/fccdn/fccdn_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/fccdn/fccdn_levircd.yaml |
||||||
|
train_model_name:best_model |
||||||
|
null:null |
||||||
|
## |
||||||
|
trainer:norm |
||||||
|
norm_train:test_tipc/run_task.py train cd |
||||||
|
pact_train:null |
||||||
|
fpgm_train:null |
||||||
|
distill_train:null |
||||||
|
null:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================eval_params=========================== |
||||||
|
eval:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================export_params=========================== |
||||||
|
--save_dir:adaptive |
||||||
|
--model_dir:adaptive |
||||||
|
--fixed_input_shape:[-1,3,256,256] |
||||||
|
norm_export:deploy/export/export_model.py |
||||||
|
quant_export:null |
||||||
|
fpgm_export:null |
||||||
|
distill_export:null |
||||||
|
export1:null |
||||||
|
export2:null |
||||||
|
===========================infer_params=========================== |
||||||
|
infer_model:null |
||||||
|
infer_export:null |
||||||
|
infer_quant:False |
||||||
|
inference:test_tipc/infer.py |
||||||
|
--device:cpu|gpu |
||||||
|
--enable_mkldnn:True |
||||||
|
--cpu_threads:6 |
||||||
|
--batch_size:1 |
||||||
|
--use_trt:False |
||||||
|
--precision:fp32 |
||||||
|
--model_dir:null |
||||||
|
--config:null |
||||||
|
--save_log_path:null |
||||||
|
--benchmark:True |
||||||
|
--model_name:fccdn |
||||||
|
null:null |
@ -0,0 +1,8 @@ |
|||||||
|
# Configurations of SNUNet with AirChange dataset |
||||||
|
|
||||||
|
_base_: ../_base_/airchange.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/cd/snunet/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: SNUNet |
@ -0,0 +1,8 @@ |
|||||||
|
# Configurations of SNUNet with LEVIR-CD dataset |
||||||
|
|
||||||
|
_base_: ../_base_/levircd.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/cd/snunet/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: SNUNet |
@ -0,0 +1,53 @@ |
|||||||
|
===========================train_params=========================== |
||||||
|
model_name:cd:snunet |
||||||
|
python:python |
||||||
|
gpu_list:0|0,1 |
||||||
|
use_gpu:null|null |
||||||
|
--precision:null |
||||||
|
--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=10 |
||||||
|
--save_dir:adaptive |
||||||
|
--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=8 |
||||||
|
--model_path:null |
||||||
|
--config:lite_train_lite_infer=./test_tipc/configs/cd/snunet/snunet_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/snunet/snunet_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/snunet/snunet_levircd.yaml |
||||||
|
train_model_name:best_model |
||||||
|
null:null |
||||||
|
## |
||||||
|
trainer:norm |
||||||
|
norm_train:test_tipc/run_task.py train cd |
||||||
|
pact_train:null |
||||||
|
fpgm_train:null |
||||||
|
distill_train:null |
||||||
|
null:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================eval_params=========================== |
||||||
|
eval:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================export_params=========================== |
||||||
|
--save_dir:adaptive |
||||||
|
--model_dir:adaptive |
||||||
|
--fixed_input_shape:[-1,3,256,256] |
||||||
|
norm_export:deploy/export/export_model.py |
||||||
|
quant_export:null |
||||||
|
fpgm_export:null |
||||||
|
distill_export:null |
||||||
|
export1:null |
||||||
|
export2:null |
||||||
|
===========================infer_params=========================== |
||||||
|
infer_model:null |
||||||
|
infer_export:null |
||||||
|
infer_quant:False |
||||||
|
inference:test_tipc/infer.py |
||||||
|
--device:cpu|gpu |
||||||
|
--enable_mkldnn:True |
||||||
|
--cpu_threads:6 |
||||||
|
--batch_size:1 |
||||||
|
--use_trt:False |
||||||
|
--precision:fp32 |
||||||
|
--model_dir:null |
||||||
|
--config:null |
||||||
|
--save_log_path:null |
||||||
|
--benchmark:True |
||||||
|
--model_name:snunet |
||||||
|
null:null |
@ -0,0 +1,8 @@ |
|||||||
|
# Configurations of STANet with AirChange dataset |
||||||
|
|
||||||
|
_base_: ../_base_/airchange.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/cd/stanet/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: STANet |
@ -0,0 +1,8 @@ |
|||||||
|
# Configurations of STANet with LEVIR-CD dataset |
||||||
|
|
||||||
|
_base_: ../_base_/levircd.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/cd/stanet/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: STANet |
@ -0,0 +1,53 @@ |
|||||||
|
===========================train_params=========================== |
||||||
|
model_name:cd:stanet |
||||||
|
python:python |
||||||
|
gpu_list:0|0,1 |
||||||
|
use_gpu:null|null |
||||||
|
--precision:null |
||||||
|
--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=10 |
||||||
|
--save_dir:adaptive |
||||||
|
--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=8 |
||||||
|
--model_path:null |
||||||
|
--config:lite_train_lite_infer=./test_tipc/configs/cd/stanet/stanet_airchange.yaml|lite_train_whole_infer=./test_tipc/configs/cd/stanet/stanet_airchange.yaml|whole_train_whole_infer=./test_tipc/configs/cd/stanet/stanet_levircd.yaml |
||||||
|
train_model_name:best_model |
||||||
|
null:null |
||||||
|
## |
||||||
|
trainer:norm |
||||||
|
norm_train:test_tipc/run_task.py train cd |
||||||
|
pact_train:null |
||||||
|
fpgm_train:null |
||||||
|
distill_train:null |
||||||
|
null:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================eval_params=========================== |
||||||
|
eval:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================export_params=========================== |
||||||
|
--save_dir:adaptive |
||||||
|
--model_dir:adaptive |
||||||
|
--fixed_input_shape:[-1,3,256,256] |
||||||
|
norm_export:deploy/export/export_model.py |
||||||
|
quant_export:null |
||||||
|
fpgm_export:null |
||||||
|
distill_export:null |
||||||
|
export1:null |
||||||
|
export2:null |
||||||
|
===========================infer_params=========================== |
||||||
|
infer_model:null |
||||||
|
infer_export:null |
||||||
|
infer_quant:False |
||||||
|
inference:test_tipc/infer.py |
||||||
|
--device:cpu|gpu |
||||||
|
--enable_mkldnn:True |
||||||
|
--cpu_threads:6 |
||||||
|
--batch_size:1 |
||||||
|
--use_trt:False |
||||||
|
--precision:fp32 |
||||||
|
--model_dir:null |
||||||
|
--config:null |
||||||
|
--save_log_path:null |
||||||
|
--benchmark:True |
||||||
|
--model_name:stanet |
||||||
|
null:null |
@ -0,0 +1,10 @@ |
|||||||
|
# Configurations of HRNet with UCMerced dataset |
||||||
|
|
||||||
|
_base_: ../_base_/ucmerced.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/clas/hrnet/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: HRNet_W18_C |
||||||
|
args: |
||||||
|
num_classes: 21 |
@ -0,0 +1,10 @@ |
|||||||
|
# Configurations of MobileNetV3 with UCMerced dataset |
||||||
|
|
||||||
|
_base_: ../_base_/ucmerced.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/clas/mobilenetv3/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: MobileNetV3_small_x1_0 |
||||||
|
args: |
||||||
|
num_classes: 21 |
@ -0,0 +1,53 @@ |
|||||||
|
===========================train_params=========================== |
||||||
|
model_name:clas:mobilenetv3 |
||||||
|
python:python |
||||||
|
gpu_list:0|0,1 |
||||||
|
use_gpu:null|null |
||||||
|
--precision:null |
||||||
|
--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=10 |
||||||
|
--save_dir:adaptive |
||||||
|
--train_batch_size:lite_train_lite_infer=16|lite_train_whole_infer=16|whole_train_whole_infer=16 |
||||||
|
--model_path:null |
||||||
|
--config:lite_train_lite_infer=./test_tipc/configs/clas/mobilenetv3/mobilenetv3_ucmerced.yaml|lite_train_whole_infer=./test_tipc/configs/clas/mobilenetv3/mobilenetv3_ucmerced.yaml|whole_train_whole_infer=./test_tipc/configs/clas/mobilenetv3/mobilenetv3_ucmerced.yaml |
||||||
|
train_model_name:best_model |
||||||
|
null:null |
||||||
|
## |
||||||
|
trainer:norm |
||||||
|
norm_train:test_tipc/run_task.py train clas |
||||||
|
pact_train:null |
||||||
|
fpgm_train:null |
||||||
|
distill_train:null |
||||||
|
null:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================eval_params=========================== |
||||||
|
eval:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================export_params=========================== |
||||||
|
--save_dir:adaptive |
||||||
|
--model_dir:adaptive |
||||||
|
--fixed_input_shape:[-1,3,256,256] |
||||||
|
norm_export:deploy/export/export_model.py |
||||||
|
quant_export:null |
||||||
|
fpgm_export:null |
||||||
|
distill_export:null |
||||||
|
export1:null |
||||||
|
export2:null |
||||||
|
===========================infer_params=========================== |
||||||
|
infer_model:null |
||||||
|
infer_export:null |
||||||
|
infer_quant:False |
||||||
|
inference:test_tipc/infer.py |
||||||
|
--device:cpu|gpu |
||||||
|
--enable_mkldnn:True |
||||||
|
--cpu_threads:6 |
||||||
|
--batch_size:1 |
||||||
|
--use_trt:False |
||||||
|
--precision:fp32 |
||||||
|
--model_dir:null |
||||||
|
--config:null |
||||||
|
--save_log_path:null |
||||||
|
--benchmark:True |
||||||
|
--model_name:mobilenetv3 |
||||||
|
null:null |
@ -0,0 +1,10 @@ |
|||||||
|
# Configurations of ResNet50-vd with UCMerced dataset |
||||||
|
|
||||||
|
_base_: ../_base_/ucmerced.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/clas/resnet50_vd/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: ResNet50_vd |
||||||
|
args: |
||||||
|
num_classes: 21 |
@ -0,0 +1,53 @@ |
|||||||
|
===========================train_params=========================== |
||||||
|
model_name:clas:resnet50_vd |
||||||
|
python:python |
||||||
|
gpu_list:0|0,1 |
||||||
|
use_gpu:null|null |
||||||
|
--precision:null |
||||||
|
--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=10 |
||||||
|
--save_dir:adaptive |
||||||
|
--train_batch_size:lite_train_lite_infer=16|lite_train_whole_infer=16|whole_train_whole_infer=16 |
||||||
|
--model_path:null |
||||||
|
--config:lite_train_lite_infer=./test_tipc/configs/clas/resnet50_vd/resnet50_vd_ucmerced.yaml|lite_train_whole_infer=./test_tipc/configs/clas/resnet50_vd/resnet50_vd_ucmerced.yaml|whole_train_whole_infer=./test_tipc/configs/clas/resnet50_vd/resnet50_vd_ucmerced.yaml |
||||||
|
train_model_name:best_model |
||||||
|
null:null |
||||||
|
## |
||||||
|
trainer:norm |
||||||
|
norm_train:test_tipc/run_task.py train clas |
||||||
|
pact_train:null |
||||||
|
fpgm_train:null |
||||||
|
distill_train:null |
||||||
|
null:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================eval_params=========================== |
||||||
|
eval:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================export_params=========================== |
||||||
|
--save_dir:adaptive |
||||||
|
--model_dir:adaptive |
||||||
|
--fixed_input_shape:[-1,3,256,256] |
||||||
|
norm_export:deploy/export/export_model.py |
||||||
|
quant_export:null |
||||||
|
fpgm_export:null |
||||||
|
distill_export:null |
||||||
|
export1:null |
||||||
|
export2:null |
||||||
|
===========================infer_params=========================== |
||||||
|
infer_model:null |
||||||
|
infer_export:null |
||||||
|
infer_quant:False |
||||||
|
inference:test_tipc/infer.py |
||||||
|
--device:cpu|gpu |
||||||
|
--enable_mkldnn:True |
||||||
|
--cpu_threads:6 |
||||||
|
--batch_size:1 |
||||||
|
--use_trt:False |
||||||
|
--precision:fp32 |
||||||
|
--model_dir:null |
||||||
|
--config:null |
||||||
|
--save_log_path:null |
||||||
|
--benchmark:True |
||||||
|
--model_name:resnet50_vd |
||||||
|
null:null |
@ -0,0 +1,72 @@ |
|||||||
|
# Basic configurations of RSOD dataset |
||||||
|
|
||||||
|
datasets: |
||||||
|
train: !Node |
||||||
|
type: VOCDetDataset |
||||||
|
args: |
||||||
|
data_dir: ./test_tipc/data/rsod/ |
||||||
|
file_list: ./test_tipc/data/rsod/train.txt |
||||||
|
label_list: ./test_tipc/data/rsod/labels.txt |
||||||
|
shuffle: True |
||||||
|
eval: !Node |
||||||
|
type: VOCDetDataset |
||||||
|
args: |
||||||
|
data_dir: ./test_tipc/data/rsod/ |
||||||
|
file_list: ./test_tipc/data/rsod/val.txt |
||||||
|
label_list: ./test_tipc/data/rsod/labels.txt |
||||||
|
shuffle: False |
||||||
|
transforms: |
||||||
|
train: |
||||||
|
- !Node |
||||||
|
type: DecodeImg |
||||||
|
- !Node |
||||||
|
type: RandomDistort |
||||||
|
- !Node |
||||||
|
type: RandomExpand |
||||||
|
- !Node |
||||||
|
type: RandomCrop |
||||||
|
- !Node |
||||||
|
type: RandomHorizontalFlip |
||||||
|
- !Node |
||||||
|
type: BatchRandomResize |
||||||
|
args: |
||||||
|
target_sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608] |
||||||
|
interp: RANDOM |
||||||
|
- !Node |
||||||
|
type: Normalize |
||||||
|
args: |
||||||
|
mean: [0.485, 0.456, 0.406] |
||||||
|
std: [0.229, 0.224, 0.225] |
||||||
|
- !Node |
||||||
|
type: ArrangeDetector |
||||||
|
args: ['train'] |
||||||
|
eval: |
||||||
|
- !Node |
||||||
|
type: DecodeImg |
||||||
|
- !Node |
||||||
|
type: Resize |
||||||
|
args: |
||||||
|
target_size: 608 |
||||||
|
interp: CUBIC |
||||||
|
- !Node |
||||||
|
type: Normalize |
||||||
|
args: |
||||||
|
mean: [0.485, 0.456, 0.406] |
||||||
|
std: [0.229, 0.224, 0.225] |
||||||
|
- !Node |
||||||
|
type: ArrangeDetector |
||||||
|
args: ['eval'] |
||||||
|
download_on: False |
||||||
|
|
||||||
|
num_epochs: 10 |
||||||
|
train_batch_size: 4 |
||||||
|
save_interval_epochs: 10 |
||||||
|
log_interval_steps: 4 |
||||||
|
save_dir: ./test_tipc/output/det/ |
||||||
|
learning_rate: 0.0001 |
||||||
|
use_vdl: False |
||||||
|
resume_checkpoint: '' |
||||||
|
train: |
||||||
|
pretrain_weights: COCO |
||||||
|
warmup_steps: 0 |
||||||
|
warmup_start_lr: 0.0 |
@ -0,0 +1,10 @@ |
|||||||
|
# Configurations of Faster R-CNN with RSOD dataset |
||||||
|
|
||||||
|
_base_: ../_base_/rsod.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/det/faster_rcnn/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: FasterRCNN |
||||||
|
args: |
||||||
|
num_classes: 4 |
@ -0,0 +1,10 @@ |
|||||||
|
# Configurations of Faster R-CNN with SARShip dataset |
||||||
|
|
||||||
|
_base_: ../_base_/sarship.yaml |
||||||
|
|
||||||
|
save_dir: ./test_tipc/output/det/faster_rcnn/ |
||||||
|
|
||||||
|
model: !Node |
||||||
|
type: FasterRCNN |
||||||
|
args: |
||||||
|
num_classes: 1 |
@ -0,0 +1,53 @@ |
|||||||
|
===========================train_params=========================== |
||||||
|
model_name:det:faster_rcnn |
||||||
|
python:python |
||||||
|
gpu_list:0|0,1 |
||||||
|
use_gpu:null|null |
||||||
|
--precision:null |
||||||
|
--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=10 |
||||||
|
--save_dir:adaptive |
||||||
|
--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4 |
||||||
|
--model_path:null |
||||||
|
--config:lite_train_lite_infer=./test_tipc/configs/det/faster_rcnn/faster_rcnn_sarship.yaml|lite_train_whole_infer=./test_tipc/configs/det/faster_rcnn/faster_rcnn_sarship.yaml|whole_train_whole_infer=./test_tipc/configs/det/faster_rcnn/faster_rcnn_rsod.yaml |
||||||
|
train_model_name:best_model |
||||||
|
null:null |
||||||
|
## |
||||||
|
trainer:norm |
||||||
|
norm_train:test_tipc/run_task.py train det |
||||||
|
pact_train:null |
||||||
|
fpgm_train:null |
||||||
|
distill_train:null |
||||||
|
null:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================eval_params=========================== |
||||||
|
eval:null |
||||||
|
null:null |
||||||
|
## |
||||||
|
===========================export_params=========================== |
||||||
|
--save_dir:adaptive |
||||||
|
--model_dir:adaptive |
||||||
|
--fixed_input_shape:[-1,3,608,608] |
||||||
|
norm_export:deploy/export/export_model.py |
||||||
|
quant_export:null |
||||||
|
fpgm_export:null |
||||||
|
distill_export:null |
||||||
|
export1:null |
||||||
|
export2:null |
||||||
|
===========================infer_params=========================== |
||||||
|
infer_model:null |
||||||
|
infer_export:null |
||||||
|
infer_quant:False |
||||||
|
inference:test_tipc/infer.py |
||||||
|
--device:cpu|gpu |
||||||
|
--enable_mkldnn:True |
||||||
|
--cpu_threads:6 |
||||||
|
--batch_size:1 |
||||||
|
--use_trt:False |
||||||
|
--precision:fp32 |
||||||
|
--model_dir:null |
||||||
|
--config:null |
||||||
|
--save_log_path:null |
||||||
|
--benchmark:True |
||||||
|
--model_name:faster_rcnn |
||||||
|
null:null |
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue