You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
312 lines
12 KiB
312 lines
12 KiB
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
|
|
import logging |
|
import os |
|
import random |
|
import numpy as np |
|
import scipy.io as scio |
|
import cv2 |
|
import paddle |
|
from paddle.io import Dataset, DataLoader |
|
from .builder import DATASETS |
|
|
|
# Module-level logger named after this module, so its output can be
# configured through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
@DATASETS.register()
class REDSDataset(Dataset):
    """REDS dataset for the EDVR model.

    Every sample is keyed by ``'<video_name>_<frame_idx>'`` (for example
    ``'010_00000015'``). ``__getitem__`` reads a window of neighboring
    low-quality (LQ) frames plus the ground-truth (GT) frame selected by the
    neighbor sampler, from ``<root>/<video_name>/<frame>.<img_format>``.
    """

    # Videos skipped in 'train' mode because they are used for validation.
    _VAL_VIDEOS = frozenset(['000', '011', '015', '020'])

    def __init__(self,
                 mode,
                 lq_folder,
                 gt_folder,
                 img_format="png",
                 crop_size=256,
                 interval_list=None,
                 random_reverse=False,
                 number_frames=5,
                 batch_size=32,
                 use_flip=False,
                 use_rot=False,
                 buf_size=1024,
                 scale=4,
                 fix_random_seed=False):
        """Store the sampling configuration and index all available frames.

        Args:
            mode (str): 'train', 'valid' or 'test' select how samples are
                drawn in get_sample_data. 'infer' does not store gt_folder
                and is rejected by get_sample_data.
            lq_folder (str): root folder holding one sub-folder of LQ frames
                per video.
            gt_folder (str): root folder holding the GT frames, with the same
                layout as lq_folder.
            img_format (str): image file extension, with or without a leading
                dot. Default: "png".
            crop_size (int): side length of the GT crop in train/valid mode.
            interval_list (list[int] | None): candidate temporal strides
                between neighboring frames. Default: [1].
            random_reverse (bool): randomly reverse the temporal order of the
                neighbor frames in train/valid mode.
            number_frames (int): number of LQ frames per sample.
            batch_size (int): stored only; not used inside this class.
            use_flip (bool): allow random horizontal flips.
            use_rot (bool): allow random vertical flips and H/W transposes.
            buf_size (int): stored only; not used inside this class.
            scale (int): GT-to-LQ resolution ratio. Values > 1 mean the LQ
                crop is crop_size // scale.
            fix_random_seed (bool): if True, seed both ``random`` and
                ``numpy.random`` with 10.
        """
        super(REDSDataset, self).__init__()
        self.format = img_format
        self.mode = mode
        self.crop_size = crop_size
        # A None default avoids sharing one mutable list between instances.
        self.interval_list = [1] if interval_list is None else interval_list
        self.random_reverse = random_reverse
        self.number_frames = number_frames
        self.batch_size = batch_size
        self.fileroot = lq_folder
        self.use_flip = use_flip
        self.use_rot = use_rot
        self.buf_size = buf_size
        self.fix_random_seed = fix_random_seed

        # GT frames are not used for inference. Keeping the attribute defined
        # lets __getitem__ reach the explicit "mode not implemented" error
        # instead of failing with AttributeError.
        self.gtroot = gt_folder if self.mode != 'infer' else None
        self.scale = scale
        self.LR_input = (self.scale > 1)
        if self.fix_random_seed:
            random.seed(10)
            np.random.seed(10)
        self.num_reader_threads = 1

        self._init_()

    def _init_(self):
        """Build ``self.filelist`` with one ``'<video>_<frame>'`` key per frame.

        In 'train' mode the validation videos are skipped. In 'test' mode the
        list is sorted so that evaluation order is deterministic.
        """
        logger.info('initialize reader ... ')
        self.filelist = []
        for video_name in os.listdir(self.fileroot):
            video_dir = os.path.join(self.fileroot, video_name)
            # Stray files next to the video folders would make the
            # os.listdir call below raise NotADirectoryError.
            if not os.path.isdir(video_dir):
                continue
            if self.mode == 'train' and video_name in self._VAL_VIDEOS:
                continue
            for frame_name in os.listdir(video_dir):
                frame_idx = frame_name.split('.')[0]
                # e.g. '010_00000015', '260_00000090'
                self.filelist.append(video_name + '_' + frame_idx)
        if self.mode == 'test':
            self.filelist.sort()
        logger.info('found %d frames', len(self.filelist))

    def __getitem__(self, index):
        """Return one sample as a dict.

        Returns:
            dict with
              'lq': float32 array [T, C, H, W] (RGB), T = number of neighbors,
              'gt': float32 array [C, H, W] (RGB),
              'lq_path': the sample key ``'<video_name>_<frame_idx>'``.
        """
        item = self.filelist[index]
        img_LQs, img_GT = self.get_sample_data(
            item, self.number_frames, self.interval_list, self.random_reverse,
            self.gtroot, self.fileroot, self.LR_input, self.crop_size,
            self.scale, self.use_flip, self.use_rot, self.mode)
        return {'lq': img_LQs, 'gt': img_GT, 'lq_path': item}

    def _frame_path(self, root, video_name, frame_name):
        """Return ``<root>/<video_name>/<frame_name>.<img_format>``."""
        ext = str(self.format or 'png').lstrip('.')
        return os.path.join(root, video_name, frame_name + '.' + ext)

    @staticmethod
    def _paired_random_crop(frame_list, img_GT, crop_size, scale, LR_input):
        """Crop the same random region from every LQ frame and from the GT.

        When ``LR_input`` is true, the LQ crop is ``crop_size // scale`` and
        the GT crop starts at ``scale`` times the LQ offset. Otherwise LQ and
        GT share one ``crop_size`` window.
        """
        H, W, _ = frame_list[0].shape
        lq_size = crop_size // scale if LR_input else crop_size
        rnd_h = random.randint(0, max(0, H - lq_size))
        rnd_w = random.randint(0, max(0, W - lq_size))
        frame_list = [
            v[rnd_h:rnd_h + lq_size, rnd_w:rnd_w + lq_size, :]
            for v in frame_list
        ]
        if LR_input:
            gt_h, gt_w = int(rnd_h * scale), int(rnd_w * scale)
        else:
            gt_h, gt_w = rnd_h, rnd_w
        img_GT = img_GT[gt_h:gt_h + crop_size, gt_w:gt_w + crop_size, :]
        return frame_list, img_GT

    def get_sample_data(self,
                        item,
                        number_frames,
                        interval_list,
                        random_reverse,
                        gtroot,
                        fileroot,
                        LR_input,
                        crop_size,
                        scale,
                        use_flip,
                        use_rot,
                        mode='train'):
        """Load, crop, augment and format one LQ frame stack and its GT frame.

        Args:
            item (str): sample key ``'<video_name>_<frame_idx>'``.
            The remaining args mirror the instance attributes of the same
            meaning (see __init__).
            mode (str): 'train'/'valid' use random neighbor sampling, random
                crop and augmentation. 'test' uses deterministic padded
                neighbors and full frames.

        Returns:
            (img_LQs, img_GT): float32 arrays [T, C, H, W] and [C, H, W],
            with channels in RGB order.

        Raises:
            NotImplementedError: for any other mode.
        """
        # rsplit keeps the numeric frame index intact even if a video name
        # itself contains '_'.
        video_name, frame_name = item.rsplit('_', 1)
        is_training = mode in ('train', 'valid')
        if is_training:
            ngb_frames, name_b = self.get_neighbor_frames(
                frame_name,
                number_frames=number_frames,
                interval_list=interval_list,
                random_reverse=random_reverse)
        elif mode == 'test':
            ngb_frames, name_b = self.get_test_neighbor_frames(
                int(frame_name), number_frames)
        else:
            raise NotImplementedError('mode {} not implemented'.format(mode))

        # The GT frame is the one picked by the neighbor sampler. In
        # train/valid mode it can differ from the requested key.
        img_GT = self.read_img(self._frame_path(gtroot, video_name, name_b))
        frame_list = [
            self.read_img(
                self._frame_path(fileroot, video_name, "%08d" % ngb_frm))
            for ngb_frm in ngb_frames
        ]

        if is_training:
            frame_list, img_GT = self._paired_random_crop(
                frame_list, img_GT, crop_size, scale, LR_input)
            # Augment LQ frames and GT together so they stay aligned.
            augmented = self.img_augment(frame_list + [img_GT], use_flip,
                                         use_rot)
            frame_list, img_GT = augmented[:-1], augmented[-1]

        # Stack LQ images to NHWC, N is the frame number.
        img_LQs = np.stack(frame_list, axis=0)
        # BGR (OpenCV order) to RGB, HWC to CHW, float32.
        img_GT = img_GT[:, :, [2, 1, 0]]
        img_LQs = img_LQs[:, :, :, [2, 1, 0]]
        img_GT = np.transpose(img_GT, (2, 0, 1)).astype('float32')
        img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32')
        return img_LQs, img_GT

    def get_neighbor_frames(self,
                            frame_name,
                            number_frames,
                            interval_list,
                            random_reverse,
                            max_frame=99,
                            bordermode=False):
        """Sample neighbor frame indices for training/validation.

        Args:
            frame_name (str): requested frame index, e.g. '00000015'.
            number_frames (int): how many indices to return.
            interval_list (list[int]): a stride is drawn from here at random.
            random_reverse (bool): randomly reverse the temporal order.
            max_frame (int): largest valid frame index.
            bordermode (bool): if True, walk forwards or backwards from the
                requested frame, and the GT is the first frame of the walk.
                Otherwise, center a window on the frame and resample the
                center uniformly while the window leaves [0, max_frame].

        Returns:
            (neighbor_list, name_b): frame indices and the 8-digit zero-padded
            name of the GT frame.

        Raises:
            ValueError: if no centered window of this size fits in
                [0, max_frame] (non-border mode).
        """
        center_frame_idx = int(frame_name)
        half_N_frames = number_frames // 2
        interval = random.choice(interval_list)
        if bordermode:
            direction = 1
            if random_reverse and random.random() < 0.5:
                direction = random.choice([0, 1])
            if center_frame_idx + interval * (number_frames - 1) > max_frame:
                direction = 0
            elif center_frame_idx - interval * (number_frames - 1) < 0:
                direction = 1
            if direction == 1:
                neighbor_list = list(
                    range(center_frame_idx,
                          center_frame_idx + interval * number_frames,
                          interval))
            else:
                neighbor_list = list(
                    range(center_frame_idx,
                          center_frame_idx - interval * number_frames,
                          -interval))
            name_b = '{:08d}'.format(neighbor_list[0])
        else:
            half_span = half_N_frames * interval
            # Without this check the resampling loop below never terminates.
            if 2 * half_span > max_frame:
                raise ValueError(
                    'cannot sample {} frames with interval {} from frame '
                    'indices 0..{}'.format(number_frames, interval, max_frame))
            # ensure not exceeding the borders
            while (center_frame_idx + half_span > max_frame) or (
                    center_frame_idx - half_span < 0):
                center_frame_idx = random.randint(0, max_frame)
            neighbor_list = list(
                range(center_frame_idx - half_span,
                      center_frame_idx + half_span + 1, interval))
            if random_reverse and random.random() < 0.5:
                neighbor_list.reverse()
            name_b = '{:08d}'.format(neighbor_list[half_N_frames])
        assert len(neighbor_list) == number_frames, \
            "frames slected have length({}), but it should be ({})".format(len(neighbor_list), number_frames)

        return neighbor_list, name_b

    def read_img(self, path, size=None):
        """Read an image with OpenCV.

        Args:
            path (str): image file path.
            size: unused; kept for interface compatibility.

        Returns:
            float32 numpy array, HWC, BGR, with at most 3 channels. Pixel
            values are divided by 255, so they lie in [0, 1] for 8-bit images.

        Raises:
            IOError: if OpenCV cannot read the file.
        """
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise IOError('cannot read image: {}'.format(path))
        img = img.astype(np.float32) / 255.
        if img.ndim == 2:
            img = np.expand_dims(img, axis=2)
        # Some images have 4 channels; keep only the first 3.
        if img.shape[2] > 3:
            img = img[:, :, :3]
        return img

    def img_augment(self, img_list, hflip=True, rot=True):
        """Apply one shared random flip/transpose combination to all images.

        Each enabled transform is applied independently with probability
        0.5: a horizontal flip (if ``hflip``), and a vertical flip and an H/W
        transpose (each if ``rot``). With both enabled, the 8 combinations
        cover every rotation by a multiple of 90 degrees and its mirror image.
        """
        hflip = hflip and random.random() < 0.5
        vflip = rot and random.random() < 0.5
        rot90 = rot and random.random() < 0.5

        def _augment(img):
            if hflip:
                img = img[:, ::-1, :]
            if vflip:
                img = img[::-1, :, :]
            if rot90:
                img = img.transpose(1, 0, 2)
            return img

        return [_augment(img) for img in img_list]

    def get_test_neighbor_frames(self, crt_i, N, max_n=100, padding='new_info'):
        """Generate indices for reading frames centered on ``crt_i``.

        Indices outside ``[0, max_n - 1]`` are replaced according to
        ``padding``.

        Args:
            crt_i (int): current center index.
            N (int): number of frames to read. The window spans
                ``crt_i - N // 2`` .. ``crt_i + N // 2`` inclusive.
            max_n (int): number of frames in the sequence (indices start at 0).
            padding (str): one of 'replicate' | 'reflection' | 'new_info' |
                'circle'. Example with crt_i = 0, N = 5:
                    replicate:  [0, 0, 0, 1, 2]
                    reflection: [2, 1, 0, 1, 2]
                    new_info:   [4, 3, 0, 1, 2]
                    circle:     [3, 4, 0, 1, 2]

        Returns:
            (return_l, name_b): list of frame indices and the 8-digit
            zero-padded name of the center frame.

        Raises:
            ValueError: for an unknown padding mode. It is raised only when
                some index actually needs padding.
        """
        last_idx = max_n - 1
        n_pad = N // 2
        return_l = []

        for i in range(crt_i - n_pad, crt_i + n_pad + 1):
            if i < 0:
                if padding == 'replicate':
                    add_idx = 0
                elif padding == 'reflection':
                    add_idx = -i
                elif padding == 'new_info':
                    add_idx = (crt_i + n_pad) + (-i)
                elif padding == 'circle':
                    add_idx = N + i
                else:
                    raise ValueError('Wrong padding mode')
            elif i > last_idx:
                if padding == 'replicate':
                    add_idx = last_idx
                elif padding == 'reflection':
                    add_idx = last_idx * 2 - i
                elif padding == 'new_info':
                    add_idx = (crt_i - n_pad) - (i - last_idx)
                elif padding == 'circle':
                    add_idx = i - N
                else:
                    raise ValueError('Wrong padding mode')
            else:
                add_idx = i
            return_l.append(add_idx)
        name_b = '{:08d}'.format(crt_i)
        return return_l, name_b

    def __len__(self):
        """Return the total number of frames indexed by the dataset."""
        return len(self.filelist)
|
|
|