You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
135 lines
4.3 KiB
135 lines
4.3 KiB
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
|
|
import os |
|
import copy |
|
from pathlib import Path |
|
from paddle.io import Dataset |
|
from abc import ABCMeta, abstractmethod |
|
|
|
from .preprocess import build_preprocess |
|
|
|
# Image file suffixes used by ``BaseDataset.scan_folder``. Each format appears
# as a lower-case / upper-case pair because suffix matching (str.endswith) is
# case-sensitive; mixed-case spellings such as '.Jpg' are not included.
IMG_EXTENSIONS = tuple(
    variant
    for stem in ('.jpg', '.jpeg', '.png', '.ppm', '.bmp')
    for variant in (stem, stem.upper())
)
|
|
|
|
|
def scandir(dir_path, suffix=None, recursive=False):
    """Scan a directory to find the interested files.

    Files whose name starts with ``.`` are skipped. When ``recursive`` is
    True, every sub-directory (including ones whose name starts with ``.``)
    is scanned too; entries that are neither regular files nor directories
    (e.g. broken symlinks) are ignored.

    Args:
        dir_path (str | obj:`Path`): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Matching is case-sensitive. Default: None
            (every non-hidden file is returned).
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.

    Returns:
        A generator for all the interested files, with paths relative to
        ``dir_path``. Entries come in the arbitrary order produced by
        :func:`os.scandir`.

    Raises:
        TypeError: If ``dir_path`` is not a str/Path, or ``suffix`` is not
            None, a str, or a tuple. Raised eagerly, when ``scandir`` is
            called rather than when the generator is first consumed.
    """
    if isinstance(dir_path, (str, Path)):
        dir_path = str(dir_path)
    else:
        raise TypeError('"dir_path" must be a string or Path object')

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        # Use the iterator as a context manager so the underlying directory
        # handle is released even if the caller stops consuming early.
        with os.scandir(dir_path) as entries:
            for entry in entries:
                if not entry.name.startswith('.') and entry.is_file():
                    rel_path = os.path.relpath(entry.path, root)
                    if suffix is None or rel_path.endswith(suffix):
                        yield rel_path
                elif recursive and entry.is_dir():
                    # Descend only into directories: calling os.scandir on a
                    # hidden file or a dangling link would raise
                    # NotADirectoryError / FileNotFoundError.
                    yield from _scandir(
                        entry.path, suffix=suffix, recursive=recursive)

    # _scandir is a generator; scandir itself is a plain function so the
    # argument validation above runs immediately.
    return _scandir(dir_path, suffix=suffix, recursive=recursive)
|
|
|
|
|
class BaseDataset(Dataset, metaclass=ABCMeta):
    """Base class for datasets.

    All datasets should subclass it. All subclasses should overwrite
    ``prepare_data_infos``, which loads the annotation information and
    fills ``self.data_infos`` with the list of samples.

    Args:
        preprocess (list[dict], optional): A sequence of data preprocess
            configs. When truthy, it is turned into a callable by
            ``build_preprocess`` and applied to every sample returned by
            ``__getitem__``. Default: None (samples are returned as-is).
    """

    def __init__(self, preprocess=None):
        super().__init__()

        # ``self.preprocess`` is only defined when a config is given;
        # ``__getitem__`` checks for the attribute before using it.
        if preprocess:
            self.preprocess = build_preprocess(preprocess)

    @abstractmethod
    def prepare_data_infos(self):
        """Abstract function for loading annotation.

        All subclasses should overwrite this function and set
        ``self.data_infos`` in it. ``data_infos`` should be a list of dict:
        [{key_path: file_path}, {key_path: file_path}, {key_path: file_path}]

        The base implementation only resets ``self.data_infos`` to None.
        """
        self.data_infos = None

    @staticmethod
    def scan_folder(path):
        """Obtain sample path list (including sub-folders) from a given folder.

        Only files whose suffix is in ``IMG_EXTENSIONS`` are collected; files
        whose name starts with ``.`` are skipped.

        Args:
            path (str|pathlib.Path): Folder path.

        Returns:
            list[str]: Sample paths obtained from the given folder, each
            joined with ``path`` and sorted so that sample order is
            deterministic regardless of filesystem enumeration order.

        Raises:
            TypeError: If ``path`` is not a str or a Path object.
            AssertionError: If no valid image file is found.
        """
        if isinstance(path, (str, Path)):
            path = str(path)
        else:
            raise TypeError("'path' must be a str or a Path object, "
                            f'but received {type(path)}.')

        # os.scandir yields entries in arbitrary order; sorting keeps dataset
        # indexing reproducible across runs and machines.
        samples = sorted(scandir(path, suffix=IMG_EXTENSIONS, recursive=True))
        samples = [os.path.join(path, v) for v in samples]
        # Explicit raise rather than ``assert`` so the check is not stripped
        # under ``python -O``; AssertionError is kept for existing callers.
        if not samples:
            raise AssertionError('{} has no valid image file.'.format(path))
        return samples

    def __getitem__(self, idx):
        """Return the ``idx``-th entry of ``self.data_infos``.

        The entry is deep-copied first so that in-place changes made by the
        preprocess pipeline do not leak back into ``self.data_infos``. When a
        preprocess pipeline was configured, the copy is passed through it.
        """
        datas = copy.deepcopy(self.data_infos[idx])

        if hasattr(self, 'preprocess') and self.preprocess:
            datas = self.preprocess(datas)

        return datas

    def __len__(self):
        """Length of the dataset.

        Returns:
            int: Length of the dataset (the number of ``data_infos``
            entries).
        """
        return len(self.data_infos)
|
|
|