# Copyright (c) OpenMMLab. All rights reserved.
import itertools

import numpy as np
import torch
from mmcv.runner import get_dist_info
from torch.utils.data.sampler import Sampler

from mmdet.core.utils import sync_random_seed


class InfiniteGroupBatchSampler(Sampler):
    """Similar to `BatchSampler` wrapping a `GroupSampler`. It is designed
    for iteration-based runners like `IterBasedRunner` and yields mini-batch
    indices each time; all indices in a batch should be in the same group.

    The implementation logic is referenced from
    https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/grouped_batch_sampler.py

    Args:
        dataset (object): The dataset.
        batch_size (int): When the model is :obj:`DistributedDataParallel`,
            it is the number of training samples on each GPU.
            When the model is :obj:`DataParallel`, it is
            `num_gpus * samples_per_gpu`.
            Default: 1.
        world_size (int, optional): Number of processes participating in
            distributed training. Default: None.
        rank (int, optional): Rank of the current process. Default: None.
        seed (int): Random seed. Default: 0.
        shuffle (bool): Whether to shuffle the indices of a dummy `epoch`.
            Note that `shuffle` cannot guarantee sequential indices, because
            all indices in a batch must belong to the same group.
            Default: True.
    """  # noqa: W605

    def __init__(self,
                 dataset,
                 batch_size=1,
                 world_size=None,
                 rank=None,
                 seed=0,
                 shuffle=True):
        _rank, _world_size = get_dist_info()
        if world_size is None:
            world_size = _world_size
        if rank is None:
            rank = _rank
        self.rank = rank
        self.world_size = world_size
        self.dataset = dataset
        self.batch_size = batch_size
        # Must be the same across all workers. If None, a random seed
        # shared among workers will be used
        # (requires synchronization among all workers).
        self.seed = sync_random_seed(seed)
        self.shuffle = shuffle

        assert hasattr(self.dataset, 'flag')
        self.flag = self.dataset.flag
        self.group_sizes = np.bincount(self.flag)
        # buffer used to save indices of each group
        self.buffer_per_group = {k: [] for k in range(len(self.group_sizes))}

        self.size = len(dataset)
        self.indices = self._indices_of_rank()

    def _infinite_indices(self):
        """Infinitely yield a sequence of indices."""
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:
            if self.shuffle:
                yield from torch.randperm(self.size, generator=g).tolist()
            else:
                yield from torch.arange(self.size).tolist()

    def _indices_of_rank(self):
        """Slice the infinite indices by rank."""
        yield from itertools.islice(self._infinite_indices(), self.rank, None,
                                    self.world_size)

    def __iter__(self):
        # once batch size is reached, yield the indices
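        # Indices are buffered per group (dataset `flag`); a batch is yielded
        # only once `batch_size` indices from the same group have accumulated,
        # so a batch never mixes groups.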
        for idx in self.indices:
            flag = self.flag[idx]
            group_buffer = self.buffer_per_group[flag]
            group_buffer.append(idx)
            if len(group_buffer) == self.batch_size:
                yield group_buffer[:]
                del group_buffer[:]

    def __len__(self):
        """Length of base dataset."""
        return self.size

    def set_epoch(self, epoch):
        """Not supported in `IterationBased` runner."""
        raise NotImplementedError
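
# Illustrative usage sketch (not part of this module): in MMDetection a
# sampler like this is typically constructed inside `build_dataloader`; the
# commented snippet below only shows how such a batch sampler plugs into a
# plain PyTorch `DataLoader`. `collate` from mmcv.parallel and the
# `samples_per_gpu=2` value are assumptions for illustration.
#
#     from functools import partial
#
#     from mmcv.parallel import collate
#     from torch.utils.data import DataLoader
#
#     batch_sampler = InfiniteGroupBatchSampler(
#         dataset, batch_size=2, seed=0, shuffle=True)
#     data_loader = DataLoader(
#         dataset,
#         batch_sampler=batch_sampler,
#         num_workers=2,
#         collate_fn=partial(collate, samples_per_gpu=2))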


class InfiniteBatchSampler(Sampler):
    """Similar to `BatchSampler` wrapping a `DistributedSampler`. It is
    designed for iteration-based runners like `IterBasedRunner` and yields
    mini-batch indices each time.

    The implementation logic is referenced from
    https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/grouped_batch_sampler.py

    Args:
        dataset (object): The dataset.
        batch_size (int): When the model is :obj:`DistributedDataParallel`,
            it is the number of training samples on each GPU.
            When the model is :obj:`DataParallel`, it is
            `num_gpus * samples_per_gpu`.
            Default: 1.
        world_size (int, optional): Number of processes participating in
            distributed training. Default: None.
        rank (int, optional): Rank of the current process. Default: None.
        seed (int): Random seed. Default: 0.
        shuffle (bool): Whether to shuffle the dataset or not. Default: True.
    """  # noqa: W605

    def __init__(self,
                 dataset,
                 batch_size=1,
                 world_size=None,
                 rank=None,
                 seed=0,
                 shuffle=True):
        _rank, _world_size = get_dist_info()
        if world_size is None:
            world_size = _world_size
        if rank is None:
            rank = _rank
        self.rank = rank
        self.world_size = world_size
        self.dataset = dataset
        self.batch_size = batch_size
        # Must be the same across all workers. If None, a random seed
        # shared among workers will be used
        # (requires synchronization among all workers).
        self.seed = sync_random_seed(seed)
        self.shuffle = shuffle
        self.size = len(dataset)
        self.indices = self._indices_of_rank()

    def _infinite_indices(self):
        """Infinitely yield a sequence of indices."""
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:
            if self.shuffle:
                yield from torch.randperm(self.size, generator=g).tolist()
            else:
                yield from torch.arange(self.size).tolist()

    def _indices_of_rank(self):
        """Slice the infinite indices by rank."""
        yield from itertools.islice(self._infinite_indices(), self.rank, None,
                                    self.world_size)

    def __iter__(self):
        # once batch size is reached, yield the indices
        batch_buffer = []
        for idx in self.indices:
            batch_buffer.append(idx)
            if len(batch_buffer) == self.batch_size:
                yield batch_buffer
                batch_buffer = []

    def __len__(self):
        """Length of base dataset."""
        return self.size

    def set_epoch(self, epoch):
        """Not supported in `IterationBased` runner."""
        raise NotImplementedError
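
# Illustrative note (not part of this module): `InfiniteBatchSampler` does not
# require the dataset to expose a `flag` attribute; it simply packs consecutive
# indices of the (optionally shuffled) infinite stream into batches, so it can
# be passed as `batch_sampler` to a `DataLoader` in the same way as the grouped
# variant sketched above.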