OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io/
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
139 lines
4.4 KiB
139 lines
4.4 KiB
import argparse |
|
import os.path as osp |
|
import xml.etree.ElementTree as ET |
|
|
|
import mmcv |
|
import numpy as np |
|
|
|
from mmdet.core import voc_classes |
|
|
|
# Lookup table from class name to its position in the sequence returned by
# ``voc_classes()``; used to turn the <name> of an object into a label id.
label_ids = {cls_name: idx for idx, cls_name in enumerate(voc_classes())}
|
|
|
|
|
def _as_annotation_arrays(bboxes, labels):
    """Turn per-object box/label lists into the arrays stored under ``ann``.

    Every box coordinate is decremented by 1 (presumably converting the
    1-based pixel indices of VOC to 0-based ones -- confirm against the
    consumer of these files).

    Args:
        bboxes (list[list[int]]): ``[xmin, ymin, xmax, ymax]`` per object.
        labels (list[int]): Label id per object, aligned with ``bboxes``.

    Returns:
        tuple[np.ndarray, np.ndarray]: float32 boxes of shape ``(n, 4)`` and
        int64 labels of shape ``(n,)``; ``n`` is 0 for empty input.
    """
    if not bboxes:
        return (np.zeros((0, 4), dtype=np.float32),
                np.zeros((0, ), dtype=np.int64))
    shifted = np.array(bboxes, ndmin=2) - 1
    return shifted.astype(np.float32), np.array(labels).astype(np.int64)


def parse_xml(args):
    """Parse one PASCAL VOC XML annotation file into an annotation dict.

    Args:
        args (tuple): ``(xml_path, img_path)``. ``xml_path`` is the XML file
            to read; ``img_path`` is stored unchanged as ``'filename'``.

    Returns:
        dict: Keys ``filename``, ``width``, ``height`` and ``ann``. ``ann``
        holds ``bboxes`` / ``labels`` for regular objects and
        ``bboxes_ignore`` / ``labels_ignore`` for objects whose
        ``<difficult>`` flag is non-zero.

    Raises:
        KeyError: If an object's ``<name>`` is not a key of ``label_ids``.
    """
    xml_path, img_path = args
    root = ET.parse(xml_path).getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    bboxes, labels = [], []
    bboxes_ignore, labels_ignore = [], []
    for obj in root.findall('object'):
        label = label_ids[obj.find('name').text]

        # A missing <difficult> tag is treated as "not difficult" instead of
        # crashing with AttributeError on ``None.text``.
        difficult_node = obj.find('difficult')
        difficult = 0 if difficult_node is None else int(difficult_node.text)

        # Coordinates may be written as floats (e.g. "45.7"); a bare int()
        # would raise ValueError on such text, so go through float() first.
        bnd_box = obj.find('bndbox')
        bbox = [
            int(float(bnd_box.find(tag).text))
            for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        ]

        if difficult:
            bboxes_ignore.append(bbox)
            labels_ignore.append(label)
        else:
            bboxes.append(bbox)
            labels.append(label)

    bboxes, labels = _as_annotation_arrays(bboxes, labels)
    bboxes_ignore, labels_ignore = _as_annotation_arrays(
        bboxes_ignore, labels_ignore)

    annotation = {
        'filename': img_path,
        'width': w,
        'height': h,
        'ann': {
            'bboxes': bboxes,
            'labels': labels,
            'bboxes_ignore': bboxes_ignore,
            'labels_ignore': labels_ignore,
        },
    }
    return annotation
