You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
75 lines
2.7 KiB
75 lines
2.7 KiB
# Copyright (c) ByteDance, Inc. and its affiliates. |
|
# All rights reserved. |
|
# |
|
# This source code is licensed under the license found in the |
|
# LICENSE file in the root directory of this source tree. |
|
|
|
import os |
|
from typing import Any, Callable, Optional, Tuple |
|
|
|
import PIL.Image as PImage |
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, create_transform |
|
from timm.data.transforms_factory import transforms_imagenet_eval |
|
from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS |
|
from torchvision.transforms import transforms |
|
# Choose the bicubic interpolation constant handed to the torchvision
# transforms below. Prefer torchvision's InterpolationMode enum; when this
# torchvision build does not provide it, fall back to the PIL constant.
try:
    from torchvision.transforms import InterpolationMode
except ImportError:
    # Only a failed import should trigger the fallback; a bare ``except``
    # would also swallow KeyboardInterrupt/SystemExit and unrelated errors.
    import PIL
    interpolation = PIL.Image.BICUBIC
else:
    interpolation = InterpolationMode.BICUBIC
|
|
|
|
|
def pil_loader(path):
    """Load the image stored at ``path`` and return it converted to RGB.

    Args:
        path: Filesystem path of the image file.

    Returns:
        The result of ``PImage.open(...).convert('RGB')`` for that file.
    """
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as stream:
        # The conversion runs while the file handle is still open.
        opened = PImage.open(stream)
        rgb_image = opened.convert('RGB')
    return rgb_image
|
|
|
|
|
class ImageNetDataset(DatasetFolder):
    """``DatasetFolder`` over one ImageNet split, yielding ``(image, label)``.

    The split directory is ``<imagenet_folder>/train`` when ``train`` is true
    and ``<imagenet_folder>/val`` otherwise. Files are decoded by
    ``pil_loader`` and passed through ``transform`` on access.
    """

    def __init__(
        self,
        imagenet_folder: str,
        train: bool,
        transform: Callable,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        """Index the chosen split directory.

        Args:
            imagenet_folder: Directory containing the ``train`` and ``val``
                sub-directories.
            train: Select the ``train`` split when true, ``val`` otherwise.
            transform: Callable applied to every loaded image.
            is_valid_file: Optional predicate on a file path. When given it
                replaces the ``IMG_EXTENSIONS`` filter.
        """
        split_name = 'train' if train else 'val'
        split_root = os.path.join(imagenet_folder, split_name)

        # Exactly one of ``extensions`` / ``is_valid_file`` is forwarded as
        # non-None (presumably because DatasetFolder rejects both being set
        # -- confirm against the torchvision version in use).
        allowed_extensions = None if is_valid_file is not None else IMG_EXTENSIONS

        super().__init__(
            split_root,
            loader=pil_loader,
            extensions=allowed_extensions,
            transform=transform,
            target_transform=None,
            is_valid_file=is_valid_file,
        )

        # Replace the sample list built by the base class with tuples, and
        # rebuild ``targets`` from the second field of every sample.
        self.samples = tuple(self.samples)
        self.targets = tuple(sample[1] for sample in self.samples)

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        """Return ``(transformed image, target)`` for the sample at ``index``."""
        image_path, label = self.samples[index]
        image = self.loader(image_path)
        return self.transform(image), label
|
|
|
|
|
def build_imagenet_pretrain(imagenet_folder, input_size):
    """Build the ImageNet ``train`` split dataset used for pre-training.

    Args:
        imagenet_folder: Path to the ImageNet root, i.e. the directory that
            contains ``train`` and ``val``. The ``train`` or ``val``
            sub-directory itself is also accepted; its parent is then used
            as the root.
        input_size: Crop size passed to ``RandomResizedCrop``.

    Returns:
        An ``ImageNetDataset`` over the ``train`` split whose images go
        through random resized crop (scale 0.67-1.0), random horizontal
        flip, tensor conversion and ImageNet mean/std normalisation.

    Side effects:
        Prints the transform pipeline via ``print_transform``.
    """
    trans_train = transforms.Compose([
        transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0), interpolation=interpolation),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])

    # abspath also normalises the path, so there is no trailing separator
    # and basename() is the real last component.
    imagenet_folder = os.path.abspath(imagenet_folder)
    # Compare the whole final component rather than using str.endswith():
    # a suffix match would also mangle roots such as ``/data/my_train`` or
    # ``/data/retrain`` into ``/data/my_`` / ``/data/re``.
    if os.path.basename(imagenet_folder) in ('train', 'val'):
        imagenet_folder = os.path.dirname(imagenet_folder)

    dataset_train = ImageNetDataset(imagenet_folder=imagenet_folder, transform=trans_train, train=True)
    print_transform(trans_train, '[pre-train]')
    return dataset_train
|
|
|
|
|
def print_transform(transform, s):
    """Print each step of a composed transform under a labelled header.

    Args:
        transform: Object exposing an iterable ``transforms`` attribute.
        s: Label interpolated into the header line.
    """
    header = f'Transform {s} = '
    print(header)
    # One step per line, in pipeline order.
    for step in transform.transforms:
        print(step)
    # The separator's own newline plus print()'s leaves a blank line after it.
    print('---------------------------\n')
|
|
|