Cherry-pick commits from refactor_data

own
Bobholamovic 3 years ago committed by Bobholamovic
parent dce23e4a6e
commit 9c382d82bb
  1. 30
      paddlers/datasets/base.py
  2. 19
      paddlers/datasets/cd_dataset.py
  3. 20
      paddlers/datasets/clas_dataset.py
  4. 16
      paddlers/datasets/coco.py
  5. 19
      paddlers/datasets/seg_dataset.py
  6. 16
      paddlers/datasets/voc.py

@ -0,0 +1,30 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from paddle.io import Dataset
from paddlers.utils import get_num_workers
class BaseDataset(Dataset):
def __init__(self, data_dir, label_list, transforms, num_workers, shuffle):
super(BaseDataset, self).__init__()
self.data_dir = data_dir
self.label_list = label_list
self.transforms = deepcopy(transforms)
self.num_workers = get_num_workers(num_workers)
self.shuffle = shuffle

@ -16,12 +16,11 @@ import copy
from enum import IntEnum from enum import IntEnum
import os.path as osp import os.path as osp
from paddle.io import Dataset from .base import BaseDataset
from paddlers.utils import logging, get_encoding, path_normalization, is_pic
from paddlers.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
class CDDataset(BaseDataset):
class CDDataset(Dataset):
""" """
读取变化检测任务数据集并对样本进行相应的处理来自SegDataset图像标签需要两个 读取变化检测任务数据集并对样本进行相应的处理来自SegDataset图像标签需要两个
@ -31,8 +30,10 @@ class CDDataset(Dataset):
False默认设置文件中每一行应依次包含第一时相影像第二时相影像以及变化检测标签的路径`with_seg_labels`为True时 False默认设置文件中每一行应依次包含第一时相影像第二时相影像以及变化检测标签的路径`with_seg_labels`为True时
文件中每一行应依次包含第一时相影像第二时相影像变化检测标签第一时相建筑物标签以及第二时相建筑物标签的路径 文件中每一行应依次包含第一时相影像第二时相影像变化检测标签第一时相建筑物标签以及第二时相建筑物标签的路径
label_list (str): 描述数据集包含的类别信息文件路径默认值为None label_list (str): 描述数据集包含的类别信息文件路径默认值为None
transforms (paddlers.transforms): 数据集中每个样本的预处理/增强算子 transforms (paddlers.transforms.Compose): 数据集中每个样本的预处理/增强算子
num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数默认为'auto' num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数默认为'auto'当设为'auto'根据
系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8`num_workers`为8否则为CPU核数的
一半
shuffle (bool): 是否需要对数据集中样本打乱顺序默认为False shuffle (bool): 是否需要对数据集中样本打乱顺序默认为False
with_seg_labels (bool, optional): 数据集中是否包含两个时相的语义分割标签默认为False with_seg_labels (bool, optional): 数据集中是否包含两个时相的语义分割标签默认为False
binarize_labels (bool, optional): 是否对数据集中的标签进行二值化操作默认为False binarize_labels (bool, optional): 是否对数据集中的标签进行二值化操作默认为False
@ -47,15 +48,13 @@ class CDDataset(Dataset):
shuffle=False, shuffle=False,
with_seg_labels=False, with_seg_labels=False,
binarize_labels=False): binarize_labels=False):
super(CDDataset, self).__init__() super(CDDataset, self).__init__(data_dir, label_list, transforms,
num_workers, shuffle)
DELIMETER = ' ' DELIMETER = ' '
self.transforms = copy.deepcopy(transforms)
# TODO: batch padding # TODO: batch padding
self.batch_transforms = None self.batch_transforms = None
self.num_workers = get_num_workers(num_workers)
self.shuffle = shuffle
self.file_list = list() self.file_list = list()
self.labels = list() self.labels = list()
self.with_seg_labels = with_seg_labels self.with_seg_labels = with_seg_labels

