Add seg models

own
Bobholamovic 2 years ago
parent 9f5c87e9dd
commit 69c160404a
  1. 23
      docs/intro/model_zoo.md
  2. 11
      paddlers/tasks/segmenter.py
  3. 8
      test_tipc/configs/cd/bit/bit.yaml
  4. 8
      test_tipc/configs/cd/changeformer/changeformer.yaml
  5. 13
      test_tipc/configs/cd/fccdn/fccdn.yaml
  6. 11
      test_tipc/configs/seg/bisenetv2/bisenetv2_rsseg.yaml
  7. 53
      test_tipc/configs/seg/bisenetv2/train_infer_python.txt
  8. 11
      test_tipc/configs/seg/fast_scnn/fast_scnn_rsseg.yaml
  9. 53
      test_tipc/configs/seg/fast_scnn/train_infer_python.txt
  10. 11
      test_tipc/configs/seg/hrnet/hrnet_rsseg.yaml
  11. 53
      test_tipc/configs/seg/hrnet/train_infer_python.txt
  12. 8
      test_tipc/docs/test_train_inference_python.md
  13. 4
      tutorials/train/README.md
  14. 93
      tutorials/train/semantic_segmentation/bisenetv2.py
  15. 93
      tutorials/train/semantic_segmentation/fast_scnn.py
  16. 93
      tutorials/train/semantic_segmentation/hrnet.py

