# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import random

import numpy as np
import torch
from torch.utils.data.sampler import Sampler

import dist

def worker_init_fn(worker_id):
    # https://pytorch.org/docs/stable/notes/randomness.html#dataloader
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
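
# A minimal usage sketch (`dataset` is a placeholder for any map-style dataset):
# passing `worker_init_fn` to a DataLoader reseeds numpy/random inside every worker
# process, so CPU-side augmentations differ across workers instead of repeating the
# parent process's RNG state:
#   loader = torch.utils.data.DataLoader(dataset, batch_size=64, num_workers=4,
#                                        worker_init_fn=worker_init_fn)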


class RASampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset for distributed
    training, with repeated augmentation.
    It ensures that each augmented version of a sample is visible to a
    different process (GPU).
    Heavily based on 'torch.utils.data.DistributedSampler'.
    This is borrowed from the DeiT repo:
    https://github.com/facebookresearch/deit/blob/main/samplers.py
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
        if num_replicas is None:
            num_replicas = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # each sample is repeated `repetitions` times, then the copies are split across replicas
        self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # only a 256-aligned subset of the repeated indices is actually consumed per epoch
        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
        self.shuffle = shuffle
        self.seed = seed
        self.repetitions = repetitions

    def __iter__(self):
        if self.shuffle:
            # Deterministically shuffle based on epoch
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # Repeat each index, then add extra samples to make it evenly divisible
        indices = [ele for ele in indices for i in range(self.repetitions)]
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # Subsample: rank r takes every num_replicas-th index, starting at offset r
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices[: self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
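
# A minimal usage sketch (assumptions: `dataset` is a map-style dataset, `epochs` is a
# placeholder, and the local `dist` module can report world size and rank, as RASampler's
# defaults require):
#   sampler = RASampler(dataset, shuffle=True, repetitions=3)
#   loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=64,
#                                        num_workers=4, worker_init_fn=worker_init_fn)
#   for epoch in range(epochs):
#       sampler.set_epoch(epoch)  # re-seeds the deterministic per-epoch shuffle
#       for batch in loader:
#           ...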


class InfiniteBatchSampler(Sampler):
    def __init__(self, dataset_len, batch_size, seed=0, filling=False, shuffle=True, drop_last=False):
        self.dataset_len = dataset_len
        self.batch_size = batch_size
        self.iters_per_ep = dataset_len // batch_size if drop_last else (dataset_len + batch_size - 1) // batch_size
        self.max_p = self.iters_per_ep * batch_size
        self.filling = filling
        self.shuffle = shuffle
        self.epoch = 0
        self.seed = seed
        self.indices = self.gener_indices()

    def gener_indices(self):
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch + self.seed)
            indices = torch.randperm(self.dataset_len, generator=g).numpy()
        else:
            indices = torch.arange(self.dataset_len).numpy()

        # with filling enabled, recycle some indices so the last batch is full
        tails = self.batch_size - (self.dataset_len % self.batch_size)
        if tails != self.batch_size and self.filling:
            tails = indices[:tails]
            np.random.shuffle(indices)
            indices = np.concatenate((indices, tails))

        # built-in list/tuple is faster than np.ndarray (when collating the data via a for-loop)
        # noinspection PyTypeChecker
        return tuple(indices.tolist())

    def __iter__(self):
        self.epoch = 0
        while True:  # yield batches of indices forever, regenerating (reshuffling) them each epoch
            self.epoch += 1
            p, q = 0, 0
            while p < self.max_p:
                q = p + self.batch_size
                yield self.indices[p:q]
                p = q
            if self.shuffle:
                self.indices = self.gener_indices()

    def __len__(self):
        return self.iters_per_ep
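
# A minimal usage sketch (assumptions: `dataset` has length dataset_len and `total_steps`
# is a placeholder). Since the sampler yields index lists forever, iterate by step count
# rather than with `for batch in loader`, which would never terminate:
#   batch_sampler = InfiniteBatchSampler(len(dataset), batch_size=64, shuffle=True)
#   loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=4)
#   it = iter(loader)
#   for step in range(total_steps):
#       batch = next(it)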


class DistInfiniteBatchSampler(InfiniteBatchSampler):
    def __init__(self, world_size, rank, dataset_len, glb_batch_size, seed=0, repeated_aug=0, filling=False, shuffle=True):
        # from torchvision.models import ResNet50_Weights
        # RA sampler: https://github.com/pytorch/vision/blob/5521e9d01ee7033b9ee9d421c1ef6fb211ed3782/references/classification/sampler.py

        assert glb_batch_size % world_size == 0
        self.world_size, self.rank = world_size, rank
        self.dataset_len = dataset_len
        self.glb_batch_size = glb_batch_size
        self.batch_size = glb_batch_size // world_size

        self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
        self.filling = filling
        self.shuffle = shuffle
        self.repeated_aug = repeated_aug
        self.epoch = 0
        self.seed = seed
        self.indices = self.gener_indices()

    def gener_indices(self):
        global_max_p = self.iters_per_ep * self.glb_batch_size  # global_max_p % world_size == 0 because glb_batch_size % world_size == 0
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch + self.seed)
            global_indices = torch.randperm(self.dataset_len, generator=g)
            if self.repeated_aug > 1:
                # repeated augmentation: keep only the first ~1/repeated_aug of the
                # permutation, then repeat each kept index repeated_aug times
                global_indices = global_indices[:(self.dataset_len + self.repeated_aug - 1) // self.repeated_aug].repeat_interleave(self.repeated_aug, dim=0)[:global_max_p]
        else:
            global_indices = torch.arange(self.dataset_len)
        filling = global_max_p - global_indices.shape[0]
        if filling > 0 and self.filling:
            global_indices = torch.cat((global_indices, global_indices[:filling]))
        global_indices = tuple(global_indices.numpy().tolist())

        # split the global index list evenly across ranks; each rank batches only its slice
        seps = torch.linspace(0, len(global_indices), self.world_size + 1, dtype=torch.int)
        local_indices = global_indices[seps[self.rank]:seps[self.rank + 1]]
        self.max_p = len(local_indices)
        return local_indices
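
# A minimal usage sketch for the distributed variant (assumptions: torch.distributed is
# initialized so the local `dist` module can report world size and rank, and `dataset`
# is a map-style dataset):
#   batch_sampler = DistInfiniteBatchSampler(
#       world_size=dist.get_world_size(), rank=dist.get_rank(),
#       dataset_len=len(dataset), glb_batch_size=1024, repeated_aug=3)
#   loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=4)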


if __name__ == '__main__':
    W = 16
    for rk in range(W):
        ind = DistInfiniteBatchSampler(W, rk, 5024, 5024).gener_indices()
        print(rk, len(ind))
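    # sanity check: dataset_len == glb_batch_size == 5024 and world_size == 16 give
    # iters_per_ep == 1 and global_max_p == 5024, so torch.linspace splits the indices
    # evenly and every rank should print 5024 // 16 == 314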