@ -15,20 +15,21 @@
import os.path as osp import os.path as osp
import copy import copy
from paddle.io import Dataset from .base import BaseDataset
from paddlers.utils import logging, get_encoding, path_normalization, is_pic
from paddlers.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
class ClasDataset(BaseDataset):
class ClasDataset(Dataset):
"""读取图像分类任务数据集,并对样本进行相应的处理。 """读取图像分类任务数据集,并对样本进行相应的处理。
Args: Args:
data_dir (str): 数据集所在的目录路径 data_dir (str): 数据集所在的目录路径
file_list (str): 描述数据集图片文件和对应标注序号文本内每行路径为相对data_dir的相对路 file_list (str): 描述数据集图片文件和对应标注序号文本内每行路径为相对data_dir的相对路
label_list (str): 描述数据集包含的类别信息文件路径文件格式为类别 说明默认值为None label_list (str): 描述数据集包含的类别信息文件路径文件格式为类别 说明默认值为None
transforms (paddlers.transforms): 数据集中每个样本的预处理/增强算子 transforms (paddlers.transforms.Compose): 数据集中每个样本的预处理/增强算子
num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数默认为'auto' num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数默认为'auto'当设为'auto'根据
系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8`num_workers`为8否则为CPU核数的
一半
shuffle (bool): 是否需要对数据集中样本打乱顺序默认为False shuffle (bool): 是否需要对数据集中样本打乱顺序默认为False
""" """
@ -39,14 +40,11 @@ class ClasDataset(Dataset):
transforms=None, transforms=None,
num_workers='auto', num_workers='auto',
shuffle=False): shuffle=False):
super(ClasDataset, self).__init__() super(ClasDataset, self).__init__(data_dir, label_list, transforms,
self.transforms = copy.deepcopy(transforms) num_workers, shuffle)
# TODO batch padding # TODO batch padding
self.batch_transforms = None self.batch_transforms = None
self.num_workers = get_num_workers(num_workers)
self.shuffle = shuffle
self.file_list = list() self.file_list = list()
self.label_list = label_list
self.labels = list() self.labels = list()
# TODO:非None时,让用户跳转数据集分析生成label_list # TODO:非None时,让用户跳转数据集分析生成label_list

@ -20,14 +20,14 @@ import random
from collections import OrderedDict from collections import OrderedDict
import numpy as np import numpy as np
from paddle.io import Dataset
from paddlers.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic from .base import BaseDataset
from paddlers.utils import logging, get_encoding, path_normalization, is_pic
from paddlers.transforms import DecodeImg, MixupImage from paddlers.transforms import DecodeImg, MixupImage
from paddlers.tools import YOLOAnchorCluster from paddlers.tools import YOLOAnchorCluster
class COCODetection(Dataset): class COCODetection(BaseDataset):
"""读取COCO格式的检测数据集,并对样本进行相应的处理。 """读取COCO格式的检测数据集,并对样本进行相应的处理。
Args: Args:
@ -35,7 +35,7 @@ class COCODetection(Dataset):
image_dir (str): 描述数据集图片文件路径 image_dir (str): 描述数据集图片文件路径
anno_path (str): COCO标注文件路径 anno_path (str): COCO标注文件路径
label_list (str): 描述数据集包含的类别信息文件路径 label_list (str): 描述数据集包含的类别信息文件路径
transforms (paddlers.det.transforms): 数据集中每个样本的预处理/增强算子 transforms (paddlers.transforms.Compose): 数据集中每个样本的预处理/增强算子
num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数默认为'auto'当设为'auto'根据 num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数默认为'auto'当设为'auto'根据
系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8`num_workers`为8否则为CPU核数的 系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8`num_workers`为8否则为CPU核数的
一半 一半
@ -60,10 +60,10 @@ class COCODetection(Dataset):
import matplotlib import matplotlib
matplotlib.use('Agg') matplotlib.use('Agg')
from pycocotools.coco import COCO from pycocotools.coco import COCO
super(COCODetection, self).__init__() super(COCODetection, self).__init__(data_dir, label_list, transforms,
self.data_dir = data_dir num_workers, shuffle)
self.data_fields = None self.data_fields = None
self.transforms = copy.deepcopy(transforms)
self.num_max_boxes = 50 self.num_max_boxes = 50
self.use_mix = False self.use_mix = False
@ -76,8 +76,6 @@ class COCODetection(Dataset):
break break
self.batch_transforms = None self.batch_transforms = None
self.num_workers = get_num_workers(num_workers)
self.shuffle = shuffle
self.allow_empty = allow_empty self.allow_empty = allow_empty
self.empty_ratio = empty_ratio self.empty_ratio = empty_ratio
self.file_list = list() self.file_list = list()