@ -20,18 +20,21 @@ PaddleRS目前已支持的全部模型如下(标注\*的为遥感专用模型
| 变化检测 | \*FCCDN | 是 |
| 变化检测 | \*SNUNet | 是 |
| 变化检测 | \*STANet | 是 |
| 场景分类 | CondenseNetV2 | 是 |
| 场景分类 | HRNet | |
| 场景分类 | MobileNetV3 | |
| 场景分类 | ResNet50-vd | |
| 场景分类 | CondenseNet V2 | 是 |
| 场景分类 | HRNet | |
| 场景分类 | MobileNetV3 | |
| 场景分类 | ResNet50-vd | |
| 图像复原 | DRN | 否 |
| 图像复原 | ESRGAN | 否 |
| 图像复原 | LESRCNN | 否 |
| 目标检测 | Faster R-CNN | 是 |
| 目标检测 | PP-YOLO | 是 |
| 目标检测 | PP-YOLO Tiny | 是 |
| 目标检测 | PP-YOLOv2 | 是 |
| 目标检测 | YOLOv3 | 是 |
| 目标检测 | Faster R-CNN | 否 |
| 目标检测 | PP-YOLO | 否 |
| 目标检测 | PP-YOLO Tiny | 否 |
| 目标检测 | PP-YOLOv2 | 否 |
| 目标检测 | YOLOv3 | 否 |
| 图像分割 | BiSeNet V2 | 是 |
| 图像分割 | DeepLab V3+ | 是 |
| 图像分割 | \*FarSeg | 否 |
| 图像分割 | \*FarSeg | 是 |
| 图像分割 | Fast-SCNN | 是 |
| 图像分割 | HRNet | 是 |
| 图像分割 | UNet | 是 |

@ -806,7 +806,7 @@ class UNet(BaseSegmenter):
})
super(UNet, self).__init__(
model_name='UNet',
input_channel=in_channels,
in_channels=in_channels,
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
losses=losses,
@ -834,7 +834,7 @@ class DeepLabV3P(BaseSegmenter):
if params.get('with_net', True):
with DisablePrint():
backbone = getattr(ppseg.models, backbone)(
input_channel=in_channels, output_stride=output_stride)
in_channels=in_channels, output_stride=output_stride)
else:
backbone = None
params.update({
@ -854,6 +854,7 @@ class DeepLabV3P(BaseSegmenter):
class FastSCNN(BaseSegmenter):
def __init__(self,
in_channels=3,
num_classes=2,
use_mixed_loss=False,
losses=None,
@ -862,6 +863,7 @@ class FastSCNN(BaseSegmenter):
params.update({'align_corners': align_corners})
super(FastSCNN, self).__init__(
model_name='FastSCNN',
in_channels=in_channels,
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
losses=losses,
@ -870,6 +872,7 @@ class FastSCNN(BaseSegmenter):
class HRNet(BaseSegmenter):
def __init__(self,
in_channels=3,
num_classes=2,
width=48,
use_mixed_loss=False,
@ -884,7 +887,7 @@ class HRNet(BaseSegmenter):
if params.get('with_net', True):
with DisablePrint():
backbone = getattr(ppseg.models, self.backbone_name)(
align_corners=align_corners)
in_channels=in_channels, align_corners=align_corners)
else:
backbone = None
@ -900,6 +903,7 @@ class HRNet(BaseSegmenter):
class BiSeNetV2(BaseSegmenter):
def __init__(self,
in_channels=3,
num_classes=2,
use_mixed_loss=False,
losses=None,
@ -908,6 +912,7 @@ class BiSeNetV2(BaseSegmenter):
params.update({'align_corners': align_corners})
super(BiSeNetV2, self).__init__(
model_name='BiSeNetV2',
in_channels=in_channels,
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
losses=losses,

@ -1,8 +0,0 @@
# Basic configurations of BIT
_base_: ../_base_/airchange.yaml
save_dir: ./test_tipc/output/cd/bit/
model: !Node
type: BIT

@ -1,8 +0,0 @@
# Basic configurations of ChangeFormer
_base_: ../_base_/airchange.yaml
save_dir: ./test_tipc/output/cd/changeformer/
model: !Node
type: ChangeFormer

@ -1,13 +0,0 @@
# Basic configurations of FCCDN
_base_: ../_base_/airchange.yaml
save_dir: ./test_tipc/output/cd/fccdn/
model: !Node
type: FCCDN
learning_rate: 0.07
lr_decay_power: 0.6
log_interval_steps: 100
save_interval_epochs: 3

@ -0,0 +1,11 @@
# Configurations of BiSeNet V2 with RSSeg dataset
_base_: ../_base_/rsseg.yaml
save_dir: ./test_tipc/output/seg/bisenetv2/
model: !Node
type: BiSeNet V2
args:
in_channels: 10
num_classes: 5

@ -0,0 +1,53 @@
===========================train_params===========================
model_name:seg:bisenetv2
python:python
gpu_list:0|0,1
use_gpu:null|null
--precision:null
--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=20
--save_dir:adaptive
--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4
--model_path:null
--config:lite_train_lite_infer=./test_tipc/configs/seg/bisenetv2/bisenetv2_rsseg.yaml|lite_train_whole_infer=./test_tipc/configs/seg/bisenetv2/bisenetv2_rsseg.yaml|whole_train_whole_infer=./test_tipc/configs/seg/bisenetv2/bisenetv2_rsseg.yaml
train_model_name:best_model
null:null
##
trainer:norm
norm_train:test_tipc/run_task.py train seg
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================export_params===========================
--save_dir:adaptive
--model_dir:adaptive
--fixed_input_shape:[-1,10,512,512]
norm_export:deploy/export/export_model.py
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
===========================infer_params===========================
infer_model:null
infer_export:null
infer_quant:False
inference:test_tipc/infer.py
--device:cpu|gpu
--enable_mkldnn:True
--cpu_threads:6
--batch_size:1
--use_trt:False
--precision:fp32
--model_dir:null
--config:null
--save_log_path:null
--benchmark:True
--model_name:bisenetv2
null:null

@ -0,0 +1,11 @@
# Configurations of Fast-SCNN with RSSeg dataset
_base_: ../_base_/rsseg.yaml
save_dir: ./test_tipc/output/seg/fast_scnn/
model: !Node
type: Fast-SCNN
args:
in_channels: 10
num_classes: 5

@ -0,0 +1,53 @@
===========================train_params===========================
model_name:seg:fast_scnn
python:python
gpu_list:0|0,1
use_gpu:null|null
--precision:null
--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=20
--save_dir:adaptive
--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4
--model_path:null
--config:lite_train_lite_infer=./test_tipc/configs/seg/fast_scnn/fast_scnn_rsseg.yaml|lite_train_whole_infer=./test_tipc/configs/seg/fast_scnn/fast_scnn_rsseg.yaml|whole_train_whole_infer=./test_tipc/configs/seg/fast_scnn/fast_scnn_rsseg.yaml
train_model_name:best_model
null:null
##
trainer:norm
norm_train:test_tipc/run_task.py train seg
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================export_params===========================
--save_dir:adaptive
--model_dir:adaptive
--fixed_input_shape:[-1,10,512,512]
norm_export:deploy/export/export_model.py
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
===========================infer_params===========================
infer_model:null
infer_export:null
infer_quant:False
inference:test_tipc/infer.py
--device:cpu|gpu
--enable_mkldnn:True
--cpu_threads:6
--batch_size:1
--use_trt:False
--precision:fp32
--model_dir:null
--config:null
--save_log_path:null
--benchmark:True
--model_name:fast_scnn
null:null

@ -0,0 +1,11 @@
# Configurations of HRNet with RSSeg dataset
_base_: ../_base_/rsseg.yaml
save_dir: ./test_tipc/output/seg/hrnet/
model: !Node
type: HRNet
args:
in_channels: 10
num_classes: 5

@ -0,0 +1,53 @@
===========================train_params===========================
model_name:seg:hrnet
python:python
gpu_list:0|0,1
use_gpu:null|null
--precision:null
--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=20
--save_dir:adaptive
--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4
--model_path:null
--config:lite_train_lite_infer=./test_tipc/configs/seg/hrnet/hrnet_rsseg.yaml|lite_train_whole_infer=./test_tipc/configs/seg/hrnet/hrnet_rsseg.yaml|whole_train_whole_infer=./test_tipc/configs/seg/hrnet/hrnet_rsseg.yaml
train_model_name:best_model
null:null
##
trainer:norm
norm_train:test_tipc/run_task.py train seg
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================export_params===========================
--save_dir:adaptive
--model_dir:adaptive
--fixed_input_shape:[-1,10,512,512]
norm_export:deploy/export/export_model.py
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
===========================infer_params===========================
infer_model:null
infer_export:null
infer_quant:False
inference:test_tipc/infer.py
--device:cpu|gpu
--enable_mkldnn:True
--cpu_threads:6
--batch_size:1
--use_trt:False
--precision:fp32
--model_dir:null
--config:null
--save_log_path:null
--benchmark:True
--model_name:hrnet
null:null

@ -19,6 +19,7 @@ Linux GPU/CPU 基础训练推理测试的主程序为`test_train_inference_pytho
| 变化检测 | FC-Siam-conc | 正常训练 | 正常训练 | IoU=65.79% |
| 变化检测 | FC-Siam-diff | 正常训练 | 正常训练 | IoU=61.23% |
| 变化检测 | FCCDN | 正常训练 | 正常训练 | IoU=24.42% |
| 场景分类 | CondenseNet V2 | 正常训练 | 正常训练 | Acc(top1)= |
| 场景分类 | HRNet | 正常训练 | 正常训练 | Acc(top1)=99.37% |
| 场景分类 | MobileNetV3 | 正常训练 | 正常训练 | Acc(top1)=99.58% |
| 场景分类 | ResNet50-vd | 正常训练 | 正常训练 | Acc(top1)=99.26% |
@ -30,8 +31,11 @@ Linux GPU/CPU 基础训练推理测试的主程序为`test_train_inference_pytho
| 目标检测 | PP-YOLO Tiny | 正常训练 | 正常训练 | mAP=44.27% |
| 目标检测 | PP-YOLOv2 | 正常训练 | 正常训练 | mAP=59.37% |
| 目标检测 | YOLOv3 | 正常训练 | 正常训练 | mAP=47.33% |
| 图像分割 | BiSeNet V2 | 正常训练 | 正常训练 | mIoU= |
| 图像分割 | DeepLab V3+ | 正常训练 | 正常训练 | mIoU=56.05% |
| 图像分割 | FarSeg | 正常训练 | 正常训练 | mIoU=49.58% |
| 图像分割 | Fast-SCNN | 正常训练 | 正常训练 | mIoU= |
| 图像分割 | HRNet | 正常训练 | 正常训练 | mIoU= |
| 图像分割 | UNet | 正常训练 | 正常训练 | mIoU=55.50% |
*注:参考预测精度为whole_train_whole_infer模式下单卡训练汇报的精度数据。*
@ -50,6 +54,7 @@ Linux GPU/CPU 基础训练推理测试的主程序为`test_train_inference_pytho
| 变化检测 | FC-EF | 支持 | 支持 | 1 |
| 变化检测 | FC-Siam-conc | 支持 | 支持 | 1 |
| 变化检测 | FC-Siam-diff | 支持 | 支持 | 1 |
| 场景分类 | CondenseNet V2 | 支持 | 支持 | 1 |
| 场景分类 | HRNet | 支持 | 支持 | 1 |
| 场景分类 | MobileNetV3 | 支持 | 支持 | 1 |
| 场景分类 | ResNet50-vd | 支持 | 支持 | 1 |
@ -61,8 +66,11 @@ Linux GPU/CPU 基础训练推理测试的主程序为`test_train_inference_pytho
| 目标检测 | PP-YOLO Tiny | 支持 | 支持 | 1 |
| 目标检测 | PP-YOLOv2 | 支持 | 支持 | 1 |
| 目标检测 | YOLOv3 | 支持 | 支持 | 1 |
| 图像分割 | BiSeNet V2 | 支持 | 支持 | 1 |
| 图像分割 | DeepLab V3+ | 支持 | 支持 | 1 |
| 图像分割 | FarSeg | 支持 | 支持 | 1 |
| 图像分割 | Fast-SCNN | 支持 | 支持 | 1 |
| 图像分割 | HRNet | 支持 | 支持 | 1 |
| 图像分割 | UNet | 支持 | 支持 | 1 |
## 2 测试流程

@ -15,6 +15,7 @@
|change_detection/fccdn.py | 变化检测 | FCCDN |
|change_detection/snunet.py | 变化检测 | SNUNet |
|change_detection/stanet.py | 变化检测 | STANet |
|classification/condensenetv2.py | 场景分类 | CondenseNet V2 |
|classification/hrnet.py | 场景分类 | HRNet |
|classification/mobilenetv3.py | 场景分类 | MobileNetV3 |
|classification/resnet50_vd.py | 场景分类 | ResNet50-vd |
@ -26,8 +27,11 @@
|object_detection/ppyolo_tiny.py | 目标检测 | PP-YOLO Tiny |
|object_detection/ppyolov2.py | 目标检测 | PP-YOLOv2 |
|object_detection/yolov3.py | 目标检测 | YOLOv3 |
|semantic_segmentation/bisenetv2.py | 图像分割 | BiSeNet V2 |
|semantic_segmentation/deeplabv3p.py | 图像分割 | DeepLab V3+ |
|semantic_segmentation/farseg.py | 图像分割 | FarSeg |
|semantic_segmentation/fast_scnn.py | 图像分割 | Fast-SCNN |
|semantic_segmentation/hrnet.py | 图像分割 | HRNet |
|semantic_segmentation/unet.py | 图像分割 | UNet |
## 环境准备

@ -0,0 +1,93 @@
#!/usr/bin/env python
# 图像分割模型BiSeNet V2训练示例脚本
# 执行此脚本前,请确认已正确安装PaddleRS库
import paddlers as pdrs
from paddlers import transforms as T
# 数据集存放目录
DATA_DIR = './data/rsseg/'
# 训练集`file_list`文件路径
TRAIN_FILE_LIST_PATH = './data/rsseg/train.txt'
# 验证集`file_list`文件路径
EVAL_FILE_LIST_PATH = './data/rsseg/val.txt'
# 数据集类别信息文件路径
LABEL_LIST_PATH = './data/rsseg/labels.txt'
# 实验目录,保存输出的模型权重和结果
EXP_DIR = './output/unet/'
# 影像波段数量
NUM_BANDS = 10
# 下载和解压多光谱地块分类数据集
pdrs.utils.download_and_decompress(
'https://paddlers.bj.bcebos.com/datasets/rsseg.zip', path='./data/')
# 定义训练和验证时使用的数据变换(数据增强、预处理等)
# 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
# API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
train_transforms = T.Compose([
# 读取影像
T.DecodeImg(),
# 将影像缩放到512x512大小
T.Resize(target_size=512),
# 以50%的概率实施随机水平翻转
T.RandomHorizontalFlip(prob=0.5),
# 将数据归一化到[-1,1]
T.Normalize(
mean=[0.5] * NUM_BANDS, std=[0.5] * NUM_BANDS),
T.ArrangeSegmenter('train')
])
eval_transforms = T.Compose([
T.DecodeImg(),
T.Resize(target_size=512),
# 验证阶段与训练阶段的数据归一化方式必须相同
T.Normalize(
mean=[0.5] * NUM_BANDS, std=[0.5] * NUM_BANDS),
T.ReloadMask(),
T.ArrangeSegmenter('eval')
])
# 分别构建训练和验证所用的数据集
train_dataset = pdrs.datasets.SegDataset(
data_dir=DATA_DIR,
file_list=TRAIN_FILE_LIST_PATH,
label_list=LABEL_LIST_PATH,
transforms=train_transforms,
num_workers=0,
shuffle=True)
eval_dataset = pdrs.datasets.SegDataset(
data_dir=DATA_DIR,
file_list=EVAL_FILE_LIST_PATH,
label_list=LABEL_LIST_PATH,
transforms=eval_transforms,
num_workers=0,
shuffle=False)
# 构建BiSeNet V2模型
# 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
# 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/segmenter.py
model = pdrs.tasks.seg.BiSeNetV2(
in_channels=NUM_BANDS, num_classes=len(train_dataset.labels))
# 执行模型训练
model.train(
num_epochs=10,
train_dataset=train_dataset,
train_batch_size=4,
eval_dataset=eval_dataset,
save_interval_epochs=5,
# 每多少次迭代记录一次日志
log_interval_steps=4,
save_dir=EXP_DIR,
# 初始学习率大小
learning_rate=0.001,
# 是否使用early stopping策略,当精度不再改善时提前终止训练
early_stop=False,
# 是否启用VisualDL日志功能
use_vdl=True,
# 指定从某个检查点继续训练
resume_checkpoint=None)

@ -0,0 +1,93 @@
#!/usr/bin/env python
# 图像分割模型Fast-SCNN训练示例脚本
# 执行此脚本前,请确认已正确安装PaddleRS库
import paddlers as pdrs
from paddlers import transforms as T
# 数据集存放目录
DATA_DIR = './data/rsseg/'
# 训练集`file_list`文件路径
TRAIN_FILE_LIST_PATH = './data/rsseg/train.txt'
# 验证集`file_list`文件路径
EVAL_FILE_LIST_PATH = './data/rsseg/val.txt'
# 数据集类别信息文件路径
LABEL_LIST_PATH = './data/rsseg/labels.txt'
# 实验目录,保存输出的模型权重和结果
EXP_DIR = './output/unet/'
# 影像波段数量
NUM_BANDS = 10
# 下载和解压多光谱地块分类数据集
pdrs.utils.download_and_decompress(
'https://paddlers.bj.bcebos.com/datasets/rsseg.zip', path='./data/')
# 定义训练和验证时使用的数据变换(数据增强、预处理等)
# 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
# API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
train_transforms = T.Compose([
# 读取影像
T.DecodeImg(),
# 将影像缩放到512x512大小
T.Resize(target_size=512),
# 以50%的概率实施随机水平翻转
T.RandomHorizontalFlip(prob=0.5),
# 将数据归一化到[-1,1]
T.Normalize(
mean=[0.5] * NUM_BANDS, std=[0.5] * NUM_BANDS),
T.ArrangeSegmenter('train')
])
eval_transforms = T.Compose([
T.DecodeImg(),
T.Resize(target_size=512),
# 验证阶段与训练阶段的数据归一化方式必须相同
T.Normalize(
mean=[0.5] * NUM_BANDS, std=[0.5] * NUM_BANDS),
T.ReloadMask(),
T.ArrangeSegmenter('eval')
])
# 分别构建训练和验证所用的数据集
train_dataset = pdrs.datasets.SegDataset(
data_dir=DATA_DIR,
file_list=TRAIN_FILE_LIST_PATH,
label_list=LABEL_LIST_PATH,
transforms=train_transforms,
num_workers=0,
shuffle=True)
eval_dataset = pdrs.datasets.SegDataset(
data_dir=DATA_DIR,
file_list=EVAL_FILE_LIST_PATH,
label_list=LABEL_LIST_PATH,
transforms=eval_transforms,
num_workers=0,
shuffle=False)
# 构建Fast-SCNN模型
# 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
# 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/segmenter.py
model = pdrs.tasks.seg.FastSCNN(
in_channels=NUM_BANDS, num_classes=len(train_dataset.labels))
# 执行模型训练
model.train(
num_epochs=10,
train_dataset=train_dataset,
train_batch_size=4,
eval_dataset=eval_dataset,
save_interval_epochs=5,
# 每多少次迭代记录一次日志
log_interval_steps=4,
save_dir=EXP_DIR,
# 初始学习率大小
learning_rate=0.001,
# 是否使用early stopping策略,当精度不再改善时提前终止训练
early_stop=False,
# 是否启用VisualDL日志功能
use_vdl=True,
# 指定从某个检查点继续训练
resume_checkpoint=None)

@ -0,0 +1,93 @@
#!/usr/bin/env python
# 图像分割模型HRNet训练示例脚本
# 执行此脚本前,请确认已正确安装PaddleRS库
import paddlers as pdrs
from paddlers import transforms as T
# 数据集存放目录
DATA_DIR = './data/rsseg/'
# 训练集`file_list`文件路径
TRAIN_FILE_LIST_PATH = './data/rsseg/train.txt'
# 验证集`file_list`文件路径
EVAL_FILE_LIST_PATH = './data/rsseg/val.txt'
# 数据集类别信息文件路径
LABEL_LIST_PATH = './data/rsseg/labels.txt'
# 实验目录,保存输出的模型权重和结果
EXP_DIR = './output/unet/'
# 影像波段数量
NUM_BANDS = 10
# 下载和解压多光谱地块分类数据集
pdrs.utils.download_and_decompress(
'https://paddlers.bj.bcebos.com/datasets/rsseg.zip', path='./data/')
# 定义训练和验证时使用的数据变换(数据增强、预处理等)
# 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行
# API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
train_transforms = T.Compose([
# 读取影像
T.DecodeImg(),
# 将影像缩放到512x512大小
T.Resize(target_size=512),
# 以50%的概率实施随机水平翻转
T.RandomHorizontalFlip(prob=0.5),
# 将数据归一化到[-1,1]
T.Normalize(
mean=[0.5] * NUM_BANDS, std=[0.5] * NUM_BANDS),
T.ArrangeSegmenter('train')
])
eval_transforms = T.Compose([
T.DecodeImg(),
T.Resize(target_size=512),
# 验证阶段与训练阶段的数据归一化方式必须相同
T.Normalize(
mean=[0.5] * NUM_BANDS, std=[0.5] * NUM_BANDS),
T.ReloadMask(),
T.ArrangeSegmenter('eval')
])
# 分别构建训练和验证所用的数据集
train_dataset = pdrs.datasets.SegDataset(
data_dir=DATA_DIR,
file_list=TRAIN_FILE_LIST_PATH,
label_list=LABEL_LIST_PATH,
transforms=train_transforms,
num_workers=0,
shuffle=True)
eval_dataset = pdrs.datasets.SegDataset(
data_dir=DATA_DIR,
file_list=EVAL_FILE_LIST_PATH,
label_list=LABEL_LIST_PATH,
transforms=eval_transforms,
num_workers=0,
shuffle=False)
# 构建HRNet模型
# 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
# 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/segmenter.py
model = pdrs.tasks.seg.HRNet(
in_channels=NUM_BANDS, num_classes=len(train_dataset.labels))
# 执行模型训练
model.train(
num_epochs=10,
train_dataset=train_dataset,
train_batch_size=4,
eval_dataset=eval_dataset,
save_interval_epochs=5,
# 每多少次迭代记录一次日志
log_interval_steps=4,
save_dir=EXP_DIR,
# 初始学习率大小
learning_rate=0.001,
# 是否使用early stopping策略,当精度不再改善时提前终止训练
early_stop=False,
# 是否启用VisualDL日志功能
use_vdl=True,
# 指定从某个检查点继续训练
resume_checkpoint=None)
Loading…
Cancel
Save