import bisect
import math
from collections import defaultdict
from unittest.mock import MagicMock

import numpy as np
import pytest

from mmdet.datasets import (DATASETS, ClassBalancedDataset, ConcatDataset,
                            CustomDataset, RepeatDataset)


@pytest.mark.parametrize('dataset',
                         ['CocoDataset', 'VOCDataset', 'CityscapesDataset'])
def test_custom_classes_override_default(dataset):
    dataset_class = DATASETS.get(dataset)
    dataset_class.load_annotations = MagicMock()
    if dataset in ['CocoDataset', 'CityscapesDataset']:
        dataset_class.coco = MagicMock()
        dataset_class.cat_ids = MagicMock()

    original_classes = dataset_class.CLASSES

    # Test setting classes as a tuple
    custom_dataset = dataset_class(
        ann_file=MagicMock(),
        pipeline=[],
        classes=('bus', 'car'),
        test_mode=True,
        img_prefix='VOC2007' if dataset == 'VOCDataset' else '')
    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == ('bus', 'car')
    assert custom_dataset.custom_classes

    # Test setting classes as a list
    custom_dataset = dataset_class(
        ann_file=MagicMock(),
        pipeline=[],
        classes=['bus', 'car'],
        test_mode=True,
        img_prefix='VOC2007' if dataset == 'VOCDataset' else '')
    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == ['bus', 'car']
    assert custom_dataset.custom_classes

    # Test overriding not a subset
    custom_dataset = dataset_class(
        ann_file=MagicMock(),
        pipeline=[],
        classes=['foo'],
        test_mode=True,
        img_prefix='VOC2007' if dataset == 'VOCDataset' else '')
    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == ['foo']
    assert custom_dataset.custom_classes

    # Test default behavior
    custom_dataset = dataset_class(
        ann_file=MagicMock(),
        pipeline=[],
        classes=None,
        test_mode=True,
        img_prefix='VOC2007' if dataset == 'VOCDataset' else '')
    assert custom_dataset.CLASSES == original_classes
    assert not custom_dataset.custom_classes

    # Test sending file path
    import tempfile
    tmp_file = tempfile.NamedTemporaryFile()
    with open(tmp_file.name, 'w') as f:
        f.write('bus\ncar\n')
    custom_dataset = dataset_class(
        ann_file=MagicMock(),
        pipeline=[],
        classes=tmp_file.name,
        test_mode=True,
        img_prefix='VOC2007' if dataset == 'VOCDataset' else '')
    tmp_file.close()

    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == ['bus', 'car']
    assert custom_dataset.custom_classes


def test_dataset_wrapper():
    # Mock out annotation loading and __getitem__ so the dataset wrappers
    # can be exercised without any real annotation files or images.
    CustomDataset.load_annotations = MagicMock()
    CustomDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)

    dataset_a = CustomDataset(
        ann_file=MagicMock(), pipeline=[], test_mode=True, img_prefix='')
    len_a = 10
    cat_ids_list_a = [
        np.random.randint(0, 80, num).tolist()
        for num in np.random.randint(1, 20, len_a)
    ]
    dataset_a.data_infos = MagicMock()
    dataset_a.data_infos.__len__.return_value = len_a
    dataset_a.get_cat_ids = MagicMock(
        side_effect=lambda idx: cat_ids_list_a[idx])

    dataset_b = CustomDataset(
        ann_file=MagicMock(), pipeline=[], test_mode=True, img_prefix='')
    len_b = 20
    cat_ids_list_b = [
        np.random.randint(0, 80, num).tolist()
        for num in np.random.randint(1, 20, len_b)
    ]
    dataset_b.data_infos = MagicMock()
    dataset_b.data_infos.__len__.return_value = len_b
    dataset_b.get_cat_ids = MagicMock(
        side_effect=lambda idx: cat_ids_list_b[idx])
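
    # ConcatDataset should map a global index onto the correct sub-dataset
    # index; the mocked __getitem__ simply echoes the local index it receives.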
    concat_dataset = ConcatDataset([dataset_a, dataset_b])
    assert concat_dataset[5] == 5
    assert concat_dataset[25] == 15
    assert concat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
    assert concat_dataset.get_cat_ids(25) == cat_ids_list_b[15]
    assert len(concat_dataset) == len(dataset_a) + len(dataset_b)
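
    # RepeatDataset repeats dataset_a 10 times, so indices wrap modulo len_a.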
    repeat_dataset = RepeatDataset(dataset_a, 10)
    assert repeat_dataset[5] == 5
    assert repeat_dataset[15] == 5
    assert repeat_dataset[27] == 7
    assert repeat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
    assert repeat_dataset.get_cat_ids(15) == cat_ids_list_a[5]
    assert repeat_dataset.get_cat_ids(27) == cat_ids_list_a[7]
    assert len(repeat_dataset) == 10 * len(dataset_a)
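
    # Re-derive the expected repeat factors for ClassBalancedDataset
    # (repeat-factor sampling: per-category frequency, a frequency threshold,
    # and a per-image factor of max(1, sqrt(thr / freq)) over its categories).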
    category_freq = defaultdict(int)
    for cat_ids in cat_ids_list_a:
        cat_ids = set(cat_ids)
        for cat_id in cat_ids:
            category_freq[cat_id] += 1
    for k, v in category_freq.items():
        category_freq[k] = v / len(cat_ids_list_a)

    mean_freq = np.mean(list(category_freq.values()))
    repeat_thr = mean_freq
    category_repeat = {
        cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
        for cat_id, cat_freq in category_freq.items()
    }

    repeat_factors = []
    for cat_ids in cat_ids_list_a:
        cat_ids = set(cat_ids)
        repeat_factor = max({category_repeat[cat_id] for cat_id in cat_ids})
        repeat_factors.append(math.ceil(repeat_factor))
    repeat_factors_cumsum = np.cumsum(repeat_factors)

    repeat_factor_dataset = ClassBalancedDataset(dataset_a, repeat_thr)
    assert len(repeat_factor_dataset) == repeat_factors_cumsum[-1]
    for idx in np.random.randint(0, len(repeat_factor_dataset), 3):
        assert repeat_factor_dataset[idx] == bisect.bisect_right(
            repeat_factors_cumsum, idx)