From 5a6f19da8b65d9577248ffdfae57756be184b493 Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Thu, 25 Aug 2022 10:47:24 +0800 Subject: [PATCH] Add tools and experimental results --- examples/rs_research/README.md | 303 ++++++++++++++---- .../levircd/ablation/custom_model_c.yaml | 8 + .../levircd/ablation/custom_model_t.yaml | 8 + .../configs/levircd/custom_model.yaml | 6 + examples/rs_research/custom_model.py | 96 ++++-- examples/rs_research/custom_trainer.py | 2 +- examples/rs_research/predict_cd.py | 68 ++++ examples/rs_research/scripts/run_ablation.sh | 5 +- examples/rs_research/tools/analyze_model.py | 134 ++++++++ examples/rs_research/tools/collect_imgs.py | 61 ++++ examples/rs_research/tools/visualize_feats.py | 193 +++++++++++ 11 files changed, 792 insertions(+), 92 deletions(-) create mode 100644 examples/rs_research/configs/levircd/ablation/custom_model_c.yaml create mode 100644 examples/rs_research/configs/levircd/ablation/custom_model_t.yaml create mode 100644 examples/rs_research/configs/levircd/custom_model.yaml create mode 100644 examples/rs_research/predict_cd.py create mode 100644 examples/rs_research/tools/analyze_model.py create mode 100644 examples/rs_research/tools/collect_imgs.py create mode 100644 examples/rs_research/tools/visualize_feats.py diff --git a/examples/rs_research/README.md b/examples/rs_research/README.md index f0c2aa4..715e61b 100644 --- a/examples/rs_research/README.md +++ b/examples/rs_research/README.md @@ -12,25 +12,27 @@ cd examples/rs_research ``` +请注意,本文档仅所提供的所有指令遵循bash语法。 + ## 2 数据准备 本案例在[LEVIR-CD数据集](https://www.mdpi.com/2072-4292/12/10/1662)[1]和[synthetic images and real season-varying remote sensing images(SVCD)数据集](https://www.int-arch-photogramm-remote-sens-spatial-inf-sci.net/XLII-2/565/2018/isprs-archives-XLII-2-565-2018.pdf)[2]上开展实验。请在[LEVIR-CD数据集下载链接](https://justchenhao.github.io/LEVIR/)和[SVCD数据集下载链接](https://drive.google.com/file/d/1GX656JqqOyBi_Ef0w65kDGVto-nHrNs9/edit)分别下载这两个数据集,解压至本地目录,并执行如下指令: -```shell +```bash mkdir data/ python ../../tools/prepare_dataset/prepare_levircd.py \ - --in_dataset_dir {LEVIR-CD数据集存放目录路径} \ - --out_dataset_dir "data/levircd" \ + --in_dataset_dir "{LEVIR-CD数据集存放目录路径}" \ + --out_dataset_dir 'data/levircd' \ --crop_size 256 \ --crop_stride 256 python ../../tools/prepare_dataset/prepare_svcd.py \ - --in_dataset_dir {SVCD数据集存放目录路径} \ - --out_dataset_dir "data/svcd" + --in_dataset_dir "{SVCD数据集存放目录路径}" \ + --out_dataset_dir 'data/svcd' ``` 以上指令利用PaddleRS提供的数据集准备工具完成数据集切分、file list创建等操作。具体而言,对于LEVIR-CD数据集,使用官方的训练/验证/测试集划分,并将原始的`1024x1024`大小的影像切分为无重叠的`256x256`的小块(参考[3]中的做法);对于SVCD数据集,使用官方的训练/验证/测试集划分,不做其它额外处理。 -## 3 模型设计与验证 +## 3 模型设计 ### 3.1 问题分析与思路拟定 @@ -43,18 +45,21 @@ python ../../tools/prepare_dataset/prepare_svcd.py \ 1. 巨大的参数量意味着巨大的存储开销。在许多实际场景中,硬件资源往往是有限的,过多的模型参数将给部署造成困难。 2. 在数据有限的情况下,大模型更易遭受过拟合,其在实验数据集上看起来良好的结果也难以泛化到真实场景。 -本案例认为,上述问题的根源在于参数量与数据量的失衡所导致的特征冗余。既然模型的特征存在冗余,也即存在一部分“无用”的特征,是否存在某种手段,能够在固定模型参数量的前提下对特征进行优化,从而“榨取”小模型的更多潜力,获取更多更加有效的特征?基于这个观点,本案例的基本思路是为现有的变化检测模型添加一个“插件式”的特征优化模块,在仅引入较少额外的参数数量的情况下,实现变化特征增强。本案例计划以变化检测领域经典的FC-Siam-diff[4]为baseline网络,利用时间、空间、通道注意力模块对网络的中间层特征进行优化,从而减小特征冗余,提升检测效果。在具体的模块设计方面,对于时间与通道维度,选用论文[5]中提出的通道注意力模块;对于空间维度,选用论文[5]中提出的空间注意力模块。 +本案例认为,上述问题的根源在于参数量与数据量的失衡所导致的特征冗余。既然模型的特征存在冗余,也即存在一部分“无用”的特征,是否存在某种手段,能够在固定模型参数量的前提下对特征进行优化,从而“榨取”小模型的更多潜力,获取更多更加有效的特征?基于这个观点,本案例的基本思路是为现有的变化检测模型添加一个“插件式”的特征优化模块,在仅引入较少额外的参数数量的情况下,实现变化特征增强。本案例计划以变化检测领域经典的FC-Siam-conc[4]为baseline网络,利用通道和时间注意力模块对网络的中间层特征进行优化,从而减小特征冗余,提升检测效果。在具体的模块设计方面,选用论文[5]中提出的通道注意力模块实现通道和时间维度的特征增强。 ### 3.2 模型定义 +本小节基于PaddlePaddle框架与PaddleRS库实现[3.1节](#3.1-问题分析与思路拟定)中提出的想法。 + #### 3.2.1 自定义模型组网 -在`custom_model.py`中定义模型的宏观(macro)结构以及组成模型的各个微观(micro)模块。例如,本案例中,`custom_model.py`中定义了改进后的FC-EF结构,其核心部分实现如下: +在`custom_model.py`中定义模型的宏观(macro)结构以及组成模型的各个微观(micro)模块。本案例在`custom_model.py`中定义了改进后的FC-Siam-conc结构,其核心部分实现如下: + ```python ... # PaddleRS提供了许多开箱即用的模块,其中有对底层基础模块的封装(如conv-bn-relu结构等),也有注意力模块等较高层级的结构 from paddlers.rs_models.cd.layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity -from paddlers.rs_models.cd.layers import ChannelAttention, SpatialAttention +from paddlers.rs_models.cd.layers import ChannelAttention from attach_tools import Attach @@ -65,52 +70,90 @@ class CustomModel(nn.Layer): def __init__(self, in_channels, num_classes, - att_types='cst', + att_types='ct', use_dropout=False): super().__init__() ... + # 构建一个混合注意力模块att4,用于处理两个编码器最终输出的特征 + self.att4 = MixedAttention(C4, att_types) + + self.init_weight() + + def forward(self, t1, t2): + ... + x4d = self.upconv4(x4p) + pad4 = (0, x43_1.shape[3] - x4d.shape[3], 0, + x43_1.shape[2] - x4d.shape[2]) + x4d = F.pad(x4d, pad=pad4, mode='replicate') + # 将注意力模块接入第一个解码单元 + x43_1, x43_2 = self.att4(x43_1, x43_2) + x4d = paddle.concat([x4d, x43_1, x43_2], 1) + x43d = self.do43d(self.conv43d(x4d)) + x42d = self.do42d(self.conv42d(x43d)) + x41d = self.do41d(self.conv41d(x42d)) + ... + + +class MixedAttention(nn.Layer): + def __init__(self, in_channels, att_types='ct'): + super(MixedAttention, self).__init__() + + self.att_types = att_types # 从`att_types`参数中获取要使用的注意力类型 # 每个注意力模块都是可选的 - if 'c' in att_types: - self.att_c = ChannelAttention(C4) + if self.has_att_c: + self.att_c = ChannelAttention(in_channels, ratio=1) + # 在时间注意力模块之后增加归一化层 + # 利用BN层中的可学习参数增强模型的拟合能力 + self.norm_c1 = nn.BatchNorm(in_channels) + self.norm_c2 = nn.BatchNorm(in_channels) else: self.att_c = Identity() - if 's' in att_types: - self.att_s = SpatialAttention() - else: - self.att_s = Identity() + self.norm_c1 = Identity() + self.norm_c2 = Identity() + # 时间注意力模块部分复用通道注意力的逻辑,在`forward()`中将具体解释 - if 't' in att_types: + if has_att_t: self.att_t = ChannelAttention(2, ratio=1) else: self.att_t = Identity() - self.init_weight() + def forward(x1, x2): + # x1和x2分别是FC-Siam-conc的两路编码器提取的特征 + + if self.has_att_c: + # 首先使用通道注意力模块对特征进行优化 + # 两个时相的编码特征共享通道注意力模块,但使用各自的归一化层 + x1 = self.att_c(x1) * x1 + x1 = self.norm_c1(x1) + x2 = self.att_c(x2) * x2 + x2 = self.norm_c2(x2) + + if self.has_att_t: + b, c = x1.shape[:2] + # 为了复用通道注意力模块执行时间维度的注意力操作,首先将两个时相的特征堆叠 + y = paddle.stack([x1, x2], axis=2) + # 堆叠后的y形状为[b, c, t, h, w],其中b表示batch size,c为特征通道数,t为2(时相数目),h和w分别为特征图高宽 + # 将b和c两个维度合并,输出tensor形状为[b*c, t, h, w] + y = paddle.flatten(y, stop_axis=1) + # 此时,时间维度已经替代了原先的通道维度,将四维tensor输入ChannelAttention模块进行处理 + y = self.att_t(y) * y + # 从处理结果中分离两个时相的信息 + y = y.reshape((b, c, 2, *y.shape[2:])) + y1, y2 = y[:, :, 0], y[:, :, 1] + else: + y1, y2 = x1, x2 - def forward(self, t1, t2): - ... - # 以下是本案例在FC-EF基础上新增的部分 - # x43_1和x43_2分别是FC-EF的两路编码器提取的特征 - # 首先使用通道和空间注意力模块对特征进行优化 - x43_1 = self.att_c(x43_1) * x43_1 - x43_1 = self.att_s(x43_1) * x43_1 - x43_2 = self.att_c(x43_2) * x43_2 - x43_2 = self.att_s(x43_2) * x43_2 - # 为了复用通道注意力模块执行时间维度的注意力操作,首先将两个时相的特征堆叠 - x43 = paddle.stack([x43_1, x43_2], axis=1) - # 堆叠后的x43形状为[b, t, c, h, w],其中b表示batch size,t为2(时相数目),c为通道数,h和w分别为特征图高宽 - # 将t和c维度交换,输出tensor形状为[b, c, t, h, w] - x43 = paddle.transpose(x43, [0, 2, 1, 3, 4]) - # 将b和c两个维度合并,输出tensor形状为[b*c, t, h, w] - x43 = paddle.flatten(x43, stop_axis=1) - # 此时,时间维度已经替代了原先的通道维度,将四维tensor输入ChannelAttention模块进行处理 - x43 = self.att_t(x43) * x43 - # 从处理结果中分离两个时相的信息 - x43 = x43.reshape((x43_1.shape[0], -1, 2, *x43.shape[2:])) - x43_1, x43_2 = x43[:,:,0], x43[:,:,1] - ... - ... + return y1, y2 + + @property + def has_att_c(self): + return 'c' in self.att_types + + @property + def has_att_t(self): + return 't' in self.att_types ``` 在编写组网相关代码时请注意以下两点: @@ -132,7 +175,7 @@ class CustomTrainer(BaseChangeDetector): use_mixed_loss=False, losses=None, in_channels=3, - att_types='cst', + att_types='ct', use_dropout=False, **params): params.update({ @@ -158,46 +201,196 @@ class CustomTrainer(BaseChangeDetector): 关于训练器的更多细节请参考[API文档](https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/train.md)。 -### 3.3 消融实验 +## 4 对比实验 -#### 3.3.1 实验设置 +为了验证模型设计的有效性,通常需要开展对比实验,在一个或多个数据集上比较所提出模型与其它模型的精度和性能。在本案例中,将自定义模型与FC-EF、FC-Siam-diff、FC-Siam-conc三种结构进行比较,这三个模型均来自论文[4]。 -#### 3.3.2 编写配置文件 +### 4.1 实验过程 -#### 3.3.3 实验结果 +使用如下指令在LEVIR-CD与SVCD数据集上执行对所有参与对比的模型的训练: -VisualDL、定量指标 +```bash +bash scripts/run_benchmark.sh +``` -### 3.4 特征可视化实验 +或者,可以按照以下格式执行对某个模型在某一数据集上的训练: -## 4 对比实验 +```bash +python run_task.py train cd \ + --config "configs/{数据集名称}/{配置文件名称}" \ + 2>&1 | tee "{日志路径}" +``` + +训练完成后,使用如下指令对验证集上最优的模型在测试集上计算指标: + +```bash +python run_task.py eval cd \ + --config "configs/{数据集名称}/{配置文件名称}" \ + --datasets.eval.args.file_list "data/{数据集名称}/test.txt" \ + --resume_checkpoint "exp/{数据集名称}/{模型名称}/best_model" +``` + +训练程序默认开启VisualDL日志记录功能。训练过程中或训练完成后,可使用VisualDL观察损失函数和精度指标的变化情况。在PaddleRS中使用VisualDL的方式请参考[使用教程](https://github.com/PaddlePaddle/PaddleRS/blob/develop/tutorials/train/README.md#visualdl%E5%8F%AF%E8%A7%86%E5%8C%96%E8%AE%AD%E7%BB%83%E6%8C%87%E6%A0%87)。 + +在训练和精度指标验证完成后,可以通过如下指令保存模型输出的二值变化图: + +```bash +python predict_cd.py \ + --model_dir "exp/{数据集名称}/{模型名称}/best_model" \ + --data_dir "data/{数据集名称}" \ + --file_list "data/{数据集名称}/test.txt" \ + --save_dir "exp/predict/{数据集名称}/{模型名称}" +``` + +之后,可在`exp/predict/{数据集名称}/{模型名称}`目录查看保存的输出结果。 + +可以通过`tools/collect_imgs.py`脚本将输入图像、真值标签以及多个模型的预测结果放置在一个目录下以便于观察比较。该脚本接受三个命令行选项: +- 使用`--globs`指定一系列通配符(可用于Python的[`glob.glob()`函数](https://docs.python.org/zh-cn/3/library/glob.html#glob.glob),用于匹配需要收集的图像; +- 使用`--tags`为`--globs`中的每一项指定一个别名,在存储目录中,相应的图像名将被替换为存储的别名; +- 使用`--save_dir`指定输出目录路径,若目录不存在将被自动创建。 -### 4.1 确定对比算法 +例如,对于LEVIR-CD数据集,执行如下指令: -### 4.2 准备对比算法配置文件 +```bash +python tools/collect_imgs.py \ + --globs "data/levircd/LEVIR-CD/test/A/*/*.png" "data/levircd/LEVIR-CD/test/B/*/*.png" "data/levircd/LEVIR-CD/test/label/*/*.png" \ + "exp/predict/levircd/fc_ef/*.png" "exp/predict/levircd/fc_siam_conc/*.png" "exp/predict/levircd/fc_siam_diff/*.png" \ + "exp/predict/levircd/custom_model/*.png" \ + --tags 'A' 'B' 'GT' \ + 'fc_ef' 'fc_siam_conc' 'fc_siam_diff' \ + 'custom_model' \ + --save_dir "exp/collect/levircd" +``` + +执行完毕后,可在`exp/collect/levircd`目录中找到两个时相的输入影像、真值标签以及各个模型的预测结果。当新增模型后,可以再次调用`tools/collect_imgs.py`脚本补充结果到`exp/collect/levircd`目录中: + +```bash +python tools/collect_imgs.py --globs "exp/predict/levircd/{新增模型名称}/*.png" --tags '{新增模型名称}' --save_dir "exp/collect/levircd" +``` + +对于SVCD数据集,执行如下指令: + +```bash +python tools/collect_imgs.py \ + --globs "data/svcd/ChangeDetectionDataset/Real/subset/test/A/*.jpg" "data/svcd/ChangeDetectionDataset/Real/subset/test/B/*.jpg" "data/svcd/ChangeDetectionDataset/Real/subset/test/OUT/*.jpg" \ + "exp/predict/svcd/fc_ef/*.png" "exp/predict/svcd/fc_siam_conc/*.png" "exp/predict/svcd/fc_siam_diff/*.png" \ + "exp/predict/svcd/custom_model/*.png" \ + --tags 'A' 'B' 'GT' \ + 'fc_ef' 'fc_siam_conc' 'fc_siam_diff' \ + 'custom_model' \ + --save_dir "exp/collect/svcd" +``` + +此外,为了从精度和性能两个方面综合评估变化检测算法,可以通过如下指令计算变化检测模型的[浮点计算数(floating point operations, FLOPs)](https://blog.csdn.net/IT_flying625/article/details/104898152)和模型参数量: + +```bash +python tools/analyze_model.py --model_dir "exp/{数据集名称}/{模型名称}/best_model" +``` -### 4.3 实验结果 +### 4.2 实验结果 -#### 4.3.1 LEVIR-CD数据集上的对比结果 +本案例使用变化类的[交并比(intersection over union, IoU)](https://paddlepedia.readthedocs.io/en/latest/tutorials/computer_vision/semantic_segmentation/Overview/Overview.html#id6)和[F1分数](https://baike.baidu.com/item/F1%E5%88%86%E6%95%B0/13864979)作为定量评价指标。在每个数据集上,从目视效果和定量指标两个方面对算法效果进行评判。 +#### 4.2.1 LEVIR-CD数据集上的对比结果 **目视效果对比** +|时相1影像|时相2影像|FC-EF|FC-Siam-diff|FC-Siam-conc|CustomModel|真值标签| +|:-:|:-:|:-:|:-:|:-:|:-:|:-:| +|![]()|![]()|![]()|![]()|![]()|![]()|![]()| + **定量指标对比** -#### 4.3.2 SVCD数据集上的对比结果 +|模型名称|FLOPs(G)|参数量(M)|IoU%|F1%| +|:-:|:-:|:-:|:-:|:-:| +|FC-EF|3.57|1.35|79.05|88.30| +|FC-Siam-diff|4.71|1.35|81.33|89.70| +|FC-Siam-conc|5.31|1.55|81.31|89.69| +|CustomModel|5.31|1.58|**82.27**|**90.27**| + +#### 4.2.2 SVCD数据集上的对比结果 **目视效果对比** +|时相1影像|时相2影像|FC-EF|FC-Siam-diff|FC-Siam-conc|CustomModel|真值标签| +|:-:|:-:|:-:|:-:|:-:|:-:|:-:| +|![]()|![]()|![]()|![]()|![]()|![]()|![]()| + **定量指标对比** +|模型名称|FLOPs(G)|参数量(M)|IoU%|F1%| +|:-:|:-:|:-:|:-:|:-:| +|FC-EF|3.57|1.35|84.11|91.37| +|FC-Siam-diff|4.71|1.35|88.75|94.04| +|FC-Siam-conc|5.31|1.55|88.29|93.78| +|CustomModel|5.31|1.58||| +## 5 消融实验 + +在科研过程中,为了验证在baseline上所做修改的有效性,常常需要开展消融实验。例如,在本案例中,自定义模型在FC-Siam-conc模型的基础上添加了通道和时间两种注意力模块,因此需要通过消融实验探讨各个注意力模块对最终精度的贡献。具体而言,包括以下4种实验情形(配置文件均存储在`configs/levircd/ablation`目录): + +1. 基础情况:不使用任何注意力模块,即baseline模型FC-Siam-conc。 +2. 仅添加通道注意力模块,对应的配置文件名称为`custom_model_c.yaml`。 +3. 仅添加时间注意力模块,对应的配置文件名称为`custom_model_t.yaml`。 +4. 标准情况:同时添加通道和时间注意力模块的完整模型。 + +其中第1和第4个模型,即baseline和完整模型,在[第4节](#4-对比实验)中已经得到了训练、验证和测试。因此,本节只需要关注情形2、3。 + +### 5.1 实验过程 + +使用如下指令执行全部消融模型的训练: + +```bash +bash scripts/run_ablation.sh +``` + +或者,可以按照以下格式执行对某一个模型的训练: + +```bash +python run_task.py train cd \ + --config "configs/levircd/ablation/{配置文件名称}" \ + 2>&1 | tee {日志路径} +``` + +训练完成后,使用如下指令对验证集上最优的模型在测试集上计算指标: + +```bash +python run_task.py eval cd \ + --config "configs/levircd/ablation/{配置文件名称}" \ + --datasets.eval.args.file_list data/levircd/test.txt \ + --resume_checkpoint "exp/levircd/ablation/{消融模型名称}/best_model" +``` + +注意,形如`custom_model_c.yaml`的配置文件默认对应的消融模型名称为`att_c`。 + +训练程序默认开启VisualDL日志记录功能。训练过程中或训练完成后,可使用VisualDL观察损失函数和精度指标的变化情况。在PaddleRS中使用VisualDL的方式请参考[使用教程](https://github.com/PaddlePaddle/PaddleRS/blob/develop/tutorials/train/README.md#visualdl%E5%8F%AF%E8%A7%86%E5%8C%96%E8%AE%AD%E7%BB%83%E6%8C%87%E6%A0%87)。 + +### 5.2 实验结果 + +实验得到的定量指标如下表所示: + +|通道注意力模块|时间注意力模块|IoU%|F1%| +|:-:|:-:|:-:|:-:| +|||81.31|89.69| +|✓||81.32|89.70| +||✓|81.61|89.88| +|✓|✓|**82.27**|**90.27**| + +其中,最高的指标用粗体表示。从表中数据可知,有限。 + +## 6 特征可视化实验 + +为了更好地探究。 + ## 5 总结与展望 ### 5.1 总结 +本案例以为经典的FC-Siam-conc模型添加注意力模块为例,演示了使用PaddleRS开展科研实验的典型流程。 +- 精度提升十分有限,算法设计。 + ### 5.2 展望 - 本案例对所有参与比较的算法使用了相同的训练超参数,但由于模型之间存在差异,使用统一的超参训练往往难以保证所有模型都能取得较好的效果。在后续工作中,可以对每个对比算法进行调参,使其获得最优精度。 -- 在评估算法效果时,仅仅对比了精度指标,而未对耗时、模型大小、FLOPs等指标进行考量。后续应当从精度和性能两个方面对算法进行综合评估。 +- 本案例只作为 ## 参考文献 diff --git a/examples/rs_research/configs/levircd/ablation/custom_model_c.yaml b/examples/rs_research/configs/levircd/ablation/custom_model_c.yaml new file mode 100644 index 0000000..d66cf44 --- /dev/null +++ b/examples/rs_research/configs/levircd/ablation/custom_model_c.yaml @@ -0,0 +1,8 @@ +_base_: ../levircd.yaml + +save_dir: ./exp/levircd/ablation/att_c/ + +model: !Node + type: CustomTrainer + args: + att_types: c diff --git a/examples/rs_research/configs/levircd/ablation/custom_model_t.yaml b/examples/rs_research/configs/levircd/ablation/custom_model_t.yaml new file mode 100644 index 0000000..028953b --- /dev/null +++ b/examples/rs_research/configs/levircd/ablation/custom_model_t.yaml @@ -0,0 +1,8 @@ +_base_: ../levircd.yaml + +save_dir: ./exp/levircd/ablation/att_t/ + +model: !Node + type: CustomTrainer + args: + att_types: t diff --git a/examples/rs_research/configs/levircd/custom_model.yaml b/examples/rs_research/configs/levircd/custom_model.yaml new file mode 100644 index 0000000..07699c4 --- /dev/null +++ b/examples/rs_research/configs/levircd/custom_model.yaml @@ -0,0 +1,6 @@ +_base_: ./levircd.yaml + +save_dir: ./exp/levircd/custom_model/ + +model: !Node + type: CustomTrainer diff --git a/examples/rs_research/custom_model.py b/examples/rs_research/custom_model.py index 63e2f60..bd11198 100644 --- a/examples/rs_research/custom_model.py +++ b/examples/rs_research/custom_model.py @@ -3,7 +3,7 @@ import paddle.nn as nn import paddle.nn.functional as F import paddlers from paddlers.rs_models.cd.layers import Conv3x3, MaxPool2x2, ConvTransposed3x3, Identity -from paddlers.rs_models.cd.layers import ChannelAttention, SpatialAttention +from paddlers.rs_models.cd.layers import ChannelAttention from attach_tools import Attach @@ -15,7 +15,7 @@ class CustomModel(nn.Layer): def __init__(self, in_channels, num_classes, - att_types='cst', + att_types='ct', use_dropout=False): super(CustomModel, self).__init__() @@ -53,7 +53,7 @@ class CustomModel(nn.Layer): self.upconv4 = ConvTransposed3x3(C4, C4, output_padding=1) - self.conv43d = Conv3x3(C5, C4, norm=True, act=True) + self.conv43d = Conv3x3(C5 + C4, C4, norm=True, act=True) self.do43d = self._make_dropout() self.conv42d = Conv3x3(C4, C4, norm=True, act=True) self.do42d = self._make_dropout() @@ -62,7 +62,7 @@ class CustomModel(nn.Layer): self.upconv3 = ConvTransposed3x3(C3, C3, output_padding=1) - self.conv33d = Conv3x3(C4, C3, norm=True, act=True) + self.conv33d = Conv3x3(C4 + C3, C3, norm=True, act=True) self.do33d = self._make_dropout() self.conv32d = Conv3x3(C3, C3, norm=True, act=True) self.do32d = self._make_dropout() @@ -71,32 +71,21 @@ class CustomModel(nn.Layer): self.upconv2 = ConvTransposed3x3(C2, C2, output_padding=1) - self.conv22d = Conv3x3(C3, C2, norm=True, act=True) + self.conv22d = Conv3x3(C3 + C2, C2, norm=True, act=True) self.do22d = self._make_dropout() self.conv21d = Conv3x3(C2, C1, norm=True, act=True) self.do21d = self._make_dropout() self.upconv1 = ConvTransposed3x3(C1, C1, output_padding=1) - self.conv12d = Conv3x3(C2, C1, norm=True, act=True) + self.conv12d = Conv3x3(C2 + C1, C1, norm=True, act=True) self.do12d = self._make_dropout() self.conv11d = Conv3x3(C1, num_classes) - if 'c' in att_types: - self.att_c = ChannelAttention(C4) - else: - self.att_c = Identity() - if 's' in att_types: - self.att_s = SpatialAttention() - else: - self.att_s = Identity() - if 't' in att_types: - self.att_t = ChannelAttention(2, ratio=1) - else: - self.att_t = Identity() - self.init_weight() + self.att4 = MixedAttention(C4, att_types) + def forward(self, t1, t2): # Encode t1 # Stage 1 @@ -144,25 +133,14 @@ class CustomModel(nn.Layer): x43_2 = self.do43(self.conv43(x42)) x4p = self.pool4(x43_2) - # Attend - x43_1 = self.att_c(x43_1) * x43_1 - x43_1 = self.att_s(x43_1) * x43_1 - x43_2 = self.att_c(x43_2) * x43_2 - x43_2 = self.att_s(x43_2) * x43_2 - x43 = paddle.stack([x43_1, x43_2], axis=1) - x43 = paddle.transpose(x43, [0, 2, 1, 3, 4]) - x43 = paddle.flatten(x43, stop_axis=1) - x43 = self.att_t(x43) * x43 - x43 = x43.reshape((x43_1.shape[0], -1, 2, *x43.shape[2:])) - x43_1, x43_2 = x43[:, :, 0], x43[:, :, 1] - # Decode # Stage 4d x4d = self.upconv4(x4p) pad4 = (0, x43_1.shape[3] - x4d.shape[3], 0, x43_1.shape[2] - x4d.shape[2]) x4d = F.pad(x4d, pad=pad4, mode='replicate') - x4d = paddle.concat([x4d, paddle.abs(x43_1 - x43_2)], 1) + x43_1, x43_2 = self.att4(x43_1, x43_2) + x4d = paddle.concat([x4d, x43_1, x43_2], 1) x43d = self.do43d(self.conv43d(x4d)) x42d = self.do42d(self.conv42d(x43d)) x41d = self.do41d(self.conv41d(x42d)) @@ -172,7 +150,7 @@ class CustomModel(nn.Layer): pad3 = (0, x33_1.shape[3] - x3d.shape[3], 0, x33_1.shape[2] - x3d.shape[2]) x3d = F.pad(x3d, pad=pad3, mode='replicate') - x3d = paddle.concat([x3d, paddle.abs(x33_1 - x33_2)], 1) + x3d = paddle.concat([x3d, x33_1, x33_2], 1) x33d = self.do33d(self.conv33d(x3d)) x32d = self.do32d(self.conv32d(x33d)) x31d = self.do31d(self.conv31d(x32d)) @@ -182,7 +160,7 @@ class CustomModel(nn.Layer): pad2 = (0, x22_1.shape[3] - x2d.shape[3], 0, x22_1.shape[2] - x2d.shape[2]) x2d = F.pad(x2d, pad=pad2, mode='replicate') - x2d = paddle.concat([x2d, paddle.abs(x22_1 - x22_2)], 1) + x2d = paddle.concat([x2d, x22_1, x22_2], 1) x22d = self.do22d(self.conv22d(x2d)) x21d = self.do21d(self.conv21d(x22d)) @@ -191,7 +169,7 @@ class CustomModel(nn.Layer): pad1 = (0, x12_1.shape[3] - x1d.shape[3], 0, x12_1.shape[2] - x1d.shape[2]) x1d = F.pad(x1d, pad=pad1, mode='replicate') - x1d = paddle.concat([x1d, paddle.abs(x12_1 - x12_2)], 1) + x1d = paddle.concat([x1d, x12_1, x12_2], 1) x12d = self.do12d(self.conv12d(x1d)) x11d = self.conv11d(x12d) @@ -205,3 +183,51 @@ class CustomModel(nn.Layer): return nn.Dropout2D(p=0.2) else: return Identity() + + +class MixedAttention(nn.Layer): + def __init__(self, in_channels, att_types='ct'): + super(MixedAttention, self).__init__() + + self.att_types = att_types + + if self.has_att_c: + self.att_c = ChannelAttention(in_channels, ratio=1) + self.norm_c1 = nn.BatchNorm(in_channels) + self.norm_c2 = nn.BatchNorm(in_channels) + else: + self.att_c = Identity() + self.norm_c1 = Identity() + self.norm_c2 = Identity() + + if self.has_att_t: + self.att_t = ChannelAttention(2, ratio=1) + else: + self.att_t = Identity() + + def forward(self, x1, x2): + if self.has_att_c: + x1 = self.att_c(x1) * x1 + x1 = self.norm_c1(x1) + x2 = self.att_c(x2) * x2 + x2 = self.norm_c2(x2) + + if self.has_att_t: + b, c = x1.shape[:2] + y = paddle.stack([x1, x2], axis=2) + y = paddle.flatten(y, stop_axis=1) + y = self.att_t(y) * y + y = y.reshape((b, c, 2, *y.shape[2:])) + y1, y2 = y[:, :, 0], y[:, :, 1] + else: + y1, y2 = x1, x2 + + return y1, y2 + + @property + def has_att_c(self): + return 'c' in self.att_types + + @property + def has_att_t(self): + return 't' in self.att_types diff --git a/examples/rs_research/custom_trainer.py b/examples/rs_research/custom_trainer.py index b3ddb41..f0bdb3f 100644 --- a/examples/rs_research/custom_trainer.py +++ b/examples/rs_research/custom_trainer.py @@ -13,7 +13,7 @@ class CustomTrainer(BaseChangeDetector): use_mixed_loss=False, losses=None, in_channels=3, - att_types='cst', + att_types='ct', use_dropout=False, **params): params.update({ diff --git a/examples/rs_research/predict_cd.py b/examples/rs_research/predict_cd.py new file mode 100644 index 0000000..df0b8b4 --- /dev/null +++ b/examples/rs_research/predict_cd.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python + +import argparse +import os +import os.path as osp + +import cv2 +import paddle +import paddlers +from tqdm import tqdm + +import custom_model +import custom_trainer + + +def read_file_list(file_list, sep=' '): + with open(file_list, 'r') as f: + for line in f: + line = line.strip() + parts = line.split(sep) + yield parts + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_dir", default=None, type=str, help="Path of saved model.") + parser.add_argument("--data_dir", type=str, help="Path of input dataset.") + parser.add_argument("--file_list", type=str, help="Path of file list.") + parser.add_argument( + "--save_dir", + default='./exp/predict', + type=str, + help="Path of directory to save prediction results.") + parser.add_argument( + "--ext", + default='.png', + type=str, + help="Extension name of the saved image file.") + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_args() + + model = paddlers.tasks.load_model(args.model_dir) + + if not osp.exists(args.save_dir): + os.makedirs(args.save_dir) + + with paddle.no_grad(): + for parts in tqdm(read_file_list(args.file_list)): + im1_path = osp.join(args.data_dir, parts[0]) + im2_path = osp.join(args.data_dir, parts[1]) + + pred = model.predict((im1_path, im2_path)) + cm = pred['label_map'] + # {0,1} -> {0,255} + cm[cm > 0] = 255 + cm = cm.astype('uint8') + + if len(parts) > 2: + name = osp.basename(parts[2]) + else: + name = osp.basename(im1_path) + name = osp.splitext(name)[0] + args.ext + out_path = osp.join(args.save_dir, name) + cv2.imwrite(out_path, cm) diff --git a/examples/rs_research/scripts/run_ablation.sh b/examples/rs_research/scripts/run_ablation.sh index d54066a..b7848ef 100644 --- a/examples/rs_research/scripts/run_ablation.sh +++ b/examples/rs_research/scripts/run_ablation.sh @@ -2,7 +2,7 @@ set -e -CONFIG_DIR='configs/levircd/custom_model' +CONFIG_DIR='configs/levircd/ablation' LOG_DIR='exp/logs/ablation' mkdir -p "${LOG_DIR}" @@ -12,6 +12,9 @@ for config_file in $(ls "${CONFIG_DIR}"/*.yaml); do printf '=%.0s' {1..100} && echo echo -e "\033[33m ${config_file} \033[0m" printf '=%.0s' {1..100} && echo + if [ ${filename} = 'custom_model_cs.yaml' ] || [ ${filename} = 'custom_model_ct.yaml' ]; then + continue + fi python run_task.py train cd --config "${config_file}" 2>&1 | tee "${LOG_DIR}/${filename%.*}.log" echo done diff --git a/examples/rs_research/tools/analyze_model.py b/examples/rs_research/tools/analyze_model.py new file mode 100644 index 0000000..3eec5b4 --- /dev/null +++ b/examples/rs_research/tools/analyze_model.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python + +# Refer to https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.6/tools/analyze_model.py + +import argparse +import os +import os.path as osp +import sys + +import paddle +import numpy as np +import paddlers +from paddle.hapi.dynamic_flops import (count_parameters, register_hooks, + count_io_info) +from paddle.hapi.static_flops import Table + +_dir = osp.dirname(osp.abspath(__file__)) +sys.path.append(osp.abspath(osp.join(_dir, '../'))) +import custom_model +import custom_trainer + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_dir", default=None, type=str, help="Path of saved model.") + parser.add_argument( + "--input_shape", + nargs='+', + type=int, + default=[1, 3, 256, 256], + help="Shape of each input tensor.") + return parser.parse_args() + + +def analyze(model, inputs, custom_ops=None, print_detail=False): + handler_collection = [] + types_collection = set() + if custom_ops is None: + custom_ops = {} + + def add_hooks(m): + if len(list(m.children())) > 0: + return + m.register_buffer('total_ops', paddle.zeros([1], dtype='int64')) + m.register_buffer('total_params', paddle.zeros([1], dtype='int64')) + m_type = type(m) + + flops_fn = None + if m_type in custom_ops: + flops_fn = custom_ops[m_type] + if m_type not in types_collection: + print("Customized function has been applied to {}".format( + m_type)) + elif m_type in register_hooks: + flops_fn = register_hooks[m_type] + if m_type not in types_collection: + print("{}'s FLOPs metric has been counted".format(m_type)) + else: + if m_type not in types_collection: + print( + "Cannot find suitable counting function for {}. Treat it as zero FLOPs." + .format(m_type)) + + if flops_fn is not None: + flops_handler = m.register_forward_post_hook(flops_fn) + handler_collection.append(flops_handler) + params_handler = m.register_forward_post_hook(count_parameters) + io_handler = m.register_forward_post_hook(count_io_info) + handler_collection.append(params_handler) + handler_collection.append(io_handler) + types_collection.add(m_type) + + training = model.training + + model.eval() + model.apply(add_hooks) + + with paddle.framework.no_grad(): + model(*inputs) + + total_ops = 0 + total_params = 0 + for m in model.sublayers(): + if len(list(m.children())) > 0: + continue + if set(['total_ops', 'total_params', 'input_shape', + 'output_shape']).issubset(set(list(m._buffers.keys()))): + total_ops += m.total_ops + total_params += m.total_params + + if training: + model.train() + for handler in handler_collection: + handler.remove() + + table = Table( + ["Layer Name", "Input Shape", "Output Shape", "Params(M)", "FLOPs(G)"]) + + for n, m in model.named_sublayers(): + if len(list(m.children())) > 0: + continue + if set(['total_ops', 'total_params', 'input_shape', + 'output_shape']).issubset(set(list(m._buffers.keys()))): + table.add_row([ + m.full_name(), list(m.input_shape.numpy()), + list(m.output_shape.numpy()), + round(float(m.total_params / 1e6), 3), + round(float(m.total_ops / 1e9), 3) + ]) + m._buffers.pop("total_ops") + m._buffers.pop("total_params") + m._buffers.pop('input_shape') + m._buffers.pop('output_shape') + if print_detail: + table.print_table() + print('Total FLOPs: {}G Total Params: {}M'.format( + round(float(total_ops / 1e9), 3), round(float(total_params / 1e6), 3))) + return int(total_ops) + + +if __name__ == '__main__': + args = parse_args() + + # Enforce the use of CPU + paddle.set_device('cpu') + + model = paddlers.tasks.load_model(args.model_dir) + net = model.net + + # Construct bi-temporal inputs + inputs = [paddle.randn(args.input_shape), paddle.randn(args.input_shape)] + + analyze(model.net, inputs) diff --git a/examples/rs_research/tools/collect_imgs.py b/examples/rs_research/tools/collect_imgs.py new file mode 100644 index 0000000..2e7d0ff --- /dev/null +++ b/examples/rs_research/tools/collect_imgs.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python + +import argparse +import os +import os.path as osp +import shutil +from glob import glob + +from tqdm import tqdm + + +def get_subdir_name(src_path): + basename = osp.basename(src_path) + subdir_name, _ = osp.splitext(basename) + return subdir_name + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--mode", + default='copy', + type=str, + choices=['copy', 'link'], + help="Copy or link images.") + parser.add_argument( + "--globs", + nargs='+', + type=str, + help="Glob patterns used to find the images to be copied.") + parser.add_argument( + "--tags", nargs='+', type=str, help="Tags of each source directory.") + parser.add_argument( + "--save_dir", + default='./', + type=str, + help="Path of directory to save collected results.") + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_args() + + if len(args.globs) != len(args.tags): + raise ValueError( + "The number of globs does not match the number of tags!") + + for pat, tag in zip(args.globs, args.tags): + im_paths = glob(pat) + print(f"Glob: {pat}\tTag: {tag}") + for p in tqdm(im_paths): + subdir_name = get_subdir_name(p) + ext = osp.splitext(p)[1] + subdir_path = osp.join(args.save_dir, subdir_name) + subdir_path = osp.abspath(osp.normpath(subdir_path)) + if not osp.exists(subdir_path): + os.makedirs(subdir_path) + if args.mode == 'copy': + shutil.copyfile(p, osp.join(subdir_path, tag + ext)) + elif args.mode == 'link': + os.symlink(p, osp.join(subdir_path, tag + ext)) diff --git a/examples/rs_research/tools/visualize_feats.py b/examples/rs_research/tools/visualize_feats.py new file mode 100644 index 0000000..62ba93c --- /dev/null +++ b/examples/rs_research/tools/visualize_feats.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python + +import argparse +import sys +import os +import os.path as osp +from collections import OrderedDict + +import numpy as np +import cv2 +import paddle +import paddlers + +_dir = osp.dirname(osp.abspath(__file__)) +sys.path.append(osp.abspath(osp.join(_dir, '../'))) +import custom_model +import custom_trainer + +FILENAME_PATTERN = "{key}_{idx}_vis.png" + + +class FeatureContainer: + def __init__(self): + self._dict = OrderedDict() + + def __setitem__(self, key, val): + if key not in self._dict: + self._dict[key] = list() + self._dict[key].append(val) + + def __getitem__(self, key): + return self._dict[key] + + def __repr__(self): + return self._dict.__repr__() + + def items(self): + return self._dict.items() + + def keys(self): + return self._dict.keys() + + def values(self): + return self._dict.values() + + +class HookHelper: + def __init__(self, model, fetch_dict, out_dict, hook_type='forward_out'): + # XXX: A HookHelper object should only be used as a context manager and should not + # persist in memory since it may keep references to some very large objects. + self.model = model + self.fetch_dict = fetch_dict + self.out_dict = out_dict + self._handles = [] + self.hook_type = hook_type + + def __enter__(self): + def _hook_proto(x, entry): + # `x` should be a tensor or a tuple; + # entry is expected to be a string or a non-nested tuple. + if isinstance(entry, tuple): + for key, f in zip(entry, x): + self.out_dict[key] = f.detach().clone() + else: + self.out_dict[entry] = x.detach().clone() + + if self.hook_type == 'forward_in': + # NOTE: Register forward hooks for LAYERs + for name, layer in self.model.named_sublayers(): + if name in self.fetch_dict: + entry = self.fetch_dict[name] + self._handles.append( + layer.register_forward_pre_hook( + lambda l, x, entry=entry: + # x is a tuple + _hook_proto(x[0] if len(x)==1 else x, entry) + ) + ) + elif self.hook_type == 'forward_out': + # NOTE: Register forward hooks for LAYERs. + for name, module in self.model.named_sublayers(): + if name in self.fetch_dict: + entry = self.fetch_dict[name] + self._handles.append( + module.register_forward_post_hook( + lambda l, x, y, entry=entry: + # y is a tensor or a tuple + _hook_proto(y, entry) + ) + ) + elif self.hook_type == 'backward': + # NOTE: Register backward hooks for TENSORs. + for name, param in self.model.named_parameters(): + if name in self.fetch_dict: + entry = self.fetch_dict[name] + self._handles.append( + param.register_hook( + lambda grad, entry=entry: _hook_proto(grad, entry))) + else: + raise RuntimeError("Hook type is not implemented.") + + def __exit__(self, exc_type, exc_val, ext_tb): + for handle in self._handles: + handle.remove() + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_dir", default=None, type=str, help="Path of saved model.") + parser.add_argument( + "--hook_type", default='forward_out', type=str, help="Type of hook.") + parser.add_argument( + "--layer_names", + nargs='+', + default=[], + type=str, + help="Layers that accepts or produces the features to visualize.") + parser.add_argument( + "--im_paths", nargs='+', type=str, help="Paths of input images.") + parser.add_argument( + "--save_dir", + type=str, + help="Path of directory to save prediction results.") + parser.add_argument( + "--to_pseudo_color", + action='store_true', + help="Whether to save pseudo-color images.") + parser.add_argument( + "--output_size", + nargs='+', + type=int, + default=None, + help="Resize the visualized image to `output_size`.") + return parser.parse_args() + + +def normalize_minmax(x): + EPS = 1e-32 + return (x - x.min()) / (x.max() - x.min() + EPS) + + +def quantize_8bit(x): + # [0.0,1.0] float => [0,255] uint8 + # or [0,1] int => [0,255] uint8 + return (x * 255).astype('uint8') + + +def to_pseudo_color(gray, color_map=cv2.COLORMAP_JET): + return cv2.applyColorMap(gray, color_map) + + +def process_fetched_feat(feat, to_pcolor=True): + # Convert tensor to array + feat = feat.squeeze(0).numpy() + # Average along channel dimension + feat = normalize_minmax(feat.mean(0)) + feat = quantize_8bit(feat) + if to_pcolor: + feat = to_pseudo_color(feat) + return feat + + +if __name__ == '__main__': + args = parse_args() + + # Load model + model = paddlers.tasks.load_model(args.model_dir) + + fetch_dict = dict(zip(args.layer_names, args.layer_names)) + out_dict = FeatureContainer() + + with HookHelper(model.net, fetch_dict, out_dict, hook_type=args.hook_type): + if len(args.im_paths) == 1: + model.predict(args.im_paths[0]) + else: + if len(args.im_paths) != 2: + raise ValueError + model.predict(tuple(args.im_paths)) + + if not osp.exists(args.save_dir): + os.makedirs(args.save_dir) + + for key, feats in out_dict.items(): + for idx, feat in enumerate(feats): + im_vis = process_fetched_feat(feat, to_pcolor=args.to_pseudo_color) + if args.output_size is not None: + im_vis = cv2.resize(im_vis, tuple(args.output_size)) + out_path = osp.join( + args.save_dir, + FILENAME_PATTERN.format( + key=key.replace('.', '_'), idx=idx)) + cv2.imwrite(out_path, im_vis)