# code was heavily based on https://github.com/AliaksandrSiarohin/first-order-model
# Users should be careful about adopting these functions in any commercial matters.
# https://github.com/AliaksandrSiarohin/first-order-model/blob/master/LICENSE.md

import glob
import logging
import os
from multiprocessing import Pool
from pathlib import Path

import cv2
import numpy as np
import tqdm
from imageio import imread, mimread
from paddle.io import Dataset

from .builder import DATASETS
from .preprocess.builder import build_transforms

# If POOL_SIZE > 0, use a multiprocessing pool to extract frames from video files
POOL_SIZE = 64


@DATASETS.register()
class FirstOrderDataset(Dataset):
    def __init__(self, **cfg):
        """Initialize the FirstOrder dataset.

        Args:
            dataroot (str): directory of the dataset.
            phase (str): 'train' or 'test'.
            num_repeats (int): number of times the dataset is repeated per epoch.
            time_flip (bool): whether to randomly swap the driving and source frames.
            batch_size (int): dataset batch size.
            id_sampling (bool): whether to sample videos by person id.
            frame_shape (list): frame shape, e.g. [256, 256, 3].
            create_frames_folder (bool): if the input videos are '.mp4' files,
                whether to extract and save their frames as images.
            num_workers (int): number of workers used to load the dataset.
        """
        super(FirstOrderDataset, self).__init__()
        self.cfg = cfg
        self.frameDataset = FramesDataset(self.cfg)

        # extract frames to folders before wrapping with 'DatasetRepeater'
        if self.cfg['create_frames_folder']:
            file_idx_set = [
                idx for idx, path in enumerate(self.frameDataset.videos)
                if not self.frameDataset.root_dir.joinpath(path).is_dir()
            ]
            if len(file_idx_set) != 0:
                if POOL_SIZE == 0:
                    for idx in tqdm.tqdm(
                            file_idx_set, desc='Extracting frames'):
                        _ = self.frameDataset[idx]
                else:
                    # extract frames with a multiprocessing pool,
                    # 2 * POOL_SIZE videos per map call
                    bar = tqdm.tqdm(
                        total=len(file_idx_set), desc='Extracting frames')
                    with Pool(POOL_SIZE) as pl:
                        _p = 0
                        while _p < len(file_idx_set):
                            chunk = file_idx_set[_p:_p + POOL_SIZE * 2]
                            _ = pl.map(self.frameDataset.__getitem__, chunk)
                            _p += POOL_SIZE * 2
                            bar.update(len(chunk))
                    bar.close()

            # the videos now live in frame folders, so drop the file suffix
            self.frameDataset.videos = [
                i.with_suffix('') for i in self.frameDataset.videos
            ]

        if self.cfg['phase'] == 'train':
            self.outDataset = DatasetRepeater(self.frameDataset,
                                              self.cfg['num_repeats'])
        else:
            self.outDataset = self.frameDataset

    def __len__(self):
        return len(self.outDataset)

    def __getitem__(self, idx):
        return self.outDataset[idx]
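

# Illustrative configuration for FirstOrderDataset (a sketch only; real values
# come from the experiment config, and the dataroot below is hypothetical):
#
#   dataset_cfg = dict(
#       dataroot='data/first_order/vox',  # must contain 'train/' and 'test/'
#       phase='train',
#       num_repeats=75,
#       time_flip=True,
#       batch_size=8,
#       id_sampling=True,
#       frame_shape=[256, 256, 3],
#       create_frames_folder=False,
#       num_workers=4,
#       transforms=[...],  # preprocessing pipeline consumed by build_transforms
#   )
#   dataset = FirstOrderDataset(**dataset_cfg)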


def read_video(name: Path, frame_shape=(256, 256, 3), saveto='folder'):
    """
    Read a video, which can be:
      - a folder containing all of its frames
      - a '.mp4', '.gif' or '.mov' file

    If saveto == 'folder', the decoded frames are written to a folder named
    after the video file and the original file is deleted.
    """
    if name.is_dir():
        frames = sorted(
            name.iterdir(), key=lambda x: int(x.with_suffix('').name))
        video_array = np.array(
            [imread(path) for path in frames], dtype='float32')
        return video_array
    elif name.suffix.lower() in ['.gif', '.mp4', '.mov']:
        try:
            video = mimread(name, memtest=False)
        except Exception as err:
            logging.error('DataLoading File:%s Msg:%s' % (str(name), str(err)))
            return None

        # convert every frame to a 3-channel image
        if video[0].shape[-1] == 4:
            video = [i[..., :3] for i in video]
        elif video[0].shape[-1] == 1:
            video = [np.tile(i, (1, 1, 3)) for i in video]
        elif len(video[0].shape) == 2:
            video = [np.tile(i[..., np.newaxis], (1, 1, 3)) for i in video]
        video_array = np.asarray(video)
        video_array_reshape = []
        for idx, img in enumerate(video_array):
            # cv2.resize expects (width, height)
            img = cv2.resize(img, (frame_shape[1], frame_shape[0]))
            video_array_reshape.append(img.astype(np.uint8))
        video_array_reshape = np.asarray(video_array_reshape)

        if saveto == 'folder':
            sub_dir = name.with_suffix('')
            try:
                sub_dir.mkdir()
            except FileExistsError:
                pass
            for idx, img in enumerate(video_array_reshape):
                # RGB -> BGR for cv2.imwrite
                cv2.imwrite(
                    str(sub_dir.joinpath('%i.png' % idx)), img[:, :, [2, 1, 0]])
            # remove the original video file once its frames are saved
            name.unlink()
        return video_array_reshape
    else:
        raise Exception("Unknown dataset file extension %s" % name)
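

# A minimal sketch of calling read_video directly (the path below is
# hypothetical). Note that saveto='folder' writes the frames next to the video
# and then deletes the original file, so pass saveto=None to read
# non-destructively:
#
#   frames = read_video(Path('data/first_order/test/driving.mp4'),
#                       frame_shape=(256, 256, 3), saveto=None)
#   # frames: uint8 ndarray of shape [num_frames, 256, 256, 3]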


class FramesDataset(Dataset):
    """
    Dataset of videos. Each video can be represented as:
      - a '.mp4' or '.gif' file
      - a folder containing all of its frames

    FramesDataset[i] returns a sample from the i-th video in self.videos.
    """

    def __init__(self, cfg):
        self.root_dir = Path(cfg['dataroot'])
        self.videos = None
        self.frame_shape = tuple(cfg['frame_shape'])
        self.id_sampling = cfg['id_sampling']
        self.time_flip = cfg['time_flip']
        self.is_train = cfg['phase'] == 'train'
        self.pairs_list = cfg.setdefault('pairs_list', None)
        self.create_frames_folder = cfg['create_frames_folder']
        self.transform = None
        assert self.root_dir.joinpath('train').exists()
        assert self.root_dir.joinpath('test').exists()
        logging.info("Using predefined train-test split.")
        if self.id_sampling:
            # one entry per person id; the id is the part of the name before '#'
            train_videos = {
                video.name.split('#')[0]
                for video in self.root_dir.joinpath('train').iterdir()
            }
            train_videos = list(train_videos)
        else:
            train_videos = list(self.root_dir.joinpath('train').iterdir())
        test_videos = list(self.root_dir.joinpath('test').iterdir())
        self.root_dir = self.root_dir.joinpath('train'
                                               if self.is_train else 'test')

        if self.is_train:
            self.videos = train_videos
            self.transform = build_transforms(cfg['transforms'])
        else:
            self.videos = test_videos
            self.transform = None

    def __len__(self):
        return len(self.videos)

    def __getitem__(self, idx):
        if self.is_train and self.id_sampling:
            # pick a random clip belonging to the sampled person id
            name = self.videos[idx]
            path = Path(
                np.random.choice(
                    glob.glob(os.path.join(self.root_dir, name + '*.mp4'))))
        else:
            path = self.videos[idx]
        video_name = path.name
        if self.is_train and path.is_dir():
            frames = sorted(
                path.iterdir(), key=lambda x: int(x.with_suffix('').name))
            num_frames = len(frames)
            frame_idx = np.sort(
                np.random.choice(
                    num_frames, replace=True, size=2))
            video_array = [imread(str(frames[i])) for i in frame_idx]
        else:
            if self.create_frames_folder:
                video_array = read_video(
                    path, frame_shape=self.frame_shape, saveto='folder')
                self.videos[idx] = path.with_suffix(
                    '')  # rename /xx/xx/xx.gif -> /xx/xx/xx
            else:
                video_array = read_video(
                    path, frame_shape=self.frame_shape, saveto=None)
            num_frames = len(video_array)
            frame_idx = np.sort(
                np.random.choice(
                    num_frames, replace=True,
                    size=2)) if self.is_train else range(num_frames)
            video_array = [video_array[i] for i in frame_idx]
        # convert every frame to a 3-channel image
        if video_array[0].shape[-1] == 4:
            video_array = [i[..., :3] for i in video_array]
        elif video_array[0].shape[-1] == 1:
            video_array = [np.tile(i, (1, 1, 3)) for i in video_array]
        elif len(video_array[0].shape) == 2:
            video_array = [
                np.tile(i[..., np.newaxis], (1, 1, 3)) for i in video_array
            ]
        out = {}
        if self.is_train:
            if self.transform is not None:
                t = self.transform(tuple(video_array))
                # HWC -> CHW, scaled to [0, 1]
                out['driving'] = t[0].transpose(2, 0,
                                                1).astype(np.float32) / 255.0
                out['source'] = t[1].transpose(2, 0,
                                               1).astype(np.float32) / 255.0
            else:
                source = np.array(
                    video_array[0],
                    dtype='float32') / 255.0  # shape is [H, W, C]
                driving = np.array(
                    video_array[1],
                    dtype='float32') / 255.0  # shape is [H, W, C]
                out['driving'] = driving.transpose(2, 0, 1)
                out['source'] = source.transpose(2, 0, 1)
            if self.time_flip and np.random.rand() < 0.5:
                out['driving'], out['source'] = out['source'], out['driving']
        else:
            video = np.stack(video_array, axis=0).astype(np.float32) / 255.0
            out['video'] = video.transpose(3, 0, 1, 2)  # [C, num_frames, H, W]
        out['name'] = video_name
        return out

    def get_sample(self, idx):
        return self.__getitem__(idx)
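

# Expected on-disk layout (illustrative), matching the asserts in
# FramesDataset.__init__:
#
#   <dataroot>/
#       train/
#           <video>.mp4 or <video>.gif or <video>/  (a folder of frames 0.png, 1.png, ...)
#       test/
#           ...
#
# With id_sampling enabled, training file names are assumed to start with a
# person id followed by '#', e.g. 'id10001#clip0001.mp4'.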


class DatasetRepeater(Dataset):
    """
    Pass several times over the same dataset for better I/O performance.
    """

    def __init__(self, dataset, num_repeats=100):
        self.dataset = dataset
        self.num_repeats = num_repeats

    def __len__(self):
        return self.num_repeats * len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx % len(self.dataset)]
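

# A minimal smoke-test sketch, assuming a prepared dataset at the hypothetical
# path below; it builds the dataset in test mode and reads a single sample.
# Run it as a module (e.g. `python -m <package>.firstorder_dataset`) so that
# the relative imports above resolve.
if __name__ == '__main__':
    example_cfg = dict(
        dataroot='data/first_order/vox',  # hypothetical dataset location
        phase='test',
        time_flip=False,
        id_sampling=False,
        frame_shape=[256, 256, 3],
        create_frames_folder=False,
    )
    dataset = FirstOrderDataset(**example_cfg)
    sample = dataset[0]
    print(sample['name'], sample['video'].shape)  # video: [C, num_frames, H, W]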