commit
ce21e5dd40
118 changed files with 5650 additions and 1939 deletions
@ -0,0 +1,33 @@ |
||||
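# PaddleRS Practical Examples

PaddleRS provides a rich set of examples that range from scientific research to industrial application. They are intended to help remote sensing researchers quickly develop, validate, and tune algorithms, and to help developers working on industrial applications conveniently build end-to-end remote sensing deep learning applications, from data preprocessing to model deployment.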
||||
|
||||
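## 1 Official Examples

- [PaddleRS Research in Practice: Designing a Deep Learning Change Detection Model](./rs_research/)

## 2 Community-Contributed Examples

[AI Studio](https://aistudio.baidu.com/aistudio/index) is an AI learning and training community built on PaddlePaddle, Baidu's deep learning platform. It offers an online programming environment, free GPU compute, a large collection of open-source algorithms, and open data to help developers quickly create and deploy models. You can explore more ways to use PaddleRS on AI Studio:

[PaddleRS-related projects on AI Studio](https://aistudio.baidu.com/aistudio/projectoverview/public?kw=PaddleRS)

This document collects some of the high-quality projects contributed by open-source enthusiasts: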
||||
|
||||
|Project link|Author|Type|Keywords|
|-|-|-|-|
|[手把手教你PaddleRS实现变化检测](https://aistudio.baidu.com/aistudio/projectdetail/3737991)|奔向未来的样子|Tutorial|Change detection|
|[【PPSIG】PaddleRS变化检测模型部署:以BIT为例](https://aistudio.baidu.com/aistudio/projectdetail/4184759)|古代飞|Tutorial|Change detection, model deployment|
|[【PPSIG】PaddleRS实现遥感影像场景分类](https://aistudio.baidu.com/aistudio/projectdetail/4198965)|古代飞|Tutorial|Scene classification|
|[PaddleRS:使用超分模型提高真实的低分辨率无人机影像的分割精度](https://aistudio.baidu.com/aistudio/projectdetail/3696814)|KeyK-小胡之父|Application|Super-resolution reconstruction, UAV imagery|
|[PaddleRS:无人机汽车识别](https://aistudio.baidu.com/aistudio/projectdetail/3713122)|geoyee|Application|Object detection, UAV imagery|
|[PaddleRS:高光谱卫星影像场景分类](https://aistudio.baidu.com/aistudio/projectdetail/3711240)|geoyee|Application|Scene classification, hyperspectral imagery|
|[PaddleRS:利用卫星影像与数字高程模型进行滑坡识别](https://aistudio.baidu.com/aistudio/projectdetail/4066570)|KeyK-小胡之父|Application|Image segmentation, DEM|
|[为PaddleRS添加一个袖珍配置系统](https://aistudio.baidu.com/aistudio/projectdetail/4203534)|古代飞|Creative development||
|[万丈高楼平地起 基于PaddleGAN与PaddleRS的建筑物生成](https://aistudio.baidu.com/aistudio/projectdetail/3716885)|奔向未来的样子|Creative development|Super-resolution reconstruction|
|[【官方】第十一届 “中国软件杯”百度遥感赛项:变化检测功能](https://aistudio.baidu.com/aistudio/projectdetail/3684588)|古代飞|Competition|Change detection, competition baseline|
|[【官方】第十一届 “中国软件杯”百度遥感赛项:目标提取功能](https://aistudio.baidu.com/aistudio/projectdetail/3792610)|古代飞|Competition|Image segmentation, competition baseline|
|[【官方】第十一届 “中国软件杯”百度遥感赛项:地物分类功能](https://aistudio.baidu.com/aistudio/projectdetail/3792606)|古代飞|Competition|Image segmentation, competition baseline|
|[【官方】第十一届 “中国软件杯”百度遥感赛项:目标检测功能](https://aistudio.baidu.com/aistudio/projectdetail/3792609)|古代飞|Competition|Object detection, competition baseline|
|[【十一届软件杯】遥感解译赛道:变化检测任务——预赛第四名方案分享](https://aistudio.baidu.com/aistudio/projectdetail/4116895)|lzzzzzm|Competition|Change detection, high-scoring solution|
|[【方案分享】第十一届 “中国软件杯”大学生软件设计大赛遥感解译赛道 比赛方案分享](https://aistudio.baidu.com/aistudio/projectdetail/4146154)|trainer|Competition|Change detection, high-scoring solution|
@ -0,0 +1,2 @@ |
||||
/data/ |
||||
/exp/ |
@ -0,0 +1,468 @@ |
||||
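# PaddleRS Research in Practice: Designing a Deep Learning Change Detection Model

This example demonstrates how to use PaddleRS to design a change detection model and to carry out comparative experiments, ablation experiments, and feature visualization experiments.

## 1 Environment Setup

Install PaddleRS and its dependencies by following the [tutorial](https://github.com/PaddlePaddle/PaddleRS/tree/develop/tutorials/train#环境准备). The GDAL library is not required for this example.

After setting up the environment, run the following command in the root directory of the PaddleRS repository to switch to the directory of this example:

```shell
cd examples/rs_research
```

Note that all commands in this document follow bash syntax.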
||||
|
||||
## 2 Data Preparation

This example runs its experiments on the [LEVIR-CD dataset](https://www.mdpi.com/2072-4292/12/10/1662)[1]. Download the dataset from the [LEVIR-CD download page](https://justchenhao.github.io/LEVIR/), extract it to a local directory, and run the following commands:

```bash
mkdir data/
python ../../tools/prepare_dataset/prepare_levircd.py \
    --in_dataset_dir "{path to the LEVIR-CD dataset directory}" \
    --out_dataset_dir "data/levircd" \
    --crop_size 256 \
    --crop_stride 256
```

The commands above use the dataset preparation tool provided by PaddleRS to split the dataset, create file lists, and so on. Specifically, the official training/validation/test split of LEVIR-CD is used, and each original `1024x1024` image is cropped into non-overlapping `256x256` patches (following the practice in [2]).
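
To make the cropping step concrete, here is a minimal NumPy sketch of sliding-window tiling with a crop size and stride of 256. It is an illustration only and not the code of `prepare_levircd.py`; the function name `crop_patches` and its border handling (windows that would cross the image border are skipped) are assumptions made for this sketch.

```python
import numpy as np


def crop_patches(image, crop_size=256, stride=256):
    """Split an H x W (x C) array into square tiles, scanning row by row.

    Hypothetical helper: windows that would extend past the border are skipped.
    """
    height, width = image.shape[:2]
    patches = []
    for top in range(0, height - crop_size + 1, stride):
        for left in range(0, width - crop_size + 1, stride):
            patches.append(image[top:top + crop_size, left:left + crop_size])
    return patches


# A dummy 1024x1024 RGB image stands in for one LEVIR-CD tile.
image = np.zeros((1024, 1024, 3), dtype=np.uint8)
patches = crop_patches(image)
print(len(patches), patches[0].shape)  # 16 (256, 256, 3)
```

With the stride equal to the crop size the tiles do not overlap, so each `1024x1024` image yields 4 x 4 = 16 patches.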
||||
|
||||
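## 3 Model Design

### 3.1 Problem Analysis and Approach

As deep learning is applied ever more widely, many remote sensing change detection algorithms based on fully convolutional networks (FCNs) have emerged in recent years. Compared with feature-based and patch-based methods, FCN-based methods are more efficient and depend on fewer hyperparameters, but they tend to have more parameters and therefore rely more heavily on the number of training samples. Although medium- and large-scale change detection datasets keep growing in number and training samples are increasingly abundant, the parameter counts of deep learning change detection models are growing as well. The figure below shows the parameter counts of FCN-based change detection models proposed in papers published from 2018 to 2021, together with the F1 scores they achieve on the SVCD dataset [3] (the height of each bar is proportional to the number of model parameters):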
||||
|
||||
<p align="center"> |
||||
<img src="https://user-images.githubusercontent.com/21275753/186670936-5f79983c-914c-4e81-8f01-11df2beadf09.png" width="850"> |
||||
</p> |
||||
|
||||
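Admittedly, increasing the number of parameters amounts in most cases to increasing model capacity, and greater capacity means stronger fitting ability, which helps the model reach higher accuracy on experimental datasets. But does "bigger" necessarily mean "better"? The answer is clearly no. In practice, "bigger" remote sensing change detection models often run into the following problems:

1. A huge number of parameters means a huge storage cost. In many real-world scenarios hardware resources are limited, and too many model parameters make deployment difficult.
2. When data is limited, large models are more prone to overfitting, and detection results that look good on experimental datasets are hard to generalize to real scenes.

This example takes the view that the root cause of these problems is feature redundancy resulting from the imbalance between the number of parameters and the amount of data. If the model's features are redundant, that is, some of them are "useless", is there a way to optimize the features while keeping the parameter count fixed, so as to squeeze more potential out of small models and obtain more, and more effective, features? Based on this view, the basic idea of this example is to add a plug-in feature optimization module to an existing change detection model, enhancing change features while introducing only a small number of extra parameters. This example takes the classic FC-Siam-conc [4] as the baseline network and uses channel and temporal attention modules to optimize the network's intermediate features, thereby reducing feature redundancy and improving detection results. As for the module design, the channel attention module proposed in [5] is used to implement feature enhancement along both the channel and temporal dimensions.

The network structure of FC-Siam-conc is shown below: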
||||
|
||||
<p align="center"> |
||||
<img src="https://user-images.githubusercontent.com/21275753/186671480-d869a500-6409-4f97-b48b-50ce95ea3a71.jpg" width="500"> |
||||
</p> |
||||
|
||||
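This example adds a mixed attention module, composed of a channel attention module and a temporal attention module, before the first Concat module in the decoder to optimize the features passed from the encoders. The new model is called CustomModel.

### 3.2 Model Definition

This subsection implements the idea proposed in Section 3.1 using the PaddlePaddle framework and the PaddleRS library.

The overall structure of the model and the modules it consists of are defined in `custom_model.py`. This example defines the modified FC-Siam-conc structure in `custom_model.py`; the core of its implementation is as follows: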
||||
|
||||
```python
...
# PaddleRS provides many out-of-the-box modules, from wrappers of low-level building blocks (e.g. conv-bn-relu) to higher-level structures such as attention modules
from paddlers.rs_models.cd.layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity
from paddlers.rs_models.cd.layers import ChannelAttention

from attach_tools import Attach

attach = Attach.to(paddlers.rs_models.cd)

@attach
class CustomModel(nn.Layer):
    def __init__(self,
                 in_channels,
                 num_classes,
                 att_types='ct',
                 use_dropout=False):
        super().__init__()
        ...
        # Build a mixed attention module att4 to process the final output features of the two encoders
        self.att4 = MixedAttention(C4, att_types)

        self.init_weight()

    def forward(self, t1, t2):
        ...
        x4d = self.upconv4(x4p)
        pad4 = (0, x43_1.shape[3] - x4d.shape[3], 0,
                x43_1.shape[2] - x4d.shape[2])
        x4d = F.pad(x4d, pad=pad4, mode='replicate')
        # Plug the attention module into the first decoding unit
        x43_1, x43_2 = self.att4(x43_1, x43_2)
        x4d = paddle.concat([x4d, x43_1, x43_2], 1)
        x43d = self.do43d(self.conv43d(x4d))
        x42d = self.do42d(self.conv42d(x43d))
        x41d = self.do41d(self.conv41d(x42d))
        ...


class MixedAttention(nn.Layer):
    def __init__(self, in_channels, att_types='ct'):
        super(MixedAttention, self).__init__()

        self.att_types = att_types

        # Each attention module is optional
        if self.has_att_c:
            self.att_c = ChannelAttention(in_channels, ratio=1)
        else:
            self.att_c = Identity()

        if self.has_att_t:
            # The temporal attention module partly reuses the channel attention logic; see `forward()` for details
            self.att_t = ChannelAttention(2, ratio=1)
        else:
            self.att_t = Identity()

    def forward(self, x1, x2):
        # x1 and x2 are the features extracted by the two encoder branches of FC-Siam-conc

        if self.has_att_c:
            # First refine the features with the channel attention module
            # The encoded features of the two phases share one channel attention module
            # A residual connection is added to speed up convergence
            x1 = (1 + self.att_c(x1)) * x1
            x2 = (1 + self.att_c(x2)) * x2

        if self.has_att_t:
            b, c = x1.shape[:2]
            # To reuse the channel attention module for attention along the time dimension, first stack the features of the two phases
            y = paddle.stack([x1, x2], axis=2)
            # After stacking, y has shape [b, c, t, h, w], where b is the batch size, c the number of feature channels, t is 2 (number of phases), and h and w are the height and width of the feature map
            # Merge the b and c dimensions; the output tensor has shape [b*c, t, h, w]
            y = paddle.flatten(y, stop_axis=1)
            # The time dimension now takes the place of the channel dimension, so the 4-D tensor can be fed to the ChannelAttention module
            # A residual connection is added here as well
            y = (1 + self.att_t(y)) * y
            # Separate the information of the two phases from the result
            y = y.reshape((b, c, 2, *y.shape[2:]))
            y1, y2 = y[:, :, 0], y[:, :, 1]
        else:
            y1, y2 = x1, x2

        return y1, y2

    @property
    def has_att_c(self):
        return 'c' in self.att_types

    @property
    def has_att_t(self):
        return 't' in self.att_types
```
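
The reshaping trick behind the temporal attention branch can be checked in isolation. The following sketch reproduces only the tensor bookkeeping with NumPy instead of PaddlePaddle; `toy_attention` is a hypothetical placeholder that produces one weight per sample and time step and is not the `ChannelAttention` module from PaddleRS.

```python
import numpy as np


def toy_attention(y):
    """Placeholder attention: a sigmoid of the spatial mean, shape [N, T, 1, 1]."""
    pooled = y.mean(axis=(2, 3), keepdims=True)
    return 1.0 / (1.0 + np.exp(-pooled))


b, c, h, w = 2, 4, 8, 8
rng = np.random.default_rng(0)
x1 = rng.standard_normal((b, c, h, w))
x2 = rng.standard_normal((b, c, h, w))

# Stack along a new time axis: [b, c, t, h, w] with t = 2.
y = np.stack([x1, x2], axis=2)
# Merge batch and channel axes so that time plays the role of channels: [b*c, t, h, w].
y = y.reshape((b * c, 2, h, w))
# Attention over the two time steps, with the same residual form as in the model.
weights = toy_attention(y)
y = (1 + weights) * y
# Restore [b, c, t, h, w] and split the two phases again.
y = y.reshape((b, c, 2, h, w))
y1, y2 = y[:, :, 0], y[:, :, 1]

print(y1.shape, y2.shape)  # (2, 4, 8, 8) (2, 4, 8, 8)
```

Because every channel of every sample becomes its own "sample" after the merge, a module that reweights channels ends up assigning one weight to each phase per feature channel.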
||||
|
||||
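When writing model-building code, note the following three points:

1. Every model must be a subclass of `paddle.nn.Layer`;
2. The outermost module that contains the overall logic of the model (such as the `CustomModel` class in this example) must be decorated with `@attach`;
3. For change detection tasks, the `forward()` method of the outermost module takes two parameters besides `self`, `t1` and `t2`, which represent the first-phase and second-phase images respectively.

For more details on model definition, see the [Development Guide](https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/dev/dev_guide.md).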
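
What the `@attach` decorator does can be sketched in a few lines: it registers the decorated class as an attribute of a target module so that it can later be looked up by name. The version below is a condensed variant of the `Attach` class in `attach_tools.py`; the function name `attach_to` and the stand-in module `fake_models` are hypothetical and replace `paddlers.rs_models.cd` so that the sketch runs without PaddleRS.

```python
import types


def attach_to(dst):
    """Return a decorator that registers an object on `dst` under its own name."""

    def decorator(obj):
        name = obj.__name__
        if hasattr(dst, name):
            raise RuntimeError(f"{dst} already has the attribute {name}.")
        setattr(dst, name, obj)
        return obj

    return decorator


# Hypothetical stand-in for the `paddlers.rs_models.cd` module.
registry = types.ModuleType("fake_models")
attach = attach_to(registry)


@attach
class CustomModel:
    pass


# The class can now be retrieved from the module by its name.
print(getattr(registry, "CustomModel") is CustomModel)  # True
```

This is presumably why a trainer can refer to the model simply as `model_name='CustomModel'`: once registered, the class is reachable by name from the module, although the exact lookup performed by PaddleRS is not shown in this document.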
||||
|
||||
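## 4 Model Training

This example provides two ways to train a model: a script-based way and a configuration-file-based way.

- For beginners, the script-based way is recommended: it is easier to understand, its code logic is simple, and it does not require writing a custom trainer.
- For more experienced researchers, or researchers who need to run many comparative and ablation experiments, the configuration-file-based way is recommended: it makes it easier to manage different model configurations and to run multiple experiments in parallel.

Note that all experimental results in this document come from models trained with the configuration-file-based way. The configuration files for all experiments covered in this document are provided in the `configs` directory.

### 4.1 Script-Based Training

This example provides the `train_cd.py` script, which trains and validates the model and reports the test-set accuracy of the model that performs best on the validation set. Run the script with the following command:

```bash
python train_cd.py
```

Reading the comments in the script helps to understand what each step does. By default, the script trains and validates the custom model CustomModel on the LEVIR-CD dataset. During your experiments you can modify parts of the script as needed, for example to tune hyperparameters or to train a different model.

The training program enables VisualDL logging by default. During or after training, you can use VisualDL to observe how the loss and the accuracy metrics change. For how to use VisualDL with PaddleRS, see the [tutorial](https://github.com/PaddlePaddle/PaddleRS/blob/develop/tutorials/train/README.md#visualdl%E5%8F%AF%E8%A7%86%E5%8C%96%E8%AE%AD%E7%BB%83%E6%8C%87%E6%A0%87).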
||||
|
||||
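### 4.2 Configuration-File-Based Training

#### 4.2.1 Writing Configuration Files

This example provides a lightweight configuration system based on [YAML](https://yaml.org/). You can adjust hyperparameters, switch models, switch datasets, and so on by editing YAML files, or add new configurations by writing new YAML files.

For the rules for writing configuration files in this example, see [this project](https://aistudio.baidu.com/aistudio/projectdetail/4203534).

#### 4.2.2 Custom Trainers

When training a model with the configuration-file-based way, a trainer must be defined in `custom_trainer.py`. For example, this example defines the trainer corresponding to the `CustomModel` model in `custom_trainer.py`: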
||||
|
||||
```python
@attach
class CustomTrainer(BaseChangeDetector):
    def __init__(self,
                 num_classes=2,
                 use_mixed_loss=False,
                 losses=None,
                 in_channels=3,
                 att_types='ct',
                 use_dropout=False,
                 **params):
        params.update({
            'in_channels': in_channels,
            'att_types': att_types,
            'use_dropout': use_dropout
        })
        super().__init__(
            model_name='CustomModel',
            num_classes=num_classes,
            use_mixed_loss=use_mixed_loss,
            losses=losses,
            **params)
```
||||
|
||||
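When writing trainer definitions, note the following three points:

1. For change detection tasks, the trainer must be a subclass of `paddlers.tasks.cd.BaseChangeDetector`;
2. Like the model, the trainer must be decorated with `@attach`;
3. The trainer and the model may share the same name.

In this example only the trainer's `__init__()` method is overridden. In actual research, you can override methods such as `train()`, `evaluate()`, and `default_loss()` to customize more sophisticated training and evaluation strategies or to replace the default loss function.

For more details on trainers, see the [API documentation](https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/train.md).

The `model` entry in a configuration file specifies the trainer name and its constructor arguments. For example:

```yaml
model: !Node
    type: CustomTrainer
    args:
        att_types: c
```

The configuration above constructs a trainer object like this: `CustomTrainer(att_types='c')`.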
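
The mapping from a `!Node` entry to a constructor call can be sketched without YAML at all. The snippet below loosely mirrors the logic of `CfgNode.build_object()` in `config_utils.py` (dictionary arguments become keyword arguments, list arguments become positional arguments); the helper `build_from_node` and the stand-in `CustomTrainer` class are hypothetical and exist only so that the sketch is self-contained.

```python
class CustomTrainer:
    """Hypothetical stand-in for the real trainer; it only records its arguments."""

    def __init__(self, att_types='ct', use_dropout=False):
        self.att_types = att_types
        self.use_dropout = use_dropout


def build_from_node(node, namespace):
    """Instantiate `node['type']` from `namespace` with `node['args']`."""
    cls = namespace[node['type']]
    args = node.get('args', [])
    if isinstance(args, dict):
        return cls(**args)  # mapping -> keyword arguments
    return cls(*args)  # sequence -> positional arguments


# The dictionary a YAML loader would produce for the `model` entry above.
node = {'type': 'CustomTrainer', 'args': {'att_types': 'c'}}
trainer = build_from_node(node, {'CustomTrainer': CustomTrainer})
print(trainer.att_types)  # c
```

Arguments that are omitted in the configuration fall back to the defaults of the constructor, so a node without `args` builds the trainer with `att_types='ct'`.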
||||
|
||||
#### 4.2.3 Training Commands

Train a model using a command of the following form:

```bash
python run_task.py train cd \
    --config "configs/levircd/{config file name}" \
    2>&1 | tee "{log path}"
```

After training, use the following command to compute test-set metrics for the model that performs best on the validation set:

```bash
python run_task.py eval cd \
    --config "configs/levircd/{config file name}" \
    --datasets.eval.args.file_list "data/levircd/test.txt" \
    --resume_checkpoint "exp/levircd/{model name}/best_model"
```
||||
|
||||
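## 5 Comparative Experiments

To verify the effectiveness of a model design, comparative experiments are usually needed, comparing the accuracy and performance of the proposed model with other models on one or more datasets. In this example, the custom model CustomModel is compared with three architectures, FC-EF, FC-Siam-diff, and FC-Siam-conc, all of which come from [4].

### 5.1 Experimental Procedure

**When training and validating models with configuration files**, the following command trains all models in the comparison on the LEVIR-CD dataset:

```bash
bash scripts/run_benchmark.sh
```

**When training and validating models with the `train_cd.py` script**, the model type and constructor arguments must be changed manually for each experiment. In addition, you can set the `EXP_DIR` variable to different values so that the results of each model are saved to a separate directory for easy comparison. The command examples in this subsection all assume that `EXP_DIR` is set to `exp/levircd/{model name}`.

After training and accuracy evaluation are complete, the binary change maps output by a model can be saved with the following command:

```bash
python predict_cd.py \
    --model_dir "exp/levircd/{model name}/best_model" \
    --data_dir "data/levircd" \
    --file_list "data/levircd/test.txt" \
    --save_dir "exp/predict/levircd/{model name}"
```

The saved outputs can then be found in the `exp/predict/levircd/{model name}` directory.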
||||
|
||||
The `tools/collect_imgs.py` script places the input images, change labels, and the predictions of multiple models in a single directory for easier comparison. The script accepts three command-line options:
- `--globs` specifies a series of glob patterns (usable with Python's [`glob.glob()` function](https://docs.python.org/zh-cn/3/library/glob.html#glob.glob)) that match the images to collect;
- `--tags` specifies an alias for each item in `--globs`; in the output directory, the corresponding image names are replaced with the specified alias;
- `--save_dir` specifies the output directory, which is created automatically if it does not exist.

For example, for the LEVIR-CD dataset, run the following command:

```bash
python tools/collect_imgs.py \
    --globs "data/levircd/LEVIR-CD/test/A/*/*.png" "data/levircd/LEVIR-CD/test/B/*/*.png" "data/levircd/LEVIR-CD/test/label/*/*.png" \
    "exp/predict/levircd/fc_ef/*.png" "exp/predict/levircd/fc_siam_conc/*.png" "exp/predict/levircd/fc_siam_diff/*.png" \
    "exp/predict/levircd/custom_model/*.png" \
    --tags 'A' 'B' 'GT' \
    'fc_ef' 'fc_siam_conc' 'fc_siam_diff' \
    'custom_model' \
    --save_dir "exp/collect/levircd"
```

Once it finishes, the input images of the two phases, the change labels, and each model's predictions can be found in the `exp/collect/levircd` directory. When a new model is added, `tools/collect_imgs.py` can be run again to add its results to `exp/collect/levircd`:

```bash
python tools/collect_imgs.py --globs "exp/predict/levircd/{new model name}/*.png" --tags '{new model name}' --save_dir "exp/collect/levircd"
```

In addition, to evaluate change detection algorithms in terms of both accuracy and performance, the following command computes the [floating point operations (FLOPs)](https://blog.csdn.net/IT_flying625/article/details/104898152) and the parameter count of a change detection model:

```bash
python tools/analyze_model.py --model_dir "exp/levircd/{model name}/best_model"
```
||||
|
||||
### 5.2 Experimental Results

This example uses the [intersection over union (IoU)](https://paddlepedia.readthedocs.io/en/latest/tutorials/computer_vision/semantic_segmentation/Overview/Overview.html#id6) and the [F1 score](https://baike.baidu.com/item/F1%E5%88%86%E6%95%B0/13864979) of the change class as quantitative metrics; the higher these two metrics, the better the detection results. On each dataset, the algorithms are judged both visually and by the quantitative metrics.
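
For reference, both metrics can be computed from the counts of true positives, false positives, and false negatives of the change class. The sketch below is a plain NumPy illustration rather than the evaluation code used by PaddleRS; the function name `change_metrics` and the toy masks are made up for this example, and the function assumes that at least one pixel is marked as changed in the prediction or the label.

```python
import numpy as np


def change_metrics(pred, label):
    """Return (IoU, F1) of the change class for two binary masks."""
    pred = np.asarray(pred, dtype=bool)
    label = np.asarray(label, dtype=bool)
    tp = np.logical_and(pred, label).sum()   # changed pixels found
    fp = np.logical_and(pred, ~label).sum()  # false alarms
    fn = np.logical_and(~pred, label).sum()  # missed changes
    iou = tp / (tp + fp + fn)
    f1 = 2 * tp / (2 * tp + fp + fn)
    return iou, f1


pred = np.array([[1, 1, 0], [0, 1, 0]])
label = np.array([[1, 0, 0], [0, 1, 1]])
iou, f1 = change_metrics(pred, label)
print(round(iou, 4), round(f1, 4))  # 0.5 0.6667
```

For a single class the two metrics are tied together by F1 = 2·IoU / (1 + IoU), so they always rank models in the same order.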
||||
|
||||
#### 5.2.1 Visual Comparison

The figure below shows the input images of the two phases, the binary change maps output by each algorithm, and the change labels. All selected samples come from the test set of LEVIR-CD.

|Phase 1 image|Phase 2 image|FC-EF|FC-Siam-diff|FC-Siam-conc|CustomModel|Change label|
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
|<img src="https://user-images.githubusercontent.com/21275753/186671764-2dc990a8-b297-43a2-ae81-e31f2d5582e5.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672204-e8e46e9a-7f29-4506-9ed4-31314284a6fb.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672237-ee5f67d8-8966-457d-8a80-0452bdb7af89.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186671987-7da0023a-0c96-413f-9088-0f6730ab54dd.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186671895-c6c40196-b86a-49d1-a4b0-48a7f40cba06.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672068-89a60f8c-c80e-4f73-bb3e-b9ad146e795d.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672106-37e8dcd0-b0f0-46e1-90a1-bd5f566ef97b.png" width="100">|
|<img src="https://user-images.githubusercontent.com/21275753/186672287-efa1209d-2786-4543-b136-5f50b7b0dd8c.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186671791-beb82760-8c3f-480f-8ada-9c1081860691.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186671861-7b7989e4-15d8-4342-9abe-2d6efa82811a.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672362-94993c68-7c31-4501-b009-755c193a00a8.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672348-3134129c-e2cd-4011-8894-901ef332a43d.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672415-da3984b2-0354-49ad-8dba-9c796a18d282.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672449-fd225e4f-ac58-4506-8b66-3a255567998a.png" width="100">|
||||
|
||||
As the figure shows, although some missed and false detections remain, CustomModel delineates the changed regions more accurately than the other algorithms.

#### 5.2.2 Quantitative Comparison

|Model|FLOPs (G)|Parameters (M)|IoU (%)|F1 (%)|
|:-:|:-:|:-:|:-:|:-:|
|FC-EF|3.57|1.35|79.05|88.30|
|FC-Siam-diff|4.71|1.35|81.33|89.70|
|FC-Siam-conc|5.31|1.55|81.31|89.69|
|CustomModel|5.31|1.58|**82.14**|**90.19**|

The best accuracy values are shown in bold. As the table shows, CustomModel achieves the highest IoU and F1 score of all algorithms (3.09 percentage points higher in IoU and 1.89 percentage points higher in F1 than FC-EF), while adding only 0.03 M parameters over the baseline model FC-Siam-conc.
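
The size of the parameter overhead can be estimated with a short back-of-the-envelope calculation. The sketch below rests on two assumptions that this document does not state: that the channel attention module is the CBAM-style shared two-layer MLP of 1x1 convolutions without bias terms, and that the encoder features entering `att4` have `C4 = 128` channels. Under these assumptions the estimate is consistent with the reported increase of about 0.03 M.

```python
def channel_attention_params(channels, ratio=1):
    """Weights of a bias-free two-layer MLP: channels -> channels // ratio -> channels."""
    hidden = channels // ratio
    return channels * hidden + hidden * channels


c4 = 128  # assumed channel count of the features fed to att4
extra = channel_attention_params(c4) + channel_attention_params(2)
print(extra, round(extra / 1e6, 2))  # 32776 0.03
```

Almost all of the extra weights come from the channel attention branch; the temporal branch operates on only two "channels" and contributes just 8 weights under the same assumptions.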
||||
|
||||
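## 6 Ablation Experiments

In research, ablation experiments are often needed to verify that the modifications made to a baseline are effective. In this example, CustomModel adds two attention modules, channel and temporal, on top of the FC-Siam-conc model, so ablation experiments are needed to examine how much each attention module contributes to the final accuracy. Specifically, the following four settings are considered (the configuration files for the ablation models are stored in the `configs/levircd/ablation` directory):

1. Base setting: no attention module is used, i.e. the baseline model FC-Siam-conc;
2. Only the channel attention module is added; the corresponding configuration file is `custom_model_c.yaml`;
3. Only the temporal attention module is added; the corresponding configuration file is `custom_model_t.yaml`;
4. Standard setting: the full model with both the channel and temporal attention modules.

The first and fourth models, i.e. the baseline and the full model, have already been trained, validated, and tested in Sections 4 and 5. This section therefore only needs to deal with settings 2 and 3.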
||||
|
||||
### 6.1 Experimental Procedure

**When training models with configuration files**, the following command trains all ablation models:

```bash
bash scripts/run_ablation.sh
```

Alternatively, train a single model using a command of the following form:

```bash
python run_task.py train cd \
    --config "configs/levircd/ablation/{config file name}" \
    2>&1 | tee "{log path}"
```

After training, use the following command to compute test-set metrics for the model that performs best on the validation set:

```bash
python run_task.py eval cd \
    --config "configs/levircd/ablation/{config file name}" \
    --datasets.eval.args.file_list "data/levircd/test.txt" \
    --resume_checkpoint "exp/levircd/ablation/{ablation model name}/best_model"
```

Note that a configuration file named like `custom_model_c.yaml` corresponds by default to an ablation model named `att_c`.

**When training models with `train_cd.py`**, change the `att_types` argument passed to the model constructor to obtain the results of the different ablation models. For example, for the ablation model with only the channel attention module, set `att_types='c'`. In addition, you can set the `EXP_DIR` variable to different values so that the results of each experiment are saved to a separate directory for easy comparison.
||||
|
||||
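### 6.2 Experimental Results

The quantitative metrics obtained in the experiments are shown in the table below:

|Channel attention module|Temporal attention module|IoU (%)|F1 (%)|
|:-:|:-:|:-:|:-:|
|||81.31|89.69|
|✓||81.97|90.09|
||✓|81.59|89.86|
|✓|✓|**82.14**|**90.19**|

The table shows that both the channel attention module and the temporal attention module contribute positively to the IoU and F1 score, and that adding both modules yields the largest gain (0.83 percentage points in IoU and 0.50 percentage points in F1 over the baseline model).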
||||
|
||||
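## 7 Feature Visualization Experiments

This section visualizes the model's intermediate features to further verify whether the modifications to the baseline model actually enhance the features.

### 7.1 Experimental Procedure

The `tools/visualize_feats.py` script visualizes a model's intermediate features. It accepts the following command-line options:
- `--model_dir` specifies the path where the model to load is stored.
- `--im_path` specifies the path of the input image; for change detection tasks, the paths of the two input images must be given in order.
- `--save_dir` specifies the output directory, which is created automatically if it does not exist.
- `--hook_type` specifies the type of feature to capture and takes one of three values: `forward_in` captures the forward input features of the specified module; `forward_out` captures the forward output features of the specified module; `backward` captures the gradients of the specified parameters.
- `--layer_names` specifies a series of names, either of the modules that receive or produce the features to capture (parent and child modules are separated by `.`) or of weight parameters in the model (i.e. keys of the [state_dict](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/load_cn.html)).
- `--to_pseudo_color` specifies whether to save the feature maps as pseudo-color images.
- `--output_size` specifies the size to which the feature maps are resized.

Files generated by `tools/visualize_feats.py` are named following the pattern `{layer_name}_{j}_vis.png` or `{layer_name}_{i}_{j}_vis.png`. Here `{layer_name}` is the value given in the `--layer_names` option; `{i}` is the index of the current feature when several input or output features are captured at once; and `{j}`, when `--hook_type` is `forward_in` or `forward_out`, indicates which call of the module the current feature map was input to or output from (some modules in a model may be called repeatedly, such as `conv4` in the FC-Siam-conc model). For example, the following commands obtain and save visualizations of the input and output features of the `att4` module in CustomModel: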
||||
|
||||
```bash |
||||
IM1_PATH="data/levircd/LEVIR-CD/test/A/test_13/test_13_3.png" |
||||
IM2_PATH="data/levircd/LEVIR-CD/test/B/test_13/test_13_3.png" |
||||
|
||||
python tools/visualize_feats.py \ |
||||
--model_dir "exp/levircd/custom_model/best_model" \ |
||||
--im_path "${IM1_PATH}" "${IM2_PATH}" \ |
||||
--save_dir "exp/vis/test_13_3/in" \ |
||||
--hook_type 'forward_in' \ |
||||
--layer_names 'att4' \ |
||||
--to_pseudo_color \ |
||||
--output_size 256 256 |
||||
|
||||
python tools/visualize_feats.py \ |
||||
--model_dir "exp/levircd/custom_model/best_model" \ |
||||
--im_path "${IM1_PATH}" "${IM2_PATH}" \ |
||||
--save_dir "exp/vis/test_13_3/out" \ |
||||
--hook_type 'forward_out' \ |
||||
--layer_names 'att4' \ |
||||
--to_pseudo_color \ |
||||
--output_size 256 256 |
||||
``` |
||||
|
||||
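Running the commands above creates two subdirectories in the `exp/vis/test_13_3` directory, each containing two files. `in/att4_0_0_vis.png` and `in/att4_1_0_vis.png` visualize the features of the two phases that are input to the `att4` module, and `out/att4_0_0_vis.png` and `out/att4_1_0_vis.png` visualize the features of the two phases that `att4` outputs.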
||||
|
||||
### 7.2 Experimental Results

From left to right, the figure below shows the input images of the two phases, the change label, the visualizations of the two phases' feature maps input to the mixed attention module `att4` (denoted x1 and x2), and the visualizations of the two phases' feature maps output by `att4` (denoted y1 and y2):

|Phase 1 image|Phase 2 image|Change label|x1|x2|y1|y2|
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
|<img src="https://user-images.githubusercontent.com/21275753/186672741-45c819f0-2591-4b97-ad32-05d787be1a0a.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672761-eb6958be-688d-4bc2-839b-6a60cb6cc3b5.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672791-ceb78cf7-5029-4991-88c2-6c4550fb27d8.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672835-7fda3499-33e0-4af1-b990-8d82f6c5c410.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672870-dba57441-509f-4cd0-bcc9-af343ddf07df.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672893-7bc692a7-c963-4686-b93c-895b5c51fecb.png" width="100">|<img src="https://user-images.githubusercontent.com/21275753/186672914-b99ffee3-9eb4-4f95-96f4-93cb00e0b109.png" width="100">|

Comparing x2 and y2 shows that, after processing by the channel and temporal attention modules, the change features are enhanced and the changed regions stand out more clearly in the feature maps.
||||
|
||||
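## 8 Summary and Outlook

### 8.1 Summary

- Taking the addition of attention modules to the classic FC-Siam-conc model as an example, this example demonstrates the typical workflow of doing research with PaddleRS.
- The modifications to the model bring some improvement in visual quality and in detection accuracy.
- Ablation experiments and feature visualization experiments confirm the effectiveness of the proposed modifications.

### 8.2 Outlook

- This example uses the same training hyperparameters for all compared algorithms. Because the models differ, however, training with uniform hyperparameters rarely guarantees that every model performs well. In future work, the hyperparameters of each compared algorithm could be tuned so that it reaches its best accuracy.
- As a simple example of doing research with PaddleRS, this example does not make major improvements in algorithm design, so the accuracy gain over the baseline is limited. More sophisticated algorithm designs and more advanced model architectures could be considered in the future.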
||||
|
||||
## References

> [1] Chen, Hao, and Zhenwei Shi. "A spatial-temporal attention-based method and a new dataset for remote sensing image change detection." *Remote Sensing* 12.10 (2020): 1662.
>
> [2] Chen, Hao, Zipeng Qi, and Zhenwei Shi. "Remote sensing image change detection with transformers." *IEEE Transactions on Geoscience and Remote Sensing* 60 (2021): 1-14.
>
> [3] Lebedev, M. A., et al. "Change detection in remote sensing images using conditional adversarial networks." *International Archives of the Photogrammetry, Remote Sensing & Spatial Information Sciences* 42.2 (2018).
>
> [4] Daudt, Rodrigo Caye, Bertrand Le Saux, and Alexandre Boulch. "Fully convolutional siamese networks for change detection." *2018 25th IEEE International Conference on Image Processing (ICIP)*. IEEE, 2018.
>
> [5] Woo, Sanghyun, et al. "CBAM: Convolutional block attention module." *Proceedings of the European Conference on Computer Vision (ECCV)*. 2018.
@ -0,0 +1,35 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
|
||||
class Attach(object):
    def __init__(self, dst):
        self.dst = dst

    def __call__(self, obj, name=None):
        if name is None:
            # Automatically get names of functions and classes
            name = obj.__name__
        if hasattr(self.dst, name):
            raise RuntimeError(
                f"{self.dst} already has the attribute {name}, which is {getattr(self.dst, name)}."
            )
        setattr(self.dst, name, obj)
        if hasattr(self.dst, '__all__'):
            self.dst.__all__.append(name)
        return obj

    @staticmethod
    def to(dst):
        return Attach(dst)
@ -0,0 +1,267 @@ |
||||
#!/usr/bin/env python |
||||
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import argparse
import os.path as osp
from collections.abc import Mapping

import yaml


def _chain_maps(*maps):
    chained = dict()
    keys = set().union(*maps)
    for key in keys:
        vals = [m[key] for m in maps if key in m]
        if isinstance(vals[0], Mapping):
            chained[key] = _chain_maps(*vals)
        else:
            chained[key] = vals[0]
    return chained


def read_config(config_path):
    with open(config_path, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)
    return cfg or {}


def parse_configs(cfg_path, inherit=True):
    if inherit:
        cfgs = []
        cfgs.append(read_config(cfg_path))
        while cfgs[-1].get('_base_'):
            base_path = cfgs[-1].pop('_base_')
            curr_dir = osp.dirname(cfg_path)
            cfgs.append(
                read_config(osp.normpath(osp.join(curr_dir, base_path))))
        return _chain_maps(*cfgs)
    else:
        return read_config(cfg_path)
||||
|
||||
|
||||
def _cfg2args(cfg, parser, prefix=''):
    node_keys = set()
    for k, v in cfg.items():
        opt = prefix + k
        if isinstance(v, list):
            if len(v) == 0:
                parser.add_argument(
                    '--' + opt, type=object, nargs='*', default=v)
            else:
                # Only apply to homogeneous lists
                if isinstance(v[0], CfgNode):
                    node_keys.add(opt)
                parser.add_argument(
                    '--' + opt, type=type(v[0]), nargs='*', default=v)
        elif isinstance(v, dict):
            # Recursively parse a dict
            _, new_node_keys = _cfg2args(v, parser, opt + '.')
            node_keys.update(new_node_keys)
        elif isinstance(v, CfgNode):
            node_keys.add(opt)
            _, new_node_keys = _cfg2args(v.to_dict(), parser, opt + '.')
            node_keys.update(new_node_keys)
        elif isinstance(v, bool):
            parser.add_argument('--' + opt, action='store_true', default=v)
        else:
            parser.add_argument('--' + opt, type=type(v), default=v)
    return parser, node_keys
||||
|
||||
|
||||
def _args2cfg(cfg, args, node_keys):
    args = vars(args)
    for k, v in args.items():
        pos = k.find('.')
        if pos != -1:
            # Iteratively parse a dict
            dict_ = cfg
            while pos != -1:
                dict_.setdefault(k[:pos], {})
                dict_ = dict_[k[:pos]]
                k = k[pos + 1:]
                pos = k.find('.')
            dict_[k] = v
        else:
            cfg[k] = v

    for k in node_keys:
        pos = k.find('.')
        if pos != -1:
            # Iteratively parse a dict
            dict_ = cfg
            while pos != -1:
                dict_.setdefault(k[:pos], {})
                dict_ = dict_[k[:pos]]
                k = k[pos + 1:]
                pos = k.find('.')
            v = dict_[k]
            dict_[k] = [CfgNode(v_) for v_ in v] if isinstance(
                v, list) else CfgNode(v)
        else:
            v = cfg[k]
            cfg[k] = [CfgNode(v_) for v_ in v] if isinstance(
                v, list) else CfgNode(v)

    return cfg
||||
|
||||
|
||||
def parse_args(*args, **kwargs):
    cfg_parser = argparse.ArgumentParser(add_help=False)
    cfg_parser.add_argument('--config', type=str, default='')
    cfg_parser.add_argument('--inherit_off', action='store_true')
    cfg_args = cfg_parser.parse_known_args()[0]
    cfg_path = cfg_args.config
    inherit_on = not cfg_args.inherit_off

    # Main parser
    parser = argparse.ArgumentParser(
        conflict_handler='resolve', parents=[cfg_parser])
    # Global settings
    parser.add_argument('cmd', choices=['train', 'eval'])
    parser.add_argument('task', choices=['cd', 'clas', 'det', 'seg'])

    # Data
    parser.add_argument('--datasets', type=dict, default={})
    parser.add_argument('--transforms', type=dict, default={})
    parser.add_argument('--download_on', action='store_true')
    parser.add_argument('--download_url', type=str, default='')
    parser.add_argument('--download_path', type=str, default='./')

    # Optimizer
    parser.add_argument('--optimizer', type=dict, default={})

    # Training related
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--save_interval_epochs', type=int, default=1)
    parser.add_argument('--log_interval_steps', type=int, default=1)
    parser.add_argument('--save_dir', default='../exp/')
    parser.add_argument('--learning_rate', type=float, default=0.01)
    parser.add_argument('--early_stop', action='store_true')
    parser.add_argument('--early_stop_patience', type=int, default=5)
    parser.add_argument('--use_vdl', action='store_true')
    parser.add_argument('--resume_checkpoint', type=str)
    parser.add_argument('--train', type=dict, default={})

    # Loss
    parser.add_argument('--losses', type=dict, nargs='+', default={})

    # Model
    parser.add_argument('--model', type=dict, default={})

    if osp.exists(cfg_path):
        cfg = parse_configs(cfg_path, inherit_on)
        parser, node_keys = _cfg2args(cfg, parser, '')
        node_keys = sorted(node_keys, reverse=True)
        args = parser.parse_args(*args, **kwargs)
        return _args2cfg(dict(), args, node_keys)
    elif cfg_path != '':
        raise FileNotFoundError
    else:
        args = parser.parse_args()
        return _args2cfg(dict(), args, set())
||||
|
||||
|
||||
class _CfgNodeMeta(yaml.YAMLObjectMetaclass): |
||||
def __call__(cls, obj): |
||||
if isinstance(obj, CfgNode): |
||||
return obj |
||||
return super(_CfgNodeMeta, cls).__call__(obj) |
||||
|
||||
|
||||
class CfgNode(yaml.YAMLObject, metaclass=_CfgNodeMeta): |
||||
yaml_tag = u'!Node' |
||||
yaml_loader = yaml.SafeLoader |
||||
# By default use a lexical scope |
||||
ctx = globals() |
||||
|
||||
def __init__(self, dict_): |
||||
super().__init__() |
||||
self.type = dict_['type'] |
||||
self.args = dict_.get('args', []) |
||||
self.module = dict_.get('module', '') |
||||
|
||||
@classmethod |
||||
def set_context(cls, ctx): |
||||
# TODO: Implement dynamic scope with inspect.stack() |
||||
old_ctx = cls.ctx |
||||
cls.ctx = ctx |
||||
return old_ctx |
||||
|
||||
def build_object(self, mod=None): |
||||
if mod is None: |
||||
mod = self._get_module(self.module) |
||||
cls = getattr(mod, self.type) |
||||
if isinstance(self.args, list): |
||||
args = build_objects(self.args) |
||||
obj = cls(*args) |
||||
elif isinstance(self.args, dict): |
||||
args = build_objects(self.args) |
||||
obj = cls(**args) |
||||
else: |
||||
raise NotImplementedError |
||||
return obj |
||||
|
||||
def _get_module(self, s): |
||||
mod = None |
||||
while s: |
||||
idx = s.find('.') |
||||
if idx == -1: |
||||
next_ = s |
||||
s = '' |
||||
else: |
||||
next_ = s[:idx] |
||||
s = s[idx + 1:] |
||||
if mod is None: |
||||
mod = self.ctx[next_] |
||||
else: |
||||
mod = getattr(mod, next_) |
||||
return mod |
||||
|
||||
@staticmethod |
||||
def build_objects(cfg, mod=None): |
||||
if isinstance(cfg, list): |
||||
return [CfgNode.build_objects(c, mod=mod) for c in cfg] |
||||
elif isinstance(cfg, CfgNode): |
||||
return cfg.build_object(mod=mod) |
||||
elif isinstance(cfg, dict): |
||||
return { |
||||
k: CfgNode.build_objects( |
||||
v, mod=mod) |
||||
for k, v in cfg.items() |
||||
} |
||||
else: |
||||
return cfg |
||||
|
||||
def __repr__(self): |
||||
return f"(type={self.type}, args={self.args}, module={self.module or ' '})" |
||||
|
||||
@classmethod |
||||
def from_yaml(cls, loader, node): |
||||
map_ = loader.construct_mapping(node) |
||||
return cls(map_) |
||||
|
||||
def items(self): |
||||
yield from [('type', self.type), ('args', self.args), ('module', |
||||
self.module)] |
||||
|
||||
def to_dict(self): |
||||
return dict(self.items()) |
||||
|
||||
|
||||
def build_objects(cfg, mod=None): |
||||
return CfgNode.build_objects(cfg, mod=mod) |
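Review note: `CfgNode._get_module` resolves a dotted `module` string by looking up the first name in the context dict and walking the rest with `getattr`; `build_object` then calls the class named by `type` with list or dict `args`. The sketch below reproduces only that idea with the standard library. The helper names `resolve` and `build` and the plain-dict spec are illustrative assumptions, not part of this commit, and the `mod` override and nested-node handling are left out.

```python
import collections
import os


def resolve(dotted, ctx):
    """Look up the first name in ctx, then follow the remaining names via getattr."""
    first, *rest = dotted.split('.')
    obj = ctx[first]
    for name in rest:
        obj = getattr(obj, name)
    return obj


def build(spec, ctx):
    """Instantiate spec['type'] found in spec['module'], using list or dict args."""
    factory = getattr(resolve(spec['module'], ctx), spec['type'])
    args = spec.get('args', [])
    if isinstance(args, dict):
        return factory(**args)
    return factory(*args)


# Hypothetical context, playing the role of the globals() passed to set_context().
ctx = {'collections': collections, 'os': os}

counter = build({'type': 'Counter', 'module': 'collections', 'args': ['aab']}, ctx)
od = build({'type': 'OrderedDict', 'module': 'collections', 'args': {'x': 1}}, ctx)
path_mod = resolve('os.path', ctx)
```

Because the first name must exist in the context, a script that wants `module: paddle.optimizer.lr` to resolve has to import `paddle` before handing over its globals.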
@ -0,0 +1,8 @@ |
||||
_base_: ../levircd.yaml |
||||
|
||||
save_dir: ./exp/levircd/ablation/att_c/ |
||||
|
||||
model: !Node |
||||
type: CustomTrainer |
||||
args: |
||||
att_types: c |
@ -0,0 +1,8 @@ |
||||
_base_: ../levircd.yaml |
||||
|
||||
save_dir: ./exp/levircd/ablation/att_t/ |
||||
|
||||
model: !Node |
||||
type: CustomTrainer |
||||
args: |
||||
att_types: t |
@ -0,0 +1,6 @@ |
||||
_base_: ./levircd.yaml |
||||
|
||||
save_dir: ./exp/levircd/custom_model/ |
||||
|
||||
model: !Node |
||||
type: CustomTrainer |
@ -0,0 +1,6 @@ |
||||
_base_: ./levircd.yaml |
||||
|
||||
save_dir: ./exp/levircd/fc_ef/ |
||||
|
||||
model: !Node |
||||
type: FCEarlyFusion |
@ -0,0 +1,6 @@ |
||||
_base_: ./levircd.yaml |
||||
|
||||
save_dir: ./exp/levircd/fc_siam_conc/ |
||||
|
||||
model: !Node |
||||
type: FCSiamConc |
@ -0,0 +1,6 @@ |
||||
_base_: ./levircd.yaml |
||||
|
||||
save_dir: ./exp/levircd/fc_siam_diff/ |
||||
|
||||
model: !Node |
||||
type: FCSiamDiff |
@ -0,0 +1,74 @@ |
||||
# Basic configurations of LEVIR-CD dataset |
||||
|
||||
datasets: |
||||
train: !Node |
||||
type: CDDataset |
||||
args: |
||||
data_dir: ./data/levircd/ |
||||
file_list: ./data/levircd/train.txt |
||||
label_list: null |
||||
num_workers: 2 |
||||
shuffle: True |
||||
with_seg_labels: False |
||||
binarize_labels: True |
||||
eval: !Node |
||||
type: CDDataset |
||||
args: |
||||
data_dir: ./data/levircd/ |
||||
file_list: ./data/levircd/val.txt |
||||
label_list: null |
||||
num_workers: 0 |
||||
shuffle: False |
||||
with_seg_labels: False |
||||
binarize_labels: True |
||||
transforms: |
||||
train: |
||||
- !Node |
||||
type: DecodeImg |
||||
- !Node |
||||
type: RandomFlipOrRotate |
||||
args: |
||||
probs: [0.35, 0.35] |
||||
probsf: [0.5, 0.5, 0, 0, 0] |
||||
probsr: [0.33, 0.34, 0.33] |
||||
- !Node |
||||
type: Normalize |
||||
args: |
||||
mean: [0.5, 0.5, 0.5] |
||||
std: [0.5, 0.5, 0.5] |
||||
- !Node |
||||
type: ArrangeChangeDetector |
||||
args: ['train'] |
||||
eval: |
||||
- !Node |
||||
type: DecodeImg |
||||
- !Node |
||||
type: Normalize |
||||
args: |
||||
mean: [0.5, 0.5, 0.5] |
||||
std: [0.5, 0.5, 0.5] |
||||
- !Node |
||||
type: ArrangeChangeDetector |
||||
args: ['eval'] |
||||
download_on: False |
||||
|
||||
num_epochs: 50 |
||||
train_batch_size: 8 |
||||
optimizer: !Node |
||||
type: Adam |
||||
args: |
||||
learning_rate: !Node |
||||
type: StepDecay |
||||
module: paddle.optimizer.lr |
||||
args: |
||||
learning_rate: 0.002 |
||||
step_size: 35000 |
||||
gamma: 0.2 |
||||
save_interval_epochs: 5 |
||||
log_interval_steps: 50 |
||||
save_dir: ./exp/levircd/ |
||||
learning_rate: 0.002 |
||||
early_stop: False |
||||
early_stop_patience: 5 |
||||
use_vdl: True |
||||
resume_checkpoint: '' |
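Review note: the `optimizer` entry nests a `StepDecay` node with `learning_rate: 0.002`, `step_size: 35000` and `gamma: 0.2`. Assuming the schedule follows the step-decay rule `lr = learning_rate * gamma ** (counter // step_size)`, the arithmetic can be checked in plain Python. Whether the counter advances once per iteration or once per epoch depends on how the trainer steps the scheduler, which this configuration does not show; `step_decay` is a hypothetical helper, not a Paddle function.

```python
def step_decay(base_lr, gamma, step_size, counter):
    """Learning rate after `counter` scheduler steps under step decay."""
    return base_lr * gamma ** (counter // step_size)


# Values taken from the configuration above.
BASE_LR, GAMMA, STEP_SIZE = 0.002, 0.2, 35000

before = step_decay(BASE_LR, GAMMA, STEP_SIZE, 34999)  # still in the first plateau
after = step_decay(BASE_LR, GAMMA, STEP_SIZE, 35000)   # first decay applied

for counter in (0, 34999, 35000, 70000):
    print(counter, step_decay(BASE_LR, GAMMA, STEP_SIZE, counter))
```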
@ -0,0 +1,241 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddle |
||||
import paddle.nn as nn |
||||
import paddle.nn.functional as F |
||||
import paddlers |
||||
from paddlers.rs_models.cd.layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity |
||||
from paddlers.rs_models.cd.layers import ChannelAttention |
||||
|
||||
from attach_tools import Attach |
||||
|
||||
attach = Attach.to(paddlers.rs_models.cd) |
||||
|
||||
|
||||
@attach |
||||
class CustomModel(nn.Layer): |
||||
def __init__(self, |
||||
in_channels, |
||||
num_classes, |
||||
att_types='ct', |
||||
use_dropout=False): |
||||
super(CustomModel, self).__init__() |
||||
|
||||
C1, C2, C3, C4, C5 = 16, 32, 64, 128, 256 |
||||
|
||||
self.use_dropout = use_dropout |
||||
|
||||
self.conv11 = Conv3x3(in_channels, C1, norm=True, act=True) |
||||
self.do11 = self._make_dropout() |
||||
self.conv12 = Conv3x3(C1, C1, norm=True, act=True) |
||||
self.do12 = self._make_dropout() |
||||
self.pool1 = MaxPool2x2() |
||||
|
||||
self.conv21 = Conv3x3(C1, C2, norm=True, act=True) |
||||
self.do21 = self._make_dropout() |
||||
self.conv22 = Conv3x3(C2, C2, norm=True, act=True) |
||||
self.do22 = self._make_dropout() |
||||
self.pool2 = MaxPool2x2() |
||||
|
||||
self.conv31 = Conv3x3(C2, C3, norm=True, act=True) |
||||
self.do31 = self._make_dropout() |
||||
self.conv32 = Conv3x3(C3, C3, norm=True, act=True) |
||||
self.do32 = self._make_dropout() |
||||
self.conv33 = Conv3x3(C3, C3, norm=True, act=True) |
||||
self.do33 = self._make_dropout() |
||||
self.pool3 = MaxPool2x2() |
||||
|
||||
self.conv41 = Conv3x3(C3, C4, norm=True, act=True) |
||||
self.do41 = self._make_dropout() |
||||
self.conv42 = Conv3x3(C4, C4, norm=True, act=True) |
||||
self.do42 = self._make_dropout() |
||||
self.conv43 = Conv3x3(C4, C4, norm=True, act=True) |
||||
self.do43 = self._make_dropout() |
||||
self.pool4 = MaxPool2x2() |
||||
|
||||
self.upconv4 = ConvTransposed3x3(C4, C4, output_padding=1) |
||||
|
||||
self.conv43d = Conv3x3(C5 + C4, C4, norm=True, act=True) |
||||
self.do43d = self._make_dropout() |
||||
self.conv42d = Conv3x3(C4, C4, norm=True, act=True) |
||||
self.do42d = self._make_dropout() |
||||
self.conv41d = Conv3x3(C4, C3, norm=True, act=True) |
||||
self.do41d = self._make_dropout() |
||||
|
||||
self.upconv3 = ConvTransposed3x3(C3, C3, output_padding=1) |
||||
|
||||
self.conv33d = Conv3x3(C4 + C3, C3, norm=True, act=True) |
||||
self.do33d = self._make_dropout() |
||||
self.conv32d = Conv3x3(C3, C3, norm=True, act=True) |
||||
self.do32d = self._make_dropout() |
||||
self.conv31d = Conv3x3(C3, C2, norm=True, act=True) |
||||
self.do31d = self._make_dropout() |
||||
|
||||
self.upconv2 = ConvTransposed3x3(C2, C2, output_padding=1) |
||||
|
||||
self.conv22d = Conv3x3(C3 + C2, C2, norm=True, act=True) |
||||
self.do22d = self._make_dropout() |
||||
self.conv21d = Conv3x3(C2, C1, norm=True, act=True) |
||||
self.do21d = self._make_dropout() |
||||
|
||||
self.upconv1 = ConvTransposed3x3(C1, C1, output_padding=1) |
||||
|
||||
self.conv12d = Conv3x3(C2 + C1, C1, norm=True, act=True) |
||||
self.do12d = self._make_dropout() |
||||
self.conv11d = Conv3x3(C1, num_classes) |
||||
|
||||
self.init_weight() |
||||
|
||||
self.att4 = MixedAttention(C4, att_types) |
||||
|
||||
def forward(self, t1, t2): |
||||
# Encode t1 |
||||
# Stage 1 |
||||
x11 = self.do11(self.conv11(t1)) |
||||
x12_1 = self.do12(self.conv12(x11)) |
||||
x1p = self.pool1(x12_1) |
||||
|
||||
# Stage 2 |
||||
x21 = self.do21(self.conv21(x1p)) |
||||
x22_1 = self.do22(self.conv22(x21)) |
||||
x2p = self.pool2(x22_1) |
||||
|
||||
# Stage 3 |
||||
x31 = self.do31(self.conv31(x2p)) |
||||
x32 = self.do32(self.conv32(x31)) |
||||
x33_1 = self.do33(self.conv33(x32)) |
||||
x3p = self.pool3(x33_1) |
||||
|
||||
# Stage 4 |
||||
x41 = self.do41(self.conv41(x3p)) |
||||
x42 = self.do42(self.conv42(x41)) |
||||
x43_1 = self.do43(self.conv43(x42)) |
||||
x4p = self.pool4(x43_1) |
||||
|
||||
# Encode t2 |
||||
# Stage 1 |
||||
x11 = self.do11(self.conv11(t2)) |
||||
x12_2 = self.do12(self.conv12(x11)) |
||||
x1p = self.pool1(x12_2) |
||||
|
||||
# Stage 2 |
||||
x21 = self.do21(self.conv21(x1p)) |
||||
x22_2 = self.do22(self.conv22(x21)) |
||||
x2p = self.pool2(x22_2) |
||||
|
||||
# Stage 3 |
||||
x31 = self.do31(self.conv31(x2p)) |
||||
x32 = self.do32(self.conv32(x31)) |
||||
x33_2 = self.do33(self.conv33(x32)) |
||||
x3p = self.pool3(x33_2) |
||||
|
||||
# Stage 4 |
||||
x41 = self.do41(self.conv41(x3p)) |
||||
x42 = self.do42(self.conv42(x41)) |
||||
x43_2 = self.do43(self.conv43(x42)) |
||||
x4p = self.pool4(x43_2) |
||||
|
||||
# Decode |
||||
# Stage 4d |
||||
x4d = self.upconv4(x4p) |
||||
pad4 = (0, x43_1.shape[3] - x4d.shape[3], 0, |
||||
x43_1.shape[2] - x4d.shape[2]) |
||||
x4d = F.pad(x4d, pad=pad4, mode='replicate') |
||||
x43_1, x43_2 = self.att4(x43_1, x43_2) |
||||
x4d = paddle.concat([x4d, x43_1, x43_2], 1) |
||||
x43d = self.do43d(self.conv43d(x4d)) |
||||
x42d = self.do42d(self.conv42d(x43d)) |
||||
x41d = self.do41d(self.conv41d(x42d)) |
||||
|
||||
# Stage 3d |
||||
x3d = self.upconv3(x41d) |
||||
pad3 = (0, x33_1.shape[3] - x3d.shape[3], 0, |
||||
x33_1.shape[2] - x3d.shape[2]) |
||||
x3d = F.pad(x3d, pad=pad3, mode='replicate') |
||||
x3d = paddle.concat([x3d, x33_1, x33_2], 1) |
||||
x33d = self.do33d(self.conv33d(x3d)) |
||||
x32d = self.do32d(self.conv32d(x33d)) |
||||
x31d = self.do31d(self.conv31d(x32d)) |
||||
|
||||
# Stage 2d |
||||
x2d = self.upconv2(x31d) |
||||
pad2 = (0, x22_1.shape[3] - x2d.shape[3], 0, |
||||
x22_1.shape[2] - x2d.shape[2]) |
||||
x2d = F.pad(x2d, pad=pad2, mode='replicate') |
||||
x2d = paddle.concat([x2d, x22_1, x22_2], 1) |
||||
x22d = self.do22d(self.conv22d(x2d)) |
||||
x21d = self.do21d(self.conv21d(x22d)) |
||||
|
||||
# Stage 1d |
||||
x1d = self.upconv1(x21d) |
||||
pad1 = (0, x12_1.shape[3] - x1d.shape[3], 0, |
||||
x12_1.shape[2] - x1d.shape[2]) |
||||
x1d = F.pad(x1d, pad=pad1, mode='replicate') |
||||
x1d = paddle.concat([x1d, x12_1, x12_2], 1) |
||||
x12d = self.do12d(self.conv12d(x1d)) |
||||
x11d = self.conv11d(x12d) |
||||
|
||||
return [x11d] |
||||
|
||||
def init_weight(self): |
||||
pass |
||||
|
||||
def _make_dropout(self): |
||||
if self.use_dropout: |
||||
return nn.Dropout2D(p=0.2) |
||||
else: |
||||
return Identity() |
||||
|
||||
|
||||
class MixedAttention(nn.Layer): |
||||
def __init__(self, in_channels, att_types='ct'): |
||||
super(MixedAttention, self).__init__() |
||||
|
||||
self.att_types = att_types |
||||
|
||||
if self.has_att_c: |
||||
self.att_c = ChannelAttention(in_channels, ratio=1) |
||||
else: |
||||
self.att_c = Identity() |
||||
|
||||
if self.has_att_t: |
||||
self.att_t = ChannelAttention(2, ratio=1) |
||||
else: |
||||
self.att_t = Identity() |
||||
|
||||
def forward(self, x1, x2): |
||||
if self.has_att_c: |
||||
x1 = (1 + self.att_c(x1)) * x1 |
||||
x2 = (1 + self.att_c(x2)) * x2 |
||||
|
||||
if self.has_att_t: |
||||
b, c = x1.shape[:2] |
||||
y = paddle.stack([x1, x2], axis=2) |
||||
y = paddle.flatten(y, stop_axis=1) |
||||
y = (1 + self.att_t(y)) * y |
||||
y = y.reshape((b, c, 2, *y.shape[2:])) |
||||
y1, y2 = y[:, :, 0], y[:, :, 1] |
||||
else: |
||||
y1, y2 = x1, x2 |
||||
|
||||
return y1, y2 |
||||
|
||||
@property |
||||
def has_att_c(self): |
||||
return 'c' in self.att_types |
||||
|
||||
@property |
||||
def has_att_t(self): |
||||
return 't' in self.att_types |
@ -0,0 +1,79 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import inspect |
||||
|
||||
import paddle |
||||
import paddlers |
||||
from paddlers.tasks.change_detector import BaseChangeDetector |
||||
|
||||
from attach_tools import Attach |
||||
|
||||
attach = Attach.to(paddlers.tasks.change_detector) |
||||
|
||||
|
||||
def make_trainer(net_type, *args, **kwargs): |
||||
def _init_func(self, |
||||
num_classes=2, |
||||
use_mixed_loss=False, |
||||
losses=None, |
||||
**params): |
||||
sig = inspect.signature(net_type.__init__) |
||||
net_params = { |
||||
k: p.default |
||||
for k, p in sig.parameters.items() if p.default is not p.empty |
||||
} |
||||
net_params.pop('self', None) |
||||
net_params.pop('num_classes', None) |
||||
net_params.update(params) |
||||
|
||||
super(trainer_type, self).__init__( |
||||
model_name=net_type.__name__, |
||||
num_classes=num_classes, |
||||
use_mixed_loss=use_mixed_loss, |
||||
losses=losses, |
||||
**net_params) |
||||
|
||||
if not issubclass(net_type, paddle.nn.Layer): |
||||
raise TypeError("net_type must be a subclass of paddle.nn.Layer") |
||||
|
||||
trainer_name = net_type.__name__ |
||||
|
||||
trainer_type = type(trainer_name, (BaseChangeDetector, ), |
||||
{'__init__': _init_func}) |
||||
|
||||
return trainer_type(*args, **kwargs) |
||||
|
||||
|
||||
@attach |
||||
class CustomTrainer(BaseChangeDetector): |
||||
def __init__(self, |
||||
num_classes=2, |
||||
use_mixed_loss=False, |
||||
losses=None, |
||||
in_channels=3, |
||||
att_types='ct', |
||||
use_dropout=False, |
||||
**params): |
||||
params.update({ |
||||
'in_channels': in_channels, |
||||
'att_types': att_types, |
||||
'use_dropout': use_dropout |
||||
}) |
||||
super().__init__( |
||||
model_name='CustomModel', |
||||
num_classes=num_classes, |
||||
use_mixed_loss=use_mixed_loss, |
||||
losses=losses, |
||||
**params) |
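Review note: `make_trainer` gathers the network's keyword defaults with `inspect.signature` and then lets caller-supplied `params` override them. A minimal, self-contained sketch of that step follows. `TinyNet` and `default_kwargs` are hypothetical names used only for illustration.

```python
import inspect


class TinyNet:
    """Hypothetical stand-in for a network class with some defaulted arguments."""

    def __init__(self, in_channels, num_classes, att_types='ct', use_dropout=False):
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.att_types = att_types
        self.use_dropout = use_dropout


def default_kwargs(cls, exclude=('self', 'num_classes')):
    """Collect every __init__ parameter that has a default value."""
    sig = inspect.signature(cls.__init__)
    return {
        name: p.default
        for name, p in sig.parameters.items()
        if p.default is not p.empty and name not in exclude
    }


defaults = default_kwargs(TinyNet)
# Caller-supplied values take precedence over the collected defaults.
merged = dict(defaults)
merged.update({'att_types': 'c', 'in_channels': 3})
```

Parameters without a default, such as `in_channels` here, are not collected, so they have to arrive through the overrides.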
@ -0,0 +1,82 @@ |
||||
#!/usr/bin/env python |
||||
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import argparse |
||||
import os |
||||
import os.path as osp |
||||
|
||||
import cv2 |
||||
import paddle |
||||
import paddlers |
||||
from tqdm import tqdm |
||||
|
||||
import custom_model |
||||
import custom_trainer |
||||
|
||||
|
||||
def read_file_list(file_list, sep=' '): |
||||
with open(file_list, 'r') as f: |
||||
for line in f: |
||||
line = line.strip() |
||||
parts = line.split(sep) |
||||
yield parts |
||||
|
||||
|
||||
def parse_args(): |
||||
parser = argparse.ArgumentParser() |
||||
parser.add_argument( |
||||
"--model_dir", default=None, type=str, help="Path of saved model.") |
||||
parser.add_argument("--data_dir", type=str, help="Path of input dataset.") |
||||
parser.add_argument("--file_list", type=str, help="Path of file list.") |
||||
parser.add_argument( |
||||
"--save_dir", |
||||
default='./exp/predict', |
||||
type=str, |
||||
help="Path of directory to save prediction results.") |
||||
parser.add_argument( |
||||
"--ext", |
||||
default='.png', |
||||
type=str, |
||||
help="Extension name of the saved image file.") |
||||
return parser.parse_args() |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
args = parse_args() |
||||
|
||||
model = paddlers.tasks.load_model(args.model_dir) |
||||
|
||||
if not osp.exists(args.save_dir): |
||||
os.makedirs(args.save_dir) |
||||
|
||||
with paddle.no_grad(): |
||||
for parts in tqdm(read_file_list(args.file_list)): |
||||
im1_path = osp.join(args.data_dir, parts[0]) |
||||
im2_path = osp.join(args.data_dir, parts[1]) |
||||
|
||||
pred = model.predict((im1_path, im2_path)) |
||||
cm = pred['label_map'] |
||||
# {0,1} -> {0,255} |
||||
cm[cm > 0] = 255 |
||||
cm = cm.astype('uint8') |
||||
|
||||
if len(parts) > 2: |
||||
name = osp.basename(parts[2]) |
||||
else: |
||||
name = osp.basename(im1_path) |
||||
name = osp.splitext(name)[0] + args.ext |
||||
out_path = osp.join(args.save_dir, name) |
||||
cv2.imwrite(out_path, cm) |
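Review note: `read_file_list` yields `['']` for a blank line, and the main loop would then fail at `parts[1]`. A hedged sketch of a variant that skips blank lines is shown below. It takes any iterable of lines so that it can be tried without a file on disk; `iter_file_list` and the sample paths are hypothetical.

```python
def iter_file_list(lines, sep=' '):
    """Yield the separated fields of each non-empty line."""
    for line in lines:
        line = line.strip()
        if not line:
            # Skip blank lines instead of yielding [''].
            continue
        yield line.split(sep)


sample = [
    "A/1.png B/1.png label/1.png\n",
    "\n",
    "A/2.png B/2.png\n",
]
rows = list(iter_file_list(sample))
```

An open file object can be passed in place of `sample`, since iterating over a file also yields lines.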
@ -0,0 +1,129 @@ |
||||
#!/usr/bin/env python |
||||
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import os |
||||
|
||||
# Import cv2 and sklearn before paddlers to solve the |
||||
# "ImportError: dlopen: cannot load any more object with static TLS" issue. |
||||
import cv2 |
||||
import sklearn |
||||
import paddle |
||||
import paddlers |
||||
from paddlers import transforms as T |
||||
|
||||
import custom_model |
||||
import custom_trainer |
||||
from config_utils import parse_args, build_objects, CfgNode |
||||
|
||||
|
||||
def format_cfg(cfg, indent=0): |
||||
s = '' |
||||
if isinstance(cfg, dict): |
||||
for i, (k, v) in enumerate(sorted(cfg.items())): |
||||
s += ' ' * indent + str(k) + ': ' |
||||
if isinstance(v, (dict, list, CfgNode)): |
||||
s += '\n' + format_cfg(v, indent=indent + 1) |
||||
else: |
||||
s += str(v) |
||||
if i != len(cfg) - 1: |
||||
s += '\n' |
||||
elif isinstance(cfg, list): |
||||
for i, v in enumerate(cfg): |
||||
s += ' ' * indent + '- ' |
||||
if isinstance(v, (dict, list, CfgNode)): |
||||
s += '\n' + format_cfg(v, indent=indent + 1) |
||||
else: |
||||
s += str(v) |
||||
if i != len(cfg) - 1: |
||||
s += '\n' |
||||
elif isinstance(cfg, CfgNode): |
||||
s += ' ' * indent + f"type: {cfg.type}" + '\n' |
||||
s += ' ' * indent + f"module: {cfg.module}" + '\n' |
||||
s += ' ' * indent + 'args: \n' + format_cfg(cfg.args, indent + 1) |
||||
return s |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
CfgNode.set_context(globals()) |
||||
|
||||
cfg = parse_args() |
||||
print(format_cfg(cfg)) |
||||
|
||||
# Automatically download data |
||||
if cfg['download_on']: |
||||
paddlers.utils.download_and_decompress( |
||||
cfg['download_url'], path=cfg['download_path']) |
||||
|
||||
if not isinstance(cfg['datasets']['eval'].args, dict): |
||||
raise ValueError("args of eval dataset must be a dict!") |
||||
if cfg['datasets']['eval'].args.get('transforms', None) is not None: |
||||
raise ValueError( |
||||
"Found key 'transforms' in args of eval dataset and the value is not None." |
||||
) |
||||
eval_transforms = T.Compose(build_objects(cfg['transforms']['eval'], mod=T)) |
||||
# Inplace modification |
||||
cfg['datasets']['eval'].args['transforms'] = eval_transforms |
||||
eval_dataset = build_objects(cfg['datasets']['eval'], mod=paddlers.datasets) |
||||
|
||||
if cfg['cmd'] == 'train': |
||||
if not isinstance(cfg['datasets']['train'].args, dict): |
||||
raise ValueError("args of train dataset must be a dict!") |
||||
if cfg['datasets']['train'].args.get('transforms', None) is not None: |
||||
raise ValueError( |
||||
"Found key 'transforms' in args of train dataset and the value is not None." |
||||
) |
||||
train_transforms = T.Compose( |
||||
build_objects( |
||||
cfg['transforms']['train'], mod=T)) |
||||
# Inplace modification |
||||
cfg['datasets']['train'].args['transforms'] = train_transforms |
||||
train_dataset = build_objects( |
||||
cfg['datasets']['train'], mod=paddlers.datasets) |
||||
model = build_objects( |
||||
cfg['model'], mod=getattr(paddlers.tasks, cfg['task'])) |
||||
if cfg['optimizer']: |
||||
if len(cfg['optimizer'].args) == 0: |
||||
cfg['optimizer'].args = {} |
||||
if not isinstance(cfg['optimizer'].args, dict): |
||||
raise TypeError("args of optimizer must be a dict!") |
||||
if cfg['optimizer'].args.get('parameters', None) is not None: |
||||
raise ValueError( |
||||
"Found key 'parameters' in args of optimizer and the value is not None." |
||||
) |
||||
cfg['optimizer'].args['parameters'] = model.net.parameters() |
||||
optimizer = build_objects(cfg['optimizer'], mod=paddle.optimizer) |
||||
else: |
||||
optimizer = None |
||||
|
||||
model.train( |
||||
num_epochs=cfg['num_epochs'], |
||||
train_dataset=train_dataset, |
||||
train_batch_size=cfg['train_batch_size'], |
||||
eval_dataset=eval_dataset, |
||||
optimizer=optimizer, |
||||
save_interval_epochs=cfg['save_interval_epochs'], |
||||
log_interval_steps=cfg['log_interval_steps'], |
||||
save_dir=cfg['save_dir'], |
||||
learning_rate=cfg['learning_rate'], |
||||
early_stop=cfg['early_stop'], |
||||
early_stop_patience=cfg['early_stop_patience'], |
||||
use_vdl=cfg['use_vdl'], |
||||
resume_checkpoint=cfg['resume_checkpoint'] or None, |
||||
**cfg['train']) |
||||
elif cfg['cmd'] == 'eval': |
||||
model = paddlers.tasks.load_model(cfg['resume_checkpoint']) |
||||
res = model.evaluate(eval_dataset) |
||||
print(res) |
@ -0,0 +1,17 @@ |
||||
#!/bin/bash |
||||
|
||||
set -eo pipefail |
||||
|
||||
CONFIG_DIR='configs/levircd/ablation' |
||||
LOG_DIR='exp/logs/ablation' |
||||
|
||||
mkdir -p "${LOG_DIR}" |
||||
|
||||
for config_file in "${CONFIG_DIR}"/*.yaml; do |
||||
filename="$(basename "${config_file}")" |
||||
printf '=%.0s' {1..100} && echo |
||||
echo -e "\033[33m ${config_file} \033[0m" |
||||
printf '=%.0s' {1..100} && echo |
||||
python run_task.py train cd --config "${config_file}" 2>&1 | tee "${LOG_DIR}/${filename%.*}.log" |
||||
echo |
||||
done |
@ -0,0 +1,22 @@ |
||||
#!/bin/bash |
||||
|
||||
set -eo pipefail |
||||
|
||||
DATASET='levircd' |
||||
|
||||
config_dir="configs/${DATASET}" |
||||
log_dir="exp/logs/${DATASET}" |
||||
|
||||
mkdir -p "${log_dir}" |
||||
|
||||
for config_file in "${config_dir}"/*.yaml; do |
||||
filename="$(basename "${config_file}")" |
||||
if [ "${filename}" = "${DATASET}.yaml" ]; then |
||||
continue |
||||
fi |
||||
printf '=%.0s' {1..100} && echo |
||||
echo -e "\033[33m ${config_file} \033[0m" |
||||
printf '=%.0s' {1..100} && echo |
||||
python run_task.py train cd --config "${config_file}" 2>&1 | tee "${log_dir}/${filename%.*}.log" |
||||
echo |
||||
done |
@ -0,0 +1,148 @@ |
||||
#!/usr/bin/env python |
||||
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
# Refer to https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.6/tools/analyze_model.py |
||||
|
||||
import argparse |
||||
import os |
||||
import os.path as osp |
||||
import sys |
||||
|
||||
import paddle |
||||
import numpy as np |
||||
import paddlers |
||||
from paddle.hapi.dynamic_flops import (count_parameters, register_hooks, |
||||
count_io_info) |
||||
from paddle.hapi.static_flops import Table |
||||
|
||||
_dir = osp.dirname(osp.abspath(__file__)) |
||||
sys.path.append(osp.abspath(osp.join(_dir, '../'))) |
||||
import custom_model |
||||
import custom_trainer |
||||
|
||||
|
||||
def parse_args(): |
||||
parser = argparse.ArgumentParser() |
||||
parser.add_argument( |
||||
"--model_dir", default=None, type=str, help="Path of saved model.") |
||||
parser.add_argument( |
||||
"--input_shape", |
||||
nargs='+', |
||||
type=int, |
||||
default=[1, 3, 256, 256], |
||||
help="Shape of each input tensor.") |
||||
return parser.parse_args() |
||||
|
||||
|
||||
def analyze(model, inputs, custom_ops=None, print_detail=False): |
||||
handler_collection = [] |
||||
types_collection = set() |
||||
if custom_ops is None: |
||||
custom_ops = {} |
||||
|
||||
def add_hooks(m): |
||||
if len(list(m.children())) > 0: |
||||
return |
||||
m.register_buffer('total_ops', paddle.zeros([1], dtype='int64')) |
||||
m.register_buffer('total_params', paddle.zeros([1], dtype='int64')) |
||||
m_type = type(m) |
||||
|
||||
flops_fn = None |
||||
if m_type in custom_ops: |
||||
flops_fn = custom_ops[m_type] |
||||
if m_type not in types_collection: |
||||
print("Customized function has been applied to {}".format( |
||||
m_type)) |
||||
elif m_type in register_hooks: |
||||
flops_fn = register_hooks[m_type] |
||||
if m_type not in types_collection: |
||||
print("{}'s FLOPs metric has been counted".format(m_type)) |
||||
else: |
||||
if m_type not in types_collection: |
||||
print( |
||||
"Cannot find suitable counting function for {}. Treat it as zero FLOPs." |
||||
.format(m_type)) |
||||
|
||||
if flops_fn is not None: |
||||
flops_handler = m.register_forward_post_hook(flops_fn) |
||||
handler_collection.append(flops_handler) |
||||
params_handler = m.register_forward_post_hook(count_parameters) |
||||
io_handler = m.register_forward_post_hook(count_io_info) |
||||
handler_collection.append(params_handler) |
||||
handler_collection.append(io_handler) |
||||
types_collection.add(m_type) |
||||
|
||||
training = model.training |
||||
|
||||
model.eval() |
||||
model.apply(add_hooks) |
||||
|
||||
with paddle.framework.no_grad(): |
||||
model(*inputs) |
||||
|
||||
total_ops = 0 |
||||
total_params = 0 |
||||
for m in model.sublayers(): |
||||
if len(list(m.children())) > 0: |
||||
continue |
||||
if set(['total_ops', 'total_params', 'input_shape', |
||||
'output_shape']).issubset(set(list(m._buffers.keys()))): |
||||
total_ops += m.total_ops |
||||
total_params += m.total_params |
||||
|
||||
if training: |
||||
model.train() |
||||
for handler in handler_collection: |
||||
handler.remove() |
||||
|
||||
table = Table( |
||||
["Layer Name", "Input Shape", "Output Shape", "Params(M)", "FLOPs(G)"]) |
||||
|
||||
for n, m in model.named_sublayers(): |
||||
if len(list(m.children())) > 0: |
||||
continue |
||||
if set(['total_ops', 'total_params', 'input_shape', |
||||
'output_shape']).issubset(set(list(m._buffers.keys()))): |
||||
table.add_row([ |
||||
m.full_name(), list(m.input_shape.numpy()), |
||||
list(m.output_shape.numpy()), |
||||
round(float(m.total_params / 1e6), 3), |
||||
round(float(m.total_ops / 1e9), 3) |
||||
]) |
||||
m._buffers.pop("total_ops") |
||||
m._buffers.pop("total_params") |
||||
m._buffers.pop('input_shape') |
||||
m._buffers.pop('output_shape') |
||||
if print_detail: |
||||
table.print_table() |
||||
print('Total FLOPs: {}G Total Params: {}M'.format( |
||||
round(float(total_ops / 1e9), 3), round(float(total_params / 1e6), 3))) |
||||
return int(total_ops) |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
args = parse_args() |
||||
|
||||
# Enforce the use of CPU |
||||
paddle.set_device('cpu') |
||||
|
||||
model = paddlers.tasks.load_model(args.model_dir) |
||||
net = model.net |
||||
|
||||
# Construct bi-temporal inputs |
||||
inputs = [paddle.randn(args.input_shape), paddle.randn(args.input_shape)] |
||||
|
||||
analyze(net, inputs) |
@ -0,0 +1,75 @@ |
||||
#!/usr/bin/env python |
||||
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import argparse |
||||
import os |
||||
import os.path as osp |
||||
import shutil |
||||
from glob import glob |
||||
|
||||
from tqdm import tqdm |
||||
|
||||
|
||||
def get_subdir_name(src_path): |
||||
basename = osp.basename(src_path) |
||||
subdir_name, _ = osp.splitext(basename) |
||||
return subdir_name |
||||
|
||||
|
||||
def parse_args(): |
||||
parser = argparse.ArgumentParser() |
||||
parser.add_argument( |
||||
"--mode", |
||||
default='copy', |
||||
type=str, |
||||
choices=['copy', 'link'], |
||||
help="Copy or link images.") |
||||
parser.add_argument( |
||||
"--globs", |
||||
nargs='+', |
||||
type=str, |
||||
help="Glob patterns used to find the images to be copied.") |
||||
parser.add_argument( |
||||
"--tags", nargs='+', type=str, help="Tags of each source directory.") |
||||
parser.add_argument( |
||||
"--save_dir", |
||||
default='./', |
||||
type=str, |
||||
help="Path of directory to save collected results.") |
||||
return parser.parse_args() |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
args = parse_args() |
||||
|
||||
if len(args.globs) != len(args.tags): |
||||
raise ValueError( |
||||
"The number of globs does not match the number of tags!") |
||||
|
||||
for pat, tag in zip(args.globs, args.tags): |
||||
im_paths = glob(pat) |
||||
print(f"Glob: {pat}\tTag: {tag}") |
||||
for p in tqdm(im_paths): |
||||
subdir_name = get_subdir_name(p) |
||||
ext = osp.splitext(p)[1] |
||||
subdir_path = osp.join(args.save_dir, subdir_name) |
||||
subdir_path = osp.abspath(osp.normpath(subdir_path)) |
||||
if not osp.exists(subdir_path): |
||||
os.makedirs(subdir_path) |
||||
if args.mode == 'copy': |
||||
shutil.copyfile(p, osp.join(subdir_path, tag + ext)) |
||||
elif args.mode == 'link': |
||||
os.symlink(osp.abspath(p), osp.join(subdir_path, tag + ext)) |
@ -0,0 +1,228 @@ |
||||
#!/usr/bin/env python |
||||
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import argparse |
||||
import sys |
||||
import os |
||||
import os.path as osp |
||||
from collections import OrderedDict |
||||
|
||||
import numpy as np |
||||
import cv2 |
||||
import paddle |
||||
import paddlers |
||||
from sklearn.decomposition import PCA |
||||
|
||||
_dir = osp.dirname(osp.abspath(__file__)) |
||||
sys.path.append(osp.abspath(osp.join(_dir, '../'))) |
||||
import custom_model |
||||
import custom_trainer |
||||
|
||||
FILENAME_PATTERN = "{key}_{idx}_vis.png" |
||||
|
||||
|
||||
class FeatureContainer: |
||||
def __init__(self): |
||||
self._dict = OrderedDict() |
||||
|
||||
def __setitem__(self, key, val): |
||||
if key not in self._dict: |
||||
self._dict[key] = list() |
||||
self._dict[key].append(val) |
||||
|
||||
def __getitem__(self, key): |
||||
return self._dict[key] |
||||
|
||||
def __repr__(self): |
||||
return self._dict.__repr__() |
||||
|
||||
def items(self): |
||||
return self._dict.items() |
||||
|
||||
def keys(self): |
||||
return self._dict.keys() |
||||
|
||||
def values(self): |
||||
return self._dict.values() |
||||
|
||||
|
||||
class HookHelper: |
||||
def __init__(self, |
||||
model, |
||||
fetch_dict, |
||||
out_dict, |
||||
hook_type='forward_out', |
||||
auto_key=True): |
||||
# XXX: A HookHelper object should only be used as a context manager and should not |
||||
# persist in memory since it may keep references to some very large objects. |
||||
self.model = model |
||||
self.fetch_dict = fetch_dict |
||||
self.out_dict = out_dict |
||||
self._handles = [] |
||||
self.hook_type = hook_type |
||||
self.auto_key = auto_key |
||||
|
||||
def __enter__(self): |
||||
def _hook_proto(x, entry): |
||||
# `x` should be a tensor or a tuple; |
||||
# entry is expected to be a string or a non-nested tuple. |
||||
if isinstance(entry, tuple): |
||||
for key, f in zip(entry, x): |
||||
self.out_dict[key] = f.detach().clone() |
||||
else: |
||||
if isinstance(x, tuple) and self.auto_key: |
||||
for i, f in enumerate(x): |
||||
key = self._gen_key(entry, i) |
||||
self.out_dict[key] = f.detach().clone() |
||||
else: |
||||
self.out_dict[entry] = x.detach().clone() |
||||
|
||||
if self.hook_type == 'forward_in': |
||||
# NOTE: Register forward hooks for LAYERs |
||||
for name, layer in self.model.named_sublayers(): |
||||
if name in self.fetch_dict: |
||||
entry = self.fetch_dict[name] |
||||
self._handles.append( |
||||
layer.register_forward_pre_hook( |
||||
lambda l, x, entry=entry: |
||||
# x is a tuple |
||||
_hook_proto(x[0] if len(x)==1 else x, entry) |
||||
) |
||||
) |
||||
elif self.hook_type == 'forward_out': |
||||
# NOTE: Register forward hooks for LAYERs. |
||||
for name, module in self.model.named_sublayers(): |
||||
if name in self.fetch_dict: |
||||
entry = self.fetch_dict[name] |
||||
self._handles.append( |
||||
module.register_forward_post_hook( |
||||
lambda l, x, y, entry=entry: |
||||
# y is a tensor or a tuple |
||||
_hook_proto(y, entry) |
||||
) |
||||
) |
||||
elif self.hook_type == 'backward': |
||||
# NOTE: Register backward hooks for TENSORs. |
||||
for name, param in self.model.named_parameters(): |
||||
if name in self.fetch_dict: |
||||
entry = self.fetch_dict[name] |
||||
self._handles.append( |
||||
param.register_hook( |
||||
lambda grad, entry=entry: _hook_proto(grad, entry))) |
||||
else: |
||||
raise RuntimeError("Hook type is not implemented.") |
||||
|
||||
def __exit__(self, exc_type, exc_val, ext_tb): |
||||
for handle in self._handles: |
||||
handle.remove() |
||||
|
||||
def _gen_key(self, key, i): |
||||
return key + f'_{i}' |
||||
|
||||
|
||||
def parse_args(): |
||||
parser = argparse.ArgumentParser() |
||||
parser.add_argument( |
||||
"--model_dir", default=None, type=str, help="Path of saved model.") |
||||
parser.add_argument( |
||||
"--hook_type", default='forward_out', type=str, help="Type of hook.") |
||||
parser.add_argument( |
||||
"--layer_names", |
||||
nargs='+', |
||||
default=[], |
||||
type=str, |
||||
help="Layers that accept or produce the features to visualize.") |
||||
parser.add_argument( |
||||
"--im_paths", nargs='+', type=str, help="Paths of input images.") |
||||
parser.add_argument( |
||||
"--save_dir", |
||||
type=str, |
||||
help="Path of directory to save prediction results.") |
||||
parser.add_argument( |
||||
"--to_pseudo_color", |
||||
action='store_true', |
||||
help="Whether to save pseudo-color images.") |
||||
parser.add_argument( |
||||
"--output_size", |
||||
nargs='+', |
||||
type=int, |
||||
default=None, |
||||
help="Resize the visualized image to `output_size`.") |
||||
return parser.parse_args() |
||||
|
||||
|
||||
def normalize_minmax(x): |
||||
EPS = 1e-32 |
||||
return (x - x.min()) / (x.max() - x.min() + EPS) |
||||
|
||||
|
||||
def quantize_8bit(x): |
||||
# [0.0,1.0] float => [0,255] uint8 |
||||
# or [0,1] int => [0,255] uint8 |
||||
return (x * 255).astype('uint8') |
||||
|
||||
|
||||
def to_pseudo_color(gray, color_map=cv2.COLORMAP_JET): |
||||
return cv2.applyColorMap(gray, color_map) |
||||
|
||||
|
||||
def process_fetched_feat(feat, to_pcolor=True): |
||||
# Convert tensor to array |
||||
feat = feat.squeeze(0).numpy() |
||||
# Get principal component |
||||
shape = feat.shape |
||||
x = feat.reshape(shape[0], -1).transpose((1, 0)) |
||||
pca = PCA(n_components=1) |
||||
y = pca.fit_transform(x) |
||||
feat = y.reshape(shape[1:]) |
||||
feat = normalize_minmax(feat) |
||||
feat = quantize_8bit(feat) |
||||
if to_pcolor: |
||||
feat = to_pseudo_color(feat) |
||||
return feat |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
args = parse_args() |
||||
|
||||
# Load model |
||||
model = paddlers.tasks.load_model(args.model_dir) |
||||
|
||||
fetch_dict = dict(zip(args.layer_names, args.layer_names)) |
||||
out_dict = FeatureContainer() |
||||
|
||||
with HookHelper(model.net, fetch_dict, out_dict, hook_type=args.hook_type): |
||||
if len(args.im_paths) == 1: |
||||
model.predict(args.im_paths[0]) |
||||
else: |
||||
if len(args.im_paths) != 2: |
||||
raise ValueError("Expected one or two image paths in `--im_paths`.") |
||||
model.predict(tuple(args.im_paths)) |
||||
|
||||
if not osp.exists(args.save_dir): |
||||
os.makedirs(args.save_dir) |
||||
|
||||
for key, feats in out_dict.items(): |
||||
for idx, feat in enumerate(feats): |
||||
im_vis = process_fetched_feat(feat, to_pcolor=args.to_pseudo_color) |
||||
if args.output_size is not None: |
||||
im_vis = cv2.resize(im_vis, tuple(args.output_size)) |
||||
out_path = osp.join( |
||||
args.save_dir, |
||||
FILENAME_PATTERN.format( |
||||
key=key.replace('.', '_'), idx=idx)) |
||||
cv2.imwrite(out_path, im_vis) |
||||
print(f"Write feature map to {out_path}") |
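The `FeatureContainer` in the script above appends on assignment instead of overwriting, so a layer that runs twice in one forward pass (for example, a shared encoder applied to both temporal images) leaves two entries under the same key. The following is a minimal stand-alone sketch of that behaviour. `AppendingDict` is a hypothetical, simplified stand-in for `FeatureContainer`, the layer name `encoder.block1` is made up, and plain strings replace tensors so that the snippet runs without Paddle.

```python
from collections import OrderedDict


class AppendingDict:
    """Simplified stand-in for FeatureContainer (hypothetical name)."""

    def __init__(self):
        self._dict = OrderedDict()

    def __setitem__(self, key, val):
        # Assignment appends to a per-key list rather than overwriting.
        self._dict.setdefault(key, []).append(val)

    def __getitem__(self, key):
        return self._dict[key]

    def items(self):
        return self._dict.items()


feats = AppendingDict()
# Simulate a hook firing twice on the same layer (once per temporal image).
feats["encoder.block1"] = "feature of image t1"
feats["encoder.block1"] = "feature of image t2"

# Output file names follow the script's "{key}_{idx}_vis.png" pattern,
# with dots in the layer name replaced by underscores.
FILENAME_PATTERN = "{key}_{idx}_vis.png"
filenames = [
    FILENAME_PATTERN.format(key=key.replace(".", "_"), idx=idx)
    for key, vals in feats.items()
    for idx, _ in enumerate(vals)
]
print(filenames)
```

With this scheme, the index suffix in each file name identifies which invocation of the layer produced the feature map.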
@ -0,0 +1,111 @@ |
||||
#!/usr/bin/env python |
||||
|
||||
import os.path as osp |
||||
|
||||
import paddle |
||||
import paddlers as pdrs |
||||
from paddlers import transforms as T |
||||
|
||||
from custom_model import CustomModel |
||||
from custom_trainer import make_trainer |
||||
|
||||
# Path of the dataset |
||||
DATA_DIR = 'data/levircd/' |
||||
# Path to save experiment results |
||||
EXP_DIR = 'exp/levircd/custom_model/' |
||||
|
||||
# Define the data transforms used for training and validation (data augmentation, preprocessing, etc.) |
||||
# Use Compose to combine multiple transforms. The transforms in a Compose are applied sequentially, in order |
||||
# API reference: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md |
||||
train_transforms = T.Compose([ |
||||
# Read the images |
||||
T.DecodeImg(), |
||||
# Random flip or rotation |
||||
T.RandomFlipOrRotate( |
||||
# Perform a random flip with probability 0.35 and a random rotation with probability 0.35 |
||||
probs=[0.35, 0.35], |
||||
# Flip horizontally with probability 0.5 and vertically with probability 0.5 |
||||
probsf=[0.5, 0.5, 0, 0, 0], |
||||
# Rotate by 90°, 180°, and 270° with probabilities 0.33, 0.34, and 0.33, respectively |
||||
probsr=[0.33, 0.34, 0.33]), |
||||
# Normalize the data to [-1, 1] |
||||
T.Normalize( |
||||
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), |
||||
T.ArrangeChangeDetector('train') |
||||
]) |
||||
|
||||
eval_transforms = T.Compose([ |
||||
T.DecodeImg(), |
||||
# Validation must use the same normalization as training |
||||
T.Normalize( |
||||
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), |
||||
T.ArrangeChangeDetector('eval') |
||||
]) |
||||
|
||||
# Build the training, validation, and test datasets |
||||
train_dataset = pdrs.datasets.CDDataset( |
||||
data_dir=DATA_DIR, |
||||
file_list=osp.join(DATA_DIR, 'train.txt'), |
||||
label_list=None, |
||||
transforms=train_transforms, |
||||
num_workers=0, |
||||
shuffle=True, |
||||
with_seg_labels=False, |
||||
binarize_labels=True) |
||||
|
||||
val_dataset = pdrs.datasets.CDDataset( |
||||
data_dir=DATA_DIR, |
||||
file_list=osp.join(DATA_DIR, 'val.txt'), |
||||
label_list=None, |
||||
transforms=eval_transforms, |
||||
num_workers=0, |
||||
shuffle=False, |
||||
with_seg_labels=False, |
||||
binarize_labels=True) |
||||
|
||||
test_dataset = pdrs.datasets.CDDataset( |
||||
data_dir=DATA_DIR, |
||||
file_list=osp.join(DATA_DIR, 'test.txt'), |
||||
label_list=None, |
||||
# Use the same transforms as in validation |
||||
transforms=eval_transforms, |
||||
num_workers=0, |
||||
shuffle=False, |
||||
with_seg_labels=False, |
||||
binarize_labels=True) |
||||
|
||||
# Build the custom model CustomModel and automatically create a trainer for it |
||||
# The first argument of make_trainer() is the model class; the remaining arguments are passed to the model constructor |
||||
model = make_trainer(CustomModel, in_channels=3) |
||||
|
||||
# Build the learning rate scheduler |
||||
# Use a step decay policy with a fixed step size |
||||
lr_scheduler = paddle.optimizer.lr.StepDecay( |
||||
learning_rate=0.002, step_size=35000, gamma=0.2) |
||||
|
||||
# Build the optimizer |
||||
optimizer = paddle.optimizer.Adam( |
||||
parameters=model.net.parameters(), learning_rate=lr_scheduler) |
||||
|
||||
# Train the model |
||||
model.train( |
||||
num_epochs=50, |
||||
train_dataset=train_dataset, |
||||
train_batch_size=8, |
||||
eval_dataset=val_dataset, |
||||
# Interval (in epochs) at which to validate and save the model |
||||
save_interval_epochs=5, |
||||
# Interval (in iterations) at which to write a log record |
||||
log_interval_steps=50, |
||||
save_dir=EXP_DIR, |
||||
# Whether to use early stopping, which ends training when accuracy stops improving |
||||
early_stop=False, |
||||
# Whether to enable VisualDL logging |
||||
use_vdl=True, |
||||
# Checkpoint to resume training from |
||||
resume_checkpoint=None) |
||||
|
||||
# Load the model that performed best on the validation set |
||||
model = pdrs.tasks.load_model(osp.join(EXP_DIR, 'best_model')) |
||||
# Compute accuracy metrics on the test set |
||||
model.evaluate(test_dataset) |
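The training script configures `paddle.optimizer.lr.StepDecay` with `learning_rate=0.002`, `step_size=35000`, and `gamma=0.2`. Under step decay, the learning rate is the base rate multiplied by `gamma` once for every `step_size` scheduler steps completed. The sketch below reproduces that arithmetic in plain Python so the schedule can be inspected without Paddle. The helper `step_decay_lr` is hypothetical, and whether the trainer advances the scheduler once per iteration or once per epoch is an assumption here; the large `step_size` suggests per-iteration stepping.

```python
import math


def step_decay_lr(base_lr, step_size, gamma, num_steps):
    """Learning rate after `num_steps` scheduler steps under step decay."""
    return base_lr * gamma ** (num_steps // step_size)


# Values taken from the training script.
BASE_LR, STEP_SIZE, GAMMA = 0.002, 35000, 0.2

for n in (0, 34999, 35000, 70000):
    # The rate stays constant within each block of STEP_SIZE steps
    # and is multiplied by GAMMA at each block boundary.
    print(n, step_decay_lr(BASE_LR, STEP_SIZE, GAMMA, n))
```

If the total number of training iterations is well below 35000, the learning rate never decays under this configuration, so it is worth comparing `step_size` against `num_epochs` times the number of batches per epoch.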
@ -0,0 +1,83 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import os.path as osp |
||||
import copy |
||||
|
||||
from .base import BaseDataset |
||||
from paddlers.utils import logging, get_encoding, norm_path, is_pic |
||||
|
||||
|
||||
class ResDataset(BaseDataset): |
||||
""" |
||||
Dataset for image restoration tasks. |
||||
|
||||
Args: |
||||
data_dir (str): Root directory of the dataset. |
||||
file_list (str): Path of the file that contains relative paths of source and target image files. |
||||
transforms (paddlers.transforms.Compose): Data preprocessing and data augmentation operators to apply. |
||||
num_workers (int|str, optional): Number of processes used for data loading. If `num_workers` is 'auto', |
||||
the number of workers will be automatically determined according to the number of CPU cores: If |
||||
there are more than 16 cores, 8 workers will be used. Otherwise, the number of workers will be half |
||||
the number of CPU cores. Defaults to 'auto'. |
||||
shuffle (bool, optional): Whether to shuffle the samples. Defaults to False. |
||||
sr_factor (int|None, optional): Scaling factor of image super-resolution task. None for other image |
||||
restoration tasks. Defaults to None. |
||||
""" |
||||
|
||||
def __init__(self, |
||||
data_dir, |
||||
file_list, |
||||
transforms, |
||||
num_workers='auto', |
||||
shuffle=False, |
||||
sr_factor=None): |
||||
super(ResDataset, self).__init__(data_dir, None, transforms, |
||||
num_workers, shuffle) |
||||
self.batch_transforms = None |
||||
self.file_list = list() |
||||
|
||||
with open(file_list, encoding=get_encoding(file_list)) as f: |
||||
for line in f: |
||||
items = line.strip().split() |
||||
if len(items) > 2: |
||||
raise ValueError( |
||||
"A space is used as the delimiter between the source and target image paths, " \ |
||||
"so neither path may contain a space, but line [{}] of " \ |
||||
"file_list [{}] contains more than two space-separated items.".format(line, file_list)) |
||||
items[0] = norm_path(items[0]) |
||||
items[1] = norm_path(items[1]) |
||||
full_path_im = osp.join(data_dir, items[0]) |
||||
full_path_tar = osp.join(data_dir, items[1]) |
||||
if not is_pic(full_path_im) or not is_pic(full_path_tar): |
||||
continue |
||||
if not osp.exists(full_path_im): |
||||
raise IOError("Source image file {} does not exist!".format( |
||||
full_path_im)) |
||||
if not osp.exists(full_path_tar): |
||||
raise IOError("Target image file {} does not exist!".format( |
||||
full_path_tar)) |
||||
sample = { |
||||
'image': full_path_im, |
||||
'target': full_path_tar, |
||||
} |
||||
if sr_factor is not None: |
||||
sample['sr_factor'] = sr_factor |
||||
self.file_list.append(sample) |
||||
self.num_samples = len(self.file_list) |
||||
logging.info("{} samples in file {}".format( |
||||
len(self.file_list), file_list)) |
||||
|
||||
def __len__(self): |
||||
return len(self.file_list) |
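`ResDataset` expects each line of `file_list` to hold two whitespace-separated relative paths: the source image first, then the target image. The following stand-alone sketch illustrates that format. The helper `parse_pair_list` and the example paths are hypothetical. Unlike `ResDataset`, the sketch skips blank lines, rejects lines with fewer than two items, and does not check that the files exist or are images.

```python
import os.path as osp


def parse_pair_list(lines, data_dir):
    """Parse 'source target' lines into sample dicts (illustrative only)."""
    samples = []
    for line in lines:
        items = line.strip().split()
        if not items:
            continue  # Skip blank lines (a choice made in this sketch).
        if len(items) != 2:
            raise ValueError(f"Expected 2 paths per line, got: {line!r}")
        samples.append({
            "image": osp.join(data_dir, items[0]),
            "target": osp.join(data_dir, items[1]),
        })
    return samples


example_lines = [
    "lr/0001.png hr/0001.png\n",
    "\n",
    "lr/0002.png hr/0002.png\n",
]
samples = parse_pair_list(example_lines, "data")
print(samples)
```

Because whitespace is the delimiter, paths that contain spaces cannot be represented in this format.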
@ -1,99 +0,0 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
|
||||
# Definition of the super-resolution dataset |
||||
class SRdataset(object): |
||||
def __init__(self, |
||||
mode, |
||||
gt_floder, |
||||
lq_floder, |
||||
transforms, |
||||
scale, |
||||
num_workers=4, |
||||
batch_size=8): |
||||
if mode == 'train': |
||||
preprocess = [] |
||||
preprocess.append({ |
||||
'name': 'LoadImageFromFile', |
||||
'key': 'lq' |
||||
}) # How the images are loaded |
||||
preprocess.append({'name': 'LoadImageFromFile', 'key': 'gt'}) |
||||
preprocess.append(transforms) # Transforms to apply |
||||
self.dataset = { |
||||
'name': 'SRDataset', |
||||
'gt_folder': gt_floder, |
||||
'lq_folder': lq_floder, |
||||
'num_workers': num_workers, |
||||
'batch_size': batch_size, |
||||
'scale': scale, |
||||
'preprocess': preprocess |
||||
} |
||||
|
||||
if mode == "test": |
||||
preprocess = [] |
||||
preprocess.append({'name': 'LoadImageFromFile', 'key': 'lq'}) |
||||
preprocess.append({'name': 'LoadImageFromFile', 'key': 'gt'}) |
||||
preprocess.append(transforms) |
||||
self.dataset = { |
||||
'name': 'SRDataset', |
||||
'gt_folder': gt_floder, |
||||
'lq_folder': lq_floder, |
||||
'scale': scale, |
||||
'preprocess': preprocess |
||||
} |
||||
|
||||
def __call__(self): |
||||
return self.dataset |
||||
|
||||
|
||||
# Combine the defined transforms and return them as a dict |
||||
class ComposeTrans(object): |
||||
def __init__(self, input_keys, output_keys, pipelines): |
||||
if not isinstance(pipelines, list): |
||||
raise TypeError( |
||||
'Type of transforms is invalid. Must be List, but received is {}' |
||||
.format(type(pipelines))) |
||||
if len(pipelines) < 1: |
||||
raise ValueError( |
||||
'Length of transforms must not be less than 1, but received is {}' |
||||
.format(len(pipelines))) |
||||
self.transforms = pipelines |
||||
self.output_length = len(output_keys) # output_keys has length 3 when training DRN |
||||
self.input_keys = input_keys |
||||
self.output_keys = output_keys |
||||
|
||||
def __call__(self): |
||||
pipeline = [] |
||||
for op in self.transforms: |
||||
if op['name'] == 'SRPairedRandomCrop': |
||||
op['keys'] = ['image'] * 2 |
||||
else: |
||||
op['keys'] = ['image'] * self.output_length |
||||
pipeline.append(op) |
||||
if self.output_length == 2: |
||||
transform_dict = { |
||||
'name': 'Transforms', |
||||
'input_keys': self.input_keys, |
||||
'pipeline': pipeline |
||||
} |
||||
else: |
||||
transform_dict = { |
||||
'name': 'Transforms', |
||||
'input_keys': self.input_keys, |
||||
'output_keys': self.output_keys, |
||||
'pipeline': pipeline |
||||
} |
||||
|
||||
return transform_dict |
@ -0,0 +1,478 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddle |
||||
import paddle.nn as nn |
||||
import paddle.nn.functional as F |
||||
|
||||
from .layers import BasicConv, MaxPool2x2, Conv1x1, Conv3x3 |
||||
|
||||
bn_mom = 1 - 0.0003 |
||||
|
||||
|
||||
class NLBlock(nn.Layer): |
||||
def __init__(self, in_channels): |
||||
super(NLBlock, self).__init__() |
||||
self.conv_v = BasicConv( |
||||
in_ch=in_channels, |
||||
out_ch=in_channels, |
||||
kernel_size=3, |
||||
norm=nn.BatchNorm2D( |
||||
in_channels, momentum=0.9)) |
||||
self.W = BasicConv( |
||||
in_ch=in_channels, |
||||
out_ch=in_channels, |
||||
kernel_size=3, |
||||
norm=nn.BatchNorm2D( |
||||
in_channels, momentum=0.9), |
||||
act=nn.ReLU()) |
||||
|
||||
def forward(self, x): |
||||
batch_size, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3] |
||||
value = self.conv_v(x) |
||||
value = value.reshape([batch_size, c, value.shape[2] * value.shape[3]]) |
||||
value = value.transpose([0, 2, 1]) # B * (H*W) * value_channels |
||||
key = x.reshape([batch_size, c, h * w]) # B * key_channels * (H*W) |
||||
query = x.reshape([batch_size, c, h * w]) |
||||
query = query.transpose([0, 2, 1]) |
||||
|
||||
sim_map = paddle.matmul(query, key) # B * (H*W) * (H*W) |
||||
sim_map = (c**-.5) * sim_map # B * (H*W) * (H*W) |
||||
sim_map = nn.functional.softmax(sim_map, axis=-1) # B * (H*W) * (H*W) |
||||
|
||||
context = paddle.matmul(sim_map, value) |
||||
context = context.transpose([0, 2, 1]) |
||||
context = context.reshape([batch_size, c, *x.shape[2:]]) |
||||
context = self.W(context) |
||||
|
||||
return context |
||||
|
||||
|
||||
class NLFPN(nn.Layer): |
||||
"""Non-local feature pyramid network""" |
||||
|
||||
def __init__(self, in_dim, reduction=True): |
||||
super(NLFPN, self).__init__() |
||||
if reduction: |
||||
self.reduction = BasicConv( |
||||
in_ch=in_dim, |
||||
out_ch=in_dim // 4, |
||||
kernel_size=1, |
||||
norm=nn.BatchNorm2D( |
||||
in_dim // 4, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
self.re_reduction = BasicConv( |
||||
in_ch=in_dim // 4, |
||||
out_ch=in_dim, |
||||
kernel_size=1, |
||||
norm=nn.BatchNorm2D( |
||||
in_dim, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
in_dim = in_dim // 4 |
||||
else: |
||||
self.reduction = None |
||||
self.re_reduction = None |
||||
self.conv_e1 = BasicConv( |
||||
in_dim, |
||||
in_dim, |
||||
kernel_size=3, |
||||
norm=nn.BatchNorm2D( |
||||
in_dim, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
self.conv_e2 = BasicConv( |
||||
in_dim, |
||||
in_dim * 2, |
||||
kernel_size=3, |
||||
norm=nn.BatchNorm2D( |
||||
in_dim * 2, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
self.conv_e3 = BasicConv( |
||||
in_dim * 2, |
||||
in_dim * 4, |
||||
kernel_size=3, |
||||
norm=nn.BatchNorm2D( |
||||
in_dim * 4, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
self.conv_d1 = BasicConv( |
||||
in_dim, |
||||
in_dim, |
||||
kernel_size=3, |
||||
norm=nn.BatchNorm2D( |
||||
in_dim, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
self.conv_d2 = BasicConv( |
||||
in_dim * 2, |
||||
in_dim, |
||||
kernel_size=3, |
||||
norm=nn.BatchNorm2D( |
||||
in_dim, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
self.conv_d3 = BasicConv( |
||||
in_dim * 4, |
||||
in_dim * 2, |
||||
kernel_size=3, |
||||
norm=nn.BatchNorm2D( |
||||
in_dim * 2, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
self.nl3 = NLBlock(in_dim * 2) |
||||
self.nl2 = NLBlock(in_dim) |
||||
self.nl1 = NLBlock(in_dim) |
||||
|
||||
self.downsample_x2 = nn.MaxPool2D(stride=2, kernel_size=2) |
||||
self.upsample_x2 = nn.UpsamplingBilinear2D(scale_factor=2) |
||||
|
||||
def forward(self, x): |
||||
if self.reduction is not None: |
||||
x = self.reduction(x) |
||||
e1 = self.conv_e1(x) # C,H,W |
||||
e2 = self.conv_e2(self.downsample_x2(e1)) # 2C,H/2,W/2 |
||||
e3 = self.conv_e3(self.downsample_x2(e2)) # 4C,H/4,W/4 |
||||
|
||||
d3 = self.conv_d3(e3) # 2C,H/4,W/4 |
||||
nl = self.nl3(d3) |
||||
d3 = self.upsample_x2(paddle.multiply(d3, nl)) # 2C,H/2,W/2 |
||||
d2 = self.conv_d2(e2 + d3) # C,H/2,W/2 |
||||
nl = self.nl2(d2) |
||||
d2 = self.upsample_x2(paddle.multiply(d2, nl)) # C,H,W |
||||
d1 = self.conv_d1(e1 + d2) |
||||
nl = self.nl1(d1) |
||||
d1 = paddle.multiply(d1, nl) # C,H,W |
||||
if self.re_reduction is not None: |
||||
d1 = self.re_reduction(d1) |
||||
|
||||
return d1 |
||||
|
||||
|
||||
class Cat(nn.Layer): |
||||
def __init__(self, in_chn_high, in_chn_low, out_chn, upsample=False): |
||||
super(Cat, self).__init__() |
||||
self.do_upsample = upsample |
||||
self.upsample = nn.Upsample(scale_factor=2, mode="nearest") |
||||
self.conv2d = BasicConv( |
||||
in_chn_high + in_chn_low, |
||||
out_chn, |
||||
kernel_size=1, |
||||
norm=nn.BatchNorm2D( |
||||
out_chn, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
|
||||
def forward(self, x, y): |
||||
if self.do_upsample: |
||||
x = self.upsample(x) |
||||
|
||||
x = paddle.concat((x, y), 1) |
||||
|
||||
return self.conv2d(x) |
||||
|
||||
|
||||
class DoubleConv(nn.Layer): |
||||
def __init__(self, in_chn, out_chn, stride=1, dilation=1): |
||||
super(DoubleConv, self).__init__() |
||||
self.conv = nn.Sequential( |
||||
nn.Conv2D( |
||||
in_chn, |
||||
out_chn, |
||||
kernel_size=3, |
||||
stride=stride, |
||||
dilation=dilation, |
||||
padding=dilation), |
||||
nn.BatchNorm2D( |
||||
out_chn, momentum=bn_mom), |
||||
nn.ReLU(), |
||||
nn.Conv2D( |
||||
out_chn, out_chn, kernel_size=3, stride=1, padding=1), |
||||
nn.BatchNorm2D( |
||||
out_chn, momentum=bn_mom), |
||||
nn.ReLU()) |
||||
|
||||
def forward(self, x): |
||||
x = self.conv(x) |
||||
return x |
||||
|
||||
|
||||
class SEModule(nn.Layer): |
||||
def __init__(self, channels, reduction_channels): |
||||
super(SEModule, self).__init__() |
||||
self.fc1 = nn.Conv2D( |
||||
channels, |
||||
reduction_channels, |
||||
kernel_size=1, |
||||
padding=0, |
||||
bias_attr=True) |
||||
self.ReLU = nn.ReLU() |
||||
self.fc2 = nn.Conv2D( |
||||
reduction_channels, |
||||
channels, |
||||
kernel_size=1, |
||||
padding=0, |
||||
bias_attr=True) |
||||
|
||||
def forward(self, x): |
||||
x_se = x.reshape( |
||||
[x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]).mean(-1).reshape( |
||||
[x.shape[0], x.shape[1], 1, 1]) |
||||
|
||||
x_se = self.fc1(x_se) |
||||
x_se = self.ReLU(x_se) |
||||
x_se = self.fc2(x_se) |
||||
return x * F.sigmoid(x_se) |
||||
|
||||
|
||||
class BasicBlock(nn.Layer): |
||||
expansion = 1 |
||||
|
||||
def __init__(self, |
||||
inplanes, |
||||
planes, |
||||
downsample=None, |
||||
use_se=False, |
||||
stride=1, |
||||
dilation=1): |
||||
super(BasicBlock, self).__init__() |
||||
first_planes = planes |
||||
outplanes = planes * self.expansion |
||||
|
||||
self.conv1 = DoubleConv(inplanes, first_planes) |
||||
self.conv2 = DoubleConv( |
||||
first_planes, outplanes, stride=stride, dilation=dilation) |
||||
self.se = SEModule(outplanes, planes // 4) if use_se else None |
||||
self.downsample = MaxPool2x2() if downsample else None |
||||
self.ReLU = nn.ReLU() |
||||
|
||||
def forward(self, x): |
||||
out = self.conv1(x) |
||||
residual = out |
||||
out = self.conv2(out) |
||||
|
||||
if self.se is not None: |
||||
out = self.se(out) |
||||
|
||||
if self.downsample is not None: |
||||
residual = self.downsample(residual) |
||||
|
||||
out = out + residual |
||||
out = self.ReLU(out) |
||||
return out |
||||
|
||||
|
||||
class DenseCatAdd(nn.Layer): |
||||
def __init__(self, in_chn, out_chn): |
||||
super(DenseCatAdd, self).__init__() |
||||
self.conv1 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) |
||||
self.conv2 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) |
||||
self.conv3 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) |
||||
self.conv_out = BasicConv( |
||||
in_chn, |
||||
out_chn, |
||||
kernel_size=1, |
||||
norm=nn.BatchNorm2D( |
||||
out_chn, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
|
||||
def forward(self, x, y): |
||||
x1 = self.conv1(x) |
||||
x2 = self.conv2(x1) |
||||
x3 = self.conv3(x2 + x1) |
||||
|
||||
y1 = self.conv1(y) |
||||
y2 = self.conv2(y1) |
||||
y3 = self.conv3(y2 + y1) |
||||
|
||||
return self.conv_out(x1 + x2 + x3 + y1 + y2 + y3) |
||||
|
||||
|
||||
class DenseCatDiff(nn.Layer): |
||||
def __init__(self, in_chn, out_chn): |
||||
super(DenseCatDiff, self).__init__() |
||||
self.conv1 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) |
||||
self.conv2 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) |
||||
self.conv3 = BasicConv(in_chn, in_chn, kernel_size=3, act=nn.ReLU()) |
||||
self.conv_out = BasicConv( |
||||
in_ch=in_chn, |
||||
out_ch=out_chn, |
||||
kernel_size=1, |
||||
norm=nn.BatchNorm2D( |
||||
out_chn, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
|
||||
def forward(self, x, y): |
||||
x1 = self.conv1(x) |
||||
x2 = self.conv2(x1) |
||||
x3 = self.conv3(x2 + x1) |
||||
|
||||
y1 = self.conv1(y) |
||||
y2 = self.conv2(y1) |
||||
y3 = self.conv3(y2 + y1) |
||||
out = self.conv_out(paddle.abs(x1 + x2 + x3 - y1 - y2 - y3)) |
||||
return out |
||||
|
||||
|
||||
class DFModule(nn.Layer): |
||||
"""Dense connection-based feature fusion module""" |
||||
|
||||
def __init__(self, dim_in, dim_out, reduction=True): |
||||
super(DFModule, self).__init__() |
||||
if reduction: |
||||
self.reduction = Conv1x1( |
||||
dim_in, |
||||
dim_in // 2, |
||||
norm=nn.BatchNorm2D( |
||||
dim_in // 2, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
dim_in = dim_in // 2 |
||||
else: |
||||
self.reduction = None |
||||
self.cat1 = DenseCatAdd(dim_in, dim_out) |
||||
self.cat2 = DenseCatDiff(dim_in, dim_out) |
||||
self.conv1 = Conv3x3( |
||||
dim_out, |
||||
dim_out, |
||||
norm=nn.BatchNorm2D( |
||||
dim_out, momentum=bn_mom), |
||||
act=nn.ReLU()) |
||||
|
||||
def forward(self, x1, x2): |
||||
if self.reduction is not None: |
||||
x1 = self.reduction(x1) |
||||
x2 = self.reduction(x2) |
||||
x_add = self.cat1(x1, x2) |
||||
x_diff = self.cat2(x1, x2) |
||||
y = self.conv1(x_diff) + x_add |
||||
return y |
||||
|
||||
|
||||
class FCCDN(nn.Layer): |
||||
""" |
||||
The FCCDN implementation based on PaddlePaddle. |
||||
|
||||
The original article refers to |
||||
Pan Chen, et al., "FCCDN: Feature Constraint Network for VHR Image Change Detection" |
||||
(https://arxiv.org/pdf/2105.10860.pdf). |
||||
|
||||
Args: |
||||
in_channels (int): Number of input channels. Default: 3. |
||||
num_classes (int): Number of target classes. Default: 2. |
||||
os (int): Output stride of the encoder. Default: 16. |
||||
use_se (bool): Whether to use SEModule. Default: True. |
||||
""" |
||||
|
||||
def __init__(self, in_channels=3, num_classes=2, os=16, use_se=True): |
||||
super(FCCDN, self).__init__() |
||||
if os >= 16: |
||||
dilation_list = [1, 1, 1, 1] |
||||
stride_list = [2, 2, 2, 2] |
||||
pool_list = [True, True, True, True] |
||||
elif os == 8: |
||||
dilation_list = [2, 1, 1, 1] |
||||
stride_list = [1, 2, 2, 2] |
||||
pool_list = [False, True, True, True] |
||||
else: |
||||
dilation_list = [2, 2, 1, 1] |
||||
stride_list = [1, 1, 2, 2] |
||||
pool_list = [False, False, True, True] |
||||
se_list = [use_se, use_se, use_se, use_se] |
||||
channel_list = [256, 128, 64, 32] |
||||
# Encoder |
||||
self.block1 = BasicBlock(in_channels, channel_list[3], pool_list[3], |
||||
se_list[3], stride_list[3], dilation_list[3]) |
||||
self.block2 = BasicBlock(channel_list[3], channel_list[2], pool_list[2], |
||||
se_list[2], stride_list[2], dilation_list[2]) |
||||
self.block3 = BasicBlock(channel_list[2], channel_list[1], pool_list[1], |
||||
se_list[1], stride_list[1], dilation_list[1]) |
||||
self.block4 = BasicBlock(channel_list[1], channel_list[0], pool_list[0], |
||||
se_list[0], stride_list[0], dilation_list[0]) |
||||
|
||||
# Center |
||||
self.center = NLFPN(channel_list[0], True) |
||||
|
||||
# Decoder |
||||
self.decoder3 = Cat(channel_list[0], |
||||
channel_list[1], |
||||
channel_list[1], |
||||
upsample=pool_list[0]) |
||||
self.decoder2 = Cat(channel_list[1], |
||||
channel_list[2], |
||||
channel_list[2], |
||||
upsample=pool_list[1]) |
||||
self.decoder1 = Cat(channel_list[2], |
||||
channel_list[3], |
||||
channel_list[3], |
||||
upsample=pool_list[2]) |
||||
|
||||
self.df1 = DFModule(channel_list[3], channel_list[3], True) |
||||
self.df2 = DFModule(channel_list[2], channel_list[2], True) |
||||
self.df3 = DFModule(channel_list[1], channel_list[1], True) |
||||
self.df4 = DFModule(channel_list[0], channel_list[0], True) |
||||
|
||||
self.catc3 = Cat(channel_list[0], |
||||
channel_list[1], |
||||
channel_list[1], |
||||
upsample=pool_list[0]) |
||||
self.catc2 = Cat(channel_list[1], |
||||
channel_list[2], |
||||
channel_list[2], |
||||
upsample=pool_list[1]) |
||||
self.catc1 = Cat(channel_list[2], |
||||
channel_list[3], |
||||
channel_list[3], |
||||
upsample=pool_list[2]) |
||||
|
||||
self.upsample_x2 = nn.Sequential( |
||||
nn.Conv2D( |
||||
channel_list[3], 8, kernel_size=3, stride=1, padding=1), |
||||
nn.BatchNorm2D( |
||||
8, momentum=bn_mom), |
||||
nn.ReLU(), |
||||
nn.UpsamplingBilinear2D(scale_factor=2)) |
||||
|
||||
self.conv_out = nn.Conv2D( |
||||
8, num_classes, kernel_size=3, stride=1, padding=1) |
||||
self.conv_out_class = nn.Conv2D( |
||||
channel_list[3], 1, kernel_size=1, stride=1, padding=0) |
||||
|
||||
def forward(self, t1, t2): |
||||
e1_1 = self.block1(t1) |
||||
e2_1 = self.block2(e1_1) |
||||
e3_1 = self.block3(e2_1) |
||||
y1 = self.block4(e3_1) |
||||
|
||||
e1_2 = self.block1(t2) |
||||
e2_2 = self.block2(e1_2) |
||||
e3_2 = self.block3(e2_2) |
||||
y2 = self.block4(e3_2) |
||||
|
||||
y1 = self.center(y1) |
||||
y2 = self.center(y2) |
||||
c = self.df4(y1, y2) |
||||
|
||||
y1 = self.decoder3(y1, e3_1) |
||||
y2 = self.decoder3(y2, e3_2) |
||||
c = self.catc3(c, self.df3(y1, y2)) |
||||
|
||||
y1 = self.decoder2(y1, e2_1) |
||||
y2 = self.decoder2(y2, e2_2) |
||||
c = self.catc2(c, self.df2(y1, y2)) |
||||
|
||||
y1 = self.decoder1(y1, e1_1) |
||||
y2 = self.decoder1(y2, e1_2) |
||||
|
||||
c = self.catc1(c, self.df1(y1, y2)) |
||||
y = self.conv_out(self.upsample_x2(c)) |
||||
|
||||
if self.training: |
||||
y1 = self.conv_out_class(y1) |
||||
y2 = self.conv_out_class(y2) |
||||
return [y, [y1, y2]] |
||||
else: |
||||
return [y] |
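The `os` argument of `FCCDN` selects, per encoder block, whether the block downsamples by 2, and the decoder upsamples only where the matching encoder block downsampled. The sketch below traces the spatial size of the feature maps through the encoder, the decoder, and the output head using only the `stride_list` and `pool_list` values from `FCCDN.__init__`. The helper `trace_fccdn_sizes` is hypothetical; it assumes a square input whose size is divisible by every stride involved and treats the `NLFPN` center as size-preserving.

```python
def trace_fccdn_sizes(size, output_stride=16):
    """Trace the spatial size of FCCDN feature maps (illustrative sketch)."""
    if output_stride >= 16:
        stride_list = [2, 2, 2, 2]
        pool_list = [True, True, True, True]
    elif output_stride == 8:
        stride_list = [1, 2, 2, 2]
        pool_list = [False, True, True, True]
    else:
        stride_list = [1, 1, 2, 2]
        pool_list = [False, False, True, True]

    # Encoder: block1..block4 use list indices 3..0, as in FCCDN.__init__.
    encoder = []
    for idx in (3, 2, 1, 0):
        size //= stride_list[idx]
        encoder.append(size)

    # Decoder: decoder3..decoder1 upsample by 2 when pool_list[0..2] is True.
    decoder = []
    for idx in (0, 1, 2):
        if pool_list[idx]:
            size *= 2
        decoder.append(size)

    # The output head always upsamples by 2 before the final convolution.
    return encoder, decoder, size * 2


for output_stride in (16, 8, 4):
    print(output_stride, trace_fccdn_sizes(256, output_stride))
```

In each configuration the decoder sizes match the skip connections from the first three encoder blocks in reverse order, which is what allows `Cat` to concatenate them along the channel axis.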
@ -0,0 +1,170 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddle |
||||
import paddle.nn as nn |
||||
import paddle.nn.functional as F |
||||
|
||||
|
||||
class DiceLoss(nn.Layer): |
||||
def __init__(self, batch=True): |
||||
super(DiceLoss, self).__init__() |
||||
self.batch = batch |
||||
|
||||
def soft_dice_coeff(self, y_pred, y_true): |
||||
smooth = 0.00001 |
||||
if self.batch: |
||||
i = paddle.sum(y_true) |
||||
j = paddle.sum(y_pred) |
||||
intersection = paddle.sum(y_true * y_pred) |
||||
else: |
||||
i = y_true.sum(1).sum(1).sum(1) |
||||
j = y_pred.sum(1).sum(1).sum(1) |
||||
intersection = (y_true * y_pred).sum(1).sum(1).sum(1) |
||||
score = (2. * intersection + smooth) / (i + j + smooth) |
||||
return score.mean() |
||||
|
||||
def soft_dice_loss(self, y_pred, y_true): |
||||
loss = 1 - self.soft_dice_coeff(y_pred, y_true) |
||||
return loss |
||||
|
||||
def forward(self, y_pred, y_true): |
||||
return self.soft_dice_loss(y_pred.astype(paddle.float32), y_true) |
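As a rough illustration of what `soft_dice_coeff` computes, the sketch below evaluates the soft Dice score on nested Python lists instead of `paddle.Tensor` objects. The helper name `soft_dice` and the sample values are assumptions made for this example only; the smoothing constant matches the one used above.

```python
def soft_dice(y_pred, y_true, batch=True, smooth=1e-5):
    """Soft Dice score over a list of flattened samples (hypothetical helper)."""

    def score(pred, true):
        intersection = sum(p * t for p, t in zip(pred, true))
        return (2.0 * intersection + smooth) / (sum(true) + sum(pred) + smooth)

    if batch:
        # Pool every pixel of every sample into a single score (batch=True branch).
        flat_pred = [p for sample in y_pred for p in sample]
        flat_true = [t for sample in y_true for t in sample]
        return score(flat_pred, flat_true)
    # Otherwise score each sample separately and average the scores.
    scores = [score(p, t) for p, t in zip(y_pred, y_true)]
    return sum(scores) / len(scores)


pred = [[1.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
true = [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
batch_loss = 1 - soft_dice(pred, true, batch=True)    # close to 1/3
sample_loss = 1 - soft_dice(pred, true, batch=False)  # close to 1/6
```

The two modes differ mainly when some samples are empty: an all-zero sample scores 1.0 on its own, which lowers the per-sample loss relative to the pooled one.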
||||
|
||||
|
||||
class MultiClassDiceLoss(nn.Layer): |
||||
def __init__( |
||||
self, |
||||
weight, |
||||
batch=True, |
||||
ignore_index=-1, |
||||
do_softmax=False, |
||||
**kwargs, ): |
||||
super(MultiClassDiceLoss, self).__init__() |
||||
self.ignore_index = ignore_index |
||||
self.weight = weight |
||||
self.do_softmax = do_softmax |
||||
self.binary_diceloss = DiceLoss(batch) |
||||
|
||||
def forward(self, y_pred, y_true): |
||||
if self.do_softmax: |
||||
y_pred = paddle.nn.functional.softmax(y_pred, axis=1) |
||||
y_true = F.one_hot(y_true.astype('int64'), y_pred.shape[1]).transpose([0, 3, 1, 2]) |
||||
total_loss = 0.0 |
||||
tmp_i = 0.0 |
||||
for i in range(y_pred.shape[1]): |
||||
if i != self.ignore_index: |
||||
diceloss = self.binary_diceloss(y_pred[:, i, :, :], |
||||
y_true[:, i, :, :]) |
||||
total_loss += paddle.multiply(diceloss, self.weight[i]) |
||||
tmp_i += 1.0 |
||||
return total_loss / tmp_i |
||||
|
||||
|
||||
class DiceBCELoss(nn.Layer): |
||||
"""Binary change detection task loss""" |
||||
|
||||
def __init__(self): |
||||
super(DiceBCELoss, self).__init__() |
||||
self.bce_loss = nn.BCELoss() |
||||
self.binary_dice = DiceLoss() |
||||
|
||||
def forward(self, scores, labels, do_sigmoid=True): |
||||
if len(scores.shape) > 3: |
||||
scores = scores.squeeze(1) |
||||
if len(labels.shape) > 3: |
||||
labels = labels.squeeze(1) |
||||
if do_sigmoid: |
||||
scores = paddle.nn.functional.sigmoid(scores.clone()) |
||||
diceloss = self.binary_dice(scores, labels) |
||||
bceloss = self.bce_loss(scores, labels) |
||||
return diceloss + bceloss |
||||
|
||||
|
||||
class McDiceBCELoss(nn.Layer): |
||||
"""Multi-class change detection task loss""" |
||||
|
||||
def __init__(self, weight, do_sigmoid=True): |
||||
super(McDiceBCELoss, self).__init__() |
||||
self.ce_loss = nn.CrossEntropyLoss(weight) |
||||
self.dice = MultiClassDiceLoss(weight, do_softmax=do_sigmoid) |
||||
|
||||
def forward(self, scores, labels): |
||||
if len(scores.shape) < 4: |
||||
scores = scores.unsqueeze(1) |
||||
if len(labels.shape) < 4: |
||||
labels = labels.unsqueeze(1) |
||||
diceloss = self.dice(scores, labels) |
||||
bceloss = self.ce_loss(scores, labels) |
||||
return diceloss + bceloss |
||||
|
||||
|
||||
def fccdn_ssl_loss(logits_list, labels): |
||||
""" |
||||
Self-supervised learning loss for change detection. |
||||
|
||||
For the original article, refer to |
||||
Pan Chen, et al., "FCCDN: Feature Constraint Network for VHR Image Change Detection" |
||||
(https://arxiv.org/pdf/2105.10860.pdf). |
||||
|
||||
Args: |
||||
logits_list (list[paddle.Tensor]): Single-channel segmentation logit maps for each of the two temporal phases. |
||||
labels (paddle.Tensor): Binary change labels. |
||||
""" |
||||
|
||||
# Create loss |
||||
criterion_ssl = DiceBCELoss() |
||||
|
||||
# Get downsampled change map |
||||
h, w = logits_list[0].shape[-2], logits_list[0].shape[-1] |
||||
labels_downsample = F.interpolate(x=labels.unsqueeze(1), size=[h, w]) |
||||
labels_type = str(labels_downsample.dtype) |
||||
assert "int" in labels_type or "bool" in labels_type,\ |
||||
f"Expected dtype of labels to be int or bool, but got {labels_type}" |
||||
|
||||
# Seg map |
||||
out1 = paddle.nn.functional.sigmoid(logits_list[0]).clone() |
||||
out2 = paddle.nn.functional.sigmoid(logits_list[1]).clone() |
||||
out3 = out1.clone() |
||||
out4 = out2.clone() |
||||
|
||||
out1 = paddle.where(labels_downsample == 1, paddle.zeros_like(out1), out1) |
||||
out2 = paddle.where(labels_downsample == 1, paddle.zeros_like(out2), out2) |
||||
out3 = paddle.where(labels_downsample != 1, paddle.zeros_like(out3), out3) |
||||
out4 = paddle.where(labels_downsample != 1, paddle.zeros_like(out4), out4) |
||||
|
||||
pred_seg_pre_tmp1 = paddle.where(out1 <= 0.5, |
||||
paddle.zeros_like(out1), |
||||
paddle.ones_like(out1)) |
||||
pred_seg_post_tmp1 = paddle.where(out2 <= 0.5, |
||||
paddle.zeros_like(out2), |
||||
paddle.ones_like(out2)) |
||||
|
||||
pred_seg_pre_tmp2 = paddle.where(out3 <= 0.5, |
||||
paddle.zeros_like(out3), |
||||
paddle.ones_like(out3)) |
||||
pred_seg_post_tmp2 = paddle.where(out4 <= 0.5, |
||||
paddle.zeros_like(out4), |
||||
paddle.ones_like(out4)) |
||||
|
||||
# Seg loss |
||||
labels_downsample = labels_downsample.astype(paddle.float32) |
||||
loss_aux = 0.2 * criterion_ssl(out1, pred_seg_post_tmp1, False) |
||||
loss_aux += 0.2 * criterion_ssl(out2, pred_seg_pre_tmp1, False) |
||||
loss_aux += 0.2 * criterion_ssl( |
||||
out3, labels_downsample - pred_seg_post_tmp2, False) |
||||
loss_aux += 0.2 * criterion_ssl(out4, labels_downsample - pred_seg_pre_tmp2, |
||||
False) |
||||
|
||||
return loss_aux |
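The masking in `fccdn_ssl_loss` can be hard to follow from the tensor code alone. The sketch below reproduces the same pairing of inputs and pseudo-targets on flat Python lists, one value per pixel. The helper name `ssl_targets`, the dictionary keys, and the sample values are assumptions for illustration; the actual loss additionally applies `DiceBCELoss` to each pair with a weight of 0.2.

```python
def ssl_targets(prob_pre, prob_post, change):
    """Build the four (input, target) pairs of the auxiliary loss (hypothetical helper)."""

    def binarize(p):
        # Mirrors paddle.where(p <= 0.5, 0, 1).
        return 1.0 if p > 0.5 else 0.0

    # Unchanged region: zero out pixels labelled as changed.
    out1 = [0.0 if c == 1 else p for p, c in zip(prob_pre, change)]
    out2 = [0.0 if c == 1 else p for p, c in zip(prob_post, change)]
    # Changed region: zero out pixels labelled as unchanged.
    out3 = [p if c == 1 else 0.0 for p, c in zip(prob_pre, change)]
    out4 = [p if c == 1 else 0.0 for p, c in zip(prob_post, change)]

    return {
        # Where nothing changed, each phase should agree with the other phase.
        "unchanged_pre": (out1, [binarize(p) for p in out2]),
        "unchanged_post": (out2, [binarize(p) for p in out1]),
        # Where a change occurred, each phase should disagree with the other.
        "changed_pre": (out3, [c - binarize(p) for c, p in zip(change, out4)]),
        "changed_post": (out4, [c - binarize(p) for c, p in zip(change, out3)]),
    }


pairs = ssl_targets(
    prob_pre=[0.9, 0.2, 0.8, 0.1],
    prob_post=[0.7, 0.4, 0.3, 0.6],
    change=[0, 0, 1, 1])
```

In the changed region the target is `1 - other_phase`, and outside it the target is 0, which is what subtracting the thresholded map from the label achieves.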
@ -0,0 +1,27 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddle |
||||
import paddle.nn as nn |
||||
|
||||
from paddlers.models.ppgan.modules.init import reset_parameters |
||||
|
||||
|
||||
def init_sr_weight(net): |
||||
def reset_func(m): |
||||
if hasattr(m, 'weight') and ( |
||||
not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D))): |
||||
reset_parameters(m) |
||||
|
||||
net.apply(reset_func) |
@ -1,106 +0,0 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddle |
||||
import paddle.nn as nn |
||||
|
||||
from .generators.builder import build_generator |
||||
from ...models.ppgan.models.criterions.builder import build_criterion |
||||
from ...models.ppgan.models.base_model import BaseModel |
||||
from ...models.ppgan.models.builder import MODELS |
||||
from ...models.ppgan.utils.visual import tensor2img |
||||
from ...models.ppgan.modules.init import reset_parameters |
||||
|
||||
|
||||
@MODELS.register() |
||||
class RCANModel(BaseModel): |
||||
""" |
||||
Base SR model for single image super-resolution. |
||||
""" |
||||
|
||||
def __init__(self, generator, pixel_criterion=None, use_init_weight=False): |
||||
""" |
||||
Args: |
||||
generator (dict): config of generator. |
||||
pixel_criterion (dict): config of pixel criterion. |
||||
""" |
||||
super(RCANModel, self).__init__() |
||||
|
||||
self.nets['generator'] = build_generator(generator) |
||||
self.error_last = 1e8 |
||||
self.batch = 0 |
||||
if pixel_criterion: |
||||
self.pixel_criterion = build_criterion(pixel_criterion) |
||||
if use_init_weight: |
||||
init_sr_weight(self.nets['generator']) |
||||
|
||||
def setup_input(self, input): |
||||
self.lq = paddle.to_tensor(input['lq']) |
||||
self.visual_items['lq'] = self.lq |
||||
if 'gt' in input: |
||||
self.gt = paddle.to_tensor(input['gt']) |
||||
self.visual_items['gt'] = self.gt |
||||
self.image_paths = input['lq_path'] |
||||
|
||||
def forward(self): |
||||
pass |
||||
|
||||
def train_iter(self, optims=None): |
||||
optims['optim'].clear_grad() |
||||
|
||||
self.output = self.nets['generator'](self.lq) |
||||
self.visual_items['output'] = self.output |
||||
# pixel loss |
||||
loss_pixel = self.pixel_criterion(self.output, self.gt) |
||||
self.losses['loss_pixel'] = loss_pixel |
||||
|
||||
skip_threshold = 1e6 |
||||
|
||||
if loss_pixel.item() < skip_threshold * self.error_last: |
||||
loss_pixel.backward() |
||||
optims['optim'].step() |
||||
else: |
||||
print('Skip this batch {}! (Loss: {})'.format(self.batch + 1, |
||||
loss_pixel.item())) |
||||
self.batch += 1 |
||||
|
||||
if self.batch % 1000 == 0: |
||||
self.error_last = loss_pixel.item() / 1000 |
||||
print("update error_last:{}".format(self.error_last)) |
||||
|
||||
def test_iter(self, metrics=None): |
||||
self.nets['generator'].eval() |
||||
with paddle.no_grad(): |
||||
self.output = self.nets['generator'](self.lq) |
||||
self.visual_items['output'] = self.output |
||||
self.nets['generator'].train() |
||||
|
||||
out_img = [] |
||||
gt_img = [] |
||||
for out_tensor, gt_tensor in zip(self.output, self.gt): |
||||
out_img.append(tensor2img(out_tensor, (0., 255.))) |
||||
gt_img.append(tensor2img(gt_tensor, (0., 255.))) |
||||
|
||||
if metrics is not None: |
||||
for metric in metrics.values(): |
||||
metric.update(out_img, gt_img) |
||||
|
||||
|
||||
def init_sr_weight(net): |
||||
def reset_func(m): |
||||
if hasattr(m, 'weight') and ( |
||||
not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D))): |
||||
reset_parameters(m) |
||||
|
||||
net.apply(reset_func) |
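The batch-skipping rule in `RCANModel.train_iter` can be isolated from the model for inspection. The following is a minimal sketch under the assumption that only the scalar loss matters; the class name `SkipGuard` and its constructor arguments are hypothetical, and the update of `error_last` copies the rule shown above (current loss divided by the update period).

```python
class SkipGuard:
    """Decide whether a batch should trigger an optimizer step (hypothetical helper)."""

    def __init__(self, skip_threshold=1e6, error_last=1e8, update_every=1000):
        self.skip_threshold = skip_threshold
        self.error_last = error_last
        self.update_every = update_every
        self.batch = 0

    def should_step(self, loss):
        # Step only when the loss has not exploded relative to the reference error.
        step = loss < self.skip_threshold * self.error_last
        self.batch += 1
        if self.batch % self.update_every == 0:
            # Refresh the reference error periodically, as train_iter does.
            self.error_last = loss / self.update_every
        return step


guard = SkipGuard(skip_threshold=10, error_last=1.0, update_every=2)
decisions = [guard.should_step(loss) for loss in [0.5, 4.0, 30.0, 1.0]]
```

With the small values chosen here, the third batch is skipped because its loss of 30.0 exceeds ten times the reference error of 2.0 set after the second batch.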
@ -1,786 +0,0 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import os |
||||
import time |
||||
import datetime |
||||
|
||||
import paddle |
||||
from paddle.distributed import ParallelEnv |
||||
|
||||
from ..models.ppgan.datasets.builder import build_dataloader |
||||
from ..models.ppgan.models.builder import build_model |
||||
from ..models.ppgan.utils.visual import tensor2img, save_image |
||||
from ..models.ppgan.utils.filesystem import makedirs, save, load |
||||
from ..models.ppgan.utils.timer import TimeAverager |
||||
from ..models.ppgan.utils.profiler import add_profiler_step |
||||
from ..models.ppgan.utils.logger import setup_logger |
||||
|
||||
|
||||
# Define the AttrDict class to enable attribute-style access to dict entries |
||||
class AttrDict(dict): |
||||
def __getattr__(self, key): |
||||
try: |
||||
return self[key] |
||||
except KeyError: |
||||
raise AttributeError(key) |
||||
|
||||
def __setattr__(self, key, value): |
||||
if key in self.__dict__: |
||||
self.__dict__[key] = value |
||||
else: |
||||
self[key] = value |
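For readers unfamiliar with the pattern, the sketch below shows how an attribute-style dictionary behaves once nested dictionaries are wrapped as well. It is a simplified stand-in: `__setattr__` here always writes to the mapping, and `to_attr_dict` is a hypothetical helper that does not attempt the string evaluation performed by `create_attr_dict`.

```python
class AttrDict(dict):
    """Dictionary whose keys can also be read and written as attributes."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        self[key] = value


def to_attr_dict(config):
    """Recursively wrap nested dicts in AttrDict (hypothetical helper)."""
    return AttrDict({
        key: to_attr_dict(value) if isinstance(value, dict) else value
        for key, value in config.items()
    })


cfg = to_attr_dict({"model": {"name": "DRN"}, "total_iters": 1000})
cfg.output_dir = "output"  # stored as cfg["output_dir"]
```

Raising `AttributeError` for unknown keys keeps `hasattr` and `getattr` with a default working as usual.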
||||
|
||||
|
||||
# Recursively convert a config dict into AttrDict objects |
||||
def create_attr_dict(config_dict): |
||||
from ast import literal_eval |
||||
for key, value in config_dict.items(): |
||||
if type(value) is dict: |
||||
config_dict[key] = value = AttrDict(value) |
||||
if isinstance(value, str): |
||||
try: |
||||
value = literal_eval(value) |
||||
except BaseException: |
||||
pass |
||||
if isinstance(value, AttrDict): |
||||
create_attr_dict(config_dict[key]) |
||||
else: |
||||
config_dict[key] = value |
||||
|
||||
|
||||
# Data loader wrapper that restarts iteration at the end of each epoch |
||||
class IterLoader: |
||||
def __init__(self, dataloader): |
||||
self._dataloader = dataloader |
||||
self.iter_loader = iter(self._dataloader) |
||||
self._epoch = 1 |
||||
|
||||
@property |
||||
def epoch(self): |
||||
return self._epoch |
||||
|
||||
def __next__(self): |
||||
try: |
||||
data = next(self.iter_loader) |
||||
except StopIteration: |
||||
self._epoch += 1 |
||||
self.iter_loader = iter(self._dataloader) |
||||
data = next(self.iter_loader) |
||||
|
||||
return data |
||||
|
||||
def __len__(self): |
||||
return len(self._dataloader) |
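The epoch-counting behaviour of `IterLoader` is easiest to see on a tiny iterable. The sketch below uses a plain list in place of a real dataloader; `CyclingLoader` is a hypothetical stand-in written for this example, not the class above.

```python
class CyclingLoader:
    """Endless iterator over a finite iterable that counts epochs (hypothetical stand-in)."""

    def __init__(self, dataloader):
        self._dataloader = dataloader
        self._iterator = iter(dataloader)
        self.epoch = 1

    def __next__(self):
        try:
            return next(self._iterator)
        except StopIteration:
            # The underlying iterable is exhausted: start a new epoch.
            self.epoch += 1
            self._iterator = iter(self._dataloader)
            return next(self._iterator)


loader = CyclingLoader(["a", "b", "c"])
seen = []
for step in range(1, 8):
    batch = next(loader)
    seen.append((step, loader.epoch, batch))
```

Because the epoch counter is bumped only when the iterator is exhausted, a training loop driven by a total iteration count can still report the current epoch.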
||||
|
||||
|
||||
# Base trainer class |
||||
class Restorer: |
||||
""" |
||||
# trainer calling logic: |
||||
# |
||||
# build_model || model(BaseModel) |
||||
# | || |
||||
# build_dataloader || dataloader |
||||
# | || |
||||
# model.setup_lr_schedulers || lr_scheduler |
||||
# | || |
||||
# model.setup_optimizers || optimizers |
||||
# | || |
||||
# train loop (model.setup_input + model.train_iter) || train loop |
||||
# | || |
||||
# print log (model.get_current_losses) || |
||||
# | || |
||||
# save checkpoint (model.nets) \/ |
||||
""" |
||||
|
||||
def __init__(self, cfg, logger): |
||||
# base config |
||||
# self.logger = logging.getLogger(__name__) |
||||
self.logger = logger |
||||
self.cfg = cfg |
||||
self.output_dir = cfg.output_dir |
||||
self.max_eval_steps = cfg.model.get('max_eval_steps', None) |
||||
|
||||
self.local_rank = ParallelEnv().local_rank |
||||
self.world_size = ParallelEnv().nranks |
||||
self.log_interval = cfg.log_config.interval |
||||
self.visual_interval = cfg.log_config.visiual_interval |
||||
self.weight_interval = cfg.snapshot_config.interval |
||||
|
||||
self.start_epoch = 1 |
||||
self.current_epoch = 1 |
||||
self.current_iter = 1 |
||||
self.inner_iter = 1 |
||||
self.batch_id = 0 |
||||
self.global_steps = 0 |
||||
|
||||
# build model |
||||
self.model = build_model(cfg.model) |
||||
# multiple gpus prepare |
||||
if ParallelEnv().nranks > 1: |
||||
self.distributed_data_parallel() |
||||
|
||||
# build metrics |
||||
self.metrics = None |
||||
self.is_save_img = True |
||||
validate_cfg = cfg.get('validate', None) |
||||
if validate_cfg and 'metrics' in validate_cfg: |
||||
self.metrics = self.model.setup_metrics(validate_cfg['metrics']) |
||||
if validate_cfg and 'save_img' in validate_cfg: |
||||
self.is_save_img = validate_cfg['save_img'] |
||||
|
||||
self.enable_visualdl = cfg.get('enable_visualdl', False) |
||||
if self.enable_visualdl: |
||||
import visualdl |
||||
self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir) |
||||
|
||||
# evaluate only |
||||
if not cfg.is_train: |
||||
return |
||||
|
||||
# build train dataloader |
||||
self.train_dataloader = build_dataloader(cfg.dataset.train) |
||||
self.iters_per_epoch = len(self.train_dataloader) |
||||
|
||||
# build lr scheduler |
||||
# TODO: is there a better way? |
||||
if 'lr_scheduler' in cfg and 'iters_per_epoch' in cfg.lr_scheduler: |
||||
cfg.lr_scheduler.iters_per_epoch = self.iters_per_epoch |
||||
self.lr_schedulers = self.model.setup_lr_schedulers(cfg.lr_scheduler) |
||||
|
||||
# build optimizers |
||||
self.optimizers = self.model.setup_optimizers(self.lr_schedulers, |
||||
cfg.optimizer) |
||||
|
||||
self.epochs = cfg.get('epochs', None) |
||||
if self.epochs: |
||||
self.total_iters = self.epochs * self.iters_per_epoch |
||||
self.by_epoch = True |
||||
else: |
||||
self.by_epoch = False |
||||
self.total_iters = cfg.total_iters |
||||
|
||||
if self.by_epoch: |
||||
self.weight_interval *= self.iters_per_epoch |
||||
|
||||
self.validate_interval = -1 |
||||
if cfg.get('validate', None) is not None: |
||||
self.validate_interval = cfg.validate.get('interval', -1) |
||||
|
||||
self.time_count = {} |
||||
self.best_metric = {} |
||||
self.model.set_total_iter(self.total_iters) |
||||
self.profiler_options = cfg.profiler_options |
||||
|
||||
def distributed_data_parallel(self): |
||||
paddle.distributed.init_parallel_env() |
||||
find_unused_parameters = self.cfg.get('find_unused_parameters', False) |
||||
for net_name, net in self.model.nets.items(): |
||||
self.model.nets[net_name] = paddle.DataParallel( |
||||
net, find_unused_parameters=find_unused_parameters) |
||||
|
||||
def learning_rate_scheduler_step(self): |
||||
if isinstance(self.model.lr_scheduler, dict): |
||||
for lr_scheduler in self.model.lr_scheduler.values(): |
||||
lr_scheduler.step() |
||||
elif isinstance(self.model.lr_scheduler, |
||||
paddle.optimizer.lr.LRScheduler): |
||||
self.model.lr_scheduler.step() |
||||
else: |
||||
raise ValueError( |
||||
'lr scheduler must be a dict or an instance of LRScheduler') |
||||
|
||||
def train(self): |
||||
reader_cost_averager = TimeAverager() |
||||
batch_cost_averager = TimeAverager() |
||||
|
||||
iter_loader = IterLoader(self.train_dataloader) |
||||
|
||||
# set model.is_train = True |
||||
self.model.setup_train_mode(is_train=True) |
||||
while self.current_iter < (self.total_iters + 1): |
||||
self.current_epoch = iter_loader.epoch |
||||
self.inner_iter = self.current_iter % self.iters_per_epoch |
||||
|
||||
add_profiler_step(self.profiler_options) |
||||
|
||||
start_time = step_start_time = time.time() |
||||
data = next(iter_loader) |
||||
reader_cost_averager.record(time.time() - step_start_time) |
||||
# unpack data from dataset and apply preprocessing |
||||
# data input should be dict |
||||
self.model.setup_input(data) |
||||
self.model.train_iter(self.optimizers) |
||||
|
||||
batch_cost_averager.record( |
||||
time.time() - step_start_time, |
||||
num_samples=self.cfg['dataset']['train'].get('batch_size', 1)) |
||||
|
||||
step_start_time = time.time() |
||||
|
||||
if self.current_iter % self.log_interval == 0: |
||||
self.data_time = reader_cost_averager.get_average() |
||||
self.step_time = batch_cost_averager.get_average() |
||||
self.ips = batch_cost_averager.get_ips_average() |
||||
self.print_log() |
||||
|
||||
reader_cost_averager.reset() |
||||
batch_cost_averager.reset() |
||||
|
||||
if self.current_iter % self.visual_interval == 0 and self.local_rank == 0: |
||||
self.visual('visual_train') |
||||
|
||||
self.learning_rate_scheduler_step() |
||||
|
||||
if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0: |
||||
self.test() |
||||
|
||||
if self.current_iter % self.weight_interval == 0: |
||||
self.save(self.current_iter, 'weight', keep=-1) |
||||
self.save(self.current_iter) |
||||
|
||||
self.current_iter += 1 |
||||
|
||||
def test(self): |
||||
if not hasattr(self, 'test_dataloader'): |
||||
self.test_dataloader = build_dataloader( |
||||
self.cfg.dataset.test, is_train=False) |
||||
iter_loader = IterLoader(self.test_dataloader) |
||||
if self.max_eval_steps is None: |
||||
self.max_eval_steps = len(self.test_dataloader) |
||||
|
||||
if self.metrics: |
||||
for metric in self.metrics.values(): |
||||
metric.reset() |
||||
|
||||
# set model.is_train = False |
||||
self.model.setup_train_mode(is_train=False) |
||||
|
||||
for i in range(self.max_eval_steps): |
||||
if self.max_eval_steps < self.log_interval or i % self.log_interval == 0: |
||||
self.logger.info('Test iter: [%d/%d]' % ( |
||||
i * self.world_size, self.max_eval_steps * self.world_size)) |
||||
|
||||
data = next(iter_loader) |
||||
self.model.setup_input(data) |
||||
self.model.test_iter(metrics=self.metrics) |
||||
|
||||
if self.is_save_img: |
||||
visual_results = {} |
||||
current_paths = self.model.get_image_paths() |
||||
current_visuals = self.model.get_current_visuals() |
||||
|
||||
if len(current_visuals) > 0 and len( |
||||
list(current_visuals.values())[0].shape) == 4: |
||||
num_samples = list(current_visuals.values())[0].shape[0] |
||||
else: |
||||
num_samples = 1 |
||||
|
||||
for j in range(num_samples): |
||||
if j < len(current_paths): |
||||
short_path = os.path.basename(current_paths[j]) |
||||
basename = os.path.splitext(short_path)[0] |
||||
else: |
||||
basename = '{:04d}_{:04d}'.format(i, j) |
||||
for k, img_tensor in current_visuals.items(): |
||||
name = '%s_%s' % (basename, k) |
||||
if len(img_tensor.shape) == 4: |
||||
visual_results.update({name: img_tensor[j]}) |
||||
else: |
||||
visual_results.update({name: img_tensor}) |
||||
|
||||
self.visual( |
||||
'visual_test', |
||||
visual_results=visual_results, |
||||
step=self.batch_id, |
||||
is_save_image=True) |
||||
|
||||
if self.metrics: |
||||
for metric_name, metric in self.metrics.items(): |
||||
self.logger.info("Metric {}: {:.4f}".format( |
||||
metric_name, metric.accumulate())) |
||||
|
||||
def print_log(self): |
||||
losses = self.model.get_current_losses() |
||||
|
||||
message = '' |
||||
if self.by_epoch: |
||||
message += 'Epoch: %d/%d, iter: %d/%d ' % ( |
||||
self.current_epoch, self.epochs, self.inner_iter, |
||||
self.iters_per_epoch) |
||||
else: |
||||
message += 'Iter: %d/%d ' % (self.current_iter, self.total_iters) |
||||
|
||||
message += f'lr: {self.current_learning_rate:.3e} ' |
||||
|
||||
for k, v in losses.items(): |
||||
message += '%s: %.3f ' % (k, v) |
||||
if self.enable_visualdl: |
||||
self.vdl_logger.add_scalar(k, v, step=self.global_steps) |
||||
|
||||
if hasattr(self, 'step_time'): |
||||
message += 'batch_cost: %.5f sec ' % self.step_time |
||||
|
||||
if hasattr(self, 'data_time'): |
||||
message += 'reader_cost: %.5f sec ' % self.data_time |
||||
|
||||
if hasattr(self, 'ips'): |
||||
message += 'ips: %.5f images/s ' % self.ips |
||||
|
||||
if hasattr(self, 'step_time'): |
||||
eta = self.step_time * (self.total_iters - self.current_iter) |
||||
eta = eta if eta > 0 else 0 |
||||
|
||||
eta_str = str(datetime.timedelta(seconds=int(eta))) |
||||
message += f'eta: {eta_str}' |
||||
|
||||
# print the message |
||||
self.logger.info(message) |
||||
|
||||
@property |
||||
def current_learning_rate(self): |
||||
for optimizer in self.model.optimizers.values(): |
||||
return optimizer.get_lr() |
||||
|
||||
def visual(self, |
||||
results_dir, |
||||
visual_results=None, |
||||
step=None, |
||||
is_save_image=False): |
||||
""" |
||||
Visualize the images, either with VisualDL or by writing them directly to the directory. |
||||
Parameters: |
||||
results_dir (str) -- directory name which contains saved images |
||||
visual_results (dict) -- the results images dict |
||||
step (int) -- global steps, used in visualdl |
||||
is_save_image (bool) -- whether to write to the directory instead of VisualDL |
||||
""" |
||||
self.model.compute_visuals() |
||||
|
||||
if visual_results is None: |
||||
visual_results = self.model.get_current_visuals() |
||||
|
||||
min_max = self.cfg.get('min_max', None) |
||||
if min_max is None: |
||||
min_max = (-1., 1.) |
||||
|
||||
image_num = self.cfg.get('image_num', None) |
||||
if (image_num is None) or (not self.enable_visualdl): |
||||
image_num = 1 |
||||
for label, image in visual_results.items(): |
||||
image_numpy = tensor2img(image, min_max, image_num) |
||||
if (not is_save_image) and self.enable_visualdl: |
||||
self.vdl_logger.add_image( |
||||
results_dir + '/' + label, |
||||
image_numpy, |
||||
step=step if step else self.global_steps, |
||||
dataformats="HWC" if image_num == 1 else "NCHW") |
||||
else: |
||||
if self.cfg.is_train: |
||||
if self.by_epoch: |
||||
msg = 'epoch%.3d_' % self.current_epoch |
||||
else: |
||||
msg = 'iter%.3d_' % self.current_iter |
||||
else: |
||||
msg = '' |
||||
makedirs(os.path.join(self.output_dir, results_dir)) |
||||
img_path = os.path.join(self.output_dir, results_dir, |
||||
msg + '%s.png' % (label)) |
||||
save_image(image_numpy, img_path) |
||||
|
||||
def save(self, epoch, name='checkpoint', keep=1): |
||||
if self.local_rank != 0: |
||||
return |
||||
|
||||
assert name in ['checkpoint', 'weight'] |
||||
|
||||
state_dicts = {} |
||||
if self.by_epoch: |
||||
save_filename = 'epoch_%s_%s.pdparams' % ( |
||||
epoch // self.iters_per_epoch, name) |
||||
else: |
||||
save_filename = 'iter_%s_%s.pdparams' % (epoch, name) |
||||
|
||||
os.makedirs(self.output_dir, exist_ok=True) |
||||
save_path = os.path.join(self.output_dir, save_filename) |
||||
for net_name, net in self.model.nets.items(): |
||||
state_dicts[net_name] = net.state_dict() |
||||
|
||||
if name == 'weight': |
||||
save(state_dicts, save_path) |
||||
return |
||||
|
||||
state_dicts['epoch'] = epoch |
||||
|
||||
for opt_name, opt in self.model.optimizers.items(): |
||||
state_dicts[opt_name] = opt.state_dict() |
||||
|
||||
save(state_dicts, save_path) |
||||
|
||||
if keep > 0: |
||||
try: |
||||
if self.by_epoch: |
||||
checkpoint_name_to_be_removed = os.path.join( |
||||
self.output_dir, 'epoch_%s_%s.pdparams' % ( |
||||
(epoch - keep * self.weight_interval) // |
||||
self.iters_per_epoch, name)) |
||||
else: |
||||
checkpoint_name_to_be_removed = os.path.join( |
||||
self.output_dir, 'iter_%s_%s.pdparams' % |
||||
(epoch - keep * self.weight_interval, name)) |
||||
|
||||
if os.path.exists(checkpoint_name_to_be_removed): |
||||
os.remove(checkpoint_name_to_be_removed) |
||||
|
||||
except Exception as e: |
||||
self.logger.info('remove old checkpoints error: {}'.format(e)) |
||||
|
||||
def resume(self, checkpoint_path): |
||||
state_dicts = load(checkpoint_path) |
||||
if state_dicts.get('epoch', None) is not None: |
||||
self.start_epoch = state_dicts['epoch'] + 1 |
||||
self.global_steps = self.iters_per_epoch * state_dicts['epoch'] |
||||
|
||||
self.current_iter = state_dicts['epoch'] + 1 |
||||
|
||||
for net_name, net in self.model.nets.items(): |
||||
net.set_state_dict(state_dicts[net_name]) |
||||
|
||||
for opt_name, opt in self.model.optimizers.items(): |
||||
opt.set_state_dict(state_dicts[opt_name]) |
||||
|
||||
def load(self, weight_path): |
||||
state_dicts = load(weight_path) |
||||
|
||||
for net_name, net in self.model.nets.items(): |
||||
if net_name in state_dicts: |
||||
net.set_state_dict(state_dicts[net_name]) |
||||
self.logger.info('Loaded pretrained weight for net {}'.format( |
||||
net_name)) |
||||
else: |
||||
self.logger.warning( |
||||
'Can not find state dict of net {}. Skip load pretrained weight for net {}' |
||||
.format(net_name, net_name)) |
||||
|
||||
def close(self): |
||||
""" |
||||
Close file handlers and other resources once training has finished. |
||||
""" |
||||
if self.enable_visualdl: |
||||
self.vdl_logger.close() |
||||
|
||||
|
||||
# Base trainer class for super-resolution models |
||||
class BasicSRNet: |
||||
def __init__(self): |
||||
self.model = {} |
||||
self.optimizer = {} |
||||
self.lr_scheduler = {} |
||||
self.min_max = '' |
||||
|
||||
def train( |
||||
self, |
||||
total_iters, |
||||
train_dataset, |
||||
test_dataset, |
||||
output_dir, |
||||
validate, |
||||
snapshot, |
||||
log, |
||||
lr_rate, |
||||
evaluate_weights='', |
||||
resume='', |
||||
pretrain_weights='', |
||||
periods=[100000], |
||||
restart_weights=[1], ): |
||||
self.lr_scheduler['learning_rate'] = lr_rate |
||||
|
||||
if self.lr_scheduler['name'] == 'CosineAnnealingRestartLR': |
||||
self.lr_scheduler['periods'] = periods |
||||
self.lr_scheduler['restart_weights'] = restart_weights |
||||
|
||||
validate = { |
||||
'interval': validate, |
||||
'save_img': False, |
||||
'metrics': { |
||||
'psnr': { |
||||
'name': 'PSNR', |
||||
'crop_border': 4, |
||||
'test_y_channel': True |
||||
}, |
||||
'ssim': { |
||||
'name': 'SSIM', |
||||
'crop_border': 4, |
||||
'test_y_channel': True |
||||
} |
||||
} |
||||
} |
||||
log_config = {'interval': log, 'visiual_interval': 500} |
||||
snapshot_config = {'interval': snapshot} |
||||
|
||||
cfg = { |
||||
'total_iters': total_iters, |
||||
'output_dir': output_dir, |
||||
'min_max': self.min_max, |
||||
'model': self.model, |
||||
'dataset': { |
||||
'train': train_dataset, |
||||
'test': test_dataset |
||||
}, |
||||
'lr_scheduler': self.lr_scheduler, |
||||
'optimizer': self.optimizer, |
||||
'validate': validate, |
||||
'log_config': log_config, |
||||
'snapshot_config': snapshot_config |
||||
} |
||||
|
||||
cfg = AttrDict(cfg) |
||||
create_attr_dict(cfg) |
||||
|
||||
cfg.is_train = True |
||||
cfg.profiler_options = None |
||||
cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime()) |
||||
|
||||
if cfg.model.name == 'BaseSRModel': |
||||
folder_model_name = cfg.model.generator.name |
||||
else: |
||||
folder_model_name = cfg.model.name |
||||
cfg.output_dir = os.path.join(cfg.output_dir, |
||||
folder_model_name + cfg.timestamp) |
||||
|
||||
logger_cfg = setup_logger(cfg.output_dir) |
||||
logger_cfg.info('Configs: {}'.format(cfg)) |
||||
|
||||
if paddle.is_compiled_with_cuda(): |
||||
paddle.set_device('gpu') |
||||
else: |
||||
paddle.set_device('cpu') |
||||
|
||||
# build trainer |
||||
trainer = Restorer(cfg, logger_cfg) |
||||
|
||||
# resume training or evaluation; the checkpoint must contain epoch and optimizer info |
||||
if len(resume) > 0: |
||||
trainer.resume(resume) |
||||
# evaluate or fine-tune; only load generator weights |
||||
elif len(pretrain_weights) > 0: |
||||
trainer.load(pretrain_weights) |
||||
if len(evaluate_weights) > 0: |
||||
trainer.load(evaluate_weights) |
||||
trainer.test() |
||||
return |
||||
# train; save weights on keyboard interrupt |
||||
try: |
||||
trainer.train() |
||||
except KeyboardInterrupt as e: |
||||
trainer.save(trainer.current_epoch) |
||||
|
||||
trainer.close() |
||||
|
||||
|
||||
# DRN model training |
||||
class DRNet(BasicSRNet): |
||||
def __init__(self, |
||||
n_blocks=30, |
||||
n_feats=16, |
||||
n_colors=3, |
||||
rgb_range=255, |
||||
negval=0.2): |
||||
super(DRNet, self).__init__() |
||||
self.min_max = '(0., 255.)' |
||||
self.generator = { |
||||
'name': 'DRNGenerator', |
||||
'scale': (2, 4), |
||||
'n_blocks': n_blocks, |
||||
'n_feats': n_feats, |
||||
'n_colors': n_colors, |
||||
'rgb_range': rgb_range, |
||||
'negval': negval |
||||
} |
||||
self.pixel_criterion = {'name': 'L1Loss'} |
||||
self.model = { |
||||
'name': 'DRN', |
||||
'generator': self.generator, |
||||
'pixel_criterion': self.pixel_criterion |
||||
} |
||||
self.optimizer = { |
||||
'optimG': { |
||||
'name': 'Adam', |
||||
'net_names': ['generator'], |
||||
'weight_decay': 0.0, |
||||
'beta1': 0.9, |
||||
'beta2': 0.999 |
||||
}, |
||||
'optimD': { |
||||
'name': 'Adam', |
||||
'net_names': ['dual_model_0', 'dual_model_1'], |
||||
'weight_decay': 0.0, |
||||
'beta1': 0.9, |
||||
'beta2': 0.999 |
||||
} |
||||
} |
||||
self.lr_scheduler = { |
||||
'name': 'CosineAnnealingRestartLR', |
||||
'eta_min': 1e-07 |
||||
} |
||||
|
||||
|
||||
# Training of LESRCNN, a lightweight super-resolution model |
||||
class LESRCNNet(BasicSRNet): |
||||
def __init__(self, scale=4, multi_scale=False, group=1): |
||||
super(LESRCNNet, self).__init__() |
||||
self.min_max = '(0., 1.)' |
||||
self.generator = { |
||||
'name': 'LESRCNNGenerator', |
||||
'scale': scale, |
||||
'multi_scale': multi_scale, |
||||
'group': group |
||||
} |
||||
self.pixel_criterion = {'name': 'L1Loss'} |
||||
self.model = { |
||||
'name': 'BaseSRModel', |
||||
'generator': self.generator, |
||||
'pixel_criterion': self.pixel_criterion |
||||
} |
||||
self.optimizer = { |
||||
'name': 'Adam', |
||||
'net_names': ['generator'], |
||||
'beta1': 0.9, |
||||
'beta2': 0.99 |
||||
} |
||||
self.lr_scheduler = { |
||||
'name': 'CosineAnnealingRestartLR', |
||||
'eta_min': 1e-07 |
||||
} |
||||
|
||||
|
||||
# ESRGAN model training |
||||
# If loss_type='gan', use perceptual, adversarial, and pixel losses |
||||
# If loss_type='pixel', use the pixel loss only |
||||
class ESRGANet(BasicSRNet): |
||||
def __init__(self, loss_type='gan', in_nc=3, out_nc=3, nf=64, nb=23): |
||||
super(ESRGANet, self).__init__() |
||||
self.min_max = '(0., 1.)' |
||||
self.generator = { |
||||
'name': 'RRDBNet', |
||||
'in_nc': in_nc, |
||||
'out_nc': out_nc, |
||||
'nf': nf, |
||||
'nb': nb |
||||
} |
||||
|
||||
if loss_type == 'gan': |
||||
            # Define the loss functions |
||||
self.pixel_criterion = {'name': 'L1Loss', 'loss_weight': 0.01} |
||||
self.discriminator = { |
||||
'name': 'VGGDiscriminator128', |
||||
'in_channels': 3, |
||||
'num_feat': 64 |
||||
} |
||||
self.perceptual_criterion = { |
||||
'name': 'PerceptualLoss', |
||||
'layer_weights': { |
||||
'34': 1.0 |
||||
}, |
||||
'perceptual_weight': 1.0, |
||||
'style_weight': 0.0, |
||||
'norm_img': False |
||||
} |
||||
self.gan_criterion = { |
||||
'name': 'GANLoss', |
||||
'gan_mode': 'vanilla', |
||||
'loss_weight': 0.005 |
||||
} |
||||
            # Define the model |
||||
self.model = { |
||||
'name': 'ESRGAN', |
||||
'generator': self.generator, |
||||
'discriminator': self.discriminator, |
||||
'pixel_criterion': self.pixel_criterion, |
||||
'perceptual_criterion': self.perceptual_criterion, |
||||
'gan_criterion': self.gan_criterion |
||||
} |
||||
self.optimizer = { |
||||
'optimG': { |
||||
'name': 'Adam', |
||||
'net_names': ['generator'], |
||||
'weight_decay': 0.0, |
||||
'beta1': 0.9, |
||||
'beta2': 0.99 |
||||
}, |
||||
'optimD': { |
||||
'name': 'Adam', |
||||
'net_names': ['discriminator'], |
||||
'weight_decay': 0.0, |
||||
'beta1': 0.9, |
||||
'beta2': 0.99 |
||||
} |
||||
} |
||||
self.lr_scheduler = { |
||||
'name': 'MultiStepDecay', |
||||
'milestones': [50000, 100000, 200000, 300000], |
||||
'gamma': 0.5 |
||||
} |
||||
else: |
||||
self.pixel_criterion = {'name': 'L1Loss'} |
||||
self.model = { |
||||
'name': 'BaseSRModel', |
||||
'generator': self.generator, |
||||
'pixel_criterion': self.pixel_criterion |
||||
} |
||||
self.optimizer = { |
||||
'name': 'Adam', |
||||
'net_names': ['generator'], |
||||
'beta1': 0.9, |
||||
'beta2': 0.99 |
||||
} |
||||
self.lr_scheduler = { |
||||
'name': 'CosineAnnealingRestartLR', |
||||
'eta_min': 1e-07 |
||||
} |
||||
|
||||
|
||||
# RCAN model training |
||||
class RCANet(BasicSRNet): |
||||
def __init__( |
||||
self, |
||||
scale=2, |
||||
n_resgroups=10, |
||||
n_resblocks=20, ): |
||||
super(RCANet, self).__init__() |
||||
self.min_max = '(0., 255.)' |
||||
self.generator = { |
||||
'name': 'RCAN', |
||||
'scale': scale, |
||||
'n_resgroups': n_resgroups, |
||||
'n_resblocks': n_resblocks |
||||
} |
||||
self.pixel_criterion = {'name': 'L1Loss'} |
||||
self.model = { |
||||
'name': 'RCANModel', |
||||
'generator': self.generator, |
||||
'pixel_criterion': self.pixel_criterion |
||||
} |
||||
self.optimizer = { |
||||
'name': 'Adam', |
||||
'net_names': ['generator'], |
||||
'beta1': 0.9, |
||||
'beta2': 0.99 |
||||
} |
||||
self.lr_scheduler = { |
||||
'name': 'MultiStepDecay', |
||||
'milestones': [250000, 500000, 750000, 1000000], |
||||
'gamma': 0.5 |
||||
} |
@ -0,0 +1,934 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import os |
||||
import os.path as osp |
||||
from collections import OrderedDict |
||||
|
||||
import numpy as np |
||||
import cv2 |
||||
import paddle |
||||
import paddle.nn.functional as F |
||||
from paddle.static import InputSpec |
||||
|
||||
import paddlers |
||||
import paddlers.models.ppgan as ppgan |
||||
import paddlers.rs_models.res as cmres |
||||
import paddlers.models.ppgan.metrics as metrics |
||||
import paddlers.utils.logging as logging |
||||
from paddlers.models import res_losses |
||||
from paddlers.transforms import Resize, decode_image |
||||
from paddlers.transforms.functions import calc_hr_shape |
||||
from paddlers.utils import get_single_card_bs |
||||
from .base import BaseModel |
||||
from .utils.res_adapters import GANAdapter, OptimizerAdapter |
||||
from .utils.infer_nets import InferResNet |
||||
|
||||
__all__ = ["DRN", "LESRCNN", "ESRGAN"] |
||||
|
||||
|
||||
class BaseRestorer(BaseModel): |
||||
MIN_MAX = (0., 1.) |
||||
TEST_OUT_KEY = None |
||||
|
||||
def __init__(self, model_name, losses=None, sr_factor=None, **params): |
||||
self.init_params = locals() |
||||
if 'with_net' in self.init_params: |
||||
del self.init_params['with_net'] |
||||
super(BaseRestorer, self).__init__('restorer') |
||||
self.model_name = model_name |
||||
self.losses = losses |
||||
self.sr_factor = sr_factor |
||||
if params.get('with_net', True): |
||||
params.pop('with_net', None) |
||||
self.net = self.build_net(**params) |
||||
self.find_unused_parameters = True |
||||
|
||||
def build_net(self, **params): |
||||
# Currently, only use models from cmres. |
||||
if not hasattr(cmres, self.model_name): |
||||
raise ValueError("ERROR: There is no model named {}.".format( |
||||
                self.model_name)) |
||||
net = dict(**cmres.__dict__)[self.model_name](**params) |
||||
return net |
||||
|
||||
def _build_inference_net(self): |
||||
# For GAN models, only the generator will be used for inference. |
||||
if isinstance(self.net, GANAdapter): |
||||
infer_net = InferResNet( |
||||
self.net.generator, out_key=self.TEST_OUT_KEY) |
||||
else: |
||||
infer_net = InferResNet(self.net, out_key=self.TEST_OUT_KEY) |
||||
infer_net.eval() |
||||
return infer_net |
||||
|
||||
def _fix_transforms_shape(self, image_shape): |
||||
if hasattr(self, 'test_transforms'): |
||||
if self.test_transforms is not None: |
||||
has_resize_op = False |
||||
resize_op_idx = -1 |
||||
normalize_op_idx = len(self.test_transforms.transforms) |
||||
for idx, op in enumerate(self.test_transforms.transforms): |
||||
name = op.__class__.__name__ |
||||
if name == 'Normalize': |
||||
normalize_op_idx = idx |
||||
if 'Resize' in name: |
||||
has_resize_op = True |
||||
resize_op_idx = idx |
||||
|
||||
if not has_resize_op: |
||||
self.test_transforms.transforms.insert( |
||||
normalize_op_idx, Resize(target_size=image_shape)) |
||||
else: |
||||
self.test_transforms.transforms[resize_op_idx] = Resize( |
||||
target_size=image_shape) |
||||
|
||||
def _get_test_inputs(self, image_shape): |
||||
if image_shape is not None: |
||||
if len(image_shape) == 2: |
||||
image_shape = [1, 3] + image_shape |
||||
self._fix_transforms_shape(image_shape[-2:]) |
||||
else: |
||||
image_shape = [None, 3, -1, -1] |
||||
self.fixed_input_shape = image_shape |
||||
input_spec = [ |
||||
InputSpec( |
||||
shape=image_shape, name='image', dtype='float32') |
||||
] |
||||
return input_spec |
||||
|
||||
def run(self, net, inputs, mode): |
||||
outputs = OrderedDict() |
||||
|
||||
if mode == 'test': |
||||
tar_shape = inputs[1] |
||||
if self.status == 'Infer': |
||||
net_out = net(inputs[0]) |
||||
res_map_list = self.postprocess( |
||||
net_out, tar_shape, transforms=inputs[2]) |
||||
else: |
||||
if isinstance(net, GANAdapter): |
||||
net_out = net.generator(inputs[0]) |
||||
else: |
||||
net_out = net(inputs[0]) |
||||
if self.TEST_OUT_KEY is not None: |
||||
net_out = net_out[self.TEST_OUT_KEY] |
||||
pred = self.postprocess( |
||||
net_out, tar_shape, transforms=inputs[2]) |
||||
res_map_list = [] |
||||
for res_map in pred: |
||||
res_map = self._tensor_to_images(res_map) |
||||
res_map_list.append(res_map) |
||||
outputs['res_map'] = res_map_list |
||||
|
||||
if mode == 'eval': |
||||
if isinstance(net, GANAdapter): |
||||
net_out = net.generator(inputs[0]) |
||||
else: |
||||
net_out = net(inputs[0]) |
||||
if self.TEST_OUT_KEY is not None: |
||||
net_out = net_out[self.TEST_OUT_KEY] |
||||
tar = inputs[1] |
||||
tar_shape = [tar.shape[-2:]] |
||||
pred = self.postprocess( |
||||
net_out, tar_shape, transforms=inputs[2])[0] # NCHW |
||||
pred = self._tensor_to_images(pred) |
||||
outputs['pred'] = pred |
||||
tar = self._tensor_to_images(tar) |
||||
outputs['tar'] = tar |
||||
|
||||
if mode == 'train': |
||||
# This is used by non-GAN models. |
||||
# For GAN models, self.run_gan() should be used. |
||||
net_out = net(inputs[0]) |
||||
loss = self.losses(net_out, inputs[1]) |
||||
outputs['loss'] = loss |
||||
return outputs |
||||
|
||||
def run_gan(self, net, inputs, mode, gan_mode): |
||||
raise NotImplementedError |
||||
|
||||
def default_loss(self): |
||||
return res_losses.L1Loss() |
||||
|
||||
def default_optimizer(self, |
||||
parameters, |
||||
learning_rate, |
||||
num_epochs, |
||||
num_steps_each_epoch, |
||||
lr_decay_power=0.9): |
||||
decay_step = num_epochs * num_steps_each_epoch |
||||
lr_scheduler = paddle.optimizer.lr.PolynomialDecay( |
||||
learning_rate, decay_step, end_lr=0, power=lr_decay_power) |
||||
optimizer = paddle.optimizer.Momentum( |
||||
learning_rate=lr_scheduler, |
||||
parameters=parameters, |
||||
momentum=0.9, |
||||
weight_decay=4e-5) |
||||
return optimizer |
||||
|
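The default optimizer above decays the learning rate polynomially over `num_epochs * num_steps_each_epoch` steps. As a rough illustration of what that schedule looks like, here is a minimal pure-Python sketch of the non-cycling polynomial decay formula. The helper name `polynomial_decay` and the example step counts are hypothetical and are not part of PaddleRS or Paddle.

```python
def polynomial_decay(base_lr, step, decay_steps, end_lr=0.0, power=0.9):
    """Closed-form polynomial decay without cycling.

    The step is clamped to decay_steps, so the rate stays at end_lr afterwards.
    """
    step = min(step, decay_steps)
    return (base_lr - end_lr) * (1 - step / decay_steps) ** power + end_lr


# Hypothetical training length: 10 epochs of 100 steps each.
num_epochs, num_steps_each_epoch = 10, 100
decay_steps = num_epochs * num_steps_each_epoch

# Learning rate at the start, the midpoint, and the end of training.
schedule = [polynomial_decay(0.01, s, decay_steps) for s in (0, 500, 1000)]
```

With `power=0.9` the curve is close to linear, so the rate at the midpoint is slightly above half of the initial value.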
||||
def train(self, |
||||
num_epochs, |
||||
train_dataset, |
||||
train_batch_size=2, |
||||
eval_dataset=None, |
||||
optimizer=None, |
||||
save_interval_epochs=1, |
||||
log_interval_steps=2, |
||||
save_dir='output', |
||||
pretrain_weights=None, |
||||
learning_rate=0.01, |
||||
lr_decay_power=0.9, |
||||
early_stop=False, |
||||
early_stop_patience=5, |
||||
use_vdl=True, |
||||
resume_checkpoint=None): |
||||
""" |
||||
Train the model. |
||||
|
||||
Args: |
||||
num_epochs (int): Number of epochs. |
||||
train_dataset (paddlers.datasets.ResDataset): Training dataset. |
||||
train_batch_size (int, optional): Total batch size among all cards used in |
||||
training. Defaults to 2. |
||||
eval_dataset (paddlers.datasets.ResDataset|None, optional): Evaluation dataset. |
||||
If None, the model will not be evaluated during training process. |
||||
Defaults to None. |
||||
optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in |
||||
training. If None, a default optimizer will be used. Defaults to None. |
||||
save_interval_epochs (int, optional): Epoch interval for saving the model. |
||||
Defaults to 1. |
||||
log_interval_steps (int, optional): Step interval for printing training |
||||
information. Defaults to 2. |
||||
save_dir (str, optional): Directory to save the model. Defaults to 'output'. |
||||
pretrain_weights (str|None, optional): None or name/path of pretrained |
||||
weights. If None, no pretrained weights will be loaded. |
||||
Defaults to None. |
||||
learning_rate (float, optional): Learning rate for training. Defaults to .01. |
||||
            lr_decay_power (float, optional): Learning rate decay power. Defaults to .9. |
||||
early_stop (bool, optional): Whether to adopt early stop strategy. Defaults |
||||
to False. |
||||
early_stop_patience (int, optional): Early stop patience. Defaults to 5. |
||||
use_vdl (bool, optional): Whether to use VisualDL to monitor the training |
||||
process. Defaults to True. |
||||
resume_checkpoint (str|None, optional): Path of the checkpoint to resume |
||||
training from. If None, no training checkpoint will be resumed. At most |
||||
                one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously. |
||||
Defaults to None. |
||||
""" |
||||
|
||||
if self.status == 'Infer': |
||||
logging.error( |
||||
"Exported inference model does not support training.", |
||||
exit=True) |
||||
if pretrain_weights is not None and resume_checkpoint is not None: |
||||
logging.error( |
||||
"pretrain_weights and resume_checkpoint cannot be set simultaneously.", |
||||
exit=True) |
||||
|
||||
if self.losses is None: |
||||
self.losses = self.default_loss() |
||||
|
||||
if optimizer is None: |
||||
num_steps_each_epoch = train_dataset.num_samples // train_batch_size |
||||
if isinstance(self.net, GANAdapter): |
||||
parameters = {'params_g': [], 'params_d': []} |
||||
for net_g in self.net.generators: |
||||
parameters['params_g'].append(net_g.parameters()) |
||||
for net_d in self.net.discriminators: |
||||
parameters['params_d'].append(net_d.parameters()) |
||||
else: |
||||
parameters = self.net.parameters() |
||||
self.optimizer = self.default_optimizer( |
||||
parameters, learning_rate, num_epochs, num_steps_each_epoch, |
||||
lr_decay_power) |
||||
else: |
||||
self.optimizer = optimizer |
||||
|
||||
if pretrain_weights is not None and not osp.exists(pretrain_weights): |
||||
logging.warning("Path of pretrain_weights('{}') does not exist!". |
||||
format(pretrain_weights)) |
||||
elif pretrain_weights is not None and osp.exists(pretrain_weights): |
||||
if osp.splitext(pretrain_weights)[-1] != '.pdparams': |
||||
logging.error( |
||||
"Invalid pretrain weights. Please specify a '.pdparams' file.", |
||||
exit=True) |
||||
pretrained_dir = osp.join(save_dir, 'pretrain') |
||||
is_backbone_weights = pretrain_weights == 'IMAGENET' |
||||
self.net_initialize( |
||||
pretrain_weights=pretrain_weights, |
||||
save_dir=pretrained_dir, |
||||
resume_checkpoint=resume_checkpoint, |
||||
is_backbone_weights=is_backbone_weights) |
||||
|
||||
self.train_loop( |
||||
num_epochs=num_epochs, |
||||
train_dataset=train_dataset, |
||||
train_batch_size=train_batch_size, |
||||
eval_dataset=eval_dataset, |
||||
save_interval_epochs=save_interval_epochs, |
||||
log_interval_steps=log_interval_steps, |
||||
save_dir=save_dir, |
||||
early_stop=early_stop, |
||||
early_stop_patience=early_stop_patience, |
||||
use_vdl=use_vdl) |
||||
|
||||
def quant_aware_train(self, |
||||
num_epochs, |
||||
train_dataset, |
||||
train_batch_size=2, |
||||
eval_dataset=None, |
||||
optimizer=None, |
||||
save_interval_epochs=1, |
||||
log_interval_steps=2, |
||||
save_dir='output', |
||||
learning_rate=0.0001, |
||||
lr_decay_power=0.9, |
||||
early_stop=False, |
||||
early_stop_patience=5, |
||||
use_vdl=True, |
||||
resume_checkpoint=None, |
||||
quant_config=None): |
||||
""" |
||||
Quantization-aware training. |
||||
|
||||
Args: |
||||
num_epochs (int): Number of epochs. |
||||
train_dataset (paddlers.datasets.ResDataset): Training dataset. |
||||
train_batch_size (int, optional): Total batch size among all cards used in |
||||
training. Defaults to 2. |
||||
eval_dataset (paddlers.datasets.ResDataset|None, optional): Evaluation dataset. |
||||
If None, the model will not be evaluated during training process. |
||||
Defaults to None. |
||||
optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used in |
||||
training. If None, a default optimizer will be used. Defaults to None. |
||||
save_interval_epochs (int, optional): Epoch interval for saving the model. |
||||
Defaults to 1. |
||||
log_interval_steps (int, optional): Step interval for printing training |
||||
information. Defaults to 2. |
||||
save_dir (str, optional): Directory to save the model. Defaults to 'output'. |
||||
learning_rate (float, optional): Learning rate for training. |
||||
Defaults to .0001. |
||||
            lr_decay_power (float, optional): Learning rate decay power. Defaults to .9. |
||||
early_stop (bool, optional): Whether to adopt early stop strategy. |
||||
Defaults to False. |
||||
early_stop_patience (int, optional): Early stop patience. Defaults to 5. |
||||
use_vdl (bool, optional): Whether to use VisualDL to monitor the training |
||||
process. Defaults to True. |
||||
quant_config (dict|None, optional): Quantization configuration. If None, |
||||
                a default rule-of-thumb configuration will be used. Defaults to None. |
||||
resume_checkpoint (str|None, optional): Path of the checkpoint to resume |
||||
quantization-aware training from. If None, no training checkpoint will |
||||
be resumed. Defaults to None. |
||||
""" |
||||
|
||||
self._prepare_qat(quant_config) |
||||
self.train( |
||||
num_epochs=num_epochs, |
||||
train_dataset=train_dataset, |
||||
train_batch_size=train_batch_size, |
||||
eval_dataset=eval_dataset, |
||||
optimizer=optimizer, |
||||
save_interval_epochs=save_interval_epochs, |
||||
log_interval_steps=log_interval_steps, |
||||
save_dir=save_dir, |
||||
pretrain_weights=None, |
||||
learning_rate=learning_rate, |
||||
lr_decay_power=lr_decay_power, |
||||
early_stop=early_stop, |
||||
early_stop_patience=early_stop_patience, |
||||
use_vdl=use_vdl, |
||||
resume_checkpoint=resume_checkpoint) |
||||
|
||||
def evaluate(self, eval_dataset, batch_size=1, return_details=False): |
||||
""" |
||||
Evaluate the model. |
||||
|
||||
Args: |
||||
eval_dataset (paddlers.datasets.ResDataset): Evaluation dataset. |
||||
batch_size (int, optional): Total batch size among all cards used for |
||||
evaluation. Defaults to 1. |
||||
return_details (bool, optional): Whether to return evaluation details. |
||||
Defaults to False. |
||||
|
||||
Returns: |
||||
If `return_details` is False, return collections.OrderedDict with |
||||
key-value pairs: |
||||
{"psnr": `peak signal-to-noise ratio`, |
||||
"ssim": `structural similarity`}. |
||||
            If `return_details` is True, return a tuple `(metrics, None)`, where |
||||
                `metrics` is the dict above; evaluation details are not implemented yet. |
||||
|
||||
""" |
||||
|
||||
self._check_transforms(eval_dataset.transforms, 'eval') |
||||
|
||||
self.net.eval() |
||||
nranks = paddle.distributed.get_world_size() |
||||
local_rank = paddle.distributed.get_rank() |
||||
if nranks > 1: |
||||
# Initialize parallel environment if not done. |
||||
if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized( |
||||
): |
||||
paddle.distributed.init_parallel_env() |
||||
|
||||
# TODO: Distributed evaluation |
||||
if batch_size > 1: |
||||
logging.warning( |
||||
                "Restorer only supports single-card evaluation with batch_size=1, " |
||||
                "so batch_size is forcibly set to 1.") |
||||
batch_size = 1 |
||||
|
||||
if nranks < 2 or local_rank == 0: |
||||
self.eval_data_loader = self.build_data_loader( |
||||
eval_dataset, batch_size=batch_size, mode='eval') |
||||
# XXX: Hard-code crop_border and test_y_channel |
||||
psnr = metrics.PSNR(crop_border=4, test_y_channel=True) |
||||
ssim = metrics.SSIM(crop_border=4, test_y_channel=True) |
||||
logging.info( |
||||
"Start to evaluate(total_samples={}, total_steps={})...".format( |
||||
eval_dataset.num_samples, eval_dataset.num_samples)) |
||||
with paddle.no_grad(): |
||||
for step, data in enumerate(self.eval_data_loader): |
||||
data.append(eval_dataset.transforms.transforms) |
||||
outputs = self.run(self.net, data, 'eval') |
||||
psnr.update(outputs['pred'], outputs['tar']) |
||||
ssim.update(outputs['pred'], outputs['tar']) |
||||
|
||||
# DO NOT use psnr.accumulate() here, otherwise the program hangs in multi-card training. |
||||
assert len(psnr.results) > 0 |
||||
assert len(ssim.results) > 0 |
||||
eval_metrics = OrderedDict( |
||||
zip(['psnr', 'ssim'], |
||||
[np.mean(psnr.results), np.mean(ssim.results)])) |
||||
|
||||
if return_details: |
||||
# TODO: Add details |
||||
return eval_metrics, None |
||||
|
||||
return eval_metrics |
||||
|
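The evaluation loop above averages per-image PSNR and SSIM with a 4-pixel border cropped. For readers unfamiliar with the metric, the following is a minimal NumPy sketch of PSNR with border cropping. It is an illustration only: the function `psnr` is hypothetical, it is not the `ppgan` metric implementation, and it skips the Y-channel conversion that `test_y_channel=True` implies.

```python
import numpy as np


def psnr(pred, target, crop_border=0, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two images of equal shape."""
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    if crop_border > 0:
        # Ignore a border of crop_border pixels on every side.
        pred = pred[crop_border:-crop_border, crop_border:-crop_border]
        target = target[crop_border:-crop_border, crop_border:-crop_border]
    mse = np.mean((pred - target) ** 2)
    if mse == 0:
        return float('inf')
    return 20.0 * np.log10(max_value / np.sqrt(mse))


# Two 16x16 RGB images that differ by exactly one gray level everywhere.
target_image = np.full((16, 16, 3), 100, dtype=np.uint8)
pred_image = target_image + 1
score = psnr(pred_image, target_image, crop_border=4)
```

Because the mean squared error is 1 here, the score reduces to `20 * log10(255)`.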
||||
def predict(self, img_file, transforms=None): |
||||
""" |
||||
Do inference. |
||||
|
||||
Args: |
||||
            img_file (list[np.ndarray|str] | str | np.ndarray): Image path or decoded |
||||
                image data. A list of such items is also accepted, in which case all |
||||
                images are predicted as one mini-batch. |
||||
transforms (paddlers.transforms.Compose|None, optional): Transforms for |
||||
inputs. If None, the transforms for evaluation process will be used. |
||||
Defaults to None. |
||||
|
||||
Returns: |
||||
            If `img_file` is a string or np.ndarray, the result is a dict with |
||||
the following key-value pairs: |
||||
res_map (np.ndarray): Restored image (HWC). |
||||
|
||||
If `img_file` is a list, the result is a list composed of dicts with the |
||||
above keys. |
||||
""" |
||||
|
||||
if transforms is None and not hasattr(self, 'test_transforms'): |
||||
            raise ValueError("`transforms` must be defined, but got None.") |
||||
if transforms is None: |
||||
transforms = self.test_transforms |
||||
if isinstance(img_file, (str, np.ndarray)): |
||||
images = [img_file] |
||||
else: |
||||
images = img_file |
||||
        # `preprocess` takes `to_tensor` as its third argument, not the model type. |
||||
        batch_im, batch_tar_shape = self.preprocess(images, transforms) |
||||
self.net.eval() |
||||
data = (batch_im, batch_tar_shape, transforms.transforms) |
||||
outputs = self.run(self.net, data, 'test') |
||||
res_map_list = outputs['res_map'] |
||||
if isinstance(img_file, list): |
||||
prediction = [{'res_map': m} for m in res_map_list] |
||||
else: |
||||
prediction = {'res_map': res_map_list[0]} |
||||
return prediction |
||||
|
||||
def preprocess(self, images, transforms, to_tensor=True): |
||||
self._check_transforms(transforms, 'test') |
||||
batch_im = list() |
||||
batch_tar_shape = list() |
||||
for im in images: |
||||
if isinstance(im, str): |
||||
im = decode_image(im, to_rgb=False) |
||||
ori_shape = im.shape[:2] |
||||
sample = {'image': im} |
||||
im = transforms(sample)[0] |
||||
batch_im.append(im) |
||||
batch_tar_shape.append(self._get_target_shape(ori_shape)) |
||||
if to_tensor: |
||||
batch_im = paddle.to_tensor(batch_im) |
||||
else: |
||||
batch_im = np.asarray(batch_im) |
||||
|
||||
return batch_im, batch_tar_shape |
||||
|
||||
def _get_target_shape(self, ori_shape): |
||||
if self.sr_factor is None: |
||||
return ori_shape |
||||
else: |
||||
return calc_hr_shape(ori_shape, self.sr_factor) |
||||
|
||||
@staticmethod |
||||
def get_transforms_shape_info(batch_tar_shape, transforms): |
||||
batch_restore_list = list() |
||||
for tar_shape in batch_tar_shape: |
||||
restore_list = list() |
||||
h, w = tar_shape[0], tar_shape[1] |
||||
for op in transforms: |
||||
if op.__class__.__name__ == 'Resize': |
||||
restore_list.append(('resize', (h, w))) |
||||
h, w = op.target_size |
||||
elif op.__class__.__name__ == 'ResizeByShort': |
||||
restore_list.append(('resize', (h, w))) |
||||
im_short_size = min(h, w) |
||||
im_long_size = max(h, w) |
||||
scale = float(op.short_size) / float(im_short_size) |
||||
if 0 < op.max_size < np.round(scale * im_long_size): |
||||
scale = float(op.max_size) / float(im_long_size) |
||||
h = int(round(h * scale)) |
||||
w = int(round(w * scale)) |
||||
elif op.__class__.__name__ == 'ResizeByLong': |
||||
restore_list.append(('resize', (h, w))) |
||||
im_long_size = max(h, w) |
||||
scale = float(op.long_size) / float(im_long_size) |
||||
h = int(round(h * scale)) |
||||
w = int(round(w * scale)) |
||||
elif op.__class__.__name__ == 'Pad': |
||||
if op.target_size: |
||||
target_h, target_w = op.target_size |
||||
else: |
||||
target_h = int( |
||||
(np.ceil(h / op.size_divisor) * op.size_divisor)) |
||||
target_w = int( |
||||
(np.ceil(w / op.size_divisor) * op.size_divisor)) |
||||
|
||||
if op.pad_mode == -1: |
||||
offsets = op.offsets |
||||
elif op.pad_mode == 0: |
||||
offsets = [0, 0] |
||||
elif op.pad_mode == 1: |
||||
offsets = [(target_h - h) // 2, (target_w - w) // 2] |
||||
else: |
||||
offsets = [target_h - h, target_w - w] |
||||
restore_list.append(('padding', (h, w), offsets)) |
||||
h, w = target_h, target_w |
||||
|
||||
batch_restore_list.append(restore_list) |
||||
return batch_restore_list |
||||
|
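`get_transforms_shape_info` records, for each resize or padding transform, the shape the image had before that transform, so that `postprocess` can undo the transforms in reverse order. The sketch below shows the same record-then-undo idea on a plain NumPy array. All names here (`nearest_resize`, `forward`, `restore`) are hypothetical, the resize is a simple nearest-neighbor stand-in, and padding offsets are assumed to be `(top, left)`.

```python
import numpy as np


def nearest_resize(im, h, w):
    """Nearest-neighbor resize of a 2-D array via index selection."""
    rows = np.arange(h) * im.shape[0] // h
    cols = np.arange(w) * im.shape[1] // w
    return im[rows][:, cols]


def forward(im, ops):
    """Apply ops in order and record what is needed to undo each one."""
    restore_list = []
    for name, target in ops:
        h, w = im.shape[:2]
        if name == 'resize':
            restore_list.append(('resize', (h, w)))
            im = nearest_resize(im, *target)
        elif name == 'pad':
            target_h, target_w = target
            # Pad at the bottom and right, so the (top, left) offset is (0, 0).
            restore_list.append(('padding', (h, w), (0, 0)))
            im = np.pad(im, ((0, target_h - h), (0, target_w - w)))
    return im, restore_list


def restore(im, restore_list):
    """Undo the recorded ops, last one first."""
    for item in restore_list[::-1]:
        h, w = item[1]
        if item[0] == 'resize':
            im = nearest_resize(im, h, w)
        elif item[0] == 'padding':
            top, left = item[2]
            im = im[top:top + h, left:left + w]
    return im


image = np.arange(6 * 8).reshape(6, 8)
processed, restore_list = forward(image, [('resize', (12, 16)), ('pad', (16, 16))])
restored = restore(processed, restore_list)
```

Iterating the restore list in reverse matters: the padding was applied last, so it has to be cropped away before the resize is undone.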
||||
def postprocess(self, batch_pred, batch_tar_shape, transforms): |
||||
batch_restore_list = BaseRestorer.get_transforms_shape_info( |
||||
batch_tar_shape, transforms) |
||||
if self.status == 'Infer': |
||||
return self._infer_postprocess( |
||||
batch_res_map=batch_pred, batch_restore_list=batch_restore_list) |
||||
results = [] |
||||
if batch_pred.dtype == paddle.float32: |
||||
mode = 'bilinear' |
||||
else: |
||||
mode = 'nearest' |
||||
for pred, restore_list in zip(batch_pred, batch_restore_list): |
||||
pred = paddle.unsqueeze(pred, axis=0) |
||||
for item in restore_list[::-1]: |
||||
h, w = item[1][0], item[1][1] |
||||
if item[0] == 'resize': |
||||
pred = F.interpolate( |
||||
pred, (h, w), mode=mode, data_format='NCHW') |
||||
elif item[0] == 'padding': |
||||
x, y = item[2] |
||||
pred = pred[:, :, y:y + h, x:x + w] |
||||
else: |
||||
pass |
||||
results.append(pred) |
||||
return results |
||||
|
||||
def _infer_postprocess(self, batch_res_map, batch_restore_list): |
||||
res_maps = [] |
||||
for res_map, restore_list in zip(batch_res_map, batch_restore_list): |
||||
if not isinstance(res_map, np.ndarray): |
||||
res_map = paddle.unsqueeze(res_map, axis=0) |
||||
for item in restore_list[::-1]: |
||||
h, w = item[1][0], item[1][1] |
||||
if item[0] == 'resize': |
||||
if isinstance(res_map, np.ndarray): |
||||
res_map = cv2.resize( |
||||
res_map, (w, h), interpolation=cv2.INTER_LINEAR) |
||||
else: |
||||
res_map = F.interpolate( |
||||
res_map, (h, w), |
||||
mode='bilinear', |
||||
data_format='NHWC') |
||||
elif item[0] == 'padding': |
||||
x, y = item[2] |
||||
if isinstance(res_map, np.ndarray): |
||||
res_map = res_map[y:y + h, x:x + w] |
||||
else: |
||||
res_map = res_map[:, y:y + h, x:x + w, :] |
||||
else: |
||||
pass |
||||
res_map = res_map.squeeze() |
||||
if not isinstance(res_map, np.ndarray): |
||||
res_map = res_map.numpy() |
||||
res_map = self._normalize(res_map) |
||||
res_maps.append(res_map.squeeze()) |
||||
return res_maps |
||||
|
||||
def _check_transforms(self, transforms, mode): |
||||
super()._check_transforms(transforms, mode) |
||||
if not isinstance(transforms.arrange, |
||||
paddlers.transforms.ArrangeRestorer): |
||||
raise TypeError( |
||||
"`transforms.arrange` must be an ArrangeRestorer object.") |
||||
|
||||
def build_data_loader(self, dataset, batch_size, mode='train'): |
||||
if dataset.num_samples < batch_size: |
||||
raise ValueError( |
||||
                'The volume of dataset({}) must be no smaller than batch size({}).' |
||||
.format(dataset.num_samples, batch_size)) |
||||
|
||||
if mode != 'train': |
||||
return paddle.io.DataLoader( |
||||
dataset, |
||||
batch_size=batch_size, |
||||
shuffle=dataset.shuffle, |
||||
drop_last=False, |
||||
collate_fn=dataset.batch_transforms, |
||||
num_workers=dataset.num_workers, |
||||
return_list=True, |
||||
use_shared_memory=False) |
||||
else: |
||||
return super(BaseRestorer, self).build_data_loader(dataset, |
||||
batch_size, mode) |
||||
|
||||
def set_losses(self, losses): |
||||
self.losses = losses |
||||
|
||||
def _tensor_to_images(self, |
||||
tensor, |
||||
transpose=True, |
||||
squeeze=True, |
||||
quantize=True): |
||||
if transpose: |
||||
tensor = paddle.transpose(tensor, perm=[0, 2, 3, 1]) # NHWC |
||||
if squeeze: |
||||
tensor = tensor.squeeze() |
||||
images = tensor.numpy().astype('float32') |
||||
images = self._normalize( |
||||
images, copy=True, clip=True, quantize=quantize) |
||||
return images |
||||
|
||||
def _normalize(self, im, copy=False, clip=True, quantize=True): |
||||
if copy: |
||||
im = im.copy() |
||||
if clip: |
||||
im = np.clip(im, self.MIN_MAX[0], self.MIN_MAX[1]) |
||||
im -= im.min() |
||||
im /= im.max() + 1e-32 |
||||
if quantize: |
||||
im *= 255 |
||||
im = im.astype('uint8') |
||||
return im |
||||
|
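`_normalize` turns a floating-point network output into a displayable 8-bit image: it clips to `MIN_MAX`, stretches the result to the full range, and quantizes. A minimal standalone sketch of those three steps follows. The function name `to_uint8` is hypothetical, and the sketch works on a copy rather than in place.

```python
import numpy as np


def to_uint8(im, min_max=(0.0, 1.0)):
    """Clip to min_max, min-max rescale to [0, 1], then quantize to uint8."""
    im = np.clip(np.asarray(im, dtype=np.float32), min_max[0], min_max[1])
    im = im - im.min()
    # The tiny epsilon avoids division by zero for constant images.
    im = im / (im.max() + 1e-32)
    return (im * 255).astype('uint8')


# Values outside [0, 1] are clipped before rescaling.
out = to_uint8([[-0.5, 0.0], [0.5, 2.0]])

# The stretch is per image: a narrow value range still maps to 0..255.
stretched = to_uint8([0.2, 0.4])
```

Note the design consequence visible in `stretched`: because the rescaling uses each image's own minimum and maximum, absolute brightness is not preserved across images.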
||||
|
||||
class DRN(BaseRestorer): |
||||
TEST_OUT_KEY = -1 |
||||
|
||||
def __init__(self, |
||||
losses=None, |
||||
sr_factor=4, |
||||
scale=(2, 4), |
||||
n_blocks=30, |
||||
n_feats=16, |
||||
n_colors=3, |
||||
rgb_range=1.0, |
||||
negval=0.2, |
||||
lq_loss_weight=0.1, |
||||
dual_loss_weight=0.1, |
||||
**params): |
||||
if sr_factor != max(scale): |
||||
            raise ValueError("`sr_factor` must be equal to `max(scale)`.") |
||||
params.update({ |
||||
'scale': scale, |
||||
'n_blocks': n_blocks, |
||||
'n_feats': n_feats, |
||||
'n_colors': n_colors, |
||||
'rgb_range': rgb_range, |
||||
'negval': negval |
||||
}) |
||||
self.lq_loss_weight = lq_loss_weight |
||||
self.dual_loss_weight = dual_loss_weight |
||||
super(DRN, self).__init__( |
||||
model_name='DRN', losses=losses, sr_factor=sr_factor, **params) |
||||
|
||||
def build_net(self, **params): |
||||
from ppgan.modules.init import init_weights |
||||
generators = [ppgan.models.generators.DRNGenerator(**params)] |
||||
init_weights(generators[-1]) |
||||
for scale in params['scale']: |
||||
dual_model = ppgan.models.generators.drn.DownBlock( |
||||
params['negval'], params['n_feats'], params['n_colors'], 2) |
||||
generators.append(dual_model) |
||||
init_weights(generators[-1]) |
||||
return GANAdapter(generators, []) |
||||
|
||||
def default_optimizer(self, parameters, *args, **kwargs): |
||||
optims_g = [ |
||||
super(DRN, self).default_optimizer(params_g, *args, **kwargs) |
||||
for params_g in parameters['params_g'] |
||||
] |
||||
return OptimizerAdapter(*optims_g) |
||||
|
||||
def run_gan(self, net, inputs, mode, gan_mode='forward_primary'): |
||||
if mode != 'train': |
||||
raise ValueError("`mode` is not 'train'.") |
||||
outputs = OrderedDict() |
||||
if gan_mode == 'forward_primary': |
||||
sr = net.generator(inputs[0]) |
||||
lr = [inputs[0]] |
||||
lr.extend([ |
||||
F.interpolate( |
||||
inputs[0], scale_factor=s, mode='bicubic') |
||||
for s in net.generator.scale[:-1] |
||||
]) |
||||
loss = self.losses(sr[-1], inputs[1]) |
||||
for i in range(1, len(sr)): |
||||
if self.lq_loss_weight > 0: |
||||
loss += self.losses(sr[i - 1 - len(sr)], |
||||
lr[i - len(sr)]) * self.lq_loss_weight |
||||
outputs['loss_prim'] = loss |
||||
outputs['sr'] = sr |
||||
outputs['lr'] = lr |
||||
elif gan_mode == 'forward_dual': |
||||
sr, lr = inputs[0], inputs[1] |
||||
sr2lr = [] |
||||
n_scales = len(net.generator.scale) |
||||
for i in range(n_scales): |
||||
sr2lr_i = net.generators[1 + i](sr[i - n_scales]) |
||||
sr2lr.append(sr2lr_i) |
||||
loss = self.losses(sr2lr[0], lr[0]) |
||||
for i in range(1, n_scales): |
||||
if self.dual_loss_weight > 0.0: |
||||
loss += self.losses(sr2lr[i], lr[i]) * self.dual_loss_weight |
||||
outputs['loss_dual'] = loss |
||||
else: |
||||
raise ValueError("Invalid `gan_mode`!") |
||||
return outputs |
||||
|
||||
def train_step(self, step, data, net): |
||||
outputs = self.run_gan( |
||||
net, data, mode='train', gan_mode='forward_primary') |
||||
outputs.update( |
||||
self.run_gan( |
||||
net, (outputs['sr'], outputs['lr']), |
||||
mode='train', |
||||
gan_mode='forward_dual')) |
||||
self.optimizer.clear_grad() |
||||
(outputs['loss_prim'] + outputs['loss_dual']).backward() |
||||
self.optimizer.step() |
||||
return { |
||||
'loss_prim': outputs['loss_prim'], |
||||
'loss_dual': outputs['loss_dual'] |
||||
} |
||||
|
||||
|
||||
class LESRCNN(BaseRestorer): |
||||
def __init__(self, |
||||
losses=None, |
||||
sr_factor=4, |
||||
multi_scale=False, |
||||
group=1, |
||||
**params): |
||||
params.update({ |
||||
'scale': sr_factor, |
||||
'multi_scale': multi_scale, |
||||
'group': group |
||||
}) |
||||
super(LESRCNN, self).__init__( |
||||
model_name='LESRCNN', losses=losses, sr_factor=sr_factor, **params) |
||||
|
||||
def build_net(self, **params): |
||||
net = ppgan.models.generators.LESRCNNGenerator(**params) |
||||
return net |
||||
|
||||
|
||||
class ESRGAN(BaseRestorer): |
||||
def __init__(self, |
||||
losses=None, |
||||
sr_factor=4, |
||||
use_gan=True, |
||||
in_channels=3, |
||||
out_channels=3, |
||||
nf=64, |
||||
nb=23, |
||||
**params): |
||||
if sr_factor != 4: |
||||
raise ValueError("`sr_factor` must be 4.") |
||||
params.update({ |
||||
'in_nc': in_channels, |
||||
'out_nc': out_channels, |
||||
'nf': nf, |
||||
'nb': nb |
||||
}) |
||||
self.use_gan = use_gan |
||||
super(ESRGAN, self).__init__( |
||||
model_name='ESRGAN', losses=losses, sr_factor=sr_factor, **params) |
||||
|
||||
def build_net(self, **params): |
||||
from ppgan.modules.init import init_weights |
||||
generator = ppgan.models.generators.RRDBNet(**params) |
||||
init_weights(generator) |
||||
if self.use_gan: |
||||
discriminator = ppgan.models.discriminators.VGGDiscriminator128( |
||||
in_channels=params['out_nc'], num_feat=64) |
||||
net = GANAdapter( |
||||
generators=[generator], discriminators=[discriminator]) |
||||
else: |
||||
net = generator |
||||
return net |
||||
|
||||
def default_loss(self): |
||||
if self.use_gan: |
||||
return { |
||||
'pixel': res_losses.L1Loss(loss_weight=0.01), |
||||
'perceptual': res_losses.PerceptualLoss( |
||||
layer_weights={'34': 1.0}, |
||||
perceptual_weight=1.0, |
||||
style_weight=0.0, |
||||
norm_img=False), |
||||
'gan': res_losses.GANLoss( |
||||
gan_mode='vanilla', loss_weight=0.005) |
||||
} |
||||
else: |
||||
return res_losses.L1Loss() |
||||
|
||||
def default_optimizer(self, parameters, *args, **kwargs): |
||||
if self.use_gan: |
||||
optim_g = super(ESRGAN, self).default_optimizer( |
||||
parameters['params_g'][0], *args, **kwargs) |
||||
optim_d = super(ESRGAN, self).default_optimizer( |
||||
parameters['params_d'][0], *args, **kwargs) |
||||
return OptimizerAdapter(optim_g, optim_d) |
||||
else: |
||||
return super(ESRGAN, self).default_optimizer(parameters, *args, |
||||
**kwargs) |
||||
|
||||
def run_gan(self, net, inputs, mode, gan_mode='forward_g'): |
||||
if mode != 'train': |
||||
raise ValueError("`mode` is not 'train'.") |
||||
outputs = OrderedDict() |
||||
if gan_mode == 'forward_g': |
||||
loss_g = 0 |
||||
g_pred = net.generator(inputs[0]) |
||||
loss_pix = self.losses['pixel'](g_pred, inputs[1]) |
||||
loss_perc, loss_sty = self.losses['perceptual'](g_pred, inputs[1]) |
||||
loss_g += loss_pix |
||||
if loss_perc is not None: |
||||
loss_g += loss_perc |
||||
if loss_sty is not None: |
||||
loss_g += loss_sty |
||||
self._set_requires_grad(net.discriminator, False) |
||||
real_d_pred = net.discriminator(inputs[1]).detach() |
||||
fake_g_pred = net.discriminator(g_pred) |
||||
loss_g_real = self.losses['gan']( |
||||
real_d_pred - paddle.mean(fake_g_pred), False, |
||||
is_disc=False) * 0.5 |
||||
loss_g_fake = self.losses['gan']( |
||||
fake_g_pred - paddle.mean(real_d_pred), True, |
||||
is_disc=False) * 0.5 |
||||
loss_g_gan = loss_g_real + loss_g_fake |
||||
outputs['g_pred'] = g_pred.detach() |
||||
outputs['loss_g_pps'] = loss_g |
||||
outputs['loss_g_gan'] = loss_g_gan |
||||
elif gan_mode == 'forward_d': |
||||
self._set_requires_grad(net.discriminator, True) |
||||
# Real |
||||
fake_d_pred = net.discriminator(inputs[0]).detach() |
||||
real_d_pred = net.discriminator(inputs[1]) |
||||
loss_d_real = self.losses['gan']( |
||||
real_d_pred - paddle.mean(fake_d_pred), True, |
||||
is_disc=True) * 0.5 |
||||
# Fake |
||||
fake_d_pred = net.discriminator(inputs[0].detach()) |
||||
loss_d_fake = self.losses['gan']( |
||||
fake_d_pred - paddle.mean(real_d_pred.detach()), |
||||
False, |
||||
is_disc=True) * 0.5 |
||||
outputs['loss_d'] = loss_d_real + loss_d_fake |
||||
else: |
||||
raise ValueError("Invalid `gan_mode`!") |
||||
return outputs |
||||
|
||||
def train_step(self, step, data, net): |
||||
if self.use_gan: |
||||
optim_g, optim_d = self.optimizer |
||||
|
||||
outputs = self.run_gan( |
||||
net, data, mode='train', gan_mode='forward_g') |
||||
optim_g.clear_grad() |
||||
(outputs['loss_g_pps'] + outputs['loss_g_gan']).backward() |
||||
optim_g.step() |
||||
|
||||
outputs.update( |
||||
self.run_gan( |
||||
net, (outputs['g_pred'], data[1]), |
||||
mode='train', |
||||
gan_mode='forward_d')) |
||||
optim_d.clear_grad() |
||||
outputs['loss_d'].backward() |
||||
optim_d.step() |
||||
|
||||
outputs['loss'] = outputs['loss_g_pps'] + outputs[ |
||||
'loss_g_gan'] + outputs['loss_d'] |
||||
|
||||
return { |
||||
'loss': outputs['loss'], |
||||
'loss_g_pps': outputs['loss_g_pps'], |
||||
'loss_g_gan': outputs['loss_g_gan'], |
||||
'loss_d': outputs['loss_d'] |
||||
} |
||||
else: |
||||
return super(ESRGAN, self).train_step(step, data, net) |
||||
|
||||
def _set_requires_grad(self, net, requires_grad): |
||||
for p in net.parameters(): |
||||
p.trainable = requires_grad |
||||
|
||||
|
||||
class RCAN(BaseRestorer): |
||||
def __init__(self, |
||||
losses=None, |
||||
sr_factor=4, |
||||
n_resgroups=10, |
||||
n_resblocks=20, |
||||
n_feats=64, |
||||
n_colors=3, |
||||
rgb_range=1.0, |
||||
kernel_size=3, |
||||
reduction=16, |
||||
**params): |
||||
params.update({ |
||||
'n_resgroups': n_resgroups, |
||||
'n_resblocks': n_resblocks, |
||||
'n_feats': n_feats, |
||||
'n_colors': n_colors, |
||||
'rgb_range': rgb_range, |
||||
'kernel_size': kernel_size, |
||||
'reduction': reduction |
||||
}) |
||||
super(RCAN, self).__init__( |
||||
model_name='RCAN', losses=losses, sr_factor=sr_factor, **params) |
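Note on the ESRGAN `run_gan` method above: before calling the GAN loss, it subtracts the mean discriminator output of the opposite batch from each prediction, which is a relativistic average formulation. The following is a minimal pure-Python sketch of that arithmetic, assuming the `'vanilla'` GAN loss is binary cross-entropy on raw logits. The helper names (`bce_with_logits`, `relativistic_d_loss`, `relativistic_g_loss`) are illustrative and not part of PaddleRS or PaddleGAN.

```python
import math


def bce_with_logits(logits, target):
    """Mean binary cross-entropy over a list of raw logits (assumed loss form)."""
    total = 0.0
    for x in logits:
        # Numerically stable form of -[t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x))]
        total += max(x, 0.0) - x * target + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def mean(values):
    return sum(values) / len(values)


def relativistic_d_loss(real_logits, fake_logits):
    """Discriminator side: real is pushed above the average fake, fake below the average real."""
    loss_real = bce_with_logits([r - mean(fake_logits) for r in real_logits], 1.0)
    loss_fake = bce_with_logits([f - mean(real_logits) for f in fake_logits], 0.0)
    return 0.5 * loss_real + 0.5 * loss_fake


def relativistic_g_loss(real_logits, fake_logits):
    """Generator side: same differences, but with the targets swapped."""
    loss_real = bce_with_logits([r - mean(fake_logits) for r in real_logits], 0.0)
    loss_fake = bce_with_logits([f - mean(real_logits) for f in fake_logits], 1.0)
    return 0.5 * loss_real + 0.5 * loss_fake
```

When the discriminator separates the two batches well, the discriminator loss is small and the generator loss is large, which is the signal the alternating `train_step` relies on.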
@ -0,0 +1,132 @@ |
||||
from functools import wraps |
||||
from inspect import isfunction, isgeneratorfunction, getmembers |
||||
from collections.abc import Sequence |
||||
from abc import ABC |
||||
|
||||
import paddle |
||||
import paddle.nn as nn |
||||
|
||||
__all__ = ['GANAdapter', 'OptimizerAdapter'] |
||||
|
||||
|
||||
class _AttrDesc: |
||||
def __init__(self, key): |
||||
self.key = key |
||||
|
||||
def __get__(self, instance, owner): |
||||
return tuple(getattr(ele, self.key) for ele in instance) |
||||
|
||||
def __set__(self, instance, value): |
||||
for ele in instance: |
||||
setattr(ele, self.key, value) |
||||
|
||||
|
||||
def _func_deco(cls, func_name): |
||||
@wraps(getattr(cls.__ducktype__, func_name)) |
||||
def _wrapper(self, *args, **kwargs): |
||||
return tuple(getattr(ele, func_name)(*args, **kwargs) for ele in self) |
||||
|
||||
return _wrapper |
||||
|
||||
|
||||
def _generator_deco(cls, func_name): |
||||
@wraps(getattr(cls.__ducktype__, func_name)) |
||||
def _wrapper(self, *args, **kwargs): |
||||
for ele in self: |
||||
yield from getattr(ele, func_name)(*args, **kwargs) |
||||
|
||||
return _wrapper |
||||
|
||||
|
||||
class Adapter(Sequence, ABC): |
||||
__ducktype__ = object |
||||
__ava__ = () |
||||
|
||||
def __init__(self, *args): |
||||
if not all(map(self._check, args)): |
||||
raise TypeError("Please check the input type.") |
||||
self._seq = tuple(args) |
||||
|
||||
def __getitem__(self, key): |
||||
return self._seq[key] |
||||
|
||||
def __len__(self): |
||||
return len(self._seq) |
||||
|
||||
def __repr__(self): |
||||
return repr(self._seq) |
||||
|
||||
@classmethod |
||||
def _check(cls, obj): |
||||
for attr in cls.__ava__: |
||||
try: |
||||
getattr(obj, attr) |
||||
# TODO: Check function signature |
||||
except AttributeError: |
||||
return False |
||||
return True |
||||
|
||||
|
||||
def make_adapter(cls): |
||||
members = dict(getmembers(cls.__ducktype__)) |
||||
for k in cls.__ava__: |
||||
if hasattr(cls, k): |
||||
continue |
||||
if k in members: |
||||
v = members[k] |
||||
if isgeneratorfunction(v): |
||||
setattr(cls, k, _generator_deco(cls, k)) |
||||
elif isfunction(v): |
||||
setattr(cls, k, _func_deco(cls, k)) |
||||
else: |
||||
setattr(cls, k, _AttrDesc(k)) |
||||
return cls |
||||
|
||||
|
||||
class GANAdapter(nn.Layer): |
||||
__ducktype__ = nn.Layer |
||||
__ava__ = ('state_dict', 'set_state_dict', 'train', 'eval') |
||||
|
||||
def __init__(self, generators, discriminators): |
||||
super(GANAdapter, self).__init__() |
||||
self.generators = nn.LayerList(generators) |
||||
self.discriminators = nn.LayerList(discriminators) |
||||
self._m = [*generators, *discriminators] |
||||
|
||||
def __len__(self): |
||||
return len(self._m) |
||||
|
||||
def __getitem__(self, key): |
||||
return self._m[key] |
||||
|
||||
def __contains__(self, m): |
||||
return m in self._m |
||||
|
||||
def __repr__(self): |
||||
return repr(self._m) |
||||
|
||||
@property |
||||
def generator(self): |
||||
return self.generators[0] |
||||
|
||||
@property |
||||
def discriminator(self): |
||||
return self.discriminators[0] |
||||
|
||||
|
||||
Adapter.register(GANAdapter) |
||||
|
||||
|
||||
@make_adapter |
||||
class OptimizerAdapter(Adapter): |
||||
__ducktype__ = paddle.optimizer.Optimizer |
||||
__ava__ = ('state_dict', 'set_state_dict', 'clear_grad', 'step', 'get_lr') |
||||
|
||||
def set_state_dict(self, state_dicts): |
||||
# Special dispatching rule |
||||
for optim, state_dict in zip(self, state_dicts): |
||||
optim.set_state_dict(state_dict) |
||||
|
||||
def get_lr(self): |
||||
# Return the lr of the first optimizer |
||||
return self[0].get_lr() |
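Note on the adapter module above: `make_adapter` fills in every name listed in `__ava__` that the class has not defined itself, using a wrapper that forwards the call to each wrapped element. The sketch below reproduces that idea without Paddle. `FakeOptimizer` and `BroadcastAdapter` are hypothetical stand-ins written for illustration, not classes from the repository.

```python
class FakeOptimizer:
    """Hypothetical stand-in for an optimizer; not part of Paddle."""

    def __init__(self, lr):
        self.lr = lr
        self.steps = 0

    def step(self):
        self.steps += 1

    def get_lr(self):
        return self.lr


class BroadcastAdapter:
    """Forwards selected method calls to every wrapped object."""

    forwarded = ('step', 'get_lr')

    def __init__(self, *items):
        self._items = tuple(items)

    def __iter__(self):
        return iter(self._items)

    def __len__(self):
        return len(self._items)

    def get_lr(self):
        # Special rule, as in OptimizerAdapter: report the first element only.
        return self._items[0].get_lr()


def _make_forwarder(name):
    def _forward(self, *args, **kwargs):
        # Call the same method on each element and collect the results.
        return tuple(getattr(item, name)(*args, **kwargs) for item in self)

    return _forward


for _name in BroadcastAdapter.forwarded:
    # Only fill in names the class has not defined itself.
    if _name not in vars(BroadcastAdapter):
        setattr(BroadcastAdapter, _name, _make_forwarder(_name))

optim_g, optim_d = FakeOptimizer(0.01), FakeOptimizer(0.02)
adapter = BroadcastAdapter(optim_g, optim_d)
adapter.step()           # advances both optimizers
first, second = adapter  # unpacks like a sequence, as in ESRGAN.train_step
```

Because hand-written methods take precedence over generated ones, a class can override the broadcast rule for individual names, as `OptimizerAdapter` does for `set_state_dict` and `get_lr`.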
@ -0,0 +1,62 @@ |
||||
# Basic configurations of LEVIR-CD dataset |
||||
|
||||
datasets: |
||||
train: !Node |
||||
type: CDDataset |
||||
args: |
||||
data_dir: ./test_tipc/data/levircd/ |
||||
file_list: ./test_tipc/data/levircd/train.txt |
||||
label_list: null |
||||
num_workers: 0 |
||||
shuffle: True |
||||
with_seg_labels: False |
||||
binarize_labels: True |
||||
eval: !Node |
||||
type: CDDataset |
||||
args: |
||||
data_dir: ./test_tipc/data/levircd/ |
||||
file_list: ./test_tipc/data/levircd/val.txt |
||||
label_list: null |
||||
num_workers: 0 |
||||
shuffle: False |
||||
with_seg_labels: False |
||||
binarize_labels: True |
||||
transforms: |
||||
train: |
||||
- !Node |
||||
type: DecodeImg |
||||
- !Node |
||||
type: RandomHorizontalFlip |
||||
args: |
||||
prob: 0.5 |
||||
- !Node |
||||
type: Normalize |
||||
args: |
||||
mean: [0.5, 0.5, 0.5] |
||||
std: [0.5, 0.5, 0.5] |
||||
- !Node |
||||
type: ArrangeChangeDetector |
||||
args: ['train'] |
||||
eval: |
||||
- !Node |
||||
type: DecodeImg |
||||
- !Node |
||||
type: Normalize |
||||
args: |
||||
mean: [0.5, 0.5, 0.5] |
||||
std: [0.5, 0.5, 0.5] |
||||
- !Node |
||||
type: ArrangeChangeDetector |
||||
args: ['eval'] |
||||
download_on: False |
||||
|
||||
num_epochs: 10 |
||||
train_batch_size: 8 |
||||
save_interval_epochs: 5 |
||||
log_interval_steps: 50 |
||||
save_dir: ./test_tipc/output/cd/ |
||||
learning_rate: 0.002 |
||||
early_stop: False |
||||
early_stop_patience: 5 |
||||
use_vdl: False |
||||
resume_checkpoint: '' |
@ -0,0 +1,8 @@ |
||||
# Basic configurations of BIT with AirChange dataset |
||||
|
||||
_base_: ../_base_/airchange.yaml |
||||
|
||||
save_dir: ./test_tipc/output/cd/bit/ |
||||
|
||||
model: !Node |
||||
type: BIT |
@ -0,0 +1,8 @@ |
||||
# Basic configurations of BIT with LEVIR-CD dataset |
||||
|
||||
_base_: ../_base_/levircd.yaml |
||||
|
||||
save_dir: ./test_tipc/output/cd/bit/ |
||||
|
||||
model: !Node |
||||
type: BIT |
@ -0,0 +1,13 @@ |
||||
# Basic configurations of FCCDN |
||||
|
||||
_base_: ../_base_/airchange.yaml |
||||
|
||||
save_dir: ./test_tipc/output/cd/fccdn/ |
||||
|
||||
model: !Node |
||||
type: FCCDN |
||||
|
||||
learning_rate: 0.07 |
||||
lr_decay_power: 0.6 |
||||
log_interval_steps: 100 |
||||
save_interval_epochs: 3 |
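Note on the `_base_` key used in the configs above: each model config names a base file and then restates only the keys it changes. The loader itself is not shown in this diff, so the sketch below is one plausible merge rule (a recursive dictionary update in which the child wins), not the repository's implementation. It works on plain dictionaries instead of parsing YAML. `airchange.yaml` is not shown here, so the base values are borrowed from the LEVIR-CD base for illustration only.

```python
import copy


def merge_configs(base, override):
    """Return a new dict where `override` wins and nested dicts merge recursively."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_configs(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Illustrative base values (taken from the LEVIR-CD base config shown earlier).
base_cfg = {
    'num_epochs': 10,
    'train_batch_size': 8,
    'save_interval_epochs': 5,
    'log_interval_steps': 50,
    'save_dir': './test_tipc/output/cd/',
    'learning_rate': 0.002,
}

# Keys restated by the FCCDN config.
child_cfg = {
    'save_dir': './test_tipc/output/cd/fccdn/',
    'model': {'type': 'FCCDN'},
    'learning_rate': 0.07,
    'lr_decay_power': 0.6,
    'log_interval_steps': 100,
    'save_interval_epochs': 3,
}

final_cfg = merge_configs(base_cfg, child_cfg)
```

Keys absent from the child, such as `num_epochs`, keep their base values, and the base dictionary is left unmodified.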
@ -0,0 +1,53 @@ |
||||
===========================train_params=========================== |
||||
model_name:cd:fccdn |
||||
python:python |
||||
gpu_list:0 |
||||
use_gpu:null|null |
||||
--precision:null |
||||
--num_epochs:lite_train_lite_infer=15|lite_train_whole_infer=15|whole_train_whole_infer=15 |
||||
--save_dir:adaptive |
||||
--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4 |
||||
--model_path:null |
||||
train_model_name:best_model |
||||
train_infer_file_list:./test_tipc/data/airchange/:./test_tipc/data/airchange/eval.txt |
||||
null:null |
||||
## |
||||
trainer:norm |
||||
norm_train:test_tipc/run_task.py train cd --config ./test_tipc/configs/cd/fccdn/fccdn.yaml |
||||
pact_train:null |
||||
fpgm_train:null |
||||
distill_train:null |
||||
null:null |
||||
null:null |
||||
## |
||||
===========================eval_params=========================== |
||||
eval:null |
||||
null:null |
||||
## |
||||
===========================export_params=========================== |
||||
--save_dir:adaptive |
||||
--model_dir:adaptive |
||||
--fixed_input_shape:[1,3,256,256] |
||||
norm_export:deploy/export/export_model.py |
||||
quant_export:null |
||||
fpgm_export:null |
||||
distill_export:null |
||||
export1:null |
||||
export2:null |
||||
===========================infer_params=========================== |
||||
infer_model:null |
||||
infer_export:null |
||||
infer_quant:False |
||||
inference:test_tipc/infer.py |
||||
--device:cpu|gpu |
||||
--enable_mkldnn:True |
||||
--cpu_threads:6 |
||||
--batch_size:1 |
||||
--use_trt:False |
||||
--precision:fp32 |
||||
--model_dir:null |
||||
--file_list:null:null |
||||
--save_log_path:null |
||||
--benchmark:True |
||||
--model_name:fccdn |
||||
null:null |
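Note on the params file above: every line is a `key:value` pair, where `|` separates alternatives and `mode=value` pairs give per-mode settings. The sketch below shows one way such a line could be read. Both function names are hypothetical, and the actual TIPC scripts may parse these lines differently.

```python
def parse_param_line(line):
    """Split 'key:value' at the first colon; the value may itself contain colons."""
    key, _, value = line.strip().partition(':')
    return key, value


def value_for_mode(value, mode):
    """Pick the entry for `mode` from 'a=1|b=2'; plain values are returned unchanged."""
    if value == 'null':
        return None
    for option in value.split('|'):
        name, sep, val = option.partition('=')
        if sep and name == mode:
            return val
    if '=' in value:
        # Per-mode value, but nothing is listed for this mode.
        return None
    return value


key, value = parse_param_line(
    '--num_epochs:lite_train_lite_infer=15|lite_train_whole_infer=15|whole_train_whole_infer=15')
epochs = value_for_mode(value, 'lite_train_lite_infer')
```

Splitting only at the first colon matters for entries such as `train_infer_file_list`, whose value contains a second colon between the data directory and the file list.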
@ -0,0 +1,46 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License. |
||||
|
||||
import paddlers |
||||
from rs_models.test_model import TestModel |
||||
|
||||
__all__ = [] |
||||
|
||||
|
||||
class TestResModel(TestModel): |
||||
def check_output(self, output, target): |
||||
output = output.numpy() |
||||
self.check_output_equal(output.shape, target.shape) |
||||
|
||||
def set_inputs(self): |
||||
def _gen_data(specs): |
||||
for spec in specs: |
||||
c = spec.get('in_channels', 3) |
||||
yield self.get_randn_tensor(c) |
||||
|
||||
self.inputs = _gen_data(self.specs) |
||||
|
||||
def set_targets(self): |
||||
def _gen_data(specs): |
||||
for spec in specs: |
||||
# XXX: Hard coding |
||||
if 'out_channels' in spec: |
||||
c = spec['out_channels'] |
||||
elif 'in_channels' in spec: |
||||
c = spec['in_channels'] |
||||
else: |
||||
c = 3 |
||||
yield [self.get_zeros_array(c)] |
||||
|
||||
self.targets = _gen_data(self.specs) |
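Note on `set_targets` above: the expected channel count falls back from `out_channels` to `in_channels` to 3. That rule can be isolated as a small function, sketched below. The name `expected_out_channels` is hypothetical and does not appear in the test suite.

```python
def expected_out_channels(spec, default=3):
    """Resolve the output channel count the way `set_targets` does."""
    if 'out_channels' in spec:
        return spec['out_channels']
    # Restoration models are assumed to keep the input channel count otherwise.
    return spec.get('in_channels', default)


specs = [{}, {'in_channels': 6}, {'in_channels': 4, 'out_channels': 1}]
channels = [expected_out_channels(s) for s in specs]
```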
@ -0,0 +1,41 @@ |
||||
#!/bin/bash |
||||
|
||||
rm -rf /usr/local/python2.7.15/bin/python |
||||
rm -rf /usr/local/python2.7.15/bin/pip |
||||
ln -s /usr/local/bin/python3.7 /usr/local/python2.7.15/bin/python |
||||
ln -s /usr/local/bin/pip3.7 /usr/local/python2.7.15/bin/pip |
||||
export PYTHONPATH=`pwd` |
||||
|
||||
python -m pip install --upgrade pip --ignore-installed |
||||
# python -m pip install --upgrade numpy --ignore-installed |
||||
python -m pip uninstall paddlepaddle-gpu -y |
||||
if [[ ${branch} == 'develop' ]];then |
||||
echo "checkout develop !" |
||||
python -m pip install ${paddle_dev} --no-cache-dir |
||||
else |
||||
echo "checkout release !" |
||||
python -m pip install ${paddle_release} --no-cache-dir |
||||
fi |
||||
|
||||
echo -e '*****************paddle_version*****' |
||||
python -c 'import paddle;print(paddle.version.commit)' |
||||
echo -e '*****************paddlers_version****' |
||||
git rev-parse HEAD |
||||
|
||||
pip install -r requirements.txt --ignore-installed |
||||
pip install -e . |
||||
pip install https://versaweb.dl.sourceforge.net/project/gdal-wheels-for-linux/GDAL-3.4.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl |
||||
|
||||
git clone https://github.com/LDOUBLEV/AutoLog |
||||
cd AutoLog |
||||
pip install -r requirements.txt |
||||
python setup.py bdist_wheel |
||||
pip install ./dist/auto_log*.whl |
||||
cd .. |
||||
|
||||
unset http_proxy https_proxy |
||||
|
||||
set -e |
||||
|
||||
cd tests/ |
||||
bash run_fast_tests.sh |
Some files were not shown because too many files have changed in this diff