# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""ImageNet folder dataset and the data pipeline used for pre-training."""

import os
from typing import Any, Callable, Optional, Tuple

import PIL.Image as PImage
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, create_transform
from timm.data.transforms_factory import transforms_imagenet_eval
from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS
from torchvision.transforms import transforms

try:
    from torchvision.transforms import InterpolationMode
    interpolation = InterpolationMode.BICUBIC
except ImportError:
    # Older torchvision releases lack ``InterpolationMode``; there the
    # transforms take PIL's resampling constants directly.
    interpolation = PImage.BICUBIC


def pil_loader(path):
    """Open the image file at ``path`` and return it converted to RGB.

    The file is opened explicitly inside a ``with`` block so the handle is
    closed deterministically, avoiding a ``ResourceWarning``
    (https://github.com/python-pillow/Pillow/issues/835). ``convert`` is called
    while the file is still open because PIL decodes pixel data lazily.
    """
    with open(path, 'rb') as f:
        img: PImage.Image = PImage.open(f).convert('RGB')
    return img


class ImageNetDataset(DatasetFolder):
    """ImageNet-style image-folder dataset for one split (``train`` or ``val``).

    Args:
        imagenet_folder: Dataset root; the split directory ``train`` or ``val``
            is joined onto it.
        train: Use the ``train`` split when True, otherwise ``val``.
        transform: Callable applied to each loaded PIL image. When ``None`` the
            image is returned untransformed.
        is_valid_file: Optional path predicate. When given, it replaces the
            default ``IMG_EXTENSIONS`` filter (``DatasetFolder`` does not accept
            both at once).
    """

    def __init__(
            self,
            imagenet_folder: str,
            train: bool,
            transform: Callable,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        imagenet_folder = os.path.join(imagenet_folder, 'train' if train else 'val')
        super(ImageNetDataset, self).__init__(
            imagenet_folder,
            loader=pil_loader,
            extensions=IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=None,
            is_valid_file=is_valid_file,
        )
        # Freeze the (path, class_index) list and the derived targets as tuples.
        self.samples = tuple(self.samples)
        self.targets = tuple(s[1] for s in self.samples)

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        """Return ``(transformed_image, class_index)`` for sample ``index``."""
        path, target = self.samples[index]
        sample = self.loader(path)
        # Mirror DatasetFolder: transforms are optional.
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target


def build_imagenet_pretrain(imagenet_folder, input_size):
    """Build the pre-training ``train`` dataset with its augmentation pipeline.

    The pipeline is RandomResizedCrop (scale 0.67-1.0, bicubic), a random
    horizontal flip, ToTensor, and normalization with the timm ImageNet
    mean/std.

    Args:
        imagenet_folder: Dataset root, or the root's ``train``/``val``
            subdirectory; a trailing split directory is stripped so that both
            forms work.
        input_size: Output crop size passed to ``RandomResizedCrop``.

    Returns:
        An ``ImageNetDataset`` over the ``train`` split.
    """
    trans_train = transforms.Compose([
        transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0), interpolation=interpolation),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])

    imagenet_folder = os.path.abspath(imagenet_folder)
    # Compare whole path components: a plain suffix check would also strip
    # names like ``imagenet_val`` or ``interval``.
    if os.path.basename(imagenet_folder) in ('train', 'val'):
        imagenet_folder = os.path.dirname(imagenet_folder)

    dataset_train = ImageNetDataset(imagenet_folder=imagenet_folder, transform=trans_train, train=True)
    print_transform(trans_train, '[pre-train]')
    return dataset_train


def print_transform(transform, s):
    """Print each step of a ``transforms.Compose`` pipeline under the label ``s``."""
    print(f'Transform {s} = ')
    for t in transform.transforms:
        print(t)
    print('---------------------------\n')