# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import random
import time
import warnings

import PIL.Image as PImage
import numpy as np
import torch
import torchvision
from timm.data import AutoAugment as TimmAutoAugment
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, create_transform
from timm.data.distributed_sampler import RepeatAugSampler
from timm.data.transforms_factory import transforms_imagenet_eval
from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler
from torchvision.transforms import AutoAugment as TorchAutoAugment
from torchvision.transforms import transforms, TrivialAugmentWide

try:
    from torchvision.transforms import InterpolationMode
    interpolation = InterpolationMode.BICUBIC
except ImportError:
    # Older torchvision has no InterpolationMode; fall back to the PIL constant.
    interpolation = PImage.BICUBIC


def _dataset_root(data_path):
    """Return the directory expected to contain the ``train/`` and ``val/`` folders.

    A ``data_path`` whose last path component is exactly ``train`` or ``val`` is
    mapped to its parent, so pointing at either split also works. Only the whole
    final component is compared, so e.g. ``/data/my_train`` is left untouched.
    """
    folder = os.path.abspath(data_path)
    if os.path.basename(folder) in ('train', 'val'):
        folder = os.path.dirname(folder)
    return folder


def create_classification_dataset(data_path, img_size, rep_aug, workers, batch_size_per_gpu, world_size, global_rank):
    """Build the train and validation loaders over ImageFolder datasets.

    Args:
        data_path: root holding ``train/`` and ``val/`` ImageFolder directories
            (a path that already ends in ``train`` or ``val`` is accepted).
        img_size: input resolution used by both transform pipelines.
        rep_aug: number of repeated-augmentation repeats; a falsy value uses a
            plain ``DistributedSampler`` instead.
        workers: number of DataLoader worker processes per loader.
        batch_size_per_gpu: per-process training batch size; validation uses
            ``val_ratio`` times this per process.
        world_size: number of distributed processes.
        global_rank: rank of the current process.

    Returns:
        ``(loader_train, iters_train, iterator_val, iters_val)``: the training
        DataLoader and its length, an endless iterator over validation batches,
        and the number of validation batches reported by its batch sampler.
    """
    # Scope the UserWarning suppression to dataset construction and restore the
    # caller's filters afterwards (resetwarnings() would discard ALL filters).
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', category=UserWarning)

        mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
        trans_train = create_transform(
            is_training=True, input_size=img_size,
            auto_augment='v0', interpolation='bicubic', re_prob=0.25, re_mode='pixel', re_count=1,
            mean=mean, std=std,
        )
        # Swap the AutoAugment op produced by create_transform for TrivialAugmentWide.
        for i, t in enumerate(trans_train.transforms):
            if isinstance(t, (TorchAutoAugment, TimmAutoAugment)):
                trans_train.transforms[i] = TrivialAugmentWide(interpolation=interpolation)
                break

        if img_size >= 384:
            # Large resolutions: resize straight to a square, no center crop.
            trans_val = transforms.Compose([
                transforms.Resize((img_size, img_size), interpolation=interpolation),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ])
        else:
            trans_val = transforms_imagenet_eval(img_size=img_size, interpolation='bicubic', crop_pct=0.95, mean=mean, std=std)
        print_transform(trans_train, '[train]')
        print_transform(trans_val, '[val]')

        imagenet_folder = _dataset_root(data_path)
        dataset_train = torchvision.datasets.ImageFolder(os.path.join(imagenet_folder, 'train'), trans_train)
        dataset_val = torchvision.datasets.ImageFolder(os.path.join(imagenet_folder, 'val'), trans_val)

        # Shard with the explicitly supplied world_size / global_rank, exactly as the
        # validation sampler below does, instead of relying on the default process group.
        if rep_aug:
            print(f'[dataset] using repeated augmentation: count={rep_aug}')
            train_sp = RepeatAugSampler(dataset_train, num_replicas=world_size, rank=global_rank, shuffle=True, num_repeats=rep_aug)
        else:
            train_sp = torch.utils.data.distributed.DistributedSampler(
                dataset_train, num_replicas=world_size, rank=global_rank, shuffle=True, drop_last=True,
            )

        loader_train = DataLoader(
            dataset=dataset_train, num_workers=workers, pin_memory=True,
            batch_size=batch_size_per_gpu, sampler=train_sp, persistent_workers=workers > 0,
            worker_init_fn=worker_init_fn,
        )
        iters_train = len(loader_train)
        print(f'[dataset: train] bs={world_size}x{batch_size_per_gpu}={world_size * batch_size_per_gpu}, num_iters={iters_train}')

        val_ratio = 2
        val_batch_size_per_gpu = val_ratio * batch_size_per_gpu
        loader_val = DataLoader(
            dataset=dataset_val, num_workers=workers, pin_memory=True,
            # DistInfiniteBatchSampler takes the GLOBAL batch size and divides it by
            # world_size, so multiply back up to get val_batch_size_per_gpu per rank.
            batch_sampler=DistInfiniteBatchSampler(
                world_size, global_rank, len(dataset_val),
                glb_batch_size=world_size * val_batch_size_per_gpu,
                filling=False, shuffle=False,
            ),
            worker_init_fn=worker_init_fn,
        )
        iters_val = len(loader_val)
        print(f'[dataset: val] bs={world_size}x{val_batch_size_per_gpu}={world_size * val_batch_size_per_gpu}, num_iters={iters_val}')

        # NOTE(review): the reason for this pause is not visible in this file; confirm before removing.
        time.sleep(3)

    return loader_train, iters_train, iter(loader_val), iters_val


