【论文复现赛】ChangeFormer (#13)

精度验收通过,代码符合规范,论文复现成功。
own
kongdebug 2 years ago committed by GitHub
parent ea9bb47b62
commit 9e3a2611e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      paddlers/rs_models/cd/__init__.py
  2. 1001
      paddlers/rs_models/cd/changeformer.py
  3. 71
      paddlers/rs_models/cd/layers/pd_timm.py
  4. 22
      paddlers/tasks/change_detector.py
  5. 8
      test_tipc/configs/cd/changeformer/changeformer.yaml
  6. 53
      test_tipc/configs/cd/changeformer/train_infer_python.txt
  7. 15
      tests/rs_models/test_cd_models.py
  8. 94
      tutorials/train/change_detection/changeformer.py

@ -22,3 +22,4 @@ from .changestar import ChangeStar
from .fc_ef import FCEarlyFusion
from .fc_siam_conc import FCSiamConc
from .fc_siam_diff import FCSiamDiff
from .changeformer import ChangeFormer

File diff suppressed because it is too large Load Diff

@ -0,0 +1,71 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License..
from itertools import repeat
import collections.abc
import paddle
import paddle.nn as nn
"""
Droppath, reimplement from https://github.com/yueatsprograms/Stochastic_Depth
"""
class DropPath(nn.Layer):
"""DropPath class"""
def __init__(self, drop_prob=None):
super().__init__()
self.drop_prob = drop_prob
def drop_path(self, inputs):
"""drop path op
Args:
input: tensor with arbitrary shape
drop_prob: float number of drop path probability, default: 0.0
training: bool, if current mode is training, default: False
Returns:
output: output tensor after drop path
"""
# if prob is 0 or eval mode, return original input
if self.drop_prob == 0. or not self.training:
return inputs
keep_prob = 1 - self.drop_prob
keep_prob = paddle.to_tensor(keep_prob, dtype='float32')
shape = (inputs.shape[0], ) + (1, ) * (inputs.ndim - 1
) # shape=(N, 1, 1, 1)
random_tensor = keep_prob + paddle.rand(shape, dtype=inputs.dtype)
random_tensor = random_tensor.floor() # mask
output = inputs.divide(
keep_prob) * random_tensor # divide to keep same output expectation
return output
def forward(self, inputs):
return self.drop_path(inputs)
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple

@ -36,7 +36,7 @@ from .utils import seg_metrics as metrics
__all__ = [
"CDNet", "FCEarlyFusion", "FCSiamConc", "FCSiamDiff", "STANet", "BIT",
"SNUNet", "DSIFN", "DSAMNet", "ChangeStar"
"SNUNet", "DSIFN", "DSAMNet", "ChangeStar", "ChangeFormer"
]
@ -1041,3 +1041,23 @@ class ChangeStar(BaseChangeDetector):
raise ValueError(
f"Currently `use_mixed_loss` must be set to False for {self.__class__}"
)
class ChangeFormer(BaseChangeDetector):
def __init__(self,
in_channels=3,
num_classes=2,
decoder_softmax=False,
embed_dim=256,
use_mixed_loss=False,
**params):
params.update({
'in_channels': in_channels,
'embed_dim': embed_dim,
'decoder_softmax': decoder_softmax
})
super(ChangeFormer, self).__init__(
model_name='ChangeFormer',
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
**params)

@ -0,0 +1,8 @@
# Basic configurations of ChangeFormer
_base_: ../_base_/airchange.yaml
save_dir: ./test_tipc/output/cd/changeformer/
model: !Node
type: ChangeFormer

@ -0,0 +1,53 @@
===========================train_params===========================
model_name:cd:changeformer
python:python
gpu_list:0|0,1
use_gpu:null|null
--precision:null
--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=10
--save_dir:adaptive
--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4
--model_path:null
train_model_name:best_model
train_infer_file_list:./test_tipc/data/airchange/:./test_tipc/data/airchange/eval.txt
null:null
##
trainer:norm
norm_train:test_tipc/run_task.py train cd --config ./test_tipc/configs/cd/changeformer/changeformer.yaml
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================export_params===========================
--save_dir:adaptive
--model_dir:adaptive
--fixed_input_shape:[1,3,256,256]
norm_export:deploy/export/export_model.py
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
===========================infer_params===========================
infer_model:null
infer_export:null
infer_quant:False
inference:test_tipc/infer.py
--device:cpu|gpu
--enable_mkldnn:True
--cpu_threads:6
--batch_size:1
--use_trt:False
--precision:fp32
--model_dir:null
--file_list:null:null
--save_log_path:null
--benchmark:True
--model_name:changeformer
null:null

@ -21,7 +21,8 @@ from rs_models.test_model import TestModel
__all__ = [
'TestBITModel', 'TestCDNetModel', 'TestChangeStarModel', 'TestDSAMNetModel',
'TestDSIFNModel', 'TestFCEarlyFusionModel', 'TestFCSiamConcModel',
'TestFCSiamDiffModel', 'TestSNUNetModel', 'TestSTANetModel'
'TestFCSiamDiffModel', 'TestSNUNetModel', 'TestSTANetModel',
'TestChangeFormerModel'
]
@ -215,6 +216,18 @@ class TestSTANetModel(TestCDModel):
] # yapf: disable
class TestChangeFormerModel(TestCDModel):
MODEL_CLASS = paddlers.rs_models.cd.ChangeFormer
def set_specs(self):
base_spec = dict(in_channels=3, num_classes=2)
self.specs = [
base_spec,
dict(**base_spec, decoder_softmax=True),
dict(**base_spec, embed_dim=56)
] # yapf: disable
# HACK:FIXME: We observe an OOM error when running TestSTANetModel.test_forward() on a Windows machine.
# Currently, we do not perform this test.
if platform.system() == 'Windows':

