OpenMMLab Detection Toolbox and Benchmark https://mmdetection.readthedocs.io/
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

109 lines
3.4 KiB

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import mmcv
import numpy as np
prog_description = '''K-Fold coco split.
To split coco data for semi-supervised object detection:
python tools/misc/split_coco.py
'''
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
'--data-root',
type=str,
help='The data root of coco dataset.',
default='./data/coco/')
parser.add_argument(
'--out-dir',
type=str,
help='The output directory of coco semi-supervised annotations.',
default='./data/coco_semi_annos/')
parser.add_argument(
'--labeled-percent',
type=float,
nargs='+',
help='The percentage of labeled data in the training set.',
default=[1, 2, 5, 10])
parser.add_argument(
'--fold',
type=int,
help='K-fold cross validation for semi-supervised object detection.',
default=5)
args = parser.parse_args()
return args
def split_coco(data_root, out_dir, percent, fold):
"""Split COCO data for Semi-supervised object detection.
Args:
data_root (str): The data root of coco dataset.
out_dir (str): The output directory of coco semi-supervised
annotations.
percent (float): The percentage of labeled data in the training set.
fold (int): The fold of dataset and set as random seed for data split.
"""
def save_anns(name, images, annotations):
sub_anns = dict()
sub_anns['images'] = images
sub_anns['annotations'] = annotations
sub_anns['licenses'] = anns['licenses']
sub_anns['categories'] = anns['categories']
sub_anns['info'] = anns['info']
mmcv.mkdir_or_exist(out_dir)
mmcv.dump(sub_anns, f'{out_dir}/{name}.json')
# set random seed with the fold
np.random.seed(fold)
ann_file = osp.join(data_root, 'annotations/instances_train2017.json')
anns = mmcv.load(ann_file)
image_list = anns['images']
labeled_total = int(percent / 100. * len(image_list))
labeled_inds = set(
np.random.choice(range(len(image_list)), size=labeled_total))
labeled_ids, labeled_images, unlabeled_images = [], [], []
for i in range(len(image_list)):
if i in labeled_inds:
labeled_images.append(image_list[i])
labeled_ids.append(image_list[i]['id'])
else:
unlabeled_images.append(image_list[i])
# get all annotations of labeled images
labeled_ids = set(labeled_ids)
labeled_annotations, unlabeled_annotations = [], []
for ann in anns['annotations']:
if ann['image_id'] in labeled_ids:
labeled_annotations.append(ann)
else:
unlabeled_annotations.append(ann)
# save labeled and unlabeled
labeled_name = f'instances_train2017.{fold}@{percent}'
unlabeled_name = f'instances_train2017.{fold}@{percent}-unlabeled'
save_anns(labeled_name, labeled_images, labeled_annotations)
save_anns(unlabeled_name, unlabeled_images, unlabeled_annotations)
def multi_wrapper(args):
return split_coco(*args)
if __name__ == '__main__':
args = parse_args()
arguments_list = [(args.data_root, args.out_dir, p, f)
for f in range(1, args.fold + 1)
for p in args.labeled_percent]
mmcv.track_parallel_progress(multi_wrapper, arguments_list, args.fold)