Merge pull request #32 from Bobholamovic/farseg_in_chns

[Feat] Add multi-channel support to FarSeg
commit 07630a11f2
17 changed files (lines changed in parentheses):
  1. paddlers/rs_models/cd/bit.py (2)
  2. paddlers/rs_models/seg/farseg.py (45)
  3. paddlers/tasks/change_detector.py (45)
  4. paddlers/tasks/classifier.py (45)
  5. paddlers/tasks/object_detector.py (42)
  6. paddlers/tasks/restorer.py (35)
  7. paddlers/tasks/segmenter.py (45)
  8. paddlers/utils/checkpoint.py (20)
  9. test_tipc/README.md (1)
  10. test_tipc/configs/seg/farseg/farseg_rsseg.yaml (11)
  11. test_tipc/configs/seg/farseg/train_infer_python.txt (53)
  12. test_tipc/docs/test_train_inference_python.md (2)
  13. tests/rs_models/test_seg_models.py (3)
  14. tutorials/train/README.md (1)
  15. tutorials/train/semantic_segmentation/deeplabv3p.py (2)
  16. tutorials/train/semantic_segmentation/farseg.py (94)
  17. tutorials/train/semantic_segmentation/unet.py (2)

@@ -56,7 +56,7 @@ class BIT(nn.Layer):
Default: 2.
enc_with_pos (bool, optional): Whether to add learned positional embedding to the input feature sequence of the
encoder. Default: True.
enc_depth (int, optional): Number of attention blocks used in the encoder. Default: 1
enc_depth (int, optional): Number of attention blocks used in the encoder. Default: 1.
enc_head_dim (int, optional): Embedding dimension of each encoder head. Default: 64.
dec_depth (int, optional): Number of attention blocks used in the decoder. Default: 8.
dec_head_dim (int, optional): Embedding dimension of each decoder head. Default: 8.

@@ -11,11 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/Z-Zheng/FarSeg
Ths copyright of Z-Zheng/FarSeg is as follows:
Apache License [see LICENSE for details]
"""
# This code is based on https://github.com/Z-Zheng/FarSeg
# The copyright of Z-Zheng/FarSeg is as follows:
# Apache License (see https://github.com/Z-Zheng/FarSeg/blob/master/LICENSE for details).
import math
@@ -164,7 +163,7 @@ class SceneRelation(nn.Layer):
return refined_feats
class AssymetricDecoder(nn.Layer):
class AsymmetricDecoder(nn.Layer):
def __init__(self,
in_channels,
out_channels,
@@ -172,7 +171,7 @@ class AssymetricDecoder(nn.Layer):
out_feat_output_stride=4,
norm_fn=nn.BatchNorm2D,
num_groups_gn=None):
super(AssymetricDecoder, self).__init__()
super(AsymmetricDecoder, self).__init__()
if norm_fn == nn.BatchNorm2D:
norm_fn_args = dict(num_features=out_channels)
elif norm_fn == nn.GroupNorm:
@@ -215,9 +214,12 @@ class AssymetricDecoder(nn.Layer):
class ResNet50Encoder(nn.Layer):
def __init__(self, pretrained=True):
def __init__(self, in_ch=3, pretrained=True):
super(ResNet50Encoder, self).__init__()
self.resnet = resnet50(pretrained=pretrained)
if in_ch != 3:
self.resnet.conv1 = nn.Conv2D(
in_ch, 64, kernel_size=7, stride=2, padding=3, bias_attr=False)
def forward(self, inputs):
x = inputs
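The multi-channel support rests on this single change: when the band count differs from 3, the encoder's stem convolution is replaced, and the pretrained RGB conv1 weights are discarded. Below is a minimal standalone sketch of the same pattern, assuming paddle.vision.models.resnet50 as the backbone (the actual import used in farseg.py may differ):

import paddle
import paddle.nn as nn
from paddle.vision.models import resnet50

def make_multiband_resnet50(in_ch, pretrained=True):
    # Standard (optionally ImageNet-pretrained) ResNet-50.
    net = resnet50(pretrained=pretrained)
    if in_ch != 3:
        # Swap the stem convolution so it accepts `in_ch` bands; the
        # pretrained 3-channel conv1 weights are dropped in the process.
        net.conv1 = nn.Conv2D(
            in_ch, 64, kernel_size=7, stride=2, padding=3, bias_attr=False)
    return net

# Quick shape check with a 10-band tensor (illustrative only).
x = paddle.rand([1, 10, 224, 224])
print(make_multiband_resnet50(10, pretrained=False)(x).shape)  # [1, 1000]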
@@ -234,25 +236,35 @@ class ResNet50Encoder(nn.Layer):
class FarSeg(nn.Layer):
"""
The FarSeg implementation based on PaddlePaddle.
The FarSeg implementation based on PaddlePaddle.
The original article refers to
Zheng, Zhuo, et al. "Foreground-Aware Relation Network for Geospatial Object Segmentation in High Spatial Resolution
Remote Sensing Imagery"
(https://openaccess.thecvf.com/content_CVPR_2020/papers/Zheng_Foreground-Aware_Relation_Network_for_Geospatial_Object_Segmentation_in_High_Spatial_CVPR_2020_paper.pdf)
The original article refers to
Zheng, Zhuo, et al. "Foreground-Aware Relation Network for Geospatial Object
Segmentation in High Spatial Resolution Remote Sensing Imagery"
(https://openaccess.thecvf.com/content_CVPR_2020/papers/Zheng_Foreground-Aware_Relation_Network_for_Geospatial_Object_Segmentation_in_High_Spatial_CVPR_2020_paper.pdf)
Args:
in_channels (int, optional): Number of bands of the input images. Default: 3.
num_classes (int, optional): Number of target classes. Default: 16.
fpn_ch_list (list[int]|tuple[int], optional): Channel list of the FPN. Default: (256, 512, 1024, 2048).
mid_ch (int, optional): Output channels of the FPN. Default: 256.
out_ch (int, optional): Output channels of the decoder. Default: 128.
sr_ch_list (list[int]|tuple[int], optional): Channel list of the foreground-scene relation module. Default: (256, 256, 256, 256).
pretrained_encoder (bool, optional): Whether to use a pretrained encoder. Default: True.
"""
def __init__(self,
in_channels=3,
num_classes=16,
fpn_ch_list=(256, 512, 1024, 2048),
mid_ch=256,
out_ch=128,
sr_ch_list=(256, 256, 256, 256),
encoder_pretrained=True):
pretrained_encoder=True):
super(FarSeg, self).__init__()
self.en = ResNet50Encoder(encoder_pretrained)
self.en = ResNet50Encoder(in_channels, pretrained_encoder)
self.fpn = FPN(in_channels_list=fpn_ch_list, out_channels=mid_ch)
self.decoder = AssymetricDecoder(
self.decoder = AsymmetricDecoder(
in_channels=mid_ch, out_channels=out_ch)
self.cls_pred_conv = nn.Conv2D(out_ch, num_classes, 1)
self.upsample4x_op = nn.UpsamplingBilinear2D(scale_factor=4)
@@ -273,5 +285,4 @@ class FarSeg(nn.Layer):
final_feat = self.decoder(refined_fpn_feat_list)
cls_pred = self.cls_pred_conv(final_feat)
cls_pred = self.upsample4x_op(cls_pred)
cls_pred = F.softmax(cls_pred, axis=1)
return [cls_pred]
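With in_channels wired through, the raw network accepts imagery with an arbitrary band count and now returns raw scores (the trailing softmax was removed). A minimal smoke test, with the import path assumed from the file location paddlers/rs_models/seg/farseg.py:

import paddle
from paddlers.rs_models.seg import FarSeg  # import path assumed from the file location

# 10-band input, 5 classes; skip the pretrained encoder to avoid a download.
model = FarSeg(in_channels=10, num_classes=5, pretrained_encoder=False)
x = paddle.rand([2, 10, 512, 512])
logits, = model(x)       # the forward pass returns a single-element list
print(logits.shape)      # expected: [2, 5, 512, 512]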

@@ -31,7 +31,7 @@ import paddlers.utils.logging as logging
from paddlers.models import seg_losses
from paddlers.transforms import Resize, decode_image
from paddlers.utils import get_single_card_bs
from paddlers.utils.checkpoint import seg_pretrain_weights_dict
from paddlers.utils.checkpoint import cd_pretrain_weights_dict
from .base import BaseModel
from .utils import seg_metrics as metrics
from .utils.infer_nets import InferCDNet
@@ -275,7 +275,7 @@ class BaseChangeDetector(BaseModel):
exit=True)
if pretrain_weights is not None and resume_checkpoint is not None:
logging.error(
"pretrain_weights and resume_checkpoint cannot be set simultaneously.",
"`pretrain_weights` and `resume_checkpoint` cannot be set simultaneously.",
exit=True)
self.labels = train_dataset.labels
if self.losses is None:
@@ -289,23 +289,30 @@ class BaseChangeDetector(BaseModel):
else:
self.optimizer = optimizer
if pretrain_weights is not None and not osp.exists(pretrain_weights):
if pretrain_weights not in seg_pretrain_weights_dict[
self.model_name]:
logging.warning(
"Path of pretrain_weights('{}') does not exist!".format(
pretrain_weights))
logging.warning("Pretrain_weights is forcibly set to '{}'. "
"If don't want to use pretrain weights, "
"set pretrain_weights to be None.".format(
seg_pretrain_weights_dict[self.model_name][
0]))
pretrain_weights = seg_pretrain_weights_dict[self.model_name][0]
elif pretrain_weights is not None and osp.exists(pretrain_weights):
if osp.splitext(pretrain_weights)[-1] != '.pdparams':
logging.error(
"Invalid pretrain weights. Please specify a '.pdparams' file.",
exit=True)
if pretrain_weights is not None:
if not osp.exists(pretrain_weights):
if self.model_name not in cd_pretrain_weights_dict:
logging.warning(
"Path of pretrained weights ('{}') does not exist!".
format(pretrain_weights))
pretrain_weights = None
elif pretrain_weights not in cd_pretrain_weights_dict[
self.model_name]:
logging.warning(
"Path of pretrained weights ('{}') does not exist!".
format(pretrain_weights))
pretrain_weights = cd_pretrain_weights_dict[
self.model_name][0]
logging.warning(
"`pretrain_weights` is forcibly set to '{}'. "
"If you don't want to use pretrained weights, "
"please set `pretrain_weights` to None.".format(
pretrain_weights))
else:
if osp.splitext(pretrain_weights)[-1] != '.pdparams':
logging.error(
"Invalid pretrained weights. Please specify a .pdparams file.",
exit=True)
pretrained_dir = osp.join(save_dir, 'pretrain')
is_backbone_weights = pretrain_weights == 'IMAGENET'
self.net_initialize(
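The same fall-back logic recurs in the classifier, detector, restorer, and segmenter trainers below. Condensed into a standalone helper, it reads roughly as follows; this is an illustrative sketch, not an actual PaddleRS function, and `weights_dict` stands for the per-task registry (e.g. cd_pretrain_weights_dict):

import os.path as osp

def resolve_pretrain_weights(pretrain_weights, model_name, weights_dict):
    # Illustrative mirror of the new validation flow.
    if pretrain_weights is None:
        return None
    if not osp.exists(pretrain_weights):
        if model_name not in weights_dict:
            # No built-in weights for this model: warn and train from scratch.
            print("Path of pretrained weights ('{}') does not exist!".format(
                pretrain_weights))
            return None
        if pretrain_weights not in weights_dict[model_name]:
            # Unknown key: fall back to the first registered weight set.
            print("Path of pretrained weights ('{}') does not exist!".format(
                pretrain_weights))
            pretrain_weights = weights_dict[model_name][0]
            print("`pretrain_weights` is forcibly set to '{}'.".format(
                pretrain_weights))
        return pretrain_weights
    # An existing local path must point to a .pdparams file.
    if osp.splitext(pretrain_weights)[-1] != '.pdparams':
        raise ValueError(
            "Invalid pretrained weights. Please specify a .pdparams file.")
    return pretrain_weights

(The real trainers emit these messages through paddlers.utils.logging and pass the result on to net_initialize.)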

@@ -246,7 +246,7 @@ class BaseClassifier(BaseModel):
exit=True)
if pretrain_weights is not None and resume_checkpoint is not None:
logging.error(
"pretrain_weights and resume_checkpoint cannot be set simultaneously.",
"`pretrain_weights` and `resume_checkpoint` cannot be set simultaneously.",
exit=True)
self.labels = train_dataset.labels
if self.losses is None:
@@ -262,25 +262,32 @@ class BaseClassifier(BaseModel):
else:
self.optimizer = optimizer
if pretrain_weights is not None and not osp.exists(pretrain_weights):
if pretrain_weights not in cls_pretrain_weights_dict[
self.model_name]:
logging.warning(
"Path of pretrain_weights('{}') does not exist!".format(
pretrain_weights))
logging.warning("Pretrain_weights is forcibly set to '{}'. "
"If don't want to use pretrain weights, "
"set pretrain_weights to be None.".format(
cls_pretrain_weights_dict[self.model_name][
0]))
pretrain_weights = cls_pretrain_weights_dict[self.model_name][0]
elif pretrain_weights is not None and osp.exists(pretrain_weights):
if osp.splitext(pretrain_weights)[-1] != '.pdparams':
logging.error(
"Invalid pretrain weights. Please specify a '.pdparams' file.",
exit=True)
if pretrain_weights is not None:
if not osp.exists(pretrain_weights):
if self.model_name not in cls_pretrain_weights_dict:
logging.warning(
"Path of `pretrain_weights` ('{}') does not exist!".
format(pretrain_weights))
pretrain_weights = None
elif pretrain_weights not in cls_pretrain_weights_dict[
self.model_name]:
logging.warning(
"Path of `pretrain_weights` ('{}') does not exist!".
format(pretrain_weights))
pretrain_weights = cls_pretrain_weights_dict[
self.model_name][0]
logging.warning(
"`pretrain_weights` is forcibly set to '{}'. "
"If you don't want to use pretrained weights, "
"set `pretrain_weights` to None.".format(
pretrain_weights))
else:
if osp.splitext(pretrain_weights)[-1] != '.pdparams':
logging.error(
"Invalid pretrained weights. Please specify a .pdparams file.",
exit=True)
pretrained_dir = osp.join(save_dir, 'pretrain')
is_backbone_weights = False # pretrain_weights == 'IMAGENET' # TODO: this is backbone
is_backbone_weights = False
self.net_initialize(
pretrain_weights=pretrain_weights,
save_dir=pretrained_dir,

@@ -274,7 +274,7 @@ class BaseDetector(BaseModel):
exit=True)
if pretrain_weights is not None and resume_checkpoint is not None:
logging.error(
"pretrain_weights and resume_checkpoint cannot be set simultaneously.",
"`pretrain_weights` and `resume_checkpoint` cannot be set simultaneously.",
exit=True)
if train_dataset.__class__.__name__ == 'VOCDetDataset':
train_dataset.data_fields = {
@@ -323,23 +323,29 @@ class BaseDetector(BaseModel):
self.optimizer = optimizer
# Initiate weights
if pretrain_weights is not None and not osp.exists(pretrain_weights):
if pretrain_weights not in det_pretrain_weights_dict['_'.join(
[self.model_name, self.backbone_name])]:
logging.warning(
"Path of pretrain_weights('{}') does not exist!".format(
pretrain_weights))
pretrain_weights = det_pretrain_weights_dict['_'.join(
[self.model_name, self.backbone_name])][0]
logging.warning("Pretrain_weights is forcibly set to '{}'. "
"If you don't want to use pretrain weights, "
"set pretrain_weights to be None.".format(
pretrain_weights))
elif pretrain_weights is not None and osp.exists(pretrain_weights):
if osp.splitext(pretrain_weights)[-1] != '.pdparams':
logging.error(
"Invalid pretrain weights. Please specify a '.pdparams' file.",
exit=True)
if pretrain_weights is not None:
if not osp.exists(pretrain_weights):
key = '_'.join([self.model_name, self.backbone_name])
if key not in det_pretrain_weights_dict:
logging.warning(
"Path of pretrained weights ('{}') does not exist!".
format(pretrain_weights))
pretrain_weights = None
elif pretrain_weights not in det_pretrain_weights_dict[key]:
logging.warning(
"Path of pretrained weights ('{}') does not exist!".
format(pretrain_weights))
pretrain_weights = det_pretrain_weights_dict[key][0]
logging.warning(
"`pretrain_weights` is forcibly set to '{}'. "
"If you don't want to use pretrained weights, "
"please set `pretrain_weights` to None.".format(
pretrain_weights))
else:
if osp.splitext(pretrain_weights)[-1] != '.pdparams':
logging.error(
"Invalid pretrained weights. Please specify a .pdparams file.",
exit=True)
pretrained_dir = osp.join(save_dir, 'pretrain')
self.net_initialize(
pretrain_weights=pretrain_weights,

@@ -31,6 +31,7 @@ from paddlers.models import res_losses
from paddlers.transforms import Resize, decode_image
from paddlers.transforms.functions import calc_hr_shape
from paddlers.utils import get_single_card_bs
from paddlers.utils.checkpoint import res_pretrain_weights_dict
from .base import BaseModel
from .utils.res_adapters import GANAdapter, OptimizerAdapter
from .utils.infer_nets import InferResNet
@@ -234,7 +235,7 @@ class BaseRestorer(BaseModel):
exit=True)
if pretrain_weights is not None and resume_checkpoint is not None:
logging.error(
"pretrain_weights and resume_checkpoint cannot be set simultaneously.",
"`pretrain_weights` and `resume_checkpoint` cannot be set simultaneously.",
exit=True)
if self.losses is None:
@@ -256,14 +257,30 @@ class BaseRestorer(BaseModel):
else:
self.optimizer = optimizer
if pretrain_weights is not None and not osp.exists(pretrain_weights):
logging.warning("Path of pretrain_weights('{}') does not exist!".
format(pretrain_weights))
elif pretrain_weights is not None and osp.exists(pretrain_weights):
if osp.splitext(pretrain_weights)[-1] != '.pdparams':
logging.error(
"Invalid pretrain weights. Please specify a '.pdparams' file.",
exit=True)
if pretrain_weights is not None:
if not osp.exists(pretrain_weights):
if self.model_name not in res_pretrain_weights_dict:
logging.warning(
"Path of pretrained weights ('{}') does not exist!".
format(pretrain_weights))
pretrain_weights = None
elif pretrain_weights not in res_pretrain_weights_dict[
self.model_name]:
logging.warning(
"Path of pretrained weights ('{}') does not exist!".
format(pretrain_weights))
pretrain_weights = res_pretrain_weights_dict[
self.model_name][0]
logging.warning(
"`pretrain_weights` is forcibly set to '{}'. "
"If you don't want to use pretrained weights, "
"please set `pretrain_weights` to None.".format(
pretrain_weights))
else:
if osp.splitext(pretrain_weights)[-1] != '.pdparams':
logging.error(
"Invalid pretrained weights. Please specify a .pdparams file.",
exit=True)
pretrained_dir = osp.join(save_dir, 'pretrain')
is_backbone_weights = pretrain_weights == 'IMAGENET'
self.net_initialize(

@@ -267,7 +267,7 @@ class BaseSegmenter(BaseModel):
exit=True)
if pretrain_weights is not None and resume_checkpoint is not None:
logging.error(
"pretrain_weights and resume_checkpoint cannot be set simultaneously.",
"`pretrain_weights` and `resume_checkpoint` cannot be set simultaneously.",
exit=True)
self.labels = train_dataset.labels
if self.losses is None:
@@ -281,23 +281,30 @@ class BaseSegmenter(BaseModel):
else:
self.optimizer = optimizer
if pretrain_weights is not None and not osp.exists(pretrain_weights):
if pretrain_weights not in seg_pretrain_weights_dict[
self.model_name]:
logging.warning(
"Path of pretrain_weights('{}') does not exist!".format(
pretrain_weights))
logging.warning("Pretrain_weights is forcibly set to '{}'. "
"If don't want to use pretrain weights, "
"set pretrain_weights to be None.".format(
seg_pretrain_weights_dict[self.model_name][
0]))
pretrain_weights = seg_pretrain_weights_dict[self.model_name][0]
elif pretrain_weights is not None and osp.exists(pretrain_weights):
if osp.splitext(pretrain_weights)[-1] != '.pdparams':
logging.error(
"Invalid pretrain weights. Please specify a '.pdparams' file.",
exit=True)
if pretrain_weights is not None:
if not osp.exists(pretrain_weights):
if self.model_name not in seg_pretrain_weights_dict:
logging.warning(
"Path of pretrained weights ('{}') does not exist!".
format(pretrain_weights))
pretrain_weights = None
elif pretrain_weights not in seg_pretrain_weights_dict[
self.model_name]:
logging.warning(
"Path of pretrained weights ('{}') does not exist!".
format(pretrain_weights))
pretrain_weights = seg_pretrain_weights_dict[
self.model_name][0]
logging.warning(
"`pretrain_weights` is forcibly set to '{}'. "
"If you don't want to use pretrained weights, "
"please set `pretrain_weights` to None.".format(
pretrain_weights))
else:
if osp.splitext(pretrain_weights)[-1] != '.pdparams':
logging.error(
"Invalid pretrained weights. Please specify a .pdparams file.",
exit=True)
pretrained_dir = osp.join(save_dir, 'pretrain')
is_backbone_weights = pretrain_weights == 'IMAGENET'
self.net_initialize(
@@ -909,6 +916,7 @@ class BiSeNetV2(BaseSegmenter):
class FarSeg(BaseSegmenter):
def __init__(self,
in_channels=3,
num_classes=2,
use_mixed_loss=False,
losses=None,
@@ -918,4 +926,5 @@ class FarSeg(BaseSegmenter):
num_classes=num_classes,
use_mixed_loss=use_mixed_loss,
losses=losses,
in_channels=in_channels,
**params)
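The trainer wrapper now forwards in_channels to the underlying network, so multispectral segmentation can be configured entirely at the task level, matching the 10-band, 5-class TIPC configuration shown further below:

import paddlers as pdrs

# Task-level FarSeg trainer for 10-band imagery with 5 target classes.
model = pdrs.tasks.seg.FarSeg(in_channels=10, num_classes=5)
# model.train(...) is then called exactly as in the
# tutorials/train/semantic_segmentation/farseg.py example near the end of this diff.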

@@ -21,20 +21,14 @@ import paddle
from . import logging
from .download import download_and_decompress
cd_pretrain_weights_dict = {}
cls_pretrain_weights_dict = {
'ResNet50_vd': ['IMAGENET'],
'MobileNetV3_small_x1_0': ['IMAGENET'],
'HRNet_W18_C': ['IMAGENET'],
}
seg_pretrain_weights_dict = {
'UNet': ['CITYSCAPES'],
'DeepLabV3P': ['CITYSCAPES', 'PascalVOC', 'IMAGENET'],
'FastSCNN': ['CITYSCAPES'],
'HRNet': ['CITYSCAPES', 'PascalVOC'],
'BiSeNetV2': ['CITYSCAPES']
}
det_pretrain_weights_dict = {
'PicoDet_ESNet_s': ['COCO', 'IMAGENET'],
'PicoDet_ESNet_m': ['COCO', 'IMAGENET'],
@@ -74,6 +68,16 @@ det_pretrain_weights_dict = {
'MaskRCNN_ResNet101_vd_fpn': ['COCO', 'IMAGENET']
}
res_pretrain_weights_dict = {}
seg_pretrain_weights_dict = {
'UNet': ['CITYSCAPES'],
'DeepLabV3P': ['CITYSCAPES', 'PascalVOC', 'IMAGENET'],
'FastSCNN': ['CITYSCAPES'],
'HRNet': ['CITYSCAPES', 'PascalVOC'],
'BiSeNetV2': ['CITYSCAPES']
}
cityscapes_weights = {
'UNet_CITYSCAPES':
'https://bj.bcebos.com/paddleseg/dygraph/cityscapes/unet_cityscapes_1024x512_160k/model.pdparams',
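For reference, each per-task registry maps a trainer name to the named weight sets it accepts, while dicts such as cityscapes_weights map '<Model>_<SET>' keys to download URLs. A small lookup sketch based on the entries visible here, assuming both names are importable from paddlers.utils.checkpoint:

from paddlers.utils.checkpoint import cityscapes_weights, seg_pretrain_weights_dict

model_name = 'UNet'
default_set = seg_pretrain_weights_dict[model_name][0]              # 'CITYSCAPES'
url = cityscapes_weights['{}_{}'.format(model_name, default_set)]   # download URL
print(default_set, url)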

@@ -44,6 +44,7 @@
| Object detection | PP-YOLOv2 | Supported | - | - | - |
| Object detection | YOLOv3 | Supported | - | - | - |
| Image segmentation | DeepLab V3+ | Supported | - | - | - |
| Image segmentation | FarSeg | Supported | - | - | - |
| Image segmentation | UNet | Supported | - | - | - |
## 3 Introduction to the Test Tools

@@ -0,0 +1,11 @@
# Configurations of FarSeg with RSSeg dataset
_base_: ../_base_/rsseg.yaml
save_dir: ./test_tipc/output/seg/farseg/
model: !Node
type: FarSeg
args:
in_channels: 10
num_classes: 5

@@ -0,0 +1,53 @@
===========================train_params===========================
model_name:seg:farseg
python:python
gpu_list:0|0,1
use_gpu:null|null
--precision:null
--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=20
--save_dir:adaptive
--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4
--model_path:null
--config:lite_train_lite_infer=./test_tipc/configs/seg/farseg/farseg_rsseg.yaml|lite_train_whole_infer=./test_tipc/configs/seg/farseg/farseg_rsseg.yaml|whole_train_whole_infer=./test_tipc/configs/seg/farseg/farseg_rsseg.yaml
train_model_name:best_model
null:null
##
trainer:norm
norm_train:test_tipc/run_task.py train seg
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================export_params===========================
--save_dir:adaptive
--model_dir:adaptive
--fixed_input_shape:[-1,10,512,512]
norm_export:deploy/export/export_model.py
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
===========================infer_params===========================
infer_model:null
infer_export:null
infer_quant:False
inference:test_tipc/infer.py
--device:cpu|gpu
--enable_mkldnn:True
--cpu_threads:6
--batch_size:1
--use_trt:False
--precision:fp32
--model_dir:null
--config:null
--save_log_path:null
--benchmark:True
--model_name:farseg
null:null

@@ -31,6 +31,7 @@ The main program for Linux GPU/CPU basic training and inference tests is `test_train_inference_pytho
| Object detection | PP-YOLOv2 | Normal training | Normal training | mAP=59.37% |
| Object detection | YOLOv3 | Normal training | Normal training | mAP=47.33% |
| Image segmentation | DeepLab V3+ | Normal training | Normal training | mIoU=56.05% |
| Image segmentation | FarSeg | Normal training | Normal training | mIoU=49.58% |
| Image segmentation | UNet | Normal training | Normal training | mIoU=55.50% |
*Note: the reference prediction accuracy is the accuracy reported for single-card training in whole_train_whole_infer mode.*
@@ -61,6 +62,7 @@ The main program for Linux GPU/CPU basic training and inference tests is `test_train_inference_pytho
| Object detection | PP-YOLOv2 | Supported | Supported | 1 |
| Object detection | YOLOv3 | Supported | Supported | 1 |
| Image segmentation | DeepLab V3+ | Supported | Supported | 1 |
| Image segmentation | FarSeg | Supported | Supported | 1 |
| Image segmentation | UNet | Supported | Supported | 1 |
## 2 Test Procedure

@@ -50,7 +50,8 @@ class TestFarSegModel(TestSegModel):
def set_specs(self):
self.specs = [
dict(), dict(num_classes=20), dict(encoder_pretrained=False)
dict(), dict(num_classes=20), dict(pretrained_encoder=False),
dict(in_channels=10)
]
def set_targets(self):

@@ -27,6 +27,7 @@
|object_detection/ppyolov2.py | Object detection | PP-YOLOv2 |
|object_detection/yolov3.py | Object detection | YOLOv3 |
|semantic_segmentation/deeplabv3p.py | Image segmentation | DeepLab V3+ |
|semantic_segmentation/farseg.py | Image segmentation | FarSeg |
|semantic_segmentation/unet.py | Image segmentation | UNet |
## Environment Setup

@@ -71,7 +71,7 @@ eval_dataset = pdrs.datasets.SegDataset(
# For the list of currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
# For the model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/segmenter.py
model = pdrs.tasks.seg.DeepLabV3P(
input_channel=NUM_BANDS,
in_channels=NUM_BANDS,
num_classes=len(train_dataset.labels),
backbone='ResNet50_vd')

@@ -0,0 +1,94 @@
#!/usr/bin/env python
# Example script for training the FarSeg image segmentation model
# Before running this script, make sure the PaddleRS library is correctly installed
import paddlers as pdrs
from paddlers import transforms as T
# Dataset directory
DATA_DIR = './data/rsseg/'
# Path of the training set `file_list` file
TRAIN_FILE_LIST_PATH = './data/rsseg/train.txt'
# Path of the validation set `file_list` file
EVAL_FILE_LIST_PATH = './data/rsseg/val.txt'
# Path of the dataset class information file
LABEL_LIST_PATH = './data/rsseg/labels.txt'
# Experiment directory, where output model weights and results are saved
EXP_DIR = './output/farseg/'
# Download and decompress the multispectral land parcel classification dataset
pdrs.utils.download_and_decompress(
'https://paddlers.bj.bcebos.com/datasets/rsseg.zip', path='./data/')
# Define the data transforms used for training and validation (data augmentation, preprocessing, etc.)
# Use Compose to combine multiple transforms. The transforms contained in Compose are executed sequentially in order
# API reference: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md
train_transforms = T.Compose([
# Read the image
T.DecodeImg(),
# Select the first three bands
T.SelectBand([1, 2, 3]),
# Resize the image to 512x512
T.Resize(target_size=512),
# Apply a random horizontal flip with a probability of 50%
T.RandomHorizontalFlip(prob=0.5),
# Normalize the data to [-1,1]
T.Normalize(
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
T.ArrangeSegmenter('train')
])
eval_transforms = T.Compose([
T.DecodeImg(),
# The same bands should be selected in the validation phase as in the training phase
T.SelectBand([1, 2, 3]),
T.Resize(target_size=512),
# Data normalization in the validation phase must be identical to that in the training phase
T.Normalize(
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
T.ReloadMask(),
T.ArrangeSegmenter('eval')
])
# Build the training and validation datasets, respectively
train_dataset = pdrs.datasets.SegDataset(
data_dir=DATA_DIR,
file_list=TRAIN_FILE_LIST_PATH,
label_list=LABEL_LIST_PATH,
transforms=train_transforms,
num_workers=0,
shuffle=True)
eval_dataset = pdrs.datasets.SegDataset(
data_dir=DATA_DIR,
file_list=EVAL_FILE_LIST_PATH,
label_list=LABEL_LIST_PATH,
transforms=eval_transforms,
num_workers=0,
shuffle=False)
# Build the FarSeg model
# For the list of currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
# For the model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/segmenter.py
model = pdrs.tasks.seg.FarSeg(num_classes=len(train_dataset.labels))
# Run model training
model.train(
num_epochs=10,
train_dataset=train_dataset,
train_batch_size=4,
eval_dataset=eval_dataset,
save_interval_epochs=5,
# Interval (in iterations) at which logs are recorded
log_interval_steps=4,
save_dir=EXP_DIR,
pretrain_weights=None,
# Initial learning rate
learning_rate=0.001,
# Whether to use the early stopping strategy, which terminates training early when accuracy no longer improves
early_stop=False,
# Whether to enable VisualDL logging
use_vdl=True,
# Specify a checkpoint from which to resume training
resume_checkpoint=None)

@@ -71,7 +71,7 @@ eval_dataset = pdrs.datasets.SegDataset(
# For the list of currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
# For the model input parameters, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/segmenter.py
model = pdrs.tasks.seg.UNet(
input_channel=NUM_BANDS, num_classes=len(train_dataset.labels))
in_channels=NUM_BANDS, num_classes=len(train_dataset.labels))
# 执行模型训练
model.train(
