Factseg paddle (#54)

* fact-seg paddle

* fact-paddle

* Coding modify

* Add files via upload

* Passed testing code

* Use offical hook for pre-commit

* Use the official .style.yarp file and passed the yarp test locally

* Restore the file which was changed by error

* Revert "Use the official .style.yarp file and passed the yarp test locally"

This reverts commit 6a294fa21b74166bc6be94b858fcf486ff13d4ea.

* Fix code style

* Code Modify

* Code Modify according to reviewer

* solve conflict

* Solve conflict

* solve yarp

* yarp

* Fix code style

* [Feat] Make FactSeg public

Co-authored-by: Bobholamovic <mhlin425@whu.edu.cn>
own
LHE-IT 2 years ago committed by GitHub
parent 95a0f9e2ac
commit afc25c93c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      docs/intro/model_zoo.md
  2. 4
      paddlers/models/ppgan/models/generators/generator_firstorder.py
  3. 4
      paddlers/models/ppseg/models/losses/lovasz_loss.py
  4. 1
      paddlers/rs_models/seg/__init__.py
  5. 141
      paddlers/rs_models/seg/factseg.py
  6. 21
      paddlers/tasks/segmenter.py
  7. 1
      test_tipc/README.md
  8. 11
      test_tipc/configs/seg/factseg/factseg_rsseg.yaml
  9. 53
      test_tipc/configs/seg/factseg/train_infer_python.txt
  10. 2
      test_tipc/docs/test_train_inference_python.md
  11. 20
      tests/rs_models/test_seg_models.py
  12. 1
      tutorials/train/README.md
  13. 94
      tutorials/train/semantic_segmentation/factseg.py

@ -34,6 +34,7 @@ PaddleRS目前已支持的全部模型如下(标注\*的为遥感专用模型
| 目标检测 | YOLOv3 | 否 |
| 图像分割 | BiSeNet V2 | 是 |
| 图像分割 | DeepLab V3+ | 是 |
| 图像分割 | \*FactSeg | 是 |
| 图像分割 | \*FarSeg | 是 |
| 图像分割 | Fast-SCNN | 是 |
| 图像分割 | HRNet | 是 |

@ -131,8 +131,8 @@ class FirstOrderGenerator(nn.Layer):
transformed_kp['jacobian']))
normed_driving = paddle.inverse(kp_driving['jacobian'])
normed_transformed = jacobian_transformed
value = paddle.matmul(
*broadcast(normed_driving, normed_transformed))
value = paddle.matmul(*broadcast(normed_driving,
normed_transformed))
eye = paddle.tensor.eye(2, dtype='float32').reshape(
(1, 1, 2, 2))
eye = paddle.tile(eye, [1, value.shape[1], 1, 1])

