[Feature] Support splitting COCO data for Semi-supervised object detection. (#7431)
* Split COCO data for Semi-supervised object detection. * import mmcv and use f-string * add a parser out_dir to set the path of semi-annos * Support multiprocessing * use mmcv.track_parallel_progress * fix * rename some variablespull/7676/head
parent
992b1984bc
commit
04db930cec
1 changed files with 109 additions and 0 deletions
@ -0,0 +1,109 @@ |
||||
# Copyright (c) OpenMMLab. All rights reserved. |
||||
import argparse |
||||
import os.path as osp |
||||
|
||||
import mmcv |
||||
import numpy as np |
||||
|
||||
prog_description = '''K-Fold coco split. |
||||
|
||||
To split coco data for semi-supervised object detection: |
||||
python tools/misc/split_coco.py |
||||
''' |
||||
|
||||
|
||||
def parse_args(): |
||||
parser = argparse.ArgumentParser() |
||||
parser.add_argument( |
||||
'--data-root', |
||||
type=str, |
||||
help='The data root of coco dataset.', |
||||
default='./data/coco/') |
||||
parser.add_argument( |
||||
'--out-dir', |
||||
type=str, |
||||
help='The output directory of coco semi-supervised annotations.', |
||||
default='./data/coco_semi_annos/') |
||||
parser.add_argument( |
||||
'--labeled-percent', |
||||
type=float, |
||||
nargs='+', |
||||
help='The percentage of labeled data in the training set.', |
||||
default=[1, 2, 5, 10]) |
||||
parser.add_argument( |
||||
'--fold', |
||||
type=int, |
||||
help='K-fold cross validation for semi-supervised object detection.', |
||||
default=5) |
||||
args = parser.parse_args() |
||||
return args |
||||
|
||||
|
||||
def split_coco(data_root, out_dir, percent, fold): |
||||
"""Split COCO data for Semi-supervised object detection. |
||||
|
||||
Args: |
||||
data_root (str): The data root of coco dataset. |
||||
out_dir (str): The output directory of coco semi-supervised |
||||
annotations. |
||||
percent (float): The percentage of labeled data in the training set. |
||||
fold (int): The fold of dataset and set as random seed for data split. |
||||
""" |
||||
|
||||
def save_anns(name, images, annotations): |
||||
sub_anns = dict() |
||||
sub_anns['images'] = images |
||||
sub_anns['annotations'] = annotations |
||||
sub_anns['licenses'] = anns['licenses'] |
||||
sub_anns['categories'] = anns['categories'] |
||||
sub_anns['info'] = anns['info'] |
||||
|
||||
mmcv.mkdir_or_exist(out_dir) |
||||
mmcv.dump(sub_anns, f'{out_dir}/{name}.json') |
||||
|
||||
# set random seed with the fold |
||||
np.random.seed(fold) |
||||
ann_file = osp.join(data_root, 'annotations/instances_train2017.json') |
||||
anns = mmcv.load(ann_file) |
||||
|
||||
image_list = anns['images'] |
||||
labeled_total = int(percent / 100. * len(image_list)) |
||||
labeled_inds = set( |
||||
np.random.choice(range(len(image_list)), size=labeled_total)) |
||||
labeled_ids, labeled_images, unlabeled_images = [], [], [] |
||||
|
||||
for i in range(len(image_list)): |
||||
if i in labeled_inds: |
||||
labeled_images.append(image_list[i]) |
||||
labeled_ids.append(image_list[i]['id']) |
||||
else: |
||||
unlabeled_images.append(image_list[i]) |
||||
|
||||
# get all annotations of labeled images |
||||
labeled_ids = set(labeled_ids) |
||||
labeled_annotations, unlabeled_annotations = [], [] |
||||
|
||||
for ann in anns['annotations']: |
||||
if ann['image_id'] in labeled_ids: |
||||
labeled_annotations.append(ann) |
||||
else: |
||||
unlabeled_annotations.append(ann) |
||||
|
||||
# save labeled and unlabeled |
||||
labeled_name = f'instances_train2017.{fold}@{percent}' |
||||
unlabeled_name = f'instances_train2017.{fold}@{percent}-unlabeled' |
||||
|
||||
save_anns(labeled_name, labeled_images, labeled_annotations) |
||||
save_anns(unlabeled_name, unlabeled_images, unlabeled_annotations) |
||||
|
||||
|
||||
def multi_wrapper(args): |
||||
return split_coco(*args) |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
args = parse_args() |
||||
arguments_list = [(args.data_root, args.out_dir, p, f) |
||||
for f in range(1, args.fold + 1) |
||||
for p in args.labeled_percent] |
||||
mmcv.track_parallel_progress(multi_wrapper, arguments_list, args.fold) |
Loading…
Reference in new issue