Factseg paddle (#54)
* fact-seg paddle
* fact-paddle
* Coding modify
* Add files via upload
* Passed testing code
* Use official hook for pre-commit
* Use the official .style.yarp file and passed the yarp test locally
* Restore the file which was changed by error
* Revert "Use the official .style.yarp file and passed the yarp test locally"

  This reverts commit 6a294fa21b74166bc6be94b858fcf486ff13d4ea.

* Fix code style
* Code Modify
* Code Modify according to reviewer
* solve conflict
* Solve conflict
* solve yarp
* yarp
* Fix code style
* [Feat] Make FactSeg public

Co-authored-by: Bobholamovic <mhlin425@whu.edu.cn>
parent 95a0f9e2ac
commit afc25c93c9
13 changed files with 347 additions and 7 deletions
@@ -0,0 +1,141 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddlers.models.ppdet.modeling import \
    initializer as init
from paddlers.rs_models.seg.farseg import FPN, \
    ResNetEncoder, AsymmetricDecoder


def conv_with_kaiming_uniform(use_gn=False, use_relu=False):
    def make_conv(in_channels, out_channels, kernel_size, stride=1,
                  dilation=1):
        conv = nn.Conv2D(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=dilation * (kernel_size - 1) // 2,
            dilation=dilation,
            bias_attr=False if use_gn else True)

        init.kaiming_uniform_(conv.weight, a=1)
        if not use_gn:
            init.constant_(conv.bias, 0)
        module = [conv]
        if use_gn:
            raise NotImplementedError
        if use_relu:
            module.append(nn.ReLU())
        if len(module) > 1:
            return nn.Sequential(*module)
        return conv

    return make_conv


default_conv_block = conv_with_kaiming_uniform(use_gn=False, use_relu=False)


class FactSeg(nn.Layer):
    """
    The FactSeg implementation based on PaddlePaddle.

    The original article refers to
    A. Ma, J. Wang, Y. Zhong and Z. Zheng, "FactSeg: Foreground
    Activation-Driven Small Object Semantic Segmentation in Large-Scale
    Remote Sensing Imagery," in IEEE Transactions on Geoscience and Remote
    Sensing, vol. 60, pp. 1-16, 2022, Art no. 5606216.

    Args:
        in_channels (int): Number of channels of the input image.
        num_classes (int): Number of target classes.
        backbone (str, optional): Backbone network, one of the models
            available in `paddle.vision.models.resnet`. Default: resnet50.
        backbone_pretrained (bool, optional): Whether the backbone network uses
            IMAGENET pretrained weights. Default: True.
    """

    def __init__(self,
                 in_channels,
                 num_classes,
                 backbone='resnet50',
                 backbone_pretrained=True):
        super(FactSeg, self).__init__()
        backbone = backbone.lower()
        self.resencoder = ResNetEncoder(
            backbone=backbone,
            in_channels=in_channels,
            pretrained=backbone_pretrained)
        self.resencoder.resnet._sub_layers.pop('fc')
        self.fgfpn = FPN(in_channels_list=[256, 512, 1024, 2048],
                         out_channels=256,
                         conv_block=default_conv_block)
        self.bifpn = FPN(in_channels_list=[256, 512, 1024, 2048],
                         out_channels=256,
                         conv_block=default_conv_block)
        self.fg_decoder = AsymmetricDecoder(
            in_channels=256,
            out_channels=128,
            in_feature_output_strides=(4, 8, 16, 32),
            out_feature_output_stride=4,
            conv_block=nn.Conv2D)
        self.bi_decoder = AsymmetricDecoder(
            in_channels=256,
            out_channels=128,
            in_feature_output_strides=(4, 8, 16, 32),
            out_feature_output_stride=4,
            conv_block=nn.Conv2D)
        self.fg_cls = nn.Conv2D(128, num_classes, kernel_size=1)
        self.bi_cls = nn.Conv2D(128, 1, kernel_size=1)
        self.config_loss = ['joint_loss']
        self.config_foreground = []
        self.fbattention_attention = False

    def forward(self, x):
        feat_list = self.resencoder(x)
        if 'skip_decoder' in []:
            fg_out = self.fgskip_decoder(feat_list)
            bi_out = self.bgskip_decoder(feat_list)
        else:
            forefeat_list = list(self.fgfpn(feat_list))
            binaryfeat_list = self.bifpn(feat_list)
            if self.fbattention_attention:
                for i in range(len(binaryfeat_list)):
                    forefeat_list[i] = self.fbatt_block_list[i](
                        binaryfeat_list[i], forefeat_list[i])
            fg_out = self.fg_decoder(forefeat_list)
            bi_out = self.bi_decoder(binaryfeat_list)
        fg_pred = self.fg_cls(fg_out)
        bi_pred = self.bi_cls(bi_out)
        fg_pred = F.interpolate(
            fg_pred, scale_factor=4.0, mode='bilinear', align_corners=True)
        bi_pred = F.interpolate(
            bi_pred, scale_factor=4.0, mode='bilinear', align_corners=True)
        if self.training:
            return [fg_pred]
        else:
            binary_prob = F.sigmoid(bi_pred)
            cls_prob = F.softmax(fg_pred, axis=1)
            cls_prob[:, 0, :, :] = cls_prob[:, 0, :, :] * (
                1 - binary_prob).squeeze(axis=1)
            cls_prob[:, 1:, :, :] = cls_prob[:, 1:, :, :] * binary_prob
            z = paddle.sum(cls_prob, axis=1)
            z = z.unsqueeze(axis=1)
            cls_prob = paddle.divide(cls_prob, z)
            return [cls_prob]
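A minimal usage sketch of the layer defined above (illustrative only, not part of the changed files; the 3-band 512x512 input, 5 classes, and backbone_pretrained=False are assumptions chosen for a quick offline check):

import paddle

# Smoke test, assuming FactSeg is in scope (e.g., defined as above).
# backbone_pretrained=False avoids downloading ImageNet weights.
model = FactSeg(in_channels=3, num_classes=5, backbone='resnet50',
                backbone_pretrained=False)
model.eval()
with paddle.no_grad():
    x = paddle.randn([1, 3, 512, 512])  # dummy 3-band 512x512 image
    cls_prob = model(x)[0]              # renormalized class probabilities
print(cls_prob.shape)                   # expected: [1, 5, 512, 512]

In eval mode the forward pass multiplies the binary (foreground) branch into the multi-class softmax and renormalizes, so the output sums to 1 over the class axis; in training mode only the foreground-branch logits are returned.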
@@ -0,0 +1,11 @@
# Configurations of FactSeg with RSSeg dataset

_base_: ../_base_/rsseg.yaml

save_dir: ./test_tipc/output/seg/factseg/

model: !Node
    type: FactSeg
    args:
        in_channels: 3
        num_classes: 5
@@ -0,0 +1,53 @@
===========================train_params===========================
model_name:seg:factseg
python:python
gpu_list:0
use_gpu:null|null
--precision:null
--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=20
--save_dir:adaptive
--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4
--model_path:null
--config:lite_train_lite_infer=./test_tipc/configs/seg/factseg/factseg_rsseg.yaml|lite_train_whole_infer=./test_tipc/configs/seg/factseg/factseg_rsseg.yaml|whole_train_whole_infer=./test_tipc/configs/seg/factseg/factseg_rsseg.yaml
train_model_name:best_model
null:null
##
trainer:norm
norm_train:test_tipc/run_task.py train seg
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================export_params===========================
--save_dir:adaptive
--model_dir:adaptive
--fixed_input_shape:[-1,3,512,512]
norm_export:deploy/export/export_model.py
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
===========================infer_params===========================
infer_model:null
infer_export:null
infer_quant:False
inference:test_tipc/infer.py
--device:cpu|gpu
--enable_mkldnn:True
--cpu_threads:6
--batch_size:1
--use_trt:False
--precision:fp32
--model_dir:null
--config:null
--save_log_path:null
--benchmark:True
--model_name:factseg
null:null
@@ -0,0 +1,94 @@
#!/usr/bin/env python

# Example training script for the FactSeg image segmentation model.
# Before running this script, make sure PaddleRS has been installed correctly.

import paddlers as pdrs
from paddlers import transforms as T

# Dataset directory
DATA_DIR = './data/rsseg/'
# Path to the `file_list` file of the training set
TRAIN_FILE_LIST_PATH = './data/rsseg/train.txt'
# Path to the `file_list` file of the validation set
EVAL_FILE_LIST_PATH = './data/rsseg/val.txt'
# Path to the file with dataset class information
LABEL_LIST_PATH = './data/rsseg/labels.txt'
# Experiment directory, where output model weights and results are saved
EXP_DIR = './output/factseg/'

# Download and extract the multispectral land parcel classification dataset
pdrs.utils.download_and_decompress(
    'https://paddlers.bj.bcebos.com/datasets/rsseg.zip', path='./data/')

# Define the data transforms used for training and validation (data
# augmentation, preprocessing, etc.).
# Compose chains multiple transforms, which are applied in sequence.
# API reference: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
train_transforms = T.Compose([
    # Decode the image
    T.DecodeImg(),
    # Select the first three bands
    T.SelectBand([1, 2, 3]),
    # Resize the image to 512x512
    T.Resize(target_size=512),
    # Apply random horizontal flipping with a probability of 50%
    T.RandomHorizontalFlip(prob=0.5),
    # Normalize the data to [-1, 1]
    T.Normalize(
        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    T.ArrangeSegmenter('train')
])

eval_transforms = T.Compose([
    T.DecodeImg(),
    # The same bands must be selected for validation and training
    T.SelectBand([1, 2, 3]),
    T.Resize(target_size=512),
    # Validation data must be normalized in the same way as training data
    T.Normalize(
        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    T.ReloadMask(),
    T.ArrangeSegmenter('eval')
])

# Build the training and validation datasets
train_dataset = pdrs.datasets.SegDataset(
    data_dir=DATA_DIR,
    file_list=TRAIN_FILE_LIST_PATH,
    label_list=LABEL_LIST_PATH,
    transforms=train_transforms,
    num_workers=0,
    shuffle=True)

eval_dataset = pdrs.datasets.SegDataset(
    data_dir=DATA_DIR,
    file_list=EVAL_FILE_LIST_PATH,
    label_list=LABEL_LIST_PATH,
    transforms=eval_transforms,
    num_workers=0,
    shuffle=False)

# Build the FactSeg model
# Currently supported models: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
# Model input arguments: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/segmenter.py
model = pdrs.tasks.seg.FactSeg(num_classes=len(train_dataset.labels))

# Train the model
model.train(
    num_epochs=10,
    train_dataset=train_dataset,
    train_batch_size=4,
    eval_dataset=eval_dataset,
    save_interval_epochs=5,
    # Log every this many iterations
    log_interval_steps=4,
    save_dir=EXP_DIR,
    pretrain_weights=None,
    # Initial learning rate
    learning_rate=0.001,
    # Whether to use early stopping, i.e., terminate training early once
    # accuracy stops improving
    early_stop=False,
    # Whether to enable VisualDL logging
    use_vdl=True,
    # Resume training from a specified checkpoint
    resume_checkpoint=None)