# Utility functions for data preparation (label-list creation and random
# dataset splitting), as added to tools/prepare_dataset/common.py.
import copy
import random


def create_label_list(label_list, labels):
    """
    Create label list.

    Each label name is written on its own line, in the given order.

    Args:
        label_list (str): Path of label list to create.
        labels (list[str]|tuple[str]): Label names.
    """

    with open(label_list, 'w') as f:
        f.writelines(label + '\n' for label in labels)


def random_split(samples,
                 ratios=(0.7, 0.2, 0.1),
                 inplace=True,
                 drop_remainder=False):
    """
    Randomly split the dataset into two or three subsets.

    Args:
        samples (list): All samples of the dataset.
        ratios (tuple[float], optional): If the length of `ratios` is 2,
            the two elements indicate the ratios of samples used for training
            and evaluation. If the length of `ratios` is 3, the three elements
            indicate the ratios of samples used for training, validation, and
            testing. Defaults to (0.7, 0.2, 0.1).
        inplace (bool, optional): Whether to shuffle `samples` in place.
            If False, a deep copy of `samples` is shuffled instead and the
            input is left untouched. Defaults to True.
        drop_remainder (bool, optional): Whether to discard the remaining
            samples. If False, the remaining samples will be included in the
            last subset. For example, if `ratios` is (0.7, 0.1) and
            `drop_remainder` is False, the two subsets after splitting will
            contain 70% and 30% of the samples, respectively.
            Defaults to False.

    Returns:
        list[list]: The subsets, one per element of `ratios`, in the same
            order as `ratios`.

    Raises:
        ValueError: If `samples` is empty, or if the length of `ratios` is
            neither 2 nor 3.
    """

    if not inplace:
        samples = copy.deepcopy(samples)

    if len(samples) == 0:
        raise ValueError("There are no samples!")

    if len(ratios) not in (2, 3):
        raise ValueError("`len(ratios)` must be 2 or 3!")

    random.shuffle(samples)

    n_samples = len(samples)
    acc_r = 0
    st_idx, ed_idx = 0, 0
    splits = []
    for r in ratios:
        # Boundaries are derived from the cumulative ratio so that rounding
        # errors do not accumulate from one subset to the next.
        acc_r += r
        # Clamp so that ratios summing to more than 1 cannot push the
        # boundary past the end of the data.
        ed_idx = min(round(acc_r * n_samples), n_samples)
        splits.append(samples[st_idx:ed_idx])
        st_idx = ed_idx

    if ed_idx < n_samples and not drop_remainder:
        # Samples not covered by `ratios` go to the last split. They must be
        # merged element-wise (extend), not appended as one nested list.
        splits[-1].extend(samples[ed_idx:])

    return splits