You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

172 lines
5.8 KiB

# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from typing import Any, Callable, Optional, Tuple
import PIL.Image as PImage
import torch
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, create_transform
from timm.data.transforms_factory import transforms_imagenet_eval
from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS
from torchvision.transforms import transforms
import dist
# Resolve the bicubic interpolation constant used by the torchvision transforms
# below. Prefer torchvision's InterpolationMode enum; if that name cannot be
# imported, fall back to Pillow's integer constant.
try:
    from torchvision.transforms import InterpolationMode
    interpolation = InterpolationMode.BICUBIC
except ImportError:
    # Only a failed import should trigger the fallback; any other error
    # (including KeyboardInterrupt) must propagate instead of being swallowed.
    import PIL.Image
    interpolation = PIL.Image.BICUBIC
def pil_loader(path):
    """Load the image stored at ``path`` and return it converted to RGB.

    The file is opened explicitly and handed to Pillow as a file object so the
    handle is closed when the ``with`` block exits, avoiding a ResourceWarning
    (https://github.com/python-pillow/Pillow/issues/835).
    """
    with open(path, 'rb') as fp:
        # convert() runs while the file is still open, so the pixel data is
        # read before the handle is released.
        return PImage.open(fp).convert('RGB')
class ImageNetDataset(DatasetFolder):
    """Folder dataset over ``<root>/train`` or ``<root>/val``, optionally subsampled.

    After the base class has scanned the folder, ``self.samples`` (a sequence of
    ``(path, class_index)`` pairs) is reduced in one of two ways:

    * ``only > 0``: a seeded random subset with about ``only`` samples for each
      of the first ``max_cls_id`` classes (see ``_subsample``).
    * otherwise: every sample whose class index is below ``max_cls_id``.

    ``self.targets`` is rebuilt to match the reduced ``self.samples``.
    """

    def __init__(
        self,
        root: str,
        train: bool,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        max_cls_id: int = 1000,
        only=-1,
    ):
        """Scan the split folder and reduce the sample list.

        Args:
            root: dataset directory; may already end in ``train``/``val`` and/or
                a trailing path separator.
            train: select the ``train`` sub-folder when true, ``val`` otherwise.
            transform: optional callable applied to each loaded image.
            target_transform: optional callable applied to each class index.
            is_valid_file: optional path predicate; when given, the extension
                whitelist is disabled and this predicate is passed instead.
            max_cls_id: keep only classes with index strictly below this value.
            only: when positive, number of samples to keep per class.
        """
        # Drop a trailing 'train'/'val' *path component* (not a mere string
        # suffix), so roots such as '/data/retrain' or '/data/imagenet_val'
        # are left intact.
        head, tail = os.path.split(root.rstrip(os.path.sep))
        if tail in ('train', 'val'):
            root = head
        root = os.path.join(root, 'train' if train else 'val')

        super().__init__(
            root,
            loader=pil_loader,
            extensions=IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )

        if only > 0:
            self.samples = self._subsample(only, max_cls_id)
        else:
            self.samples = tuple(s for s in self.samples if s[-1] < max_cls_id)
        self.targets = tuple(s[1] for s in self.samples)

    def _subsample(self, only: int, max_cls_id: int) -> tuple:
        """Return a seeded random subset with about ``only`` samples per class."""
        # Fixed seed: the permutation, and hence the subset, is reproducible.
        g = torch.Generator()
        g.manual_seed(0)
        idx = torch.randperm(len(self.samples), generator=g).numpy().tolist()

        # Pad the target size up to a multiple of the world size, then by one
        # more world size if needed so that the per-rank count is even.
        ws = dist.get_world_size()
        res = (max_cls_id * only) % ws
        more = 0 if res == 0 else (ws - res)
        max_total = max_cls_id * only + more
        if (max_total // ws) % 2 == 1:
            more += ws
            max_total += ws

        d = {c: [] for c in range(max_cls_id)}
        max_len = {c: only for c in range(max_cls_id)}
        # The last `more` classes each take one extra sample to absorb the padding.
        # NOTE(review): assumes more <= max_cls_id; otherwise the range starts
        # below zero and the lookup fails. Confirm against expected world sizes.
        for c in range(max_cls_id - more, max_cls_id):
            max_len[c] += 1

        total = 0
        for i in idx:
            path, target = self.samples[i]
            if target >= max_cls_id:
                # Classes outside the requested range are dropped, matching the
                # filter applied when `only` is not set.
                continue
            if len(d[target]) < max_len[target]:
                d[target].append((path, target))
                total += 1
                if total == max_total:
                    break

        # Flatten the per-class buckets in class-index order.
        sp = [sample for bucket in d.values() for sample in bucket]
        print(f'[ds] more={more}, len(sp)={len(sp)}')
        return tuple(sp)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load sample ``index`` and return ``(sample, target)`` after transforms."""
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target
def build_imagenet(mode, data_path, data_set, img_size, eval_crop_pct=None, rrc=0.3, aa='rand-m7-mstd0.5', re_prob=0.0, colorj=0.4):
    """Build the train and validation datasets together with their transforms.

    Args:
        mode: case-insensitive augmentation preset. ``'pt'`` and ``'le'`` use a
            random-resized-crop + horizontal-flip pipeline (``'pt'`` with
            ``scale=(rrc, 1.0)``, ``'le'`` with the default scale); any other
            value uses timm's ``create_transform`` training pipeline.
            NOTE(review): presumably pre-training / linear evaluation /
            fine-tuning; confirm with callers.
        data_path: dataset directory; a trailing separator and a trailing
            ``train``/``val`` component are ignored.
        data_set: dataset identifier; only ``'imn'`` is supported.
        img_size: side length of the square model input.
        eval_crop_pct: crop ratio forwarded to timm's eval transform (only used
            when ``img_size < 384``).
        rrc: lower bound of the random-resized-crop scale in ``'pt'`` mode.
        aa: auto-augment policy string forwarded to ``create_transform``.
        re_prob: random-erasing probability forwarded to ``create_transform``.
        colorj: color-jitter strength forwarded to ``create_transform``.

    Returns:
        ``(dataset_train, dataset_val)``.

    Raises:
        NotImplementedError: if ``data_set`` is not ``'imn'``.
    """
    mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
    norm = transforms.Normalize(mean=mean, std=std)

    if img_size >= 384:
        # Large inputs: resize straight to the target size, with no crop.
        trans_val = transforms.Compose([
            transforms.Resize((img_size, img_size), interpolation=interpolation),
            transforms.ToTensor(),
            norm,
        ])
    else:
        trans_val = transforms_imagenet_eval(
            img_size=img_size, interpolation='bicubic', crop_pct=eval_crop_pct,
            mean=mean, std=std
        )

    mode = mode.lower()
    if mode == 'pt':
        trans_train = transforms.Compose([
            transforms.RandomResizedCrop(img_size, scale=(rrc, 1.0), interpolation=interpolation),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            norm,
        ])
    elif mode == 'le':
        trans_train = transforms.Compose([
            transforms.RandomResizedCrop(img_size, interpolation=interpolation),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            norm,
        ])
    else:
        trans_train = create_transform(
            is_training=True,
            input_size=img_size,
            auto_augment=aa,
            interpolation='bicubic',
            re_prob=re_prob,
            re_mode='pixel',
            re_count=1,
            color_jitter=colorj,
            mean=mean, std=std,
        )

    # Drop a trailing 'train'/'val' *path component* (not a mere string
    # suffix), so a path such as '/data/imagenet_val' is not truncated.
    head, tail = os.path.split(data_path.rstrip(os.path.sep))
    if tail in ('train', 'val'):
        data_path = head

    if data_set == 'imn':
        dataset_train = ImageNetDataset(root=data_path, transform=trans_train, train=True)
        dataset_val = ImageNetDataset(root=data_path, transform=trans_val, train=False)
    else:
        raise NotImplementedError(f'unsupported data_set: {data_set!r}')

    print_transform(trans_train, '[train]')
    print_transform(trans_val, '[val]')

    return dataset_train, dataset_val
def print_transform(transform, s):
    """Print each step of a composed transform between a header and a rule.

    Args:
        transform: object exposing an iterable ``transforms`` attribute.
        s: label interpolated into the header line.
    """
    report = [f'Transform {s} = ']
    report.extend(str(step) for step in transform.transforms)
    # The rule carries its own newline, so the output ends with a blank line.
    report.append('---------------------------\n')
    print('\n'.join(report))