OpenMMLab Detection Toolbox and Benchmark https://mmdetection.readthedocs.io/
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

112 lines
4.6 KiB

# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from mmcv.utils import print_log
from mmdet.core import eval_map, eval_recalls
from .builder import DATASETS
from .xml_style import XMLDataset
@DATASETS.register_module()
class VOCDataset(XMLDataset):
    """PASCAL VOC detection dataset built on top of ``XMLDataset``.

    The release year (2007 or 2012) is inferred from ``img_prefix`` when the
    dataset is constructed and later decides which dataset identifier is
    handed to ``eval_map`` during mAP evaluation.
    """

    # The 20 object category names, in the order used by this dataset.
    CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
               'tvmonitor')

    # One 3-int tuple per entry of ``CLASSES`` (presumably RGB colours used
    # for visualisation -- confirm against the consumer of ``PALETTE``).
    PALETTE = [(106, 0, 228), (119, 11, 32), (165, 42, 42), (0, 0, 192),
               (197, 226, 255), (0, 60, 100), (0, 0, 142), (255, 77, 255),
               (153, 69, 1), (120, 166, 157), (0, 182, 199), (0, 226, 252),
               (182, 182, 255), (0, 0, 230), (220, 20, 60), (163, 255, 0),
               (0, 82, 0), (3, 95, 161), (0, 80, 100), (183, 130, 88)]

    def __init__(self, **kwargs):
        """Build the dataset and infer the VOC release year.

        Args:
            **kwargs: Forwarded unchanged to ``XMLDataset.__init__``.

        Raises:
            ValueError: If ``img_prefix`` contains neither ``'VOC2007'`` nor
                ``'VOC2012'``.
        """
        super().__init__(**kwargs)
        if 'VOC2007' in self.img_prefix:
            self.year = 2007
        elif 'VOC2012' in self.img_prefix:
            self.year = 2012
        else:
            raise ValueError('Cannot infer dataset year from img_prefix')

    def evaluate(self,
                 results,
                 metric='mAP',
                 logger=None,
                 proposal_nums=(100, 300, 1000),
                 iou_thr=0.5,
                 scale_ranges=None):
        """Evaluate in VOC protocol.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Options are
                'mAP', 'recall'. A non-string value must contain exactly
                one entry.
            logger (logging.Logger | str, optional): Logger used for printing
                related information during evaluation. Default: None.
            proposal_nums (Sequence[int]): Proposal number used for evaluating
                recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thr (float | int | list[float]): IoU threshold(s). A single
                number is wrapped into a one-element list. For 'mAP' the
                thresholds must be given as a number or a ``list``.
                Default: 0.5.
            scale_ranges (list[tuple], optional): Scale ranges for evaluating
                mAP. NOTE: this argument is currently accepted but not
                forwarded; ``eval_map`` is always called with
                ``scale_ranges=None``, so all bounding boxes are included.
                Default: None.

        Returns:
            dict[str, float]: AP/recall metrics. For 'mAP' the dict holds
            'mAP' (the plain average over all thresholds, stored first)
            followed by one 'APxx' entry per threshold rounded to three
            decimals. For 'recall' it holds one 'recall@{num}@{iou_thr}'
            entry per proposal number and threshold, plus 'AR@{num}' entries
            when more than one threshold is evaluated.

        Raises:
            KeyError: If ``metric`` is neither 'mAP' nor 'recall'.
            AssertionError: If ``metric`` is a sequence whose length is not
                1, or if the thresholds for 'mAP' are not a ``list``.
        """
        if not isinstance(metric, str):
            assert len(metric) == 1
            metric = metric[0]
        allowed_metrics = ['mAP', 'recall']
        if metric not in allowed_metrics:
            raise KeyError(f'metric {metric} is not supported')
        annotations = [self.get_ann_info(i) for i in range(len(self))]
        eval_results = OrderedDict()
        # Wrap any bare number (not only ``float``) so that an integer
        # threshold such as ``iou_thr=1`` is handled like ``iou_thr=1.0``.
        iou_thrs = [iou_thr] if isinstance(iou_thr, (int, float)) else iou_thr
        if metric == 'mAP':
            assert isinstance(iou_thrs, list)
            # VOC2007 is identified by name; for any other year the class
            # names themselves are passed as the dataset descriptor.
            ds_name = 'voc07' if self.year == 2007 else self.CLASSES
            mean_aps = []
            for thr in iou_thrs:
                # Route the banner through the caller's logger so that it
                # follows the same destination as the eval_map output.
                print_log(
                    f'\n{"-" * 15}iou_thr: {thr}{"-" * 15}', logger=logger)
                # Follow the official implementation,
                # http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCdevkit_18-May-2011.tar
                # we should use the legacy coordinate system in mmdet 1.x,
                # which means w, h should be computed as `x2 - x1 + 1` and
                # `y2 - y1 + 1`
                mean_ap, _ = eval_map(
                    results,
                    annotations,
                    scale_ranges=None,
                    iou_thr=thr,
                    dataset=ds_name,
                    logger=logger,
                    use_legacy_coordinate=True)
                mean_aps.append(mean_ap)
                eval_results[f'AP{int(thr * 100):02d}'] = round(mean_ap, 3)
            eval_results['mAP'] = sum(mean_aps) / len(mean_aps)
            # Keep the aggregate value as the first entry of the result dict.
            eval_results.move_to_end('mAP', last=False)
        elif metric == 'recall':
            gt_bboxes = [ann['bboxes'] for ann in annotations]
            recalls = eval_recalls(
                gt_bboxes,
                results,
                proposal_nums,
                iou_thrs,
                logger=logger,
                use_legacy_coordinate=True)
            # ``recalls`` is indexed as [proposal number, IoU threshold].
            for i, num in enumerate(proposal_nums):
                for j, thr in enumerate(iou_thrs):
                    eval_results[f'recall@{num}@{thr}'] = recalls[i, j]
            if recalls.shape[1] > 1:
                # Average recall over all IoU thresholds, per proposal number.
                ar = recalls.mean(axis=1)
                for i, num in enumerate(proposal_nums):
                    eval_results[f'AR@{num}'] = ar[i]
        return eval_results