@ -77,8 +77,8 @@ class LovaszHingeLoss(nn.Layer):
"""
if logits.shape[1] == 2:
logits = binary_channel_to_unary(logits)
loss = lovasz_hinge_flat(
*flatten_binary_scores(logits, labels, self.ignore_index))
loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels,
self.ignore_index))
return loss

@ -13,3 +13,4 @@
# limitations under the License.
from .farseg import FarSeg
from .factseg import FactSeg

@ -0,0 +1,141 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlers.models.ppdet.modeling import \
initializer as init
from paddlers.rs_models.seg.farseg import FPN, \
ResNetEncoder,AsymmetricDecoder
def conv_with_kaiming_uniform(use_gn=False, use_relu=False):
def make_conv(in_channels, out_channels, kernel_size, stride=1, dilation=1):
conv = nn.Conv2D(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=dilation * (kernel_size - 1) // 2,
dilation=dilation,
bias_attr=False if use_gn else True)
init.kaiming_uniform_(conv.weight, a=1)
if not use_gn:
init.constant_(conv.bias, 0)
module = [conv, ]
if use_gn:
raise NotImplementedError
if use_relu:
module.append(nn.ReLU())
if len(module) > 1:
return nn.Sequential(*module)
return conv
return make_conv
default_conv_block = conv_with_kaiming_uniform(use_gn=False, use_relu=False)
class FactSeg(nn.Layer):
"""
The FactSeg implementation based on PaddlePaddle.
The original article refers to
A. Ma, J. Wang, Y. Zhong and Z. Zheng, "FactSeg: Foreground Activation
-Driven Small Object Semantic Segmentation in Large-Scale Remote Sensing
Imagery,"in IEEE Transactions on Geoscience and Remote Sensing, vol. 60,
pp. 1-16, 2022, Art no. 5606216.
Args:
in_channels (int): The number of image channels for the input model.
num_classes (int): The unique number of target classes.
backbone (str, optional): A backbone network, models available in
`paddle.vision.models.resnet`. Default: resnet50.
backbone_pretrained (bool, optional): Whether the backbone network uses
IMAGENET pretrained weights. Default: True.
"""
def __init__(self,
in_channels,
num_classes,
backbone='resnet50',
backbone_pretrained=True):
super(FactSeg, self).__init__()
backbone = backbone.lower()
self.resencoder = ResNetEncoder(
backbone=backbone,
in_channels=in_channels,
pretrained=backbone_pretrained)
self.resencoder.resnet._sub_layers.pop('fc')
self.fgfpn = FPN(in_channels_list=[256, 512, 1024, 2048],
out_channels=256,
conv_block=default_conv_block)
self.bifpn = FPN(in_channels_list=[256, 512, 1024, 2048],
out_channels=256,
conv_block=default_conv_block)
self.fg_decoder = AsymmetricDecoder(
in_channels=256,
out_channels=128,
in_feature_output_strides=(4, 8, 16, 32),
out_feature_output_stride=4,
conv_block=nn.Conv2D)
self.bi_decoder = AsymmetricDecoder(
in_channels=256,
out_channels=128,
in_feature_output_strides=(4, 8, 16, 32),
out_feature_output_stride=4,
conv_block=nn.Conv2D)
self.fg_cls = nn.Conv2D(128, num_classes, kernel_size=1)
self.bi_cls = nn.Conv2D(128, 1, kernel_size=1)
self.config_loss = ['joint_loss']
self.config_foreground = []
self.fbattention_atttention = False
def forward(self, x):
feat_list = self.resencoder(x)
if 'skip_decoder' in []:
fg_out = self.fgskip_deocder(feat_list)
bi_out = self.bgskip_deocder(feat_list)
else:
forefeat_list = list(self.fgfpn(feat_list))
binaryfeat_list = self.bifpn(feat_list)
if self.fbattention_atttention:
for i in range(len(binaryfeat_list)):
forefeat_list[i] = self.fbatt_block_list[i](
binaryfeat_list[i], forefeat_list[i])
fg_out = self.fg_decoder(forefeat_list)
bi_out = self.bi_decoder(binaryfeat_list)
fg_pred = self.fg_cls(fg_out)
bi_pred = self.bi_cls(bi_out)
fg_pred = F.interpolate(
fg_pred, scale_factor=4.0, mode='bilinear', align_corners=True)
bi_pred = F.interpolate(
bi_pred, scale_factor=4.0, mode='bilinear', align_corners=True)
if self.training:
return [fg_pred]
else:
binary_prob = F.sigmoid(bi_pred)
cls_prob = F.softmax(fg_pred, axis=1)
cls_prob[:, 0, :, :] = cls_prob[:, 0, :, :] * (
1 - binary_prob).squeeze(axis=1)
cls_prob[:, 1:, :, :] = cls_prob[:, 1:, :, :] * binary_prob
z = paddle.sum(cls_prob, axis=1)
z = z.unsqueeze(axis=1)
cls_prob = paddle.divide(cls_prob, z)
return [cls_prob]

@ -13,7 +13,6 @@
# limitations under the License.
import math
import os
import os.path as osp
from collections import OrderedDict
@ -36,7 +35,9 @@ from .utils import seg_metrics as metrics
from .utils.infer_nets import InferSegNet
from .utils.slider_predict import slider_predict
__all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"]
__all__ = [
"UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg", "FactSeg"
]
class BaseSegmenter(BaseModel):
@ -894,3 +895,19 @@ class FarSeg(BaseSegmenter):
losses=losses,
in_channels=in_channels,
**params)
class FactSeg(BaseSegmenter):
def __init__(self,
in_channels=3,
num_classes=2,
use_mixed_loss=False,
losses=None,
**params):
super(FactSeg, self).__init__(
model_name='FactSeg',
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
losses=losses,
in_channels=in_channels,
**params)

@ -46,6 +46,7 @@
| 目标检测 | YOLOv3 | 支持 | - | - | - |
| 图像分割 | BiSeNet V2 | 支持 | - | - | - |
| 图像分割 | DeepLab V3+ | 支持 | - | - | - |
| 图像分割 | FactSeg | 支持 | - | - | - |
| 图像分割 | FarSeg | 支持 | - | - | - |
| 图像分割 | Fast-SCNN | 支持 | - | - | - |
| 图像分割 | HRNet | 支持 | - | - | - |

@ -0,0 +1,11 @@
# Configurations of FactSeg with RSSeg dataset
_base_: ../_base_/rsseg.yaml
save_dir: ./test_tipc/output/seg/factseg/
model: !Node
type: FactSeg
args:
in_channels: 3
num_classes: 5

@ -0,0 +1,53 @@
===========================train_params===========================
model_name:seg:factseg
python:python
gpu_list:0
use_gpu:null|null
--precision:null
--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=20
--save_dir:adaptive
--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4
--model_path:null
--config:lite_train_lite_infer=./test_tipc/configs/seg/factseg/factseg_rsseg.yaml|lite_train_whole_infer=./test_tipc/configs/seg/factseg/factseg_rsseg.yaml|whole_train_whole_infer=./test_tipc/configs/seg/factseg/factseg_rsseg.yaml
train_model_name:best_model
null:null
##
trainer:norm
norm_train:test_tipc/run_task.py train seg
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================export_params===========================
--save_dir:adaptive
--model_dir:adaptive
--fixed_input_shape:[-1,3,512,512]
norm_export:deploy/export/export_model.py
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
===========================infer_params===========================
infer_model:null
infer_export:null
infer_quant:False
inference:test_tipc/infer.py
--device:cpu|gpu
--enable_mkldnn:True
--cpu_threads:6
--batch_size:1
--use_trt:False
--precision:fp32
--model_dir:null
--config:null
--save_log_path:null
--benchmark:True
--model_name:factseg
null:null

@ -33,6 +33,7 @@ Linux GPU/CPU 基础训练推理测试的主程序为`test_train_inference_pytho
| 图像复原 | LESRCNN | 正常训练 | 正常训练 | PSNR=23.67 |
| 图像分割 | BiSeNet V2 | 正常训练 | 正常训练 | mIoU=70.52% |
| 图像分割 | DeepLab V3+ | 正常训练 | 正常训练 | mIoU=64.41% |
| 图像分割 | FactSeg | 正常训练 | 正常训练 | |
| 图像分割 | FarSeg | 正常训练 | 正常训练 | mIoU=50.60% |
| 图像分割 | Fast-SCNN | 正常训练 | 正常训练 | mIoU=49.27% |
| 图像分割 | HRNet | 正常训练 | 正常训练 | mIoU=33.03% |
@ -68,6 +69,7 @@ Linux GPU/CPU 基础训练推理测试的主程序为`test_train_inference_pytho
| 目标检测 | YOLOv3 | 支持 | 支持 | 1 |
| 图像分割 | BiSeNet V2 | 支持 | 支持 | 1 |
| 图像分割 | DeepLab V3+ | 支持 | 支持 | 1 |
| 图像分割 | FactSeg | 支持 | 支持 | 1 |
| 图像分割 | FarSeg | 支持 | 支持 | 1 |
| 图像分割 | Fast-SCNN | 支持 | 支持 | 1 |
| 图像分割 | HRNet | 支持 | 支持 | 1 |

@ -15,7 +15,7 @@
import paddlers
from rs_models.test_model import TestModel
__all__ = ['TestFarSegModel']
__all__ = ['TestFarSegModel', 'TestFactSegModel']
class TestSegModel(TestModel):
@ -70,3 +70,21 @@ class TestFarSegModel(TestSegModel):
self.targets = [[self.get_zeros_array(2)], [self.get_zeros_array(10)],
[self.get_zeros_array(2)], [self.get_zeros_array(2)],
[self.get_zeros_array(2)]]
class TestFactSegModel(TestSegModel):
MODEL_CLASS = paddlers.rs_models.seg.FactSeg
def set_specs(self):
base_spec = dict(in_channels=3, num_classes=2)
self.specs = [
base_spec,
dict(in_channels=6, num_classes=10),
dict(**base_spec,
backbone='resnet50',
backbone_pretrained=False)
] # yapf: disable
def set_targets(self):
self.targets = [[self.get_zeros_array(2)], [self.get_zeros_array(10)],
[self.get_zeros_array(2)]]

@ -29,6 +29,7 @@
|object_detection/yolov3.py | 目标检测 | YOLOv3 |
|semantic_segmentation/bisenetv2.py | 图像分割 | BiSeNet V2 |
|semantic_segmentation/deeplabv3p.py | 图像分割 | DeepLab V3+ |
|semantic_segmentation/factseg.py | 图像分割 | FactSeg |
|semantic_segmentation/farseg.py | 图像分割 | FarSeg |
|semantic_segmentation/fast_scnn.py | 图像分割 | Fast-SCNN |
|semantic_segmentation/hrnet.py | 图像分割 | HRNet |

@ -0,0 +1,94 @@
#!/usr/bin/env python
# 图像分割模型FactSeg训练示例脚本
# 执行此脚本前,请确认已正确安装PaddleRS库
import paddlers as pdrs
from paddlers import transforms as T
# 数据集存放目录
DATA_DIR = './data/rsseg/'
# 训练集`file_list`文件路径
TRAIN_FILE_LIST_PATH = './data/rsseg/train.txt'
# 验证集`file_list`文件路径
EVAL_FILE_LIST_PATH = './data/rsseg/val.txt'
# 数据集类别信息文件路径
LABEL_LIST_PATH = './data/rsseg/labels.txt'
# 实验目录,保存输出的模型权重和结果
EXP_DIR = './output/factseg/'
# 下载和解压多光谱地块分类数据集
pdrs.utils.download_and_decompress(
'https://paddlers.bj.bcebos.com/datasets/rsseg.zip', path='./data/')
# 定义训练和验证时使用的数据变换(数据增强、预处理等)
# 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
# API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
train_transforms = T.Compose([
# 读取影像
T.DecodeImg(),
# 选择前三个波段
T.SelectBand([1, 2, 3]),
# 将影像缩放到512x512大小
T.Resize(target_size=512),
# 以50%的概率实施随机水平翻转
T.RandomHorizontalFlip(prob=0.5),
# 将数据归一化到[-1,1]
T.Normalize(
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
T.ArrangeSegmenter('train')
])
eval_transforms = T.Compose([
T.DecodeImg(),
# 验证阶段与训练阶段应当选择相同的波段
T.SelectBand([1, 2, 3]),
T.Resize(target_size=512),
# 验证阶段与训练阶段的数据归一化方式必须相同
T.Normalize(
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
T.ReloadMask(),
T.ArrangeSegmenter('eval')
])
# 分别构建训练和验证所用的数据集
train_dataset = pdrs.datasets.SegDataset(
data_dir=DATA_DIR,
file_list=TRAIN_FILE_LIST_PATH,
label_list=LABEL_LIST_PATH,
transforms=train_transforms,
num_workers=0,
shuffle=True)
eval_dataset = pdrs.datasets.SegDataset(
data_dir=DATA_DIR,
file_list=EVAL_FILE_LIST_PATH,
label_list=LABEL_LIST_PATH,
transforms=eval_transforms,
num_workers=0,
shuffle=False)
# 构建FactSeg模型
# 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
# 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/segmenter.py
model = pdrs.tasks.seg.FactSeg(num_classes=len(train_dataset.labels))
# 执行模型训练
model.train(
num_epochs=10,
train_dataset=train_dataset,
train_batch_size=4,
eval_dataset=eval_dataset,
save_interval_epochs=5,
# 每多少次迭代记录一次日志
log_interval_steps=4,
save_dir=EXP_DIR,
pretrain_weights=None,
# 初始学习率大小
learning_rate=0.001,
# 是否使用early stopping策略,当精度不再改善时提前终止训练
early_stop=False,
# 是否启用VisualDL日志功能
use_vdl=True,
# 指定从某个检查点继续训练
resume_checkpoint=None)
Loading…
Cancel
Save