@ -15,20 +15,21 @@
import os.path as osp import os.path as osp
import copy import copy
from paddle.io import Dataset from .base import BaseDataset
from paddlers.utils import logging, get_encoding, path_normalization, is_pic
from paddlers.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic
class SegDataset(BaseDataset):
class SegDataset(Dataset):
"""读取语义分割任务数据集,并对样本进行相应的处理。 """读取语义分割任务数据集,并对样本进行相应的处理。
Args: Args:
data_dir (str): 数据集所在的目录路径 data_dir (str): 数据集所在的目录路径
file_list (str): 描述数据集图片文件和对应标注文件的文件路径文本内每行路径为相对data_dir的相对路 file_list (str): 描述数据集图片文件和对应标注文件的文件路径文本内每行路径为相对data_dir的相对路
label_list (str): 描述数据集包含的类别信息文件路径默认值为None label_list (str): 描述数据集包含的类别信息文件路径默认值为None
transforms (paddlers.transforms): 数据集中每个样本的预处理/增强算子 transforms (paddlers.transforms.Compose): 数据集中每个样本的预处理/增强算子
num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数默认为'auto' num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数默认为'auto'当设为'auto'根据
系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8`num_workers`为8否则为CPU核数的
一半
shuffle (bool): 是否需要对数据集中样本打乱顺序默认为False shuffle (bool): 是否需要对数据集中样本打乱顺序默认为False
""" """
@ -39,12 +40,10 @@ class SegDataset(Dataset):
transforms=None, transforms=None,
num_workers='auto', num_workers='auto',
shuffle=False): shuffle=False):
super(SegDataset, self).__init__() super(SegDataset, self).__init__(data_dir, label_list, transforms,
self.transforms = copy.deepcopy(transforms) num_workers, shuffle)
# TODO batch padding # TODO batch padding
self.batch_transforms = None self.batch_transforms = None
self.num_workers = get_num_workers(num_workers)
self.shuffle = shuffle
self.file_list = list() self.file_list = list()
self.labels = list() self.labels = list()

@ -22,21 +22,21 @@ from collections import OrderedDict
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
import numpy as np import numpy as np
from paddle.io import Dataset
from paddlers.utils import logging, get_num_workers, get_encoding, path_normalization, is_pic from .base import BaseDataset
from paddlers.utils import logging, get_encoding, path_normalization, is_pic
from paddlers.transforms import DecodeImg, MixupImage from paddlers.transforms import DecodeImg, MixupImage
from paddlers.tools import YOLOAnchorCluster from paddlers.tools import YOLOAnchorCluster
class VOCDetection(Dataset): class VOCDetection(BaseDataset):
"""读取PascalVOC格式的检测数据集,并对样本进行相应的处理。 """读取PascalVOC格式的检测数据集,并对样本进行相应的处理。
Args: Args:
data_dir (str): 数据集所在的目录路径 data_dir (str): 数据集所在的目录路径
file_list (str): 描述数据集图片文件和对应标注文件的文件路径文本内每行路径为相对data_dir的相对路 file_list (str): 描述数据集图片文件和对应标注文件的文件路径文本内每行路径为相对data_dir的相对路
label_list (str): 描述数据集包含的类别信息文件路径 label_list (str): 描述数据集包含的类别信息文件路径
transforms (paddlers.det.transforms): 数据集中每个样本的预处理/增强算子 transforms (paddlers.transforms.Compose): 数据集中每个样本的预处理/增强算子
num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数默认为'auto'当设为'auto'根据 num_workers (int|str): 数据集中样本在预处理过程中的线程或进程数默认为'auto'当设为'auto'根据
系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8`num_workers`为8否则为CPU核数的 系统的实际CPU核数设置`num_workers`: 如果CPU核数的一半大于8`num_workers`为8否则为CPU核数的
一半 一半
@ -60,10 +60,10 @@ class VOCDetection(Dataset):
import matplotlib import matplotlib
matplotlib.use('Agg') matplotlib.use('Agg')
from pycocotools.coco import COCO from pycocotools.coco import COCO
super(VOCDetection, self).__init__() super(VOCDetection, self).__init__(data_dir, label_list, transforms,
self.data_dir = data_dir num_workers, shuffle)
self.data_fields = None self.data_fields = None
self.transforms = copy.deepcopy(transforms)
self.num_max_boxes = 50 self.num_max_boxes = 50
self.use_mix = False self.use_mix = False
@ -76,8 +76,6 @@ class VOCDetection(Dataset):
break break
self.batch_transforms = None self.batch_transforms = None
self.num_workers = get_num_workers(num_workers)
self.shuffle = shuffle
self.allow_empty = allow_empty self.allow_empty = allow_empty
self.empty_ratio = empty_ratio self.empty_ratio = empty_ratio
self.file_list = list() self.file_list = list()

Loading…
Cancel
Save