You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
87 lines
3.1 KiB
87 lines
3.1 KiB
# Copyright (c) ByteDance, Inc. and its affiliates. |
|
# All rights reserved. |
|
# |
|
# This source code is licensed under the license found in the |
|
# LICENSE file in the root directory of this source tree. |
|
|
|
import os |
|
from typing import Any, Callable, Optional, Tuple |
|
|
|
import PIL.Image as PImage |
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
|
from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS |
|
from torchvision.transforms import transforms |
|
from torch.utils.data import Dataset |
|
|
|
# Select the bicubic interpolation constant used by the pre-training transform.
# Prefer torchvision's InterpolationMode enum; if this torchvision build does
# not export it, fall back to the equivalent PIL constant.
try:
    from torchvision.transforms import InterpolationMode
except ImportError:
    # Only a failed import should trigger the fallback; a bare `except:` would
    # also swallow KeyboardInterrupt/SystemExit and hide unrelated errors.
    import PIL.Image  # import the submodule explicitly so `PIL.Image` resolves

    interpolation = PIL.Image.BICUBIC
else:
    interpolation = InterpolationMode.BICUBIC
|
|
|
|
|
def pil_loader(path):
    """Read the image stored at ``path`` and return it converted to RGB.

    :param path: filesystem path of the image file
    :return: the decoded image after ``convert('RGB')``
    """
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as stream:
        decoded = PImage.open(stream)
        # convert() runs while the file handle is still open.
        rgb_image: PImage.Image = decoded.convert('RGB')
    return rgb_image
|
|
|
|
|
class ImageNetDataset(DatasetFolder):
    """Label-free image dataset over the ``train`` or ``val`` sub-folder.

    The folder is indexed by ``DatasetFolder``; afterwards the labels are
    discarded, so indexing the dataset yields only the transformed image.
    """

    def __init__(
            self,
            imagenet_folder: str,
            train: bool,
            transform: Callable,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        """Index the chosen split of ``imagenet_folder``.

        :param imagenet_folder: directory that contains the ``train`` / ``val`` sub-folders
        :param train: use the ``train`` sub-folder when true, otherwise ``val``
        :param transform: callable applied to every loaded image
        :param is_valid_file: optional predicate on a file path; when given it
            is passed on and the extension filter is disabled (set to ``None``)
        """
        split_name = 'train' if train else 'val'
        split_root = os.path.join(imagenet_folder, split_name)
        # Pass either the extension list or the predicate, never both.
        allowed_extensions = None if is_valid_file is not None else IMG_EXTENSIONS

        super().__init__(
            split_root,
            loader=pil_loader,
            extensions=allowed_extensions,
            transform=transform,
            target_transform=None,
            is_valid_file=is_valid_file,
        )

        # Keep only the file paths from the (path, label) pairs.
        self.samples = tuple(file_path for file_path, _ in self.samples)
        self.targets = None  # this is self-supervised learning so we don't need labels

    def __getitem__(self, index: int) -> Any:
        """Load the image at position ``index`` and return it transformed."""
        loaded = self.loader(self.samples[index])
        return self.transform(loaded)
|
|
|
|
|
def build_dataset_to_pretrain(dataset_path, input_size) -> Dataset:
    """
    You may need to modify this function to return your own dataset.
    Define a new class, a subclass of `Dataset`, to replace our ImageNetDataset.
    Use dataset_path to build your image file path list.
    Use input_size to create the transformation function for your images, can refer to the `trans_train` below.

    :param dataset_path: the folder of dataset; may point either at the dataset
        root or directly at its `train` / `val` sub-folder
    :param input_size: the input size (image resolution)
    :return: the dataset used for pretraining
    """
    trans_train = transforms.Compose([
        transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0), interpolation=interpolation),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])

    # abspath() normalises the path (no trailing separator), so the last
    # component can be inspected reliably.
    dataset_path = os.path.abspath(dataset_path)
    # If the caller passed the split folder itself, step up to the dataset
    # root; ImageNetDataset appends 'train' again. Compare the whole final
    # component rather than a raw string suffix, so a directory such as
    # '/data/retrain' is not truncated to '/data/re'.
    if os.path.basename(dataset_path) in ('train', 'val'):
        dataset_path = os.path.dirname(dataset_path)

    dataset_train = ImageNetDataset(imagenet_folder=dataset_path, transform=trans_train, train=True)
    print_transform(trans_train, '[pre-train]')
    return dataset_train
|
|
|
|
|
def print_transform(transform, s):
    """Print the steps of a composed transform, one per line.

    :param transform: object exposing an iterable ``transforms`` attribute
    :param s: tag inserted into the header line
    """
    report = [f'Transform {s} = ']
    report.extend(str(step) for step in transform.transforms)
    # The footer keeps its own newline, so a blank line follows it once
    # print() appends the final line break.
    report.append('---------------------------\n')
    print('\n'.join(report))
|
|
|