|
|
|
# Copyright (c) ByteDance, Inc. and its affiliates.
|
|
|
|
# All rights reserved.
|
|
|
|
#
|
|
|
|
# This source code is licensed under the license found in the
|
|
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
|
import os
|
|
|
|
from typing import Any, Callable, Optional, Tuple
|
|
|
|
|
|
|
|
import PIL.Image as PImage
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, create_transform
|
|
|
|
from timm.data.transforms_factory import transforms_imagenet_eval
|
|
|
|
from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS
|
|
|
|
from torchvision.transforms import transforms
|
|
|
|
# Pick the bicubic resampling constant used by the transforms below.
# Prefer torchvision's InterpolationMode enum; if this torchvision build does
# not export it, fall back to Pillow's BICUBIC constant.
try:
    from torchvision.transforms import InterpolationMode
    interpolation = InterpolationMode.BICUBIC
except ImportError:
    # Only a failed import should trigger the fallback; a bare `except:` would
    # also hide unrelated errors (and KeyboardInterrupt / SystemExit).
    import PIL.Image  # binds the name `PIL` and guarantees the submodule is loaded
    interpolation = PIL.Image.BICUBIC
|
|
|
|
|
|
|
|
|
|
|
|
def pil_loader(path):
    """Open the image file at ``path`` and return it converted to RGB mode."""
    # Go through an explicit file handle rather than handing Pillow the path,
    # to avoid a ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as fp:
        rgb_image: PImage.Image = PImage.open(fp).convert('RGB')
    return rgb_image
|
|
|
|
|
|
|
|
|
|
|
|
class ImageNetDataset(DatasetFolder):
    """A ``DatasetFolder`` rooted at the ``train`` or ``val`` subfolder of an ImageNet directory.

    Items are ``(transformed image, target)`` pairs; images are read with
    ``pil_loader``.
    """

    def __init__(
            self,
            imagenet_folder: str,
            train: bool,
            transform: Callable,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        """Index the chosen split.

        Args:
            imagenet_folder: directory that contains the ``train`` and ``val`` subfolders.
            train: use the ``train`` subfolder when true, otherwise ``val``.
            transform: callable applied to every loaded image.
            is_valid_file: optional file-path predicate; when given it is passed
                on instead of the ``IMG_EXTENSIONS`` extension list.
        """
        split_name = 'train' if train else 'val'
        split_root = os.path.join(imagenet_folder, split_name)

        # Exactly one of `extensions` / `is_valid_file` is handed to the base class.
        allowed_extensions = IMG_EXTENSIONS if is_valid_file is None else None

        super().__init__(
            split_root,
            loader=pil_loader,
            extensions=allowed_extensions,
            transform=transform,
            target_transform=None,
            is_valid_file=is_valid_file,
        )

        # Keep the sample list and the per-sample targets as immutable tuples.
        self.samples = tuple(self.samples)
        self.targets = tuple(sample[1] for sample in self.samples)

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        """Return ``(transform(image), target)`` for the sample at ``index``."""
        path, target = self.samples[index]
        image = self.loader(path)
        return self.transform(image), target
|
|
|
|
|
|
|
|
|
|
|
|
def build_imagenet_pretrain(imagenet_folder, input_size):
    """Build the ImageNet training-split dataset used for pre-training.

    Args:
        imagenet_folder: the ImageNet root directory, or its ``train`` / ``val``
            subfolder (in which case the parent directory is used as the root).
        input_size: crop size passed to ``RandomResizedCrop``.

    Returns:
        An ``ImageNetDataset`` over the ``train`` split with the augmentation
        pipeline below attached. The pipeline is also printed via
        ``print_transform``.
    """
    trans_train = transforms.Compose([
        transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0), interpolation=interpolation),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])

    imagenet_folder = os.path.abspath(imagenet_folder)
    # Accept a path that points at a split folder instead of the root.
    # Compare the last path component exactly: a plain string-suffix check
    # would also truncate roots such as ".../imagenet_train" or ".../trainval".
    if os.path.basename(imagenet_folder) in ('train', 'val'):
        imagenet_folder = os.path.dirname(imagenet_folder)

    dataset_train = ImageNetDataset(imagenet_folder=imagenet_folder, transform=trans_train, train=True)
    print_transform(trans_train, '[pre-train]')
    return dataset_train
|
|
|
|
|
|
|
|
|
|
|
|
def print_transform(transform, s):
    """Print the steps of a composed transform, one per line.

    Args:
        transform: object exposing an iterable ``transforms`` attribute.
        s: label inserted into the header line.
    """
    header = f'Transform {s} = '
    footer = '---------------------------\n'

    print(header)
    for step in transform.transforms:
        print(step)
    print(footer)
|