# code was heavily based on https://github.com/clovaai/stargan-v2
# Users should be careful about adopting these functions in any commercial matters.
# https://github.com/clovaai/stargan-v2#license

import paddle
from .base_dataset import BaseDataset
from .builder import DATASETS
import os
from itertools import chain
from pathlib import Path
import traceback
import random
import numpy as np
from PIL import Image

from paddle.io import Dataset, WeightedRandomSampler


def listdir(dname):
    """Recursively collect image files (png/jpg/jpeg/JPG) under `dname`."""
    fnames = list(
        chain(*[
            list(Path(dname).rglob('*.' + ext))
            for ext in ['png', 'jpg', 'jpeg', 'JPG']
        ]))
    return fnames


def _make_balanced_sampler(labels):
    """Build a sampler that draws every class with equal probability.

    Each sample is weighted by the inverse frequency of its class, so a
    class with 3x fewer images has its images drawn 3x more often.
    """
    class_counts = np.bincount(labels)
    class_weights = 1. / class_counts
    weights = class_weights[labels]
    return WeightedRandomSampler(weights, len(weights))


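# Both loaders below expect `root` to contain one sub-directory per domain
# (class); each image is labelled with the index of its directory in sorted
# order. The layout sketched here is illustrative (directory names are
# examples, not fixed by this module):
#
#   root/
#       cat/       0001.jpg, 0002.png, ...
#       dog/       0001.jpg, ...
#       wildlife/  ...

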
class ImageFolder(Dataset):
    """Loads (image path, domain label) pairs from a directory of domains."""

    def __init__(self, root, use_sampler=False):
        self.samples, self.targets = self._make_dataset(root)
        self.use_sampler = use_sampler
        if self.use_sampler:
            # With balanced sampling enabled, indices come from the weighted
            # sampler rather than from the index passed to __getitem__.
            self.sampler = _make_balanced_sampler(self.targets)
            self.iter_sampler = iter(self.sampler)

    def _make_dataset(self, root):
        domains = os.listdir(root)
        fnames, labels = [], []
        for idx, domain in enumerate(sorted(domains)):
            class_dir = os.path.join(root, domain)
            cls_fnames = listdir(class_dir)
            fnames += cls_fnames
            labels += [idx] * len(cls_fnames)
        return fnames, labels

    def __getitem__(self, i):
        if self.use_sampler:
            try:
                index = next(self.iter_sampler)
            except StopIteration:
                # Restart the sampler once it is exhausted.
                self.iter_sampler = iter(self.sampler)
                index = next(self.iter_sampler)
        else:
            index = i
        fname = self.samples[index]
        label = self.targets[index]
        return fname, label

    def __len__(self):
        return len(self.targets)


class ReferenceDataset(Dataset):
    """Like ImageFolder, but returns a pair of images from the same domain."""

    def __init__(self, root, use_sampler=False):
        self.samples, self.targets = self._make_dataset(root)
        self.use_sampler = use_sampler
        if self.use_sampler:
            self.sampler = _make_balanced_sampler(self.targets)
            self.iter_sampler = iter(self.sampler)

    def _make_dataset(self, root):
        domains = os.listdir(root)
        fnames, fnames2, labels = [], [], []
        for idx, domain in enumerate(sorted(domains)):
            class_dir = os.path.join(root, domain)
            cls_fnames = listdir(class_dir)
            fnames += cls_fnames
            # Pair every image with a randomly chosen image of the same domain.
            fnames2 += random.sample(cls_fnames, len(cls_fnames))
            labels += [idx] * len(cls_fnames)
        return list(zip(fnames, fnames2)), labels

    def __getitem__(self, i):
        if self.use_sampler:
            try:
                index = next(self.iter_sampler)
            except StopIteration:
                self.iter_sampler = iter(self.sampler)
                index = next(self.iter_sampler)
        else:
            index = i
        fname, fname2 = self.samples[index]
        label = self.targets[index]
        return fname, fname2, label

    def __len__(self):
        return len(self.targets)


@DATASETS.register()
class StarGANv2Dataset(BaseDataset):
    """Paired source/reference image dataset for StarGAN v2 training and test."""

    def __init__(self, dataroot, is_train, preprocess, test_count=0):
        """Initialize StarGAN v2 dataset class.

        Args:
            dataroot (str): Directory of dataset.
            is_train (bool): Whether the dataset is used for training.
            preprocess (list[dict]): A sequence of data preprocess config.
            test_count (int): Maximum number of samples used at test time.
        """
        super(StarGANv2Dataset, self).__init__(preprocess)

        self.dataroot = dataroot
        self.is_train = is_train
        if self.is_train:
            # Training: balanced sampling over all domains in dataroot.
            self.src_loader = ImageFolder(self.dataroot, use_sampler=True)
            self.ref_loader = ReferenceDataset(self.dataroot, use_sampler=True)
            self.counts = len(self.src_loader)
        else:
            # Test: use separate 'src' and 'ref' folders when available,
            # otherwise read source and reference images from dataroot itself.
            files = os.listdir(self.dataroot)
            if 'src' in files and 'ref' in files:
                self.src_loader = ImageFolder(
                    os.path.join(self.dataroot, 'src'))
                self.ref_loader = ImageFolder(
                    os.path.join(self.dataroot, 'ref'))
            else:
                self.src_loader = ImageFolder(self.dataroot)
                self.ref_loader = ImageFolder(self.dataroot)
            self.counts = min(test_count, len(self.src_loader))
            self.counts = min(self.counts, len(self.ref_loader))

    def _fetch_inputs(self):
        # Fetch the next (path, label) source pair, (re)creating the iterator
        # when it does not exist yet or is exhausted.
        try:
            x, y = next(self.iter_src)
        except (AttributeError, StopIteration):
            self.iter_src = iter(self.src_loader)
            x, y = next(self.iter_src)
        return x, y

    def _fetch_refs(self):
        # Fetch the next (path, paired path, label) reference triplet.
        try:
            x, x2, y = next(self.iter_ref)
        except (AttributeError, StopIteration):
            self.iter_ref = iter(self.ref_loader)
            x, x2, y = next(self.iter_ref)
        return x, x2, y

    def __getitem__(self, idx):
        if self.is_train:
            # idx is ignored during training; samples come from the balanced
            # samplers inside the source and reference loaders.
            x, y = self._fetch_inputs()
            x_ref, x_ref2, y_ref = self._fetch_refs()
            datas = {
                'src_path': x,
                'src_cls': y,
                'ref_path': x_ref,
                'ref2_path': x_ref2,
                'ref_cls': y_ref,
            }
        else:
            x, y = self.src_loader[idx]
            x_ref, y_ref = self.ref_loader[idx]
            datas = {
                'src_path': x,
                'src_cls': y,
                'ref_path': x_ref,
                'ref_cls': y_ref,
            }

        if hasattr(self, 'preprocess') and self.preprocess:
            datas = self.preprocess(datas)

        return datas

    def __len__(self):
        return self.counts

    def prepare_data_infos(self, dataroot):
        # File lists are already built by ImageFolder/ReferenceDataset in
        # __init__, so there is nothing to prepare here.
        pass
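

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The dataroot and the empty
# preprocess pipeline below are assumptions, not values shipped with this
# module; in a full setup the dataset is typically built from a training
# config through the DATASETS registry.
#
#   train_set = StarGANv2Dataset(dataroot='data/afhq/train',
#                                is_train=True,
#                                preprocess=[])
#   sample = train_set[0]
#   # -> dict with 'src_path', 'src_cls', 'ref_path', 'ref2_path', 'ref_cls'
# ---------------------------------------------------------------------------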