def worker_init_fn(worker_id):
    """Seed numpy and ``random`` in each DataLoader worker from torch's per-worker seed."""
    # see: https://pytorch.org/docs/stable/notes/randomness.html#dataloader
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def print_transform(transform, s):
    """Print every step of a composed transform under the heading ``s``."""
    print(f'Transform {s} = ')
    for t in transform.transforms:
        print(t)
    print('---------------------------\n')


class DistInfiniteBatchSampler(Sampler):
    """Endless batch sampler that splits dataset indices across distributed ranks.

    Every epoch the global index order (a seeded permutation when ``shuffle``,
    otherwise ``0..dataset_len-1``) is cut into ``world_size`` contiguous slices;
    this rank walks its own slice in batches of up to
    ``glb_batch_size // world_size`` indices and then starts the next epoch.
    ``len()`` returns ``ceil(dataset_len / glb_batch_size)``.

    Args:
        world_size: number of ranks the indices are split across.
        global_rank: rank whose slice this sampler yields.
        dataset_len: number of samples in the dataset.
        glb_batch_size: batch size summed over all ranks; must be divisible by
            ``world_size``.
        seed: base seed for the shuffle generator (offset by the epoch counter).
        filling: when True, pad the index list by tiling it so it covers
            ``len() * glb_batch_size`` entries.
        shuffle: when True, draw a fresh permutation for every epoch.
    """

    def __init__(self, world_size, global_rank, dataset_len, glb_batch_size, seed=0, filling=False, shuffle=True):
        assert glb_batch_size % world_size == 0, \
            f'glb_batch_size={glb_batch_size} must be divisible by world_size={world_size}'
        self.world_size, self.rank = world_size, global_rank
        self.dataset_len = dataset_len
        self.glb_batch_size = glb_batch_size
        self.batch_size = glb_batch_size // world_size
        self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
        self.filling = filling
        self.shuffle = shuffle
        self.epoch = 0
        self.seed = seed
        self.indices = self.gener_indices()

    def gener_indices(self):
        """Return this rank's index tuple for ``self.epoch`` and store its length in ``self.max_p``."""
        # global_max_p % world_size == 0 because glb_batch_size % world_size == 0
        global_max_p = self.iters_per_ep * self.glb_batch_size
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch + self.seed)
            global_indices = torch.randperm(self.dataset_len, generator=g)
        else:
            global_indices = torch.arange(self.dataset_len)

        if self.filling and global_indices.numel() < global_max_p:
            # Tile rather than append the head once, so even a dataset smaller than
            # the padding amount fills every global batch. For the usual case this
            # equals cat((idx, idx[:filling])). numel() > 0 here since global_max_p > 0.
            reps = -(-global_max_p // global_indices.numel())
            global_indices = global_indices.repeat(reps)[:global_max_p]
        global_indices = tuple(global_indices.tolist())

        seps = torch.linspace(0, len(global_indices), self.world_size + 1, dtype=torch.int)
        start, stop = int(seps[self.rank]), int(seps[self.rank + 1])
        local_indices = global_indices[start:stop]
        self.max_p = len(local_indices)
        return local_indices

    def __iter__(self):
        self.epoch = 0
        if self.shuffle:
            # Restart from the epoch-0 permutation so every new iterator yields the
            # same sequence (otherwise a second iter() first replays the last epoch).
            self.indices = self.gener_indices()
        if self.max_p == 0:
            # This rank owns no samples; looping forever would spin without yielding.
            return
        while True:
            self.epoch += 1
            for p in range(0, self.max_p, self.batch_size):
                yield self.indices[p:p + self.batch_size]
            if self.shuffle:
                self.indices = self.gener_indices()

    def __len__(self):
        return self.iters_per_ep