@ -0,0 +1,94 @@
#!/usr/bin/env python
# 变化检测模型ChangeFormer训练示例脚本
# 执行此脚本前,请确认已正确安装PaddleRS库
import paddlers as pdrs
from paddlers import transforms as T
# 数据集存放目录
DATA_DIR = './data/airchange/'
# 训练集`file_list`文件路径
TRAIN_FILE_LIST_PATH = './data/airchange/train.txt'
# 验证集`file_list`文件路径
EVAL_FILE_LIST_PATH = './data/airchange/eval.txt'
# 实验目录,保存输出的模型权重和结果
EXP_DIR = './output/changeformer/'
# 下载和解压AirChange数据集
pdrs.utils.download_and_decompress(
'https://paddlers.bj.bcebos.com/datasets/airchange.zip', path='./data/')
# 定义训练和验证时使用的数据变换(数据增强、预处理等)
# 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
# API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md
train_transforms = T.Compose([
# 读取影像
T.DecodeImg(),
# 随机裁剪
T.RandomCrop(
# 裁剪区域将被缩放到256x256
crop_size=256,
# 裁剪区域的横纵比在0.5-2之间变动
aspect_ratio=[0.5, 2.0],
# 裁剪区域相对原始影像长宽比例在一定范围内变动,最小不低于原始长宽的1/5
scaling=[0.2, 1.0]),
# 以50%的概率实施随机水平翻转
T.RandomHorizontalFlip(prob=0.5),
# 将数据归一化到[-1,1]
T.Normalize(
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
T.ArrangeChangeDetector('train')
])
eval_transforms = T.Compose([
T.DecodeImg(),
# 验证阶段与训练阶段的数据归一化方式必须相同
T.Normalize(
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
T.ReloadMask(),
T.ArrangeChangeDetector('eval')
])
# 分别构建训练和验证所用的数据集
train_dataset = pdrs.datasets.CDDataset(
data_dir=DATA_DIR,
file_list=TRAIN_FILE_LIST_PATH,
label_list=None,
transforms=train_transforms,
num_workers=0,
shuffle=True,
with_seg_labels=False,
binarize_labels=True)
eval_dataset = pdrs.datasets.CDDataset(
data_dir=DATA_DIR,
file_list=EVAL_FILE_LIST_PATH,
label_list=None,
transforms=eval_transforms,
num_workers=0,
shuffle=False,
with_seg_labels=False,
binarize_labels=True)
# 使用默认参数构建ChangeFormer模型
# 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/model_zoo.md
# 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/change_detector.py
model = pdrs.tasks.cd.ChangeFormer()
# 执行模型训练
model.train(
num_epochs=5,
train_dataset=train_dataset,
train_batch_size=4,
eval_dataset=eval_dataset,
save_interval_epochs=3,
# 每多少次迭代记录一次日志
log_interval_steps=50,
save_dir=EXP_DIR,
# 是否使用early stopping策略,当精度不再改善时提前终止训练
early_stop=False,
# 是否启用VisualDL日志功能
use_vdl=True,
# 指定从某个检查点继续训练
resume_checkpoint=None)
Loading…
Cancel
Save