diff --git a/docs/README.md b/docs/README.md index 40174d8..2479f0f 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1 +1,5 @@ -PaddleSeg commit fec42fd869b6f796c74cd510671595e3512bc8e9 \ No newline at end of file +PaddleSeg commit fec42fd869b6f796c74cd510671595e3512bc8e9 + +# 开发规范 +请注意,paddlers/models/ppxxx系列除了修改import路径和支持多通道模型外,不要增删改任何代码。 +新增的模型需放在paddlers/models/下的seg、det、cls、cd目录下。 \ No newline at end of file diff --git a/docs/datasets.md b/docs/datasets.md new file mode 100644 index 0000000..c56bdae --- /dev/null +++ b/docs/datasets.md @@ -0,0 +1,40 @@ +# 遥感数据集 + +遥感影像的格式多种多样,不同传感器产生的数据格式也可能不同。PaddleRS至少兼容以下6种格式图片读取: + +- `tif` +- `png`, `jpeg`, `bmp` +- `img` +- `npy` + +标注图要求必须为单通道的png格式图像,像素值即为对应的类别,像素标注类别需要从0开始递增。例如0,1,2,3表示有4种类别,255用于指定不参与训练和评估的像素,标注类别最多为256类。 + +## L8 SPARCS数据集 +[L8 SPARCS公开数据集](https://www.usgs.gov/land-resources/nli/landsat/spatial-procedures-automated-removal-cloud-and-shadow-sparcs-validation)进行云雪分割,该数据集包含80张卫星影像,涵盖10个波段。原始标注图片包含7个类别,分别是`cloud`, `cloud shadow`, `shadow over water`, `snow/ice`, `water`, `land`和`flooded`。由于`flooded`和`shadow over water`2个类别占比仅为`1.8%`和`0.24%`,我们将其进行合并,`flooded`归为`land`,`shadow over water`归为`shadow`,合并后标注包含5个类别。 + +数值、类别、颜色对应表: + +|Pixel value|Class|Color| +|---|---|---| +|0|cloud|white| +|1|shadow|black| +|2|snow/ice|cyan| +|3|water|blue| +|4|land|grey| + +
+
+
+
+ L8 SPARCS数据集示例 +
+ +执行以下命令下载并解压经过类别合并后的数据集: +```shell script +mkdir dataset && cd dataset +wget https://paddleseg.bj.bcebos.com/dataset/remote_sensing_seg.zip +unzip remote_sensing_seg.zip +cd .. +``` +其中`data`目录存放遥感影像,`data_vis`目录存放彩色合成预览图,`mask`目录存放标注图。 diff --git a/paddlers/datasets/seg_dataset.py b/paddlers/datasets/seg_dataset.py index ae4fcd5..23cd266 100644 --- a/paddlers/datasets/seg_dataset.py +++ b/paddlers/datasets/seg_dataset.py @@ -64,10 +64,10 @@ class SegDataset(Dataset): " file_list[{}] has a space in the image or label path.".format(line, file_list)) items[0] = path_normalization(items[0]) items[1] = path_normalization(items[1]) - if not is_pic(items[0]) or not is_pic(items[1]): - continue full_path_im = osp.join(data_dir, items[0]) full_path_label = osp.join(data_dir, items[1]) + if not is_pic(full_path_im) or not is_pic(full_path_label): + continue if not osp.exists(full_path_im): raise IOError('Image file {} does not exist!'.format( full_path_im)) diff --git a/paddlers/datasets/voc.py b/paddlers/datasets/voc.py index b0d4fd8..111db3e 100644 --- a/paddlers/datasets/voc.py +++ b/paddlers/datasets/voc.py @@ -23,7 +23,7 @@ from collections import OrderedDict import xml.etree.ElementTree as ET from paddle.io import Dataset from paddlers.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic -from paddlers.transforms import Decode, MixupImage +from paddlers.transforms import ImgDecoder, MixupImage from paddlers.tools import YOLOAnchorCluster @@ -319,7 +319,7 @@ class VOCDetection(Dataset): if self.data_fields is not None: sample_mix = {k: sample_mix[k] for k in self.data_fields} sample = self.mixup_op(sample=[ - Decode(to_rgb=False)(sample), Decode(to_rgb=False)(sample_mix) + ImgDecoder(to_rgb=False)(sample), ImgDecoder(to_rgb=False)(sample_mix) ]) sample = self.transforms(sample) return sample diff --git a/paddlers/models/ppseg/models/backbones/resnet_vd.py b/paddlers/models/ppseg/models/backbones/resnet_vd.py index 90c92f3..357af4e 100644 
--- a/paddlers/models/ppseg/models/backbones/resnet_vd.py +++ b/paddlers/models/ppseg/models/backbones/resnet_vd.py @@ -211,13 +211,14 @@ class ResNet_vd(nn.Layer): """ def __init__(self, + input_channel=3, layers=50, output_stride=8, multi_grid=(1, 1, 1), pretrained=None, data_format='NCHW'): super(ResNet_vd, self).__init__() - + self.data_format = data_format self.conv1_logit = None # for gscnn shape stream self.layers = layers @@ -251,7 +252,7 @@ class ResNet_vd(nn.Layer): dilation_dict = {3: 2} self.conv1_1 = ConvBNLayer( - in_channels=3, + in_channels=input_channel, out_channels=32, kernel_size=3, stride=2, diff --git a/paddlers/tasks/changedetector.py b/paddlers/tasks/changedetector.py index b6cb37b..fa42ef1 100644 --- a/paddlers/tasks/changedetector.py +++ b/paddlers/tasks/changedetector.py @@ -28,7 +28,7 @@ import paddlers.utils.logging as logging from .base import BaseModel from .utils import seg_metrics as metrics from paddlers.utils.checkpoint import seg_pretrain_weights_dict -from paddlers.transforms import Decode, Resize +from paddlers.transforms import ImgDecoder, Resize from paddlers.models.ppcd import CDNet as _CDNet __all__ = ["CDNet"] @@ -516,7 +516,7 @@ class BaseChangeDetector(BaseModel): for im in images: sample = {'image': im} if isinstance(sample['image'], str): - sample = Decode(to_rgb=False)(sample) + sample = ImgDecoder(to_rgb=False)(sample) ori_shape = sample['image'].shape[:2] im = transforms(sample)[0] batch_im.append(im) diff --git a/paddlers/tasks/classifier.py b/paddlers/tasks/classifier.py index 34b86ca..fa03ba8 100644 --- a/paddlers/tasks/classifier.py +++ b/paddlers/tasks/classifier.py @@ -29,7 +29,7 @@ from paddlers.models.ppcls.metric import build_metrics from paddlers.models.ppcls.loss import build_loss from paddlers.models.ppcls.data.postprocess import build_postprocess from paddlers.utils.checkpoint import cls_pretrain_weights_dict -from paddlers.transforms import Decode, Resize +from paddlers.transforms import ImgDecoder, 
Resize __all__ = ["ResNet50_vd", "MobileNetV3_small_x1_0", "HRNet_W18_C"] @@ -433,7 +433,7 @@ class BaseClassifier(BaseModel): for im in images: sample = {'image': im} if isinstance(sample['image'], str): - sample = Decode(to_rgb=False)(sample) + sample = ImgDecoder(to_rgb=False)(sample) ori_shape = sample['image'].shape[:2] im = transforms(sample)[0] batch_im.append(im) diff --git a/paddlers/tasks/segmenter.py b/paddlers/tasks/segmenter.py index 63510b9..8e63e3d 100644 --- a/paddlers/tasks/segmenter.py +++ b/paddlers/tasks/segmenter.py @@ -28,7 +28,7 @@ import paddlers.utils.logging as logging from .base import BaseModel from .utils import seg_metrics as metrics from paddlers.utils.checkpoint import seg_pretrain_weights_dict -from paddlers.transforms import Decode, Resize +from paddlers.transforms import ImgDecoder, Resize __all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"] @@ -525,7 +525,7 @@ class BaseSegmenter(BaseModel): for im in images: sample = {'image': im} if isinstance(sample['image'], str): - sample = Decode(to_rgb=False)(sample) + sample = ImgDecoder(to_rgb=False)(sample) ori_shape = sample['image'].shape[:2] im = transforms(sample)[0] batch_im.append(im) @@ -679,6 +679,7 @@ class UNet(BaseSegmenter): class DeepLabV3P(BaseSegmenter): def __init__(self, + input_channel=3, num_classes=2, backbone='ResNet50_vd', use_mixed_loss=False, @@ -696,6 +697,7 @@ class DeepLabV3P(BaseSegmenter): if params.get('with_net', True): with DisablePrint(): backbone = getattr(paddleseg.models, backbone)( + input_channel=input_channel, output_stride=output_stride) else: backbone = None diff --git a/paddlers/transforms/img_decoder.py b/paddlers/transforms/img_decoder.py deleted file mode 100644 index 5720d88..0000000 --- a/paddlers/transforms/img_decoder.py +++ /dev/null @@ -1,157 +0,0 @@ -# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import os.path as osp -import cv2 -import copy -import random -import imghdr -from PIL import Image - -try: - from collections.abc import Sequence -except Exception: - from collections import Sequence - -# from paddlers.transforms.operators import Transform - - -class Transform(object): - """ - Parent class of all data augmentation operations - """ - - def __init__(self): - pass - - def apply_im(self, image): - pass - - def apply_mask(self, mask): - pass - - def apply_bbox(self, bbox): - pass - - def apply_segm(self, segms): - pass - - def apply(self, sample): - sample['image'] = self.apply_im(sample['image']) - if 'mask' in sample: - sample['mask'] = self.apply_mask(sample['mask']) - if 'gt_bbox' in sample: - sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox']) - - return sample - - def __call__(self, sample): - if isinstance(sample, Sequence): - sample = [self.apply(s) for s in sample] - else: - sample = self.apply(sample) - - return sample - - -class ImgDecode(Transform): - """ - Decode image(s) in input. - Args: - to_rgb (bool, optional): If True, convert input images from BGR format to RGB format. Defaults to True. 
- """ - - def __init__(self, to_rgb=True): - super(ImgDecode, self).__init__() - self.to_rgb = to_rgb - - def read_img(self, img_path, input_channel=3): - img_format = imghdr.what(img_path) - name, ext = osp.splitext(img_path) - if img_format == 'tiff' or ext == '.img': - try: - import gdal - except: - try: - from osgeo import gdal - except: - raise Exception( - "Failed to import gdal! You can try use conda to install gdal" - ) - six.reraise(*sys.exc_info()) - - dataset = gdal.Open(img_path) - if dataset == None: - raise Exception('Can not open', img_path) - im_data = dataset.ReadAsArray() - return im_data.transpose((1, 2, 0)) - elif img_format in ['jpeg', 'bmp', 'png', 'jpg']: - if input_channel == 3: - return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH | - cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR) - else: - return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH | - cv2.IMREAD_ANYCOLOR) - elif ext == '.npy': - return np.load(img_path) - else: - raise Exception('Image format {} is not supported!'.format(ext)) - - def apply_im(self, im_path): - if isinstance(im_path, str): - try: - image = self.read_img(im_path) - except: - raise ValueError('Cannot read the image file {}!'.format( - im_path)) - else: - image = im_path - - if self.to_rgb: - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - - return image - - def apply_mask(self, mask): - try: - mask = np.asarray(Image.open(mask)) - except: - raise ValueError("Cannot read the mask file {}!".format(mask)) - if len(mask.shape) != 2: - raise Exception( - "Mask should be a 1-channel image, but recevied is a {}-channel image.". - format(mask.shape[2])) - return mask - - def apply(self, sample): - """ - Args: - sample (dict): Input sample, containing 'image' at least. - Returns: - dict: Decoded sample. 
- """ - sample['image'] = self.apply_im(sample['image']) - if 'mask' in sample: - sample['mask'] = self.apply_mask(sample['mask']) - im_height, im_width, _ = sample['image'].shape - se_height, se_width = sample['mask'].shape - if im_height != se_height or im_width != se_width: - raise Exception( - "The height or width of the im is not same as the mask") - - sample['im_shape'] = np.array( - sample['image'].shape[:2], dtype=np.float32) - sample['scale_factor'] = np.array([1., 1.], dtype=np.float32) - return sample \ No newline at end of file diff --git a/paddlers/transforms/operators.py b/paddlers/transforms/operators.py index ed2729c..c365093 100644 --- a/paddlers/transforms/operators.py +++ b/paddlers/transforms/operators.py @@ -31,7 +31,7 @@ from .functions import normalize, horizontal_flip, permute, vertical_flip, cente crop_rle, expand_poly, expand_rle, resize_poly, resize_rle __all__ = [ - "Compose", "Decode", "Resize", "RandomResize", "ResizeByShort", + "Compose", "ImgDecoder", "Resize", "RandomResize", "ResizeByShort", "RandomResizeByShort", "ResizeByLong", "RandomHorizontalFlip", "RandomVerticalFlip", "Normalize", "CenterCrop", "RandomCrop", "RandomScaleAspect", "RandomExpand", "Padding", "MixupImage", @@ -90,66 +90,15 @@ class Transform(object): return sample -class Compose(Transform): - """ - Apply a series of data augmentation to the input. - All input images are in Height-Width-Channel ([H, W, C]) format. - - Args: - transforms (List[paddlers.transforms.Transform]): List of data preprocess or augmentations. - Raises: - TypeError: Invalid type of transforms. - ValueError: Invalid length of transforms. - """ - - def __init__(self, transforms): - super(Compose, self).__init__() - if not isinstance(transforms, list): - raise TypeError( - 'Type of transforms is invalid. 
Must be List, but received is {}' - .format(type(transforms))) - if len(transforms) < 1: - raise ValueError( - 'Length of transforms must not be less than 1, but received is {}' - .format(len(transforms))) - self.transforms = transforms - self.decode_image = Decode() - self.arrange_outputs = None - self.apply_im_only = False - - def __call__(self, sample): - if self.apply_im_only and 'mask' in sample: - mask_backup = copy.deepcopy(sample['mask']) - del sample['mask'] - - sample = self.decode_image(sample) - - for op in self.transforms: - # skip batch transforms amd mixup - if isinstance(op, (paddlers.transforms.BatchRandomResize, - paddlers.transforms.BatchRandomResizeByShort, - MixupImage)): - continue - sample = op(sample) - - if self.arrange_outputs is not None: - if self.apply_im_only: - sample['mask'] = mask_backup - sample = self.arrange_outputs(sample) - - return sample - - -class Decode(Transform): +class ImgDecoder(Transform): """ Decode image(s) in input. - Args: to_rgb (bool, optional): If True, convert input images from BGR format to RGB format. Defaults to True. """ def __init__(self, to_rgb=True): - super(Decode, self).__init__() + super(ImgDecoder, self).__init__() self.to_rgb = to_rgb def read_img(self, img_path, input_channel=3): @@ -172,7 +121,7 @@ class Decode(Transform): raise Exception('Can not open', img_path) im_data = dataset.ReadAsArray() if im_data.ndim == 3: - im_data.transpose((1, 2, 0)) + im_data = im_data.transpose((1, 2, 0)) return im_data elif img_format in ['jpeg', 'bmp', 'png', 'jpg']: if input_channel == 3: @@ -196,7 +145,7 @@ class Decode(Transform): else: image = im_path - if self.to_rgb: + if self.to_rgb and image.shape[-1] == 3: image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) return image @@ -214,13 +163,10 @@ class Decode(Transform): def apply(self, sample): """ - Args: sample (dict): Input sample, containing 'image' at least. - Returns: dict: Decoded sample. 
- """ if 'image' in sample: sample['image'] = self.apply_im(sample['image']) @@ -234,12 +180,63 @@ class Decode(Transform): if im_height != se_height or im_width != se_width: raise Exception( "The height or width of the im is not same as the mask") + sample['im_shape'] = np.array( sample['image'].shape[:2], dtype=np.float32) sample['scale_factor'] = np.array([1., 1.], dtype=np.float32) return sample +class Compose(Transform): + """ + Apply a series of data augmentation to the input. + All input images are in Height-Width-Channel ([H, W, C]) format. + + Args: + transforms (List[paddlers.transforms.Transform]): List of data preprocess or augmentations. + Raises: + TypeError: Invalid type of transforms. + ValueError: Invalid length of transforms. + """ + + def __init__(self, transforms): + super(Compose, self).__init__() + if not isinstance(transforms, list): + raise TypeError( + 'Type of transforms is invalid. Must be List, but received is {}' + .format(type(transforms))) + if len(transforms) < 1: + raise ValueError( + 'Length of transforms must not be less than 1, but received is {}' + .format(len(transforms))) + self.transforms = transforms + self.decode_image = ImgDecoder() + self.arrange_outputs = None + self.apply_im_only = False + + def __call__(self, sample): + if self.apply_im_only and 'mask' in sample: + mask_backup = copy.deepcopy(sample['mask']) + del sample['mask'] + + sample = self.decode_image(sample) + + for op in self.transforms: + # skip batch transforms amd mixup + if isinstance(op, (paddlers.transforms.BatchRandomResize, + paddlers.transforms.BatchRandomResizeByShort, + MixupImage)): + continue + sample = op(sample) + + if self.arrange_outputs is not None: + if self.apply_im_only: + sample['mask'] = mask_backup + sample = self.arrange_outputs(sample) + + return sample + + class Resize(Transform): """ Resize input. 
@@ -618,10 +615,16 @@ class Normalize(Transform): def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], - min_val=[0, 0, 0], - max_val=[255., 255., 255.], + min_val=None, + max_val=None, is_scale=True): super(Normalize, self).__init__() + channel = len(mean) + if min_val is None: + min_val = [0] * channel + if max_val is None: + max_val = [255.] * channel + from functools import reduce if reduce(lambda x, y: x * y, std) == 0: raise ValueError( @@ -633,7 +636,6 @@ class Normalize(Transform): '(max_val - min_val) should not have 0, but received is {}'. format((np.asarray(max_val) - np.asarray(min_val)).tolist( ))) - self.mean = mean self.std = std self.min_val = min_val diff --git a/paddlers/utils/utils.py b/paddlers/utils/utils.py index 18d84ed..0d284f3 100644 --- a/paddlers/utils/utils.py +++ b/paddlers/utils/utils.py @@ -14,8 +14,10 @@ import sys import os +import os.path as osp import time import math +import imghdr import chardet import json import numpy as np @@ -73,12 +75,16 @@ def path_normalization(path): return path -def is_pic(img_name): - valid_suffix = ['JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png', 'tiff'] - suffix = img_name.split('.')[-1] - if suffix not in valid_suffix: - return False - return True +def is_pic(img_path): + valid_suffix = ['JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png', 'npy'] + suffix = img_path.split('.')[-1] + if suffix in valid_suffix: + return True + img_format = imghdr.what(img_path) + _, ext = osp.splitext(img_path) + if img_format == 'tiff' or ext == '.img': + return True + return False class MyEncoder(json.JSONEncoder): diff --git a/tutorials/train/README.md b/tutorials/train/README.md index 4923f49..5284dfe 100644 --- a/tutorials/train/README.md +++ b/tutorials/train/README.md @@ -5,7 +5,7 @@ |代码 | 模型任务 | 数据 | |------|--------|---------| |object_detection/ppyolo.py | 目标检测PPYOLO | 昆虫检测 | -|semantic_segmentation/deeplabv3p_resnet50_vd.py | 语义分割DeepLabV3 | 视盘分割 | 
+|semantic_segmentation/deeplabv3p_resnet50_multi_channel.py | 语义分割DeepLabV3P | 多光谱地块分类 | |semantic_segmentation/farseg_test.py | 语义分割FarSeg | 遥感建筑分割 | |change_detection/cdnet_build.py | 变化检测CDNet | 遥感变化检测 | |classification/resnet50_vd_rs.py | 图像分类ResNet50_vd | 遥感场景分类 | @@ -25,7 +25,7 @@ ## 开始训练 -* 修改tutorials/train/semantic_segmentation/deeplabv3p_resnet50_vd.py中sys.path路径 +* 修改tutorials/train/semantic_segmentation/deeplabv3p_resnet50_multi_channel.py中sys.path路径 ``` sys.path.append("your/PaddleRS/path") ``` @@ -34,13 +34,13 @@ sys.path.append("your/PaddleRS/path") ```commandline export CUDA_VISIBLE_DEVICES=0 -python tutorials/train/semantic_segmentation/deeplabv3p_resnet50_vd.py +python tutorials/train/semantic_segmentation/deeplabv3p_resnet50_multi_channel.py ``` * 若需使用多张GPU卡进行训练,例如使用2张卡时执行: ```commandline -python -m paddle.distributed.launch --gpus 0,1 tutorials/train/semantic_segmentation/deeplabv3p_resnet50_vd.py +python -m paddle.distributed.launch --gpus 0,1 tutorials/train/semantic_segmentation/deeplabv3p_resnet50_multi_channel.py ``` 使用多卡时,参考[训练参数调整](../../docs/parameters.md)调整学习率和批量大小。 @@ -48,7 +48,7 @@ python -m paddle.distributed.launch --gpus 0,1 tutorials/train/semantic_segmenta ## VisualDL可视化训练指标 在模型训练过程,在`train`函数中,将`use_vdl`设为True,则训练过程会自动将训练日志以VisualDL的格式打点在`save_dir`(用户自己指定的路径)下的`vdl_log`目录,用户可以使用如下命令启动VisualDL服务,查看可视化指标 ```commandline -visualdl --logdir output/deeplabv3p_resnet50_vd/vdl_log --port 8001 +visualdl --logdir output/deeplabv3p_resnet50_multi_channel/vdl_log --port 8001 ``` 服务启动后,使用浏览器打开 https://0.0.0.0:8001 或 https://localhost:8001 diff --git a/tutorials/train/semantic_segmentation/deeplabv3p_resnet50_vd.py b/tutorials/train/semantic_segmentation/deeplabv3p_resnet50_multi_channel.py similarity index 67% rename from tutorials/train/semantic_segmentation/deeplabv3p_resnet50_vd.py rename to tutorials/train/semantic_segmentation/deeplabv3p_resnet50_multi_channel.py index 4c46255..55fc5ab 100644 --- 
a/tutorials/train/semantic_segmentation/deeplabv3p_resnet50_vd.py +++ b/tutorials/train/semantic_segmentation/deeplabv3p_resnet50_multi_channel.py @@ -5,39 +5,40 @@ sys.path.append("/mnt/chulutao/PaddleRS") import paddlers as pdrs from paddlers import transforms as T -# 下载和解压视盘分割数据集 -optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz' -pdrs.utils.download_and_decompress(optic_dataset, path='./') +# 下载和解压多光谱地块分类数据集 +dataset = 'https://paddleseg.bj.bcebos.com/dataset/remote_sensing_seg.zip' +pdrs.utils.download_and_decompress(dataset, path='./data') # 定义训练和验证时的transforms # API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/transforms/transforms.md +channel = 10 train_transforms = T.Compose([ T.Resize(target_size=512), T.RandomHorizontalFlip(), T.Normalize( - mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + mean=[0.5] * 10, std=[0.5] * 10), ]) eval_transforms = T.Compose([ T.Resize(target_size=512), T.Normalize( - mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + mean=[0.5] * 10, std=[0.5] * 10), ]) # 定义训练和验证所用的数据集 # API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/datasets.md train_dataset = pdrs.datasets.SegDataset( - data_dir='optic_disc_seg', - file_list='optic_disc_seg/train_list.txt', - label_list='optic_disc_seg/labels.txt', + data_dir='./data/remote_sensing_seg', + file_list='./data/remote_sensing_seg/train.txt', + label_list='./data/remote_sensing_seg/labels.txt', transforms=train_transforms, num_workers=0, shuffle=True) eval_dataset = pdrs.datasets.SegDataset( - data_dir='optic_disc_seg', - file_list='optic_disc_seg/val_list.txt', - label_list='optic_disc_seg/labels.txt', + data_dir='./data/remote_sensing_seg', + file_list='./data/remote_sensing_seg/val.txt', + label_list='./data/remote_sensing_seg/labels.txt', transforms=eval_transforms, num_workers=0, shuffle=False) @@ -45,7 +46,7 @@ eval_dataset = pdrs.datasets.SegDataset( # 初始化模型,并进行训练 # 
可使用VisualDL查看训练指标,参考https://github.com/PaddlePaddle/paddlers/blob/develop/docs/visualdl.md num_classes = len(train_dataset.labels) -model = pdrs.tasks.DeepLabV3P(num_classes=num_classes, backbone='ResNet50_vd') +model = pdrs.tasks.DeepLabV3P(input_channel=channel, num_classes=num_classes, backbone='ResNet50_vd') # API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/models/semantic_segmentation.md # 各参数介绍与调整说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/parameters.md