diff --git a/paddlers/rs_models/cd/bit.py b/paddlers/rs_models/cd/bit.py index 720969d..13f2290 100644 --- a/paddlers/rs_models/cd/bit.py +++ b/paddlers/rs_models/cd/bit.py @@ -56,7 +56,7 @@ class BIT(nn.Layer): Default: 2. enc_with_pos (bool, optional): Whether to add leanred positional embedding to the input feature sequence of the encoder. Default: True. - enc_depth (int, optional): Number of attention blocks used in the encoder. Default: 1 + enc_depth (int, optional): Number of attention blocks used in the encoder. Default: 1. enc_head_dim (int, optional): Embedding dimension of each encoder head. Default: 64. dec_depth (int, optional): Number of attention blocks used in the decoder. Default: 8. dec_head_dim (int, optional): Embedding dimension of each decoder head. Default: 8. diff --git a/paddlers/rs_models/seg/farseg.py b/paddlers/rs_models/seg/farseg.py index b9c6b95..7a5a62a 100644 --- a/paddlers/rs_models/seg/farseg.py +++ b/paddlers/rs_models/seg/farseg.py @@ -11,11 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -This code is based on https://github.com/Z-Zheng/FarSeg -Ths copyright of Z-Zheng/FarSeg is as follows: -Apache License [see LICENSE for details] -""" + +# This code is based on https://github.com/Z-Zheng/FarSeg +# The copyright of Z-Zheng/FarSeg is as follows: +# Apache License (see https://github.com/Z-Zheng/FarSeg/blob/master/LICENSE for details). import math @@ -164,7 +163,7 @@ class SceneRelation(nn.Layer): return refined_feats -class AssymetricDecoder(nn.Layer): +class AsymmetricDecoder(nn.Layer): def __init__(self, in_channels, out_channels, @@ -172,7 +171,7 @@ class AssymetricDecoder(nn.Layer): out_feat_output_stride=4, norm_fn=nn.BatchNorm2D, num_groups_gn=None): - super(AssymetricDecoder, self).__init__() + super(AsymmetricDecoder, self).__init__() if norm_fn == nn.BatchNorm2D: norm_fn_args = dict(num_features=out_channels) elif norm_fn == nn.GroupNorm: @@ -215,9 +214,12 @@ class AssymetricDecoder(nn.Layer): class ResNet50Encoder(nn.Layer): - def __init__(self, pretrained=True): + def __init__(self, in_ch=3, pretrained=True): super(ResNet50Encoder, self).__init__() self.resnet = resnet50(pretrained=pretrained) + if in_ch != 3: + self.resnet.conv1 = nn.Conv2D( + in_ch, 64, kernel_size=7, stride=2, padding=3, bias_attr=False) def forward(self, inputs): x = inputs @@ -234,25 +236,35 @@ class ResNet50Encoder(nn.Layer): class FarSeg(nn.Layer): """ - The FarSeg implementation based on PaddlePaddle. + The FarSeg implementation based on PaddlePaddle. + + The original article refers to + Zheng, Zhuo, et al. "Foreground-Aware Relation Network for Geospatial Object Segmentation in High Spatial Resolution + Remote Sensing Imagery" + (https://openaccess.thecvf.com/content_CVPR_2020/papers/Zheng_Foreground-Aware_Relation_Network_for_Geospatial_Object_Segmentation_in_High_Spatial_CVPR_2020_paper.pdf) - The original article refers to - Zheng, Zhuo, et al. "Foreground-Aware Relation Network for Geospatial Object - Segmentation in High Spatial Resolution Remote Sensing Imagery" - (https://openaccess.thecvf.com/content_CVPR_2020/papers/Zheng_Foreground-Aware_Relation_Network_for_Geospatial_Object_Segmentation_in_High_Spatial_CVPR_2020_paper.pdf) + Args: + in_channels (int, optional): Number of bands of the input images. Default: 3. + num_classes (int, optional): Number of target classes. Default: 16. + fpn_ch_list (list[int]|tuple[int], optional): Channel list of the FPN. Default: (256, 512, 1024, 2048). + mid_ch (int, optional): Output channels of the FPN. Default: 256. + out_ch (int, optional): Output channels of the decoder. Default: 128. + sr_ch_list (list[int]|tuple[int], optional): Channel list of the foreground-scene relation module. Default: (256, 256, 256, 256). + pretrained_encoder (bool, optional): Whether to use a pretrained encoder. Default: True. """ def __init__(self, + in_channels=3, num_classes=16, fpn_ch_list=(256, 512, 1024, 2048), mid_ch=256, out_ch=128, sr_ch_list=(256, 256, 256, 256), - encoder_pretrained=True): + pretrained_encoder=True): super(FarSeg, self).__init__() - self.en = ResNet50Encoder(encoder_pretrained) + self.en = ResNet50Encoder(in_channels, pretrained_encoder) self.fpn = FPN(in_channels_list=fpn_ch_list, out_channels=mid_ch) - self.decoder = AssymetricDecoder( + self.decoder = AsymmetricDecoder( in_channels=mid_ch, out_channels=out_ch) self.cls_pred_conv = nn.Conv2D(out_ch, num_classes, 1) self.upsample4x_op = nn.UpsamplingBilinear2D(scale_factor=4) @@ -273,5 +285,4 @@ class FarSeg(nn.Layer): final_feat = self.decoder(refined_fpn_feat_list) cls_pred = self.cls_pred_conv(final_feat) cls_pred = self.upsample4x_op(cls_pred) - cls_pred = F.softmax(cls_pred, axis=1) return [cls_pred] diff --git a/paddlers/tasks/change_detector.py b/paddlers/tasks/change_detector.py index 7a45172..17f3446 100644 --- a/paddlers/tasks/change_detector.py +++ b/paddlers/tasks/change_detector.py @@ -31,7 +31,7 @@ import paddlers.utils.logging as logging from paddlers.models import seg_losses from paddlers.transforms import Resize, decode_image from paddlers.utils import get_single_card_bs -from paddlers.utils.checkpoint import seg_pretrain_weights_dict +from paddlers.utils.checkpoint import cd_pretrain_weights_dict from .base import BaseModel from .utils import seg_metrics as metrics from .utils.infer_nets import InferCDNet @@ -275,7 +275,7 @@ class BaseChangeDetector(BaseModel): exit=True) if pretrain_weights is not None and resume_checkpoint is not None: logging.error( - "pretrain_weights and resume_checkpoint cannot be set simultaneously.", + "`pretrain_weights` and `resume_checkpoint` cannot be set simultaneously.", exit=True) self.labels = train_dataset.labels if self.losses is None: @@ -289,23 +289,30 @@ class BaseChangeDetector(BaseModel): else: self.optimizer = optimizer - if pretrain_weights is not None and not osp.exists(pretrain_weights): - if pretrain_weights not in seg_pretrain_weights_dict[ - self.model_name]: - logging.warning( - "Path of pretrain_weights('{}') does not exist!".format( - pretrain_weights)) - logging.warning("Pretrain_weights is forcibly set to '{}'. " - "If don't want to use pretrain weights, " - "set pretrain_weights to be None.".format( - seg_pretrain_weights_dict[self.model_name][ - 0])) - pretrain_weights = seg_pretrain_weights_dict[self.model_name][0] - elif pretrain_weights is not None and osp.exists(pretrain_weights): - if osp.splitext(pretrain_weights)[-1] != '.pdparams': - logging.error( - "Invalid pretrain weights. Please specify a '.pdparams' file.", - exit=True) + if pretrain_weights is not None: + if not osp.exists(pretrain_weights): + if self.model_name not in cd_pretrain_weights_dict: + logging.warning( + "Path of pretrained weights ('{}') does not exist!". + format(pretrain_weights)) + pretrain_weights = None + elif pretrain_weights not in cd_pretrain_weights_dict[ + self.model_name]: + logging.warning( + "Path of pretrained weights ('{}') does not exist!". + format(pretrain_weights)) + pretrain_weights = cd_pretrain_weights_dict[ + self.model_name][0] + logging.warning( + "`pretrain_weights` is forcibly set to '{}'. " + "If you don't want to use pretrained weights, " + "please set `pretrain_weights` to None.".format( + pretrain_weights)) + else: + if osp.splitext(pretrain_weights)[-1] != '.pdparams': + logging.error( + "Invalid pretrained weights. Please specify a .pdparams file.", + exit=True) pretrained_dir = osp.join(save_dir, 'pretrain') is_backbone_weights = pretrain_weights == 'IMAGENET' self.net_initialize( diff --git a/paddlers/tasks/classifier.py b/paddlers/tasks/classifier.py index 23ab154..7c80301 100644 --- a/paddlers/tasks/classifier.py +++ b/paddlers/tasks/classifier.py @@ -246,7 +246,7 @@ class BaseClassifier(BaseModel): exit=True) if pretrain_weights is not None and resume_checkpoint is not None: logging.error( - "pretrain_weights and resume_checkpoint cannot be set simultaneously.", + "`pretrain_weights` and `resume_checkpoint` cannot be set simultaneously.", exit=True) self.labels = train_dataset.labels if self.losses is None: @@ -262,25 +262,32 @@ class BaseClassifier(BaseModel): else: self.optimizer = optimizer - if pretrain_weights is not None and not osp.exists(pretrain_weights): - if pretrain_weights not in cls_pretrain_weights_dict[ - self.model_name]: - logging.warning( - "Path of pretrain_weights('{}') does not exist!".format( - pretrain_weights)) - logging.warning("Pretrain_weights is forcibly set to '{}'. " - "If don't want to use pretrain weights, " - "set pretrain_weights to be None.".format( - cls_pretrain_weights_dict[self.model_name][ - 0])) - pretrain_weights = cls_pretrain_weights_dict[self.model_name][0] - elif pretrain_weights is not None and osp.exists(pretrain_weights): - if osp.splitext(pretrain_weights)[-1] != '.pdparams': - logging.error( - "Invalid pretrain weights. Please specify a '.pdparams' file.", - exit=True) + if pretrain_weights is not None: + if not osp.exists(pretrain_weights): + if self.model_name not in cls_pretrain_weights_dict: + logging.warning( + "Path of `pretrain_weights` ('{}') does not exist!". + format(pretrain_weights)) + pretrain_weights = None + elif pretrain_weights not in cls_pretrain_weights_dict[ + self.model_name]: + logging.warning( + "Path of `pretrain_weights` ('{}') does not exist!". + format(pretrain_weights)) + pretrain_weights = cls_pretrain_weights_dict[ + self.model_name][0] + logging.warning( + "`pretrain_weights` is forcibly set to '{}'. " + "If you don't want to use pretrained weights, " + "set `pretrain_weights` to None.".format( + pretrain_weights)) + else: + if osp.splitext(pretrain_weights)[-1] != '.pdparams': + logging.error( + "Invalid pretrained weights. Please specify a .pdparams file.", + exit=True) pretrained_dir = osp.join(save_dir, 'pretrain') - is_backbone_weights = False # pretrain_weights == 'IMAGENET' # TODO: this is backbone + is_backbone_weights = False self.net_initialize( pretrain_weights=pretrain_weights, save_dir=pretrained_dir, diff --git a/paddlers/tasks/object_detector.py b/paddlers/tasks/object_detector.py index 6531481..8b92ca0 100644 --- a/paddlers/tasks/object_detector.py +++ b/paddlers/tasks/object_detector.py @@ -274,7 +274,7 @@ class BaseDetector(BaseModel): exit=True) if pretrain_weights is not None and resume_checkpoint is not None: logging.error( - "pretrain_weights and resume_checkpoint cannot be set simultaneously.", + "`pretrain_weights` and `resume_checkpoint` cannot be set simultaneously.", exit=True) if train_dataset.__class__.__name__ == 'VOCDetDataset': train_dataset.data_fields = { @@ -323,23 +323,29 @@ class BaseDetector(BaseModel): self.optimizer = optimizer # Initiate weights - if pretrain_weights is not None and not osp.exists(pretrain_weights): - if pretrain_weights not in det_pretrain_weights_dict['_'.join( - [self.model_name, self.backbone_name])]: - logging.warning( - "Path of pretrain_weights('{}') does not exist!".format( - pretrain_weights)) - pretrain_weights = det_pretrain_weights_dict['_'.join( - [self.model_name, self.backbone_name])][0] - logging.warning("Pretrain_weights is forcibly set to '{}'. " - "If you don't want to use pretrain weights, " - "set pretrain_weights to be None.".format( - pretrain_weights)) - elif pretrain_weights is not None and osp.exists(pretrain_weights): - if osp.splitext(pretrain_weights)[-1] != '.pdparams': - logging.error( - "Invalid pretrain weights. Please specify a '.pdparams' file.", - exit=True) + if pretrain_weights is not None: + if not osp.exists(pretrain_weights): + key = '_'.join([self.model_name, self.backbone_name]) + if key not in det_pretrain_weights_dict: + logging.warning( + "Path of pretrained weights ('{}') does not exist!". + format(pretrain_weights)) + pretrain_weights = None + elif pretrain_weights not in det_pretrain_weights_dict[key]: + logging.warning( + "Path of pretrained weights ('{}') does not exist!". + format(pretrain_weights)) + pretrain_weights = det_pretrain_weights_dict[key][0] + logging.warning( + "`pretrain_weights` is forcibly set to '{}'. " + "If you don't want to use pretrained weights, " + "please set `pretrain_weights` to None.".format( + pretrain_weights)) + else: + if osp.splitext(pretrain_weights)[-1] != '.pdparams': + logging.error( + "Invalid pretrained weights. Please specify a .pdparams file.", + exit=True) pretrained_dir = osp.join(save_dir, 'pretrain') self.net_initialize( pretrain_weights=pretrain_weights, diff --git a/paddlers/tasks/restorer.py b/paddlers/tasks/restorer.py index fe17f82..d39e283 100644 --- a/paddlers/tasks/restorer.py +++ b/paddlers/tasks/restorer.py @@ -31,6 +31,7 @@ from paddlers.models import res_losses from paddlers.transforms import Resize, decode_image from paddlers.transforms.functions import calc_hr_shape from paddlers.utils import get_single_card_bs +from paddlers.utils.checkpoint import res_pretrain_weights_dict from .base import BaseModel from .utils.res_adapters import GANAdapter, OptimizerAdapter from .utils.infer_nets import InferResNet @@ -234,7 +235,7 @@ class BaseRestorer(BaseModel): exit=True) if pretrain_weights is not None and resume_checkpoint is not None: logging.error( - "pretrain_weights and resume_checkpoint cannot be set simultaneously.", + "`pretrain_weights` and `resume_checkpoint` cannot be set simultaneously.", exit=True) if self.losses is None: @@ -256,14 +257,30 @@ class BaseRestorer(BaseModel): else: self.optimizer = optimizer - if pretrain_weights is not None and not osp.exists(pretrain_weights): - logging.warning("Path of pretrain_weights('{}') does not exist!". - format(pretrain_weights)) - elif pretrain_weights is not None and osp.exists(pretrain_weights): - if osp.splitext(pretrain_weights)[-1] != '.pdparams': - logging.error( - "Invalid pretrain weights. Please specify a '.pdparams' file.", - exit=True) + if pretrain_weights is not None: + if not osp.exists(pretrain_weights): + if self.model_name not in res_pretrain_weights_dict: + logging.warning( + "Path of pretrained weights ('{}') does not exist!". + format(pretrain_weights)) + pretrain_weights = None + elif pretrain_weights not in res_pretrain_weights_dict[ + self.model_name]: + logging.warning( + "Path of pretrained weights ('{}') does not exist!". + format(pretrain_weights)) + pretrain_weights = res_pretrain_weights_dict[ + self.model_name][0] + logging.warning( + "`pretrain_weights` is forcibly set to '{}'. " + "If you don't want to use pretrained weights, " + "please set `pretrain_weights` to None.".format( + pretrain_weights)) + else: + if osp.splitext(pretrain_weights)[-1] != '.pdparams': + logging.error( + "Invalid pretrained weights. Please specify a .pdparams file.", + exit=True) pretrained_dir = osp.join(save_dir, 'pretrain') is_backbone_weights = pretrain_weights == 'IMAGENET' self.net_initialize( diff --git a/paddlers/tasks/segmenter.py b/paddlers/tasks/segmenter.py index b9c586f..bfa8977 100644 --- a/paddlers/tasks/segmenter.py +++ b/paddlers/tasks/segmenter.py @@ -267,7 +267,7 @@ class BaseSegmenter(BaseModel): exit=True) if pretrain_weights is not None and resume_checkpoint is not None: logging.error( - "pretrain_weights and resume_checkpoint cannot be set simultaneously.", + "`pretrain_weights` and `resume_checkpoint` cannot be set simultaneously.", exit=True) self.labels = train_dataset.labels if self.losses is None: @@ -281,23 +281,30 @@ class BaseSegmenter(BaseModel): else: self.optimizer = optimizer - if pretrain_weights is not None and not osp.exists(pretrain_weights): - if pretrain_weights not in seg_pretrain_weights_dict[ - self.model_name]: - logging.warning( - "Path of pretrain_weights('{}') does not exist!".format( - pretrain_weights)) - logging.warning("Pretrain_weights is forcibly set to '{}'. " - "If don't want to use pretrain weights, " - "set pretrain_weights to be None.".format( - seg_pretrain_weights_dict[self.model_name][ - 0])) - pretrain_weights = seg_pretrain_weights_dict[self.model_name][0] - elif pretrain_weights is not None and osp.exists(pretrain_weights): - if osp.splitext(pretrain_weights)[-1] != '.pdparams': - logging.error( - "Invalid pretrain weights. Please specify a '.pdparams' file.", - exit=True) + if pretrain_weights is not None: + if not osp.exists(pretrain_weights): + if self.model_name not in seg_pretrain_weights_dict: + logging.warning( + "Path of pretrained weights ('{}') does not exist!". + format(pretrain_weights)) + pretrain_weights = None + elif pretrain_weights not in seg_pretrain_weights_dict[ + self.model_name]: + logging.warning( + "Path of pretrained weights ('{}') does not exist!". + format(pretrain_weights)) + pretrain_weights = seg_pretrain_weights_dict[ + self.model_name][0] + logging.warning( + "`pretrain_weights` is forcibly set to '{}'. " + "If you don't want to use pretrained weights, " + "please set `pretrain_weights` to None.".format( + pretrain_weights)) + else: + if osp.splitext(pretrain_weights)[-1] != '.pdparams': + logging.error( + "Invalid pretrained weights. Please specify a .pdparams file.", + exit=True) pretrained_dir = osp.join(save_dir, 'pretrain') is_backbone_weights = pretrain_weights == 'IMAGENET' self.net_initialize( @@ -909,6 +916,7 @@ class BiSeNetV2(BaseSegmenter): class FarSeg(BaseSegmenter): def __init__(self, + in_channels=3, num_classes=2, use_mixed_loss=False, losses=None, @@ -918,4 +926,5 @@ class FarSeg(BaseSegmenter): num_classes=num_classes, use_mixed_loss=use_mixed_loss, losses=losses, + in_channels=in_channels, **params) diff --git a/paddlers/utils/checkpoint.py b/paddlers/utils/checkpoint.py index 82b5b12..029c4fc 100644 --- a/paddlers/utils/checkpoint.py +++ b/paddlers/utils/checkpoint.py @@ -21,20 +21,14 @@ import paddle from . import logging from .download import download_and_decompress +cd_pretrain_weights_dict = {} + cls_pretrain_weights_dict = { 'ResNet50_vd': ['IMAGENET'], 'MobileNetV3_small_x1_0': ['IMAGENET'], 'HRNet_W18_C': ['IMAGENET'], } -seg_pretrain_weights_dict = { - 'UNet': ['CITYSCAPES'], - 'DeepLabV3P': ['CITYSCAPES', 'PascalVOC', 'IMAGENET'], - 'FastSCNN': ['CITYSCAPES'], - 'HRNet': ['CITYSCAPES', 'PascalVOC'], - 'BiSeNetV2': ['CITYSCAPES'] -} - det_pretrain_weights_dict = { 'PicoDet_ESNet_s': ['COCO', 'IMAGENET'], 'PicoDet_ESNet_m': ['COCO', 'IMAGENET'], @@ -74,6 +68,16 @@ det_pretrain_weights_dict = { 'MaskRCNN_ResNet101_vd_fpn': ['COCO', 'IMAGENET'] } +res_pretrain_weights_dict = {} + +seg_pretrain_weights_dict = { + 'UNet': ['CITYSCAPES'], + 'DeepLabV3P': ['CITYSCAPES', 'PascalVOC', 'IMAGENET'], + 'FastSCNN': ['CITYSCAPES'], + 'HRNet': ['CITYSCAPES', 'PascalVOC'], + 'BiSeNetV2': ['CITYSCAPES'] +} + cityscapes_weights = { 'UNet_CITYSCAPES': 'https://bj.bcebos.com/paddleseg/dygraph/cityscapes/unet_cityscapes_1024x512_160k/model.pdparams', diff --git a/test_tipc/README.md b/test_tipc/README.md index 70fd203..934a83f 100644 --- a/test_tipc/README.md +++ b/test_tipc/README.md @@ -44,6 +44,7 @@ | 目标检测 | PP-YOLOv2 | 支持 | - | - | - | | 目标检测 | YOLOv3 | 支持 | - | - | - | | 图像分割 | DeepLab V3+ | 支持 | - | - | - | +| 图像分割 | FarSeg | 支持 | - | - | - | | 图像分割 | UNet | 支持 | - | - | - | ## 3 测试工具简介 diff --git a/test_tipc/configs/seg/farseg/farseg_rsseg.yaml b/test_tipc/configs/seg/farseg/farseg_rsseg.yaml new file mode 100644 index 0000000..fa6d97b --- /dev/null +++ b/test_tipc/configs/seg/farseg/farseg_rsseg.yaml @@ -0,0 +1,11 @@ +# Configurations of FarSeg with RSSeg dataset + +_base_: ../_base_/rsseg.yaml + +save_dir: ./test_tipc/output/seg/farseg/ + +model: !Node + type: FarSeg + args: + in_channels: 10 + num_classes: 5 \ No newline at end of file diff --git a/test_tipc/configs/seg/farseg/train_infer_python.txt b/test_tipc/configs/seg/farseg/train_infer_python.txt new file mode 100644 index 0000000..6619052 --- /dev/null +++ b/test_tipc/configs/seg/farseg/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:seg:farseg +python:python +gpu_list:0|0,1 +use_gpu:null|null +--precision:null +--num_epochs:lite_train_lite_infer=3|lite_train_whole_infer=3|whole_train_whole_infer=20 +--save_dir:adaptive +--train_batch_size:lite_train_lite_infer=4|lite_train_whole_infer=4|whole_train_whole_infer=4 +--model_path:null +--config:lite_train_lite_infer=./test_tipc/configs/seg/farseg/farseg_rsseg.yaml|lite_train_whole_infer=./test_tipc/configs/seg/farseg/farseg_rsseg.yaml|whole_train_whole_infer=./test_tipc/configs/seg/farseg/farseg_rsseg.yaml +train_model_name:best_model +null:null +## +trainer:norm +norm_train:test_tipc/run_task.py train seg +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================export_params=========================== +--save_dir:adaptive +--model_dir:adaptive +--fixed_input_shape:[-1,10,512,512] +norm_export:deploy/export/export_model.py +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +===========================infer_params=========================== +infer_model:null +infer_export:null +infer_quant:False +inference:test_tipc/infer.py +--device:cpu|gpu +--enable_mkldnn:True +--cpu_threads:6 +--batch_size:1 +--use_trt:False +--precision:fp32 +--model_dir:null +--config:null +--save_log_path:null +--benchmark:True +--model_name:farseg +null:null \ No newline at end of file diff --git a/test_tipc/docs/test_train_inference_python.md b/test_tipc/docs/test_train_inference_python.md index 72a321b..b117cce 100644 --- a/test_tipc/docs/test_train_inference_python.md +++ b/test_tipc/docs/test_train_inference_python.md @@ -31,6 +31,7 @@ Linux GPU/CPU 基础训练推理测试的主程序为`test_train_inference_pytho | 目标检测 | PP-YOLOv2 | 正常训练 | 正常训练 | mAP=59.37% | | 目标检测 | YOLOv3 | 正常训练 | 正常训练 | mAP=47.33% | | 图像分割 | DeepLab V3+ | 正常训练 | 正常训练 | mIoU=56.05% | +| 图像分割 | FarSeg | 正常训练 | 正常训练 | mIoU=49.58% | | 图像分割 | UNet | 正常训练 | 正常训练 | mIoU=55.50% | *注:参考预测精度为whole_train_whole_infer模式下单卡训练汇报的精度数据。* @@ -61,6 +62,7 @@ Linux GPU/CPU 基础训练推理测试的主程序为`test_train_inference_pytho | 目标检测 | PP-YOLOv2 | 支持 | 支持 | 1 | | 目标检测 | YOLOv3 | 支持 | 支持 | 1 | | 图像分割 | DeepLab V3+ | 支持 | 支持 | 1 | +| 图像分割 | FarSeg | 支持 | 支持 | 1 | | 图像分割 | UNet | 支持 | 支持 | 1 | ## 2 测试流程 diff --git a/tests/rs_models/test_seg_models.py b/tests/rs_models/test_seg_models.py index 88fb6e1..4b19f83 100644 --- a/tests/rs_models/test_seg_models.py +++ b/tests/rs_models/test_seg_models.py @@ -50,7 +50,8 @@ class TestFarSegModel(TestSegModel): def set_specs(self): self.specs = [ - dict(), dict(num_classes=20), dict(encoder_pretrained=False) + dict(), dict(num_classes=20), dict(pretrained_encoder=False), + dict(in_channels=10) ] def set_targets(self): diff --git a/tutorials/train/README.md b/tutorials/train/README.md index 44e93a3..44c2491 100644 --- a/tutorials/train/README.md +++ b/tutorials/train/README.md @@ -27,6 +27,7 @@ |object_detection/ppyolov2.py | 目标检测 | PP-YOLOv2 | |object_detection/yolov3.py | 目标检测 | YOLOv3 | |semantic_segmentation/deeplabv3p.py | 图像分割 | DeepLab V3+ | +|semantic_segmentation/farseg.py | 图像分割 | FarSeg | |semantic_segmentation/unet.py | 图像分割 | UNet | ## 环境准备 diff --git a/tutorials/train/semantic_segmentation/deeplabv3p.py b/tutorials/train/semantic_segmentation/deeplabv3p.py index b3dbd50..4e0cb0b 100644 --- a/tutorials/train/semantic_segmentation/deeplabv3p.py +++ b/tutorials/train/semantic_segmentation/deeplabv3p.py @@ -71,7 +71,7 @@ eval_dataset = pdrs.datasets.SegDataset( # 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md # 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/segmenter.py model = pdrs.tasks.seg.DeepLabV3P( - input_channel=NUM_BANDS, + in_channels=NUM_BANDS, num_classes=len(train_dataset.labels), backbone='ResNet50_vd') diff --git a/tutorials/train/semantic_segmentation/farseg.py b/tutorials/train/semantic_segmentation/farseg.py new file mode 100644 index 0000000..f8561b5 --- /dev/null +++ b/tutorials/train/semantic_segmentation/farseg.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python + +# 图像分割模型FarSeg训练示例脚本 +# 执行此脚本前,请确认已正确安装PaddleRS库 + +import paddlers as pdrs +from paddlers import transforms as T + +# 数据集存放目录 +DATA_DIR = './data/rsseg/' +# 训练集`file_list`文件路径 +TRAIN_FILE_LIST_PATH = './data/rsseg/train.txt' +# 验证集`file_list`文件路径 +EVAL_FILE_LIST_PATH = './data/rsseg/val.txt' +# 数据集类别信息文件路径 +LABEL_LIST_PATH = './data/rsseg/labels.txt' +# 实验目录,保存输出的模型权重和结果 +EXP_DIR = './output/farseg/' + +# 下载和解压多光谱地块分类数据集 +pdrs.utils.download_and_decompress( + 'https://paddlers.bj.bcebos.com/datasets/rsseg.zip', path='./data/') + +# 定义训练和验证时使用的数据变换(数据增强、预处理等) +# 使用Compose组合多种变换方式。Compose中包含的变换将按顺序串行执行 +# API说明:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/apis/data.md +train_transforms = T.Compose([ + # 读取影像 + T.DecodeImg(), + # 选择前三个波段 + T.SelectBand([1, 2, 3]), + # 将影像缩放到512x512大小 + T.Resize(target_size=512), + # 以50%的概率实施随机水平翻转 + T.RandomHorizontalFlip(prob=0.5), + # 将数据归一化到[-1,1] + T.Normalize( + mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + T.ArrangeSegmenter('train') +]) + +eval_transforms = T.Compose([ + T.DecodeImg(), + # 验证阶段与训练阶段应当选择相同的波段 + T.SelectBand([1, 2, 3]), + T.Resize(target_size=512), + # 验证阶段与训练阶段的数据归一化方式必须相同 + T.Normalize( + mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + T.ReloadMask(), + T.ArrangeSegmenter('eval') +]) + +# 分别构建训练和验证所用的数据集 +train_dataset = pdrs.datasets.SegDataset( + data_dir=DATA_DIR, + file_list=TRAIN_FILE_LIST_PATH, + label_list=LABEL_LIST_PATH, + transforms=train_transforms, + num_workers=0, + shuffle=True) + +eval_dataset = pdrs.datasets.SegDataset( + data_dir=DATA_DIR, + file_list=EVAL_FILE_LIST_PATH, + label_list=LABEL_LIST_PATH, + transforms=eval_transforms, + num_workers=0, + shuffle=False) + +# 构建FarSeg模型 +# 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md +# 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/segmenter.py +model = pdrs.tasks.seg.FarSeg(num_classes=len(train_dataset.labels)) + +# 执行模型训练 +model.train( + num_epochs=10, + train_dataset=train_dataset, + train_batch_size=4, + eval_dataset=eval_dataset, + save_interval_epochs=5, + # 每多少次迭代记录一次日志 + log_interval_steps=4, + save_dir=EXP_DIR, + pretrain_weights=None, + # 初始学习率大小 + learning_rate=0.001, + # 是否使用early stopping策略,当精度不再改善时提前终止训练 + early_stop=False, + # 是否启用VisualDL日志功能 + use_vdl=True, + # 指定从某个检查点继续训练 + resume_checkpoint=None) diff --git a/tutorials/train/semantic_segmentation/unet.py b/tutorials/train/semantic_segmentation/unet.py index e1e8b82..1aee709 100644 --- a/tutorials/train/semantic_segmentation/unet.py +++ b/tutorials/train/semantic_segmentation/unet.py @@ -71,7 +71,7 @@ eval_dataset = pdrs.datasets.SegDataset( # 目前已支持的模型请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md # 模型输入参数请参考:https://github.com/PaddlePaddle/PaddleRS/blob/develop/paddlers/tasks/segmenter.py model = pdrs.tasks.seg.UNet( - input_channel=NUM_BANDS, num_classes=len(train_dataset.labels)) + in_channels=NUM_BANDS, num_classes=len(train_dataset.labels)) # 执行模型训练 model.train(