# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import collections
import copy
import math
from collections import defaultdict

import numpy as np
from mmcv.utils import build_from_cfg, print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

from .builder import DATASETS, PIPELINES
from .coco import CocoDataset


@DATASETS.register_module()
class ConcatDataset(_ConcatDataset):
    """A wrapper of concatenated dataset.

    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but it also
    concatenates the group flags used for image aspect ratio grouping.

    Args:
        datasets (list[:obj:`Dataset`]): A list of datasets.
        separate_eval (bool): Whether to evaluate the results
            separately if it is used as a validation dataset.
            Defaults to True.
    """

    def __init__(self, datasets, separate_eval=True):
        super(ConcatDataset, self).__init__(datasets)
        self.CLASSES = datasets[0].CLASSES
        self.PALETTE = getattr(datasets[0], 'PALETTE', None)
        self.separate_eval = separate_eval
        if not separate_eval:
            if any([isinstance(ds, CocoDataset) for ds in datasets]):
                raise NotImplementedError(
                    'Evaluating concatenated CocoDataset as a whole is not'
                    ' supported! Please set "separate_eval=True"')
            elif len(set([type(ds) for ds in datasets])) != 1:
                raise NotImplementedError(
                    'All the datasets should have same types')

        if hasattr(datasets[0], 'flag'):
            # Concatenate the aspect-ratio group flags of all sub-datasets.
            flags = [dataset.flag for dataset in datasets]
            self.flag = np.concatenate(flags)

    def get_cat_ids(self, idx):
        """Get category ids of concatenated dataset by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: All categories in the image of specified index.
        """

        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    'absolute value of index should not exceed dataset length')
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx].get_cat_ids(sample_idx)

    def get_ann_info(self, idx):
        """Get annotation of concatenated dataset by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """

        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    'absolute value of index should not exceed dataset length')
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx].get_ann_info(sample_idx)

    def evaluate(self, results, logger=None, **kwargs):
        """Evaluate the results.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.

        Returns:
            dict[str, float]: AP results of the total dataset, or of each
                separate dataset if ``self.separate_eval=True``.
        """
        assert len(results) == self.cumulative_sizes[-1], \
            ('Dataset and results have different sizes: '
             f'{self.cumulative_sizes[-1]} vs. {len(results)}')

        # Check whether all the datasets support evaluation
        for dataset in self.datasets:
            assert hasattr(dataset, 'evaluate'), \
                f'{type(dataset)} does not implement evaluate function'

        if self.separate_eval:
            dataset_idx = -1
            total_eval_results = dict()
            for size, dataset in zip(self.cumulative_sizes, self.datasets):
                start_idx = 0 if dataset_idx == -1 else \
                    self.cumulative_sizes[dataset_idx]
                end_idx = self.cumulative_sizes[dataset_idx + 1]

                results_per_dataset = results[start_idx:end_idx]
                print_log(
                    f'\nEvaluating {dataset.ann_file} with '
                    f'{len(results_per_dataset)} images now',
                    logger=logger)

                eval_results_per_dataset = dataset.evaluate(
                    results_per_dataset, logger=logger, **kwargs)
                dataset_idx += 1
                for k, v in eval_results_per_dataset.items():
                    total_eval_results.update({f'{dataset_idx}_{k}': v})

            return total_eval_results
        elif any([isinstance(ds, CocoDataset) for ds in self.datasets]):
            raise NotImplementedError(
                'Evaluating concatenated CocoDataset as a whole is not'
                ' supported! Please set "separate_eval=True"')
        elif len(set([type(ds) for ds in self.datasets])) != 1:
            raise NotImplementedError(
                'All the datasets should have same types')
        else:
            original_data_infos = self.datasets[0].data_infos
            self.datasets[0].data_infos = sum(
                [dataset.data_infos for dataset in self.datasets], [])
            eval_results = self.datasets[0].evaluate(
                results, logger=logger, **kwargs)
            self.datasets[0].data_infos = original_data_infos
            return eval_results


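# Illustrative usage sketch (not part of the original file). `AnyDataset`
# is a hypothetical dataset class exposing CLASSES, `flag` and
# `get_cat_ids`; any datasets of the same type work the same way.
#
#     >>> concat = ConcatDataset([AnyDataset('a.json'), AnyDataset('b.json')])
#     >>> len(concat) == concat.cumulative_sizes[-1]
#     True
#     >>> # A global index maps to (dataset_idx, sample_idx) via bisect on
#     >>> # cumulative_sizes, e.g. cumulative_sizes == [100, 250] sends
#     >>> # global index 120 to sample 20 of the second dataset.
#
# With separate_eval=True (the default), `evaluate` runs each sub-dataset's
# own evaluator and prefixes the returned keys with the dataset index,
# e.g. '0_mAP', '1_mAP'.

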
@DATASETS.register_module()
class RepeatDataset:
    """A wrapper of repeated dataset.

    The length of the repeated dataset will be ``times`` times as large as
    the original dataset. This is useful when the data loading time is long
    but the dataset is small. Using RepeatDataset can reduce the data loading
    time between epochs.

    Args:
        dataset (:obj:`Dataset`): The dataset to be repeated.
        times (int): Repeat times.
    """

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times
        self.CLASSES = dataset.CLASSES
        self.PALETTE = getattr(dataset, 'PALETTE', None)
        if hasattr(self.dataset, 'flag'):
            # Tile the aspect-ratio group flags to match the repeated length.
            self.flag = np.tile(self.dataset.flag, times)

        self._ori_len = len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx % self._ori_len]

    def get_cat_ids(self, idx):
        """Get category ids of repeat dataset by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: All categories in the image of specified index.
        """

        return self.dataset.get_cat_ids(idx % self._ori_len)

    def get_ann_info(self, idx):
        """Get annotation of repeat dataset by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """

        return self.dataset.get_ann_info(idx % self._ori_len)

    def __len__(self):
        """Length after repetition."""
        return self.times * self._ori_len


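# Illustrative usage sketch (not part of the original file). In a config
# this wrapper is typically nested around another dataset config; the
# annotation path below is a placeholder.
#
#     data = dict(
#         train=dict(
#             type='RepeatDataset',
#             times=10,
#             dataset=dict(type='CocoDataset', ann_file='train.json', ...)))
#
# With times=10 and an original length of N, len(repeat_dataset) == 10 * N
# and index i is served from the original dataset at index i % N.

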
# Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57 # noqa
@DATASETS.register_module()
class ClassBalancedDataset:
    """A wrapper of repeated dataset with repeat factor.

    Suitable for training on class imbalanced datasets like LVIS. Following
    the sampling strategy in the `paper <https://arxiv.org/abs/1908.03195>`_,
    in each epoch, an image may appear multiple times based on its
    "repeat factor".
    The repeat factor for an image is a function of the frequency of the
    rarest category labeled in that image. The "frequency of category c" in
    [0, 1] is defined as the fraction of images in the training set (without
    repeats) in which category c appears.
    The dataset needs to implement :func:`self.get_cat_ids` to support
    ClassBalancedDataset.

    The repeat factor is computed as follows.

    1. For each category c, compute the fraction of images
       that contain it: :math:`f(c)`
    2. For each category c, compute the category-level repeat factor:
       :math:`r(c) = max(1, sqrt(t/f(c)))`
    3. For each image I, compute the image-level repeat factor:
       :math:`r(I) = max_{c in I} r(c)`

    Args:
        dataset (:obj:`CustomDataset`): The dataset to be repeated.
        oversample_thr (float): frequency threshold below which data is
            repeated. For categories with ``f_c >= oversample_thr``, there is
            no oversampling. For categories with ``f_c < oversample_thr``, the
            degree of oversampling follows the square-root inverse frequency
            heuristic above.
        filter_empty_gt (bool, optional): If set to True, images without
            bounding boxes will not be oversampled. Otherwise, they will be
            categorized as the pure background class and included in the
            oversampling. Default: True.
    """

    def __init__(self, dataset, oversample_thr, filter_empty_gt=True):
        self.dataset = dataset
        self.oversample_thr = oversample_thr
        self.filter_empty_gt = filter_empty_gt
        self.CLASSES = dataset.CLASSES
        self.PALETTE = getattr(dataset, 'PALETTE', None)

        repeat_factors = self._get_repeat_factors(dataset, oversample_thr)
        repeat_indices = []
        for dataset_idx, repeat_factor in enumerate(repeat_factors):
            repeat_indices.extend([dataset_idx] * math.ceil(repeat_factor))
        self.repeat_indices = repeat_indices

        flags = []
        if hasattr(self.dataset, 'flag'):
            for flag, repeat_factor in zip(self.dataset.flag, repeat_factors):
                flags.extend([flag] * int(math.ceil(repeat_factor)))
            assert len(flags) == len(repeat_indices)
        self.flag = np.asarray(flags, dtype=np.uint8)

    def _get_repeat_factors(self, dataset, repeat_thr):
        """Get the repeat factor for each image in the dataset.

        Args:
            dataset (:obj:`CustomDataset`): The dataset.
            repeat_thr (float): The frequency threshold. If an image
                contains categories whose frequency is below the threshold,
                it will be repeated.

        Returns:
            list[float]: The repeat factor for each image in the dataset.
        """

        # 1. For each category c, compute the fraction of images
        #    that contain it: f(c)
        category_freq = defaultdict(int)
        num_images = len(dataset)
        for idx in range(num_images):
            cat_ids = set(self.dataset.get_cat_ids(idx))
            if len(cat_ids) == 0 and not self.filter_empty_gt:
                cat_ids = set([len(self.CLASSES)])
            for cat_id in cat_ids:
                category_freq[cat_id] += 1
        for k, v in category_freq.items():
            category_freq[k] = v / num_images

        # 2. For each category c, compute the category-level repeat factor:
        #    r(c) = max(1, sqrt(t/f(c)))
        category_repeat = {
            cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
            for cat_id, cat_freq in category_freq.items()
        }

        # 3. For each image I, compute the image-level repeat factor:
        #    r(I) = max_{c in I} r(c)
        repeat_factors = []
        for idx in range(num_images):
            cat_ids = set(self.dataset.get_cat_ids(idx))
            if len(cat_ids) == 0 and not self.filter_empty_gt:
                cat_ids = set([len(self.CLASSES)])
            repeat_factor = 1
            if len(cat_ids) > 0:
                repeat_factor = max(
                    {category_repeat[cat_id]
                     for cat_id in cat_ids})
            repeat_factors.append(repeat_factor)

        return repeat_factors

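    # Worked example (illustrative, not part of the original file): with
    # repeat_thr = 0.3 and category frequencies f(0) = 0.9 and f(1) = 0.01,
    # the category-level repeat factors are
    #     r(0) = max(1, sqrt(0.3 / 0.9))  ~= 1.0
    #     r(1) = max(1, sqrt(0.3 / 0.01)) ~= 5.48
    # so an image whose rarest category is 1 gets repeat factor 5.48 and is
    # repeated ceil(5.48) = 6 times by __init__, while an image containing
    # only category 0 is kept once.
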
    def __getitem__(self, idx):
        ori_index = self.repeat_indices[idx]
        return self.dataset[ori_index]

    def get_ann_info(self, idx):
        """Get annotation of dataset by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """
        ori_index = self.repeat_indices[idx]
        return self.dataset.get_ann_info(ori_index)

    def __len__(self):
        """Length after repetition."""
        return len(self.repeat_indices)


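# Illustrative usage sketch (not part of the original file). The wrapped
# dataset must implement `get_cat_ids`; the annotation path below is a
# placeholder.
#
#     data = dict(
#         train=dict(
#             type='ClassBalancedDataset',
#             oversample_thr=1e-3,
#             dataset=dict(type='LVISV1Dataset',
#                          ann_file='lvis_train.json', ...)))
#
# With oversample_thr=1e-3, categories appearing in at least 0.1% of the
# training images are not oversampled; rarer categories are oversampled
# following the square-root inverse frequency heuristic above.

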
@DATASETS.register_module()
class MultiImageMixDataset:
    """A wrapper of multiple images mixed dataset.

    Suitable for training with multi-image mixed data augmentation such as
    mosaic and mixup. A transform in the augmentation pipeline that mixes
    image data needs to provide the `get_indexes` method to obtain the
    indexes of the extra images, and `skip_type_keys` can be set to skip
    certain transforms while running the pipeline.

    Args:
        dataset (:obj:`CustomDataset`): The dataset to be mixed.
        pipeline (Sequence[dict]): Sequence of transform objects or
            config dicts to be composed.
        dynamic_scale (tuple[int], optional): Deprecated; previously used to
            change the output image scale dynamically. Default to None.
        skip_type_keys (list[str], optional): Sequence of type strings to be
            skipped in the pipeline. Default to None.
    """

    def __init__(self,
                 dataset,
                 pipeline,
                 dynamic_scale=None,
                 skip_type_keys=None):
        if dynamic_scale is not None:
            raise RuntimeError(
                'dynamic_scale is deprecated. Please use Resize pipeline '
                'to achieve similar functions')
        assert isinstance(pipeline, collections.abc.Sequence)
        if skip_type_keys is not None:
            assert all([
                isinstance(skip_type_key, str)
                for skip_type_key in skip_type_keys
            ])
        self._skip_type_keys = skip_type_keys

        self.pipeline = []
        self.pipeline_types = []
        for transform in pipeline:
            if isinstance(transform, dict):
                self.pipeline_types.append(transform['type'])
                transform = build_from_cfg(transform, PIPELINES)
                self.pipeline.append(transform)
            else:
                raise TypeError('pipeline must be a dict')

        self.dataset = dataset
        self.CLASSES = dataset.CLASSES
        self.PALETTE = getattr(dataset, 'PALETTE', None)
        if hasattr(self.dataset, 'flag'):
            self.flag = dataset.flag
        self.num_samples = len(dataset)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        results = copy.deepcopy(self.dataset[idx])
        for (transform, transform_type) in zip(self.pipeline,
                                               self.pipeline_types):
            if self._skip_type_keys is not None and \
                    transform_type in self._skip_type_keys:
                continue

            if hasattr(transform, 'get_indexes'):
                # Ask the transform which extra samples it needs and attach
                # deep copies of them so the transform can mix them in.
                indexes = transform.get_indexes(self.dataset)
                if not isinstance(indexes, collections.abc.Sequence):
                    indexes = [indexes]
                mix_results = [
                    copy.deepcopy(self.dataset[index]) for index in indexes
                ]
                results['mix_results'] = mix_results

            results = transform(results)

            if 'mix_results' in results:
                results.pop('mix_results')

        return results

    def update_skip_type_keys(self, skip_type_keys):
        """Update skip_type_keys. It is called by an external hook.

        Args:
            skip_type_keys (list[str], optional): Sequence of type
                strings to be skipped in the pipeline.
        """
        assert all([
            isinstance(skip_type_key, str) for skip_type_key in skip_type_keys
        ])
        self._skip_type_keys = skip_type_keys

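# Illustrative usage sketch (not part of the original file), loosely
# following the YOLOX-style configs: Mosaic asks the wrapper for extra
# image indexes via `get_indexes`, the corresponding samples are attached
# as `results['mix_results']`, and the transform consumes them. The image
# scale and annotation path are placeholders.
#
#     train_pipeline = [
#         dict(type='Mosaic', img_scale=(640, 640)),
#         dict(type='RandomAffine'),
#         # ... resize / flip / formatting transforms omitted
#     ]
#     train_dataset = dict(
#         type='MultiImageMixDataset',
#         dataset=dict(
#             type='CocoDataset',
#             ann_file='train.json',
#             pipeline=[
#                 dict(type='LoadImageFromFile'),
#                 dict(type='LoadAnnotations', with_bbox=True)
#             ]),
#         pipeline=train_pipeline)
#
# `update_skip_type_keys(['Mosaic'])` can later be called, for example by a
# hook, to disable Mosaic for the last training epochs.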