You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

76 lines
2.7 KiB

2 years ago
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from typing import Any, Callable, Optional, Tuple
import PIL.Image as PImage
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, create_transform
from timm.data.transforms_factory import transforms_imagenet_eval
from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS
from torchvision.transforms import transforms
# Pick the bicubic interpolation constant understood by the installed torchvision.
try:
    from torchvision.transforms import InterpolationMode
    interpolation = InterpolationMode.BICUBIC
except ImportError:
    # Fallback for torchvision versions that do not provide InterpolationMode:
    # use the PIL resampling constant instead. Only ImportError is caught so
    # that unrelated failures (and KeyboardInterrupt/SystemExit) still surface.
    import PIL.Image
    interpolation = PIL.Image.BICUBIC
def pil_loader(path):
    """Load the image stored at ``path`` and return it converted to RGB.

    Args:
        path: filesystem path of the image file.

    Returns:
        The result of ``PImage.open(...).convert('RGB')``.
    """
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        # convert() runs while the file is still open; the handle is closed
        # by the ``with`` block on the way out.
        return PImage.open(f).convert('RGB')
class ImageNetDataset(DatasetFolder):
    """``DatasetFolder`` over the ``train`` or ``val`` sub-directory of an ImageNet root.

    Directory scanning and class indexing are delegated to ``DatasetFolder``;
    this class only selects the split directory, fixes the loader to
    ``pil_loader`` and freezes the sample list into tuples.
    """

    def __init__(
            self,
            imagenet_folder: str,
            train: bool,
            transform: Callable,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        """Index one split of the dataset.

        Args:
            imagenet_folder: directory containing the ``train`` and ``val`` sub-directories.
            train: if true use ``<imagenet_folder>/train``, otherwise ``<imagenet_folder>/val``.
            transform: callable applied to every loaded image in ``__getitem__``.
            is_valid_file: optional predicate on a file path; when given it is
                passed to ``DatasetFolder`` in place of the ``IMG_EXTENSIONS`` filter.
        """
        split_dir = os.path.join(imagenet_folder, 'train' if train else 'val')
        super().__init__(
            split_dir,
            loader=pil_loader,
            # Pass exactly one of `extensions` / `is_valid_file` to the base class.
            extensions=IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=None,
            is_valid_file=is_valid_file,
        )

        # Replace the lists built by DatasetFolder with immutable tuples.
        self.samples = tuple(self.samples)
        self.targets = tuple(target for _, target in self.samples)

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        """Return ``(transform(image), target)`` for the sample at ``index``.

        No target transform is applied; ``target`` is returned as stored in
        ``self.samples``.
        """
        path, target = self.samples[index]
        return self.transform(self.loader(path)), target
def build_imagenet_pretrain(imagenet_folder, input_size):
    """Build the training-split ``ImageNetDataset`` with the pre-training augmentation.

    The augmentation pipeline is: random resized crop (scale 0.67-1.0, using the
    module-level ``interpolation``), random horizontal flip, ``ToTensor`` and
    normalization with the ImageNet default mean/std. The pipeline is printed
    via ``print_transform`` before returning.

    Args:
        imagenet_folder: ImageNet root directory. A path whose last component
            is ``train`` or ``val`` is also accepted; that component is dropped.
        input_size: output size passed to ``RandomResizedCrop``.

    Returns:
        The ``ImageNetDataset`` for the ``train`` split.
    """
    trans_train = transforms.Compose([
        transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0), interpolation=interpolation),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])

    imagenet_folder = os.path.abspath(imagenet_folder)
    # ImageNetDataset appends the split name itself, so step up one level when
    # the caller already pointed at a split directory. Compare the last path
    # component rather than using str.endswith, which would also truncate
    # roots such as '/data/retrain' or '/data/imagenet_val'.
    if os.path.basename(imagenet_folder) in ('train', 'val'):
        imagenet_folder = os.path.dirname(imagenet_folder)

    dataset_train = ImageNetDataset(imagenet_folder=imagenet_folder, transform=trans_train, train=True)
    print_transform(trans_train, '[pre-train]')
    return dataset_train
def print_transform(transform, s):
    """Print each step of a composed transform to stdout.

    Args:
        transform: object exposing a ``transforms`` iterable (e.g. a ``Compose``).
        s: label inserted into the header line.
    """
    print(f'Transform {s} = ')
    for t in transform.transforms:
        print(t)
    print('---------------------------\n')