parent
ea9bb47b62
commit
9e3a2611e0
8 changed files with 1263 additions and 2 deletions
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,71 @@ |
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
||||
# |
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
# you may not use this file except in compliance with the License. |
||||
# You may obtain a copy of the License at |
||||
# |
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
# |
||||
# Unless required by applicable law or agreed to in writing, software |
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
# See the License for the specific language governing permissions and |
||||
# limitations under the License.. |
||||
|
||||
from itertools import repeat |
||||
import collections.abc |
||||
|
||||
import paddle |
||||
import paddle.nn as nn |
||||
""" |
||||
Droppath, reimplement from https://github.com/yueatsprograms/Stochastic_Depth |
||||
""" |
||||
|
||||
|
||||
class DropPath(nn.Layer): |
||||
"""DropPath class""" |
||||
|
||||
def __init__(self, drop_prob=None): |
||||
super().__init__() |
||||
self.drop_prob = drop_prob |
||||
|
||||
def drop_path(self, inputs): |
||||
"""drop path op |
||||
Args: |
||||
input: tensor with arbitrary shape |
||||
drop_prob: float number of drop path probability, default: 0.0 |
||||
training: bool, if current mode is training, default: False |
||||
Returns: |
||||
output: output tensor after drop path |
||||
""" |
||||
# if prob is 0 or eval mode, return original input |
||||
if self.drop_prob == 0. or not self.training: |
||||
return inputs |
||||
keep_prob = 1 - self.drop_prob |
||||
keep_prob = paddle.to_tensor(keep_prob, dtype='float32') |
||||
shape = (inputs.shape[0], ) + (1, ) * (inputs.ndim - 1 |
||||
) # shape=(N, 1, 1, 1) |
||||
random_tensor = keep_prob + paddle.rand(shape, dtype=inputs.dtype) |
||||
random_tensor = random_tensor.floor() # mask |
||||
output = inputs.divide( |
||||
keep_prob) * random_tensor # divide to keep same output expectation |
||||
return output |
||||
|
||||
def forward(self, inputs): |
||||
return self.drop_path(inputs) |
||||
|
||||
|
||||
def _ntuple(n): |
||||
def parse(x): |
||||
if isinstance(x, collections.abc.Iterable): |
||||
return x |
||||
return tuple(repeat(x, n)) |
||||
|
||||
return parse |
||||
|
||||
|
||||
to_1tuple = _ntuple(1) |
||||
to_2tuple = _ntuple(2) |
||||
to_3tuple = _ntuple(3) |
||||
to_4tuple = _ntuple(4) |
||||
to_ntuple = _ntuple |
@ -0,0 +1,8 @@ |
||||
# Basic configurations of ChangeFormer |
||||
|
||||
_base_: ../_base_/airchange.yaml |
||||
|
||||
save_dir: ./test_tipc/output/cd/changeformer/ |
||||
|
||||
model: !Node |
||||
type: ChangeFormer |
@ -0,0 +1,53 @@ |
||||
===========================train_params=========================== |
||||
model_name:cd:changeformer |
||||
python:python |
||||
gpu_list:0|0,1 |
||||
use_gpu:null|null |
||||
--precision:null |
||||
--num_epochs:lite_train_lite_infer=5|lite_train_whole_infer=5|whole_train_whole_infer=10 |
||||
--save_dir:adaptive |
||||
--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4 |
||||
--model_path:null |
||||
train_model_name:best_model |
||||
train_infer_file_list:./test_tipc/data/airchange/:./test_tipc/data/airchange/eval.txt |
||||
null:null |
||||
## |
||||
trainer:norm |
||||
norm_train:test_tipc/run_task.py train cd --config ./test_tipc/configs/cd/changeformer/changeformer.yaml |
||||
pact_train:null |
||||
fpgm_train:null |
||||
distill_train:null |
||||
null:null |
||||
null:null |
||||
## |
||||
===========================eval_params=========================== |
||||
eval:null |
||||
null:null |
||||
## |
||||
===========================export_params=========================== |
||||
--save_dir:adaptive |
||||
--model_dir:adaptive |
||||
--fixed_input_shape:[1,3,256,256] |
||||
norm_export:deploy/export/export_model.py |
||||
quant_export:null |
||||
fpgm_export:null |
||||
distill_export:null |
||||
export1:null |
||||
export2:null |
||||
===========================infer_params=========================== |
||||
infer_model:null |
||||
infer_export:null |
||||
infer_quant:False |
||||
inference:test_tipc/infer.py |
||||
--device:cpu|gpu |
||||
--enable_mkldnn:True |
||||
--cpu_threads:6 |
||||
--batch_size:1 |
||||
--use_trt:False |
||||
--precision:fp32 |
||||
--model_dir:null |
||||
--file_list:null:null |
||||
--save_log_path:null |
||||
--benchmark:True |
||||
--model_name:changeformer |
||||
null:null |
@ -0,0 +1,94 @@ |
||||
#!/usr/bin/env python |
||||
|
||||
# 变化检测模型ChangeFormer训练示例脚本 |
||||
# 执行此脚本前,请确认已正确安装PaddleRS库 |
||||
|
||||
import paddlers as pdrs |
||||
from paddlers import transforms as T |
||||
|
||||
# 数据集存放目录 |
||||
DATA_DIR = './data/airchange/' |
||||
# 训练集`file_list`文件路径 |
||||
TRAIN_FILE_LIST_PATH = './data/airchange/train.txt' |
||||
# 验证集`file_list`文件路径 |
||||
EVAL_FILE_LIST_PATH = './data/airchange/eval.txt' |
||||
# 实验目录,保存输出的模型权重和结果 |
||||
EXP_DIR = './output/changeformer/' |
||||
|
||||
# 下载和解压AirChange数据集 |
||||
pdrs.utils.download_and_decompress( |
||||
'https://paddlers.bj.bcebos.com/datasets/airchange.zip', path='./data/') |
||||
|
||||
# 定义训练和验证时使用的数据变换(数据增强、预处理等) |
||||
# 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行 |
||||
# API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/transforms.md |
||||
train_transforms = T.Compose([ |
||||
# 读取影像 |
||||
T.DecodeImg(), |
||||
# 随机裁剪 |
||||
T.RandomCrop( |
||||
# 裁剪区域将被缩放到256x256 |
||||
crop_size=256, |
||||
# 裁剪区域的横纵比在0.5-2之间变动 |
||||
aspect_ratio=[0.5, 2.0], |
||||
# 裁剪区域相对原始影像长宽比例在一定范围内变动,最小不低于原始长宽的1/5 |
||||
scaling=[0.2, 1.0]), |
||||
# 以50%的概率实施随机水平翻转 |
||||
T.RandomHorizontalFlip(prob=0.5), |
||||
# 将数据归一化到[-1,1] |
||||
T.Normalize( |
||||
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), |
||||
T.ArrangeChangeDetector('train') |
||||
]) |
||||
|
||||
eval_transforms = T.Compose([ |
||||
T.DecodeImg(), |
||||
# 验证阶段与训练阶段的数据归一化方式必须相同 |
||||
T.Normalize( |
||||
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), |
||||
T.ReloadMask(), |
||||
T.ArrangeChangeDetector('eval') |
||||
]) |
||||
|
||||
# 分别构建训练和验证所用的数据集 |
||||
train_dataset = pdrs.datasets.CDDataset( |
||||
data_dir=DATA_DIR, |
||||
file_list=TRAIN_FILE_LIST_PATH, |
||||
label_list=None, |
||||
transforms=train_transforms, |
||||
num_workers=0, |
||||
shuffle=True, |
||||
with_seg_labels=False, |
||||
binarize_labels=True) |
||||
|
||||
eval_dataset = pdrs.datasets.CDDataset( |
||||
data_dir=DATA_DIR, |
||||
file_list=EVAL_FILE_LIST_PATH, |
||||
label_list=None, |
||||
transforms=eval_transforms, |
||||
num_workers=0, |
||||
shuffle=False, |
||||
with_seg_labels=False, |
||||
binarize_labels=True) |
||||
|
||||
# 使用默认参数构建ChangeFormer模型 |
||||
# 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/model_zoo.md |
||||
# 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/change_detector.py |
||||
model = pdrs.tasks.cd.ChangeFormer() |
||||
|
||||
# 执行模型训练 |
||||
model.train( |
||||
num_epochs=5, |
||||
train_dataset=train_dataset, |
||||
train_batch_size=4, |
||||
eval_dataset=eval_dataset, |
||||
save_interval_epochs=3, |
||||
# 每多少次迭代记录一次日志 |
||||
log_interval_steps=50, |
||||
save_dir=EXP_DIR, |
||||
# 是否使用early stopping策略,当精度不再改善时提前终止训练 |
||||
early_stop=False, |
||||
# 是否启用VisualDL日志功能 |
||||
use_vdl=True, |
||||
# 指定从某个检查点继续训练 |
||||
resume_checkpoint=None) |
Loading…
Reference in new issue