|
|
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
import os
|
|
|
|
import copy
|
|
|
|
from pathlib import Path
|
|
|
|
from paddle.io import Dataset
|
|
|
|
from abc import ABCMeta, abstractmethod
|
|
|
|
|
|
|
|
from .preprocess import build_preprocess
|
|
|
|
|
|
|
|
IMG_EXTENSIONS = ('.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm',
|
|
|
|
'.PPM', '.bmp', '.BMP')
|
|
|
|
|
|
|
|
|
|
|
|
def scandir(dir_path, suffix=None, recursive=False):
|
|
|
|
"""Scan a directory to find the interested files.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
dir_path (str | obj:`Path`): Path of the directory.
|
|
|
|
suffix (str | tuple(str), optional): File suffix that we are
|
|
|
|
interested in. Default: None.
|
|
|
|
recursive (bool, optional): If set to True, recursively scan the
|
|
|
|
directory. Default: False.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
A generator for all the interested files with relative pathes.
|
|
|
|
"""
|
|
|
|
if isinstance(dir_path, (str, Path)):
|
|
|
|
dir_path = str(dir_path)
|
|
|
|
else:
|
|
|
|
raise TypeError('"dir_path" must be a string or Path object')
|
|
|
|
|
|
|
|
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
|
|
|
|
raise TypeError('"suffix" must be a string or tuple of strings')
|
|
|
|
|
|
|
|
root = dir_path
|
|
|
|
|
|
|
|
def _scandir(dir_path, suffix, recursive):
|
|
|
|
for entry in os.scandir(dir_path):
|
|
|
|
if not entry.name.startswith('.') and entry.is_file():
|
|
|
|
rel_path = os.path.relpath(entry.path, root)
|
|
|
|
if suffix is None:
|
|
|
|
yield rel_path
|
|
|
|
elif rel_path.endswith(suffix):
|
|
|
|
yield rel_path
|
|
|
|
else:
|
|
|
|
if recursive:
|
|
|
|
yield from _scandir(
|
|
|
|
entry.path, suffix=suffix, recursive=recursive)
|
|
|
|
else:
|
|
|
|
continue
|
|
|
|
|
|
|
|
return _scandir(dir_path, suffix=suffix, recursive=recursive)
|
|
|
|
|
|
|
|
|
|
|
|
class BaseDataset(Dataset, metaclass=ABCMeta):
|
|
|
|
"""Base class for datasets.
|
|
|
|
|
|
|
|
All datasets should subclass it.
|
|
|
|
All subclasses should overwrite:
|
|
|
|
|
|
|
|
``prepare_data_infos``, supporting to load information and generate
|
|
|
|
image lists.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
preprocess (list[dict]): A sequence of data preprocess config.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(self, preprocess=None):
|
|
|
|
super(BaseDataset, self).__init__()
|
|
|
|
|
|
|
|
if preprocess:
|
|
|
|
self.preprocess = build_preprocess(preprocess)
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
def prepare_data_infos(self):
|
|
|
|
"""Abstract function for loading annotation.
|
|
|
|
|
|
|
|
All subclasses should overwrite this function
|
|
|
|
should set self.annotations in this fucntion
|
|
|
|
data_infos should be as list of dict:
|
|
|
|
[{key_path: file_path}, {key_path: file_path}, {key_path: file_path}]
|
|
|
|
"""
|
|
|
|
self.data_infos = None
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def scan_folder(path):
|
|
|
|
"""Obtain sample path list (including sub-folders) from a given folder.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
path (str|pathlib.Path): Folder path.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
list[str]: sample list obtained form given folder.
|
|
|
|
"""
|
|
|
|
|
|
|
|
if isinstance(path, (str, Path)):
|
|
|
|
path = str(path)
|
|
|
|
else:
|
|
|
|
raise TypeError("'path' must be a str or a Path object, "
|
|
|
|
f'but received {type(path)}.')
|
|
|
|
|
|
|
|
samples = list(scandir(path, suffix=IMG_EXTENSIONS, recursive=True))
|
|
|
|
samples = [os.path.join(path, v) for v in samples]
|
|
|
|
assert samples, '{} has no valid image file.'.format(path)
|
|
|
|
return samples
|
|
|
|
|
|
|
|
def __getitem__(self, idx):
|
|
|
|
datas = copy.deepcopy(self.data_infos[idx])
|
|
|
|
|
|
|
|
if hasattr(self, 'preprocess') and self.preprocess:
|
|
|
|
datas = self.preprocess(datas)
|
|
|
|
|
|
|
|
return datas
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
"""Length of the dataset.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
int: Length of the dataset.
|
|
|
|
"""
|
|
|
|
return len(self.data_infos)
|