|
|
|
|
|
def cvt_annotations(devkit_path, years, split, out_file):
    """Convert one split of one or more VOC years and dump the result.

    Args:
        devkit_path (str): Directory that contains the ``VOC<year>`` folders.
        years (str | list[str] | tuple[str]): A single year or a sequence of
            years whose annotations are concatenated in the given order.
        split (str): Name (without ``.txt``) of the image-set file under
            ``VOC<year>/ImageSets/Main``.
        out_file (str): Destination handed to ``mmcv.dump``.

    Returns:
        list[dict] | None: The parsed annotations, or ``None`` when the
        image-set file of any requested year is missing. In that case a
        message is printed and nothing is written to ``out_file``.
    """
    # Accept tuples as well as lists; previously a tuple was wrapped as a
    # single bogus "year" and formatted into the path as "VOC('2007', ...)".
    if not isinstance(years, (list, tuple)):
        years = [years]

    annotations = []
    for year in years:
        filelist = osp.join(devkit_path,
                            f'VOC{year}/ImageSets/Main/{split}.txt')
        if not osp.isfile(filelist):
            print(f'filelist does not exist: {filelist}, '
                  f'skip voc{year} {split}')
            # Abort the whole conversion rather than dump a partial dataset.
            return None
        img_names = mmcv.list_from_file(filelist)
        xml_paths = [
            osp.join(devkit_path, f'VOC{year}/Annotations/{img_name}.xml')
            for img_name in img_names
        ]
        # Image paths are stored relative to the devkit directory.
        img_paths = [
            f'VOC{year}/JPEGImages/{img_name}.jpg' for img_name in img_names
        ]
        part_annotations = mmcv.track_progress(parse_xml,
                                               list(zip(xml_paths, img_paths)))
        annotations.extend(part_annotations)

    mmcv.dump(annotations, out_file)
    return annotations
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the converter.

    Args:
        argv (list[str] | None): Argument strings to parse. With the default
            ``None`` argparse reads them from ``sys.argv[1:]``, so existing
            zero-argument callers behave as before.

    Returns:
        argparse.Namespace: ``devkit_path`` (required positional) and
        ``out_dir`` (value of ``-o`` / ``--out-dir``, ``None`` if not given).
    """
    parser = argparse.ArgumentParser(
        description='Convert PASCAL VOC annotations to mmdetection format')
    parser.add_argument('devkit_path', help='pascal voc devkit path')
    parser.add_argument('-o', '--out-dir', help='output path')
    return parser.parse_args(argv)
|
|
|
|
|
def main():
    """Convert every VOC dataset found under the devkit path to ``.pkl``.

    For each of VOC2007 / VOC2012 present in the devkit directory the
    ``train``, ``val``, ``trainval`` and ``test`` splits are converted. When
    both years are present, a combined ``voc0712`` dataset is converted too,
    without a ``test`` split. Output goes to ``--out-dir`` or, if that is not
    given, to the devkit directory itself.

    Raises:
        IOError: If the devkit path has neither a ``VOC2007`` nor a
            ``VOC2012`` subfolder.
    """
    args = parse_args()
    devkit_path = args.devkit_path
    out_dir = args.out_dir or devkit_path
    mmcv.mkdir_or_exist(out_dir)

    # Single-year datasets that actually exist on disk, in fixed order.
    single_year_prefixes = {'2007': 'voc07', '2012': 'voc12'}
    jobs = [(prefix, year)
            for year, prefix in single_year_prefixes.items()
            if osp.isdir(osp.join(devkit_path, 'VOC' + year))]
    if not jobs:
        raise IOError(f'The devkit path {devkit_path} contains neither '
                      '"VOC2007" nor "VOC2012" subfolder')
    if len(jobs) == len(single_year_prefixes):
        # Both years available: additionally build the merged dataset.
        jobs.append(('voc0712', ['2007', '2012']))

    for prefix, year in jobs:
        splits = ['train', 'val', 'trainval']
        if not isinstance(year, list):
            # The test split is converted for single-year datasets only.
            splits.append('test')
        for split in splits:
            dataset_name = f'{prefix}_{split}'
            print(f'processing {dataset_name} ...')
            cvt_annotations(devkit_path, year, split,
                            osp.join(out_dir, f'{dataset_name}.pkl'))
    print('Done!')
|
|
|
|
|
# Run the conversion only when this file is executed as a script, so that
# importing the module has no side effects beyond its definitions.
if __name__ == '__main__':
    main()
|
|
|