You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
74 lines
2.6 KiB
74 lines
2.6 KiB
# Copyright (c) ByteDance, Inc. and its affiliates. |
|
# All rights reserved. |
|
# |
|
# This source code is licensed under the license found in the |
|
# LICENSE file in the root directory of this source tree. |
|
|
|
import random |
|
|
|
import numpy as np |
|
import torch |
|
from torch.utils.data.sampler import Sampler |
|
|
|
|
|
def worker_init_fn(worker_id):
    """Seed Python's ``random`` and NumPy's global RNG inside a DataLoader worker.

    Both RNGs are seeded from ``torch.initial_seed()``, reduced to 32 bits.
    ``worker_id`` is accepted to match the DataLoader ``worker_init_fn``
    signature but is not used.

    See https://pytorch.org/docs/stable/notes/randomness.html#dataloader
    """
    # np.random.seed only accepts seeds in [0, 2**32); masking the low 32 bits
    # is the same as taking the value modulo 2**32.
    derived_seed = torch.initial_seed() & 0xFFFFFFFF
    random.seed(derived_seed)
    np.random.seed(derived_seed)
|
|
|
|
|
class DistInfiniteBatchSampler(Sampler):
    """Infinite batch sampler that shards a dataset across distributed ranks.

    Each epoch the global index list is built, optionally padded, and split
    into ``world_size`` contiguous shards. The global list is either shuffled
    with a generator seeded by ``seed + epoch`` or kept in order. Padding, when
    enabled, extends the list to ``iters_per_ep * glb_batch_size``. This rank's
    shard is then yielded in chunks of ``glb_batch_size // world_size``
    indices, forever.

    Args:
        world_size: number of ranks the global batch is split over.
        rank: index of this rank, in ``[0, world_size)``.
        dataset_len: number of samples in the dataset.
        glb_batch_size: global batch size; must be divisible by ``world_size``.
        seed: base seed for the per-epoch shuffle.
        filling: if True, pad the global index list (by cycling it) so that
            every rank gets ``iters_per_ep`` full batches.
        shuffle: if True, reshuffle the indices after every epoch; otherwise
            use ``0 .. dataset_len - 1`` in order.
    """

    def __init__(self, world_size, rank, dataset_len, glb_batch_size, seed=1, filling=False, shuffle=True):
        assert glb_batch_size % world_size == 0, 'glb_batch_size must be divisible by world_size'
        self.world_size, self.rank = world_size, rank
        self.dataset_len = dataset_len
        self.glb_batch_size = glb_batch_size
        self.batch_size = glb_batch_size // world_size  # per-rank batch size

        # ceil(dataset_len / glb_batch_size): the last global batch may be partial.
        self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
        self.filling = filling
        self.shuffle = shuffle
        self.epoch = 0
        self.seed = seed
        self.indices = self.gener_indices()  # also sets self.max_p

    def gener_indices(self):
        """Build this rank's shard of dataset indices for the current ``self.epoch``.

        Side effect: sets ``self.max_p`` to the length of the returned shard.

        Returns:
            tuple of int: the dataset indices assigned to ``self.rank``.
        """
        # global_max_p % world_size == 0 because glb_batch_size % world_size == 0.
        global_max_p = self.iters_per_ep * self.glb_batch_size
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch + self.seed)
            global_indices = torch.randperm(self.dataset_len, generator=g)
        else:
            global_indices = torch.arange(self.dataset_len)

        n = global_indices.shape[0]
        filling = global_max_p - n
        if filling > 0 and self.filling:
            # Cycle the index list until it is exactly global_max_p long. Appending
            # only global_indices[:filling] would come up short whenever
            # filling > n. (n > 0 here: n == 0 implies global_max_p == 0.)
            reps = -(-global_max_p // n)  # ceil(global_max_p / n)
            global_indices = global_indices.repeat(reps)[:global_max_p]
        global_indices = tuple(global_indices.numpy().tolist())

        # Split into world_size contiguous shards with boundaries
        # floor(i * total / world_size), computed exactly in integers.
        total = len(global_indices)
        start = total * self.rank // self.world_size
        end = total * (self.rank + 1) // self.world_size
        local_indices = global_indices[start:end]
        self.max_p = len(local_indices)
        return local_indices

    def __iter__(self):
        """Yield this rank's batches (tuples of dataset indices) forever.

        Each full pass over the shard counts as one epoch. With shuffling
        enabled, the shard is regenerated with the next epoch's seed after
        every pass.

        Raises:
            ValueError: if this rank's shard is empty, since the infinite loop
                would otherwise spin without ever yielding.
        """
        self.epoch = 0
        if self.shuffle:
            # Rebuild the epoch-0 shard so every new iterator replays the same
            # deterministic sequence. Otherwise it would start from whatever
            # shard a previous iterator left in self.indices.
            self.indices = self.gener_indices()
        if self.max_p == 0:
            raise ValueError(
                f'rank {self.rank} received no indices '
                f'(dataset_len={self.dataset_len}, world_size={self.world_size})'
            )
        while True:
            self.epoch += 1
            for p in range(0, self.max_p, self.batch_size):
                # The final batch of a pass may be shorter when filling is off.
                yield self.indices[p:p + self.batch_size]
            if self.shuffle:
                self.indices = self.gener_indices()

    def __len__(self):
        """Return the number of global iterations per epoch (``iters_per_ep``)."""
        return self.iters_per_ep
|
|
|
|
|
if __name__ == '__main__':
    # Smoke check: for 5024 samples and a global batch of 5024 split over
    # 16 ranks, print each rank's id and the size of the shard it receives.
    num_ranks = 16
    for rank_id in range(num_ranks):
        shard = DistInfiniteBatchSampler(num_ranks, rank_id, 5024, 5024).gener_indices()
        print(rank_id, len(shard))
|
|
|