# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp

import mmcv
import numpy as np

prog_description = '''K-Fold coco split.

To split coco data for semi-supervised object detection:
    python tools/misc/split_coco.py
'''


def parse_args():
    """Parse the command line options of the split script.

    Returns:
        argparse.Namespace: Parsed options with the attributes
        ``data_root``, ``out_dir``, ``labeled_percent`` (a list) and
        ``fold``.
    """
    # RawDescriptionHelpFormatter keeps the line breaks of the usage example
    # in ``prog_description`` when it is printed by ``--help``.
    parser = argparse.ArgumentParser(
        description=prog_description,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument(
        '--data-root',
        type=str,
        help='The data root of coco dataset.',
        default='./data/coco/')
    parser.add_argument(
        '--out-dir',
        type=str,
        help='The output directory of coco semi-supervised annotations.',
        default='./data/coco_semi_annos/')
    parser.add_argument(
        '--labeled-percent',
        type=float,
        nargs='+',
        help='The percentage of labeled data in the training set.',
        default=[1, 2, 5, 10])
    parser.add_argument(
        '--fold',
        type=int,
        help='K-fold cross validation for semi-supervised object detection.',
        default=5)
    args = parser.parse_args()
    return args


def split_coco(data_root, out_dir, percent, fold):
    """Split COCO data for Semi-supervised object detection.

    Loads ``annotations/instances_train2017.json`` under ``data_root``,
    randomly picks ``percent`` percent of its images as the labeled subset
    and writes two annotation files into ``out_dir``:
    ``instances_train2017.{fold}@{percent}.json`` (labeled images and their
    annotations) and ``instances_train2017.{fold}@{percent}-unlabeled.json``
    (all remaining images and annotations).

    Args:
        data_root (str): The data root of coco dataset.
        out_dir (str): The output directory of coco semi-supervised
            annotations.
        percent (float): The percentage of labeled data in the training set.
            Must lie in the range [0, 100].
        fold (int): The index of the fold. It is also used as the random
            seed, so the same fold always produces the same split.

    Raises:
        ValueError: If ``percent`` is outside of [0, 100].
    """
    if not 0 <= percent <= 100:
        raise ValueError(
            f'percent must be in the range [0, 100], but got {percent}')

    ann_file = osp.join(data_root, 'annotations/instances_train2017.json')
    anns = mmcv.load(ann_file)

    def save_anns(name, images, annotations):
        """Dump a subset that shares the meta fields of the full file."""
        sub_anns = dict()
        sub_anns['images'] = images
        sub_anns['annotations'] = annotations
        sub_anns['licenses'] = anns['licenses']
        sub_anns['categories'] = anns['categories']
        sub_anns['info'] = anns['info']

        mmcv.mkdir_or_exist(out_dir)
        mmcv.dump(sub_anns, osp.join(out_dir, f'{name}.json'))

    # set random seed with the fold
    np.random.seed(fold)

    image_list = anns['images']
    labeled_total = int(percent / 100. * len(image_list))
    # Sample WITHOUT replacement: the default (with replacement) may draw
    # the same index several times, so the de-duplicated set would hold
    # fewer images than the requested percentage.
    labeled_inds = set(
        np.random.choice(
            len(image_list), size=labeled_total, replace=False).tolist())

    labeled_images, unlabeled_images = [], []
    for idx, image in enumerate(image_list):
        if idx in labeled_inds:
            labeled_images.append(image)
        else:
            unlabeled_images.append(image)

    # get all annotations of labeled images
    labeled_ids = {image['id'] for image in labeled_images}
    labeled_annotations, unlabeled_annotations = [], []
    for ann in anns['annotations']:
        if ann['image_id'] in labeled_ids:
            labeled_annotations.append(ann)
        else:
            unlabeled_annotations.append(ann)

    # save labeled and unlabeled
    # ``--labeled-percent`` is parsed as float while the defaults are ints;
    # the ``g`` format names both ``10`` and ``10.0`` as ``10`` so the file
    # name does not depend on how the value was supplied.
    percent_tag = f'{percent:g}'
    labeled_name = f'instances_train2017.{fold}@{percent_tag}'
    unlabeled_name = f'instances_train2017.{fold}@{percent_tag}-unlabeled'

    save_anns(labeled_name, labeled_images, labeled_annotations)
    save_anns(unlabeled_name, unlabeled_images, unlabeled_annotations)


def multi_wrapper(args):
    """Unpack one ``(data_root, out_dir, percent, fold)`` tuple.

    The parallel runner hands a single task object to the worker function,
    so the tuple is expanded into the positional arguments of
    :func:`split_coco` here.
    """
    return split_coco(*args)


if __name__ == '__main__':
    args = parse_args()
    # One task per (fold, percent) pair; folds are numbered from 1.
    arguments_list = [(args.data_root, args.out_dir, p, f)
                      for f in range(1, args.fold + 1)
                      for p in args.labeled_percent]
    mmcv.track_parallel_progress(multi_wrapper, arguments_list, args.fold)