# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import random

import numpy as np
import torch
from torch.utils.data.sampler import Sampler

import dist


def worker_init_fn(worker_id):
    # https://pytorch.org/docs/stable/notes/randomness.html#dataloader
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


class RASampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset for distributed
    training, with repeated augmentation.
    It ensures that each augmented version of a sample will be visible to a different
    process (GPU).
    Heavily based on 'torch.utils.data.DistributedSampler'.
    
    This is borrowed from the DeiT repo:
    https://github.com/facebookresearch/deit/blob/main/samplers.py
    """
    
    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, repetitions=3):
        if num_replicas is None:
            num_replicas = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * float(repetitions) / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
        self.shuffle = shuffle
        self.seed = seed
        self.repetitions = repetitions
    
    def __iter__(self):
        if self.shuffle:
            # Deterministically shuffle based on epoch
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))
        
        # Repeat each index `repetitions` times, then pad to make it evenly divisible
        indices = [ele for ele in indices for i in range(self.repetitions)]
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size
        
        # Subsample: each process takes every `num_replicas`-th index
        indices = indices[self.rank: self.total_size: self.num_replicas]
        assert len(indices) == self.num_samples
        
        return iter(indices[: self.num_selected_samples])
    
    def __len__(self):
        return self.num_selected_samples
    
    def set_epoch(self, epoch):
        self.epoch = epoch


class InfiniteBatchSampler(Sampler):
    """Batch sampler that yields batches of indices indefinitely; when `shuffle` is True
    the index order is re-generated (seeded by epoch) after every pass over the dataset,
    and when `filling` is True the last incomplete batch is padded with extra indices."""
    
    def __init__(self, dataset_len, batch_size, seed=0, filling=False, shuffle=True, drop_last=False):
        self.dataset_len = dataset_len
        self.batch_size = batch_size
        self.iters_per_ep = dataset_len // batch_size if drop_last else (dataset_len + batch_size - 1) // batch_size
        self.max_p = self.iters_per_ep * batch_size
        self.filling = filling
        self.shuffle = shuffle
        self.epoch = 0
        self.seed = seed
        self.indices = self.gener_indices()
    
    def gener_indices(self):
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch + self.seed)
            indices = torch.randperm(self.dataset_len, generator=g).numpy()
        else:
            indices = torch.arange(self.dataset_len).numpy()
        
        tails = self.batch_size - (self.dataset_len % self.batch_size)
        if tails != self.batch_size and self.filling:
            tails = indices[:tails]
            np.random.shuffle(indices)
            indices = np.concatenate((indices, tails))
        
        # built-in list/tuple is faster than np.ndarray (when collating the data via a for-loop)
        # noinspection PyTypeChecker
        return tuple(indices.tolist())
    
    def __iter__(self):
        self.epoch = 0
        while True:
            self.epoch += 1
            p, q = 0, 0
            while p < self.max_p:
                q = p + self.batch_size
                yield self.indices[p:q]
                p = q
            if self.shuffle:
                self.indices = self.gener_indices()
    
    def __len__(self):
        return self.iters_per_ep
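# Usage sketch for InfiniteBatchSampler (illustrative only; `my_dataset` and `max_steps`
# are hypothetical placeholders, not names from this repo). Because the sampler yields
# tuples of indices, it plugs into DataLoader via `batch_sampler`, and `worker_init_fn`
# above re-seeds numpy/random in each worker process:
#
#   loader = torch.utils.data.DataLoader(
#       my_dataset,
#       batch_sampler=InfiniteBatchSampler(len(my_dataset), batch_size=256, shuffle=True),
#       num_workers=8,
#       worker_init_fn=worker_init_fn,
#   )
#   # __iter__ never terminates, so bound the loop by steps rather than by epoch, e.g.:
#   #   for _, batch in zip(range(max_steps), loader): ...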
class DistInfiniteBatchSampler(InfiniteBatchSampler):
    """Distributed variant of `InfiniteBatchSampler`: every rank builds the same global
    index order (optionally with repeated augmentation), keeps only its own contiguous
    shard, and yields per-rank batches of `glb_batch_size // world_size` forever."""
    
    def __init__(self, world_size, rank, dataset_len, glb_batch_size, seed=0, repeated_aug=0, filling=False, shuffle=True):
        # RA sampler: https://github.com/pytorch/vision/blob/5521e9d01ee7033b9ee9d421c1ef6fb211ed3782/references/classification/sampler.py
        assert glb_batch_size % world_size == 0
        self.world_size, self.rank = world_size, rank
        self.dataset_len = dataset_len
        self.glb_batch_size = glb_batch_size
        self.batch_size = glb_batch_size // world_size
        
        self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
        self.filling = filling
        self.shuffle = shuffle
        self.repeated_aug = repeated_aug
        self.epoch = 0
        self.seed = seed
        self.indices = self.gener_indices()
    
    def gener_indices(self):
        global_max_p = self.iters_per_ep * self.glb_batch_size  # global_max_p % world_size must be 0 because glb_batch_size % world_size == 0
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch + self.seed)
            global_indices = torch.randperm(self.dataset_len, generator=g)
            if self.repeated_aug > 1:
                # repeated augmentation: keep the first ceil(N / repeated_aug) permuted indices,
                # repeat each of them `repeated_aug` times, then truncate to global_max_p
                global_indices = global_indices[:(self.dataset_len + self.repeated_aug - 1) // self.repeated_aug].repeat_interleave(self.repeated_aug, dim=0)[:global_max_p]
        else:
            global_indices = torch.arange(self.dataset_len)
        
        filling = global_max_p - global_indices.shape[0]
        if filling > 0 and self.filling:
            global_indices = torch.cat((global_indices, global_indices[:filling]))
        global_indices = tuple(global_indices.numpy().tolist())
        
        # split the global order into `world_size` contiguous shards; this rank keeps its own shard
        seps = torch.linspace(0, len(global_indices), self.world_size + 1, dtype=torch.int)
        local_indices = global_indices[seps[self.rank]:seps[self.rank + 1]]
        self.max_p = len(local_indices)
        return local_indices


if __name__ == '__main__':
    W = 16
    for rk in range(W):
        ind = DistInfiniteBatchSampler(W, rk, 5024, 5024).gener_indices()
        print(rk, len(ind))
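    
    # A minimal sketch of wiring one rank's sampler into a DataLoader. The dummy
    # TensorDataset below is an assumed stand-in that only illustrates the plumbing;
    # it is not the dataset this sampler is actually used with.
    dummy = torch.utils.data.TensorDataset(torch.arange(5024).unsqueeze(1))
    sampler = DistInfiniteBatchSampler(world_size=W, rank=0, dataset_len=len(dummy), glb_batch_size=5024)
    loader = torch.utils.data.DataLoader(dummy, batch_sampler=sampler, num_workers=0)
    batch = next(iter(loader))[0]
    print('rank-0 batch shape:', tuple(batch.shape))  # expected (314, 1): glb_batch_size // world_size samples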