You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
238 lines
8.3 KiB
238 lines
8.3 KiB
3 years ago
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
import logging
|
||
|
import os
|
||
|
import random
|
||
|
import numpy as np
|
||
|
import cv2
|
||
|
from paddle.io import Dataset
|
||
|
|
||
|
from .builder import DATASETS
|
||
|
|
||
|
# Module-level logger, named after this module per the logging convention.
logger = logging.getLogger(name=__name__)
|
||
|
|
||
|
|
||
|
@DATASETS.register()
class SRREDSMultipleGTDataset(Dataset):
    """REDS dataset for video super resolution for recurrent networks.

    Each item is a window of ``number_frames`` (optionally strided) LQ
    (Low-Quality) frames together with the matching GT (Ground-Truth) frames
    from one clip. In ``'train'``/``'valid'`` mode the frames are randomly
    cropped and flip/transpose-augmented; the result is returned as a dict of
    stacked arrays.

    Args:
        mode (str): ``'train'`` selects every clip outside the validation
            partition; any other value selects only the validation clips.
            ``'train'`` and ``'valid'`` additionally enable random cropping
            and augmentation.
        lq_folder (str | :obj:`Path`): Path to the LQ folder. Frames are read
            from ``<lq_folder>/<clip>/<frame:08d>.png``.
        gt_folder (str | :obj:`Path`): Path to the GT folder (same layout).
        crop_size (int): Side length of the GT crop; the LQ crop side is
            ``crop_size // scale``. Default: 256.
        interval_list (list[int] | None): Candidate temporal strides; one is
            chosen at random per item. ``None`` means ``[1]``.
        random_reverse (bool): Reverse the frame order with probability 0.5.
            Default: False.
        number_frames (int): Number of frames per item. Default: 15.
        use_flip (bool): Enable random horizontal flip. Default: False.
        use_rot (bool): Enable random vertical flip and H/W transpose.
            Default: False.
        scale (int): Upsampling scale ratio. Default: 4.
        val_partition (str): Validation partition mode. Choices
            ``'official'`` (clips 240-269) or ``'REDS4'`` (clips 000, 011,
            015, 020). Default: ``'REDS4'``.
        batch_size (int): Stored on the instance; not used by this class.
            Default: 4.
        num_clips (int): Number of clips; keys are ``'000'`` up to
            ``num_clips - 1`` zero-padded to 3 digits. Default: 270.
        frames_per_clip (int): Number of frames available in each clip
            folder (indices ``0 .. frames_per_clip - 1``). Default: 100.
    """

    def __init__(self,
                 mode,
                 lq_folder,
                 gt_folder,
                 crop_size=256,
                 interval_list=None,
                 random_reverse=False,
                 number_frames=15,
                 use_flip=False,
                 use_rot=False,
                 scale=4,
                 val_partition='REDS4',
                 batch_size=4,
                 num_clips=270,
                 frames_per_clip=100):
        super(SRREDSMultipleGTDataset, self).__init__()
        self.mode = mode
        self.fileroot = str(lq_folder)
        self.gtroot = str(gt_folder)
        self.crop_size = crop_size
        # Avoid a shared mutable default; None keeps the old default of [1].
        self.interval_list = [1] if interval_list is None else interval_list
        self.random_reverse = random_reverse
        self.number_frames = number_frames
        self.use_flip = use_flip
        self.use_rot = use_rot
        self.scale = scale
        self.val_partition = val_partition
        self.batch_size = batch_size
        self.num_clips = num_clips  # training num of LQ and GT pairs
        self.frames_per_clip = frames_per_clip
        self.data_infos = self.load_annotations()

    def __getitem__(self, idx):
        """Load a randomly positioned frame window from clip ``idx``.

        Args:
            idx (int): Index into ``self.data_infos`` (the clip keys).

        Returns:
            dict: ``'lq'`` and ``'gt'`` float32 arrays of shape
            (N, C, H, W) with N = ``number_frames``, and ``'lq_path'``
            holding the clip key.

        Raises:
            ValueError: If a window of ``number_frames`` frames at the
                largest stride in ``interval_list`` cannot fit in
                ``frames_per_clip`` frames.
        """
        clip_key = self.data_infos[idx]
        # Number of frames a window covers: the start plus (n - 1) strides.
        # Use the largest candidate stride so that whichever interval
        # get_neighbor_frames draws, the window stays inside the clip.
        max_interval = max(self.interval_list)
        span = (self.number_frames - 1) * max_interval + 1
        max_start = self.frames_per_clip - span
        if max_start < 0:
            raise ValueError(
                f'A window of {self.number_frames} frames with stride up to '
                f'{max_interval} needs {span} frames, but clips only have '
                f'{self.frames_per_clip}.')
        idt = random.randint(0, max_start)
        item = f'{clip_key}_{idt:03d}'
        img_LQs, img_GTs = self.get_sample_data(
            item, self.number_frames, self.interval_list, self.random_reverse,
            self.gtroot, self.fileroot, self.crop_size, self.scale,
            self.use_flip, self.use_rot, self.mode)
        return {'lq': img_LQs, 'gt': img_GTs, 'lq_path': clip_key}

    def load_annotations(self):
        """Build the list of clip keys for the current mode.

        Returns:
            list[str]: Zero-padded clip keys (e.g. ``'011'``). In ``'train'``
            mode the validation clips are excluded; in any other mode only
            the validation clips are kept.

        Raises:
            ValueError: If ``val_partition`` is neither ``'official'`` nor
                ``'REDS4'``.
        """
        keys = [f'{i:03d}' for i in range(self.num_clips)]

        if self.val_partition == 'REDS4':
            val_partition = {'000', '011', '015', '020'}
        elif self.val_partition == 'official':
            val_partition = {f'{i:03d}' for i in range(240, 270)}
        else:
            raise ValueError(f'Wrong validation partition {self.val_partition}. '
                             'Supported ones are ["official", "REDS4"]')

        if self.mode == 'train':
            return [key for key in keys if key not in val_partition]
        return [key for key in keys if key in val_partition]

    def get_sample_data(self,
                        item,
                        number_frames,
                        interval_list,
                        random_reverse,
                        gtroot,
                        fileroot,
                        crop_size,
                        scale,
                        use_flip,
                        use_rot,
                        mode='train'):
        """Read, crop, augment and stack the LQ/GT frames of one window.

        Args:
            item (str): ``'<clip>_<start frame>'``, e.g. ``'011_042'``.
            number_frames (int): Number of frames in the window.
            interval_list (list[int]): Candidate temporal strides.
            random_reverse (bool): Randomly reverse the frame order.
            gtroot (str): GT root folder.
            fileroot (str): LQ root folder.
            crop_size (int): GT crop side; the LQ crop side is
                ``crop_size // scale``.
            scale (int): Upsampling scale ratio.
            use_flip (bool): Enable random horizontal flip.
            use_rot (bool): Enable random vertical flip / transpose.
            mode (str): Cropping and augmentation run only for ``'train'``
                and ``'valid'``.

        Returns:
            tuple[np.ndarray, np.ndarray]: LQ and GT stacks, each float32 of
            shape (N, C, H, W) with N the number of frames in the window.
        """
        parts = item.split('_')
        video_name, frame_name = parts[0], parts[1]
        frame_idxs = self.get_neighbor_frames(frame_name,
                                              number_frames=number_frames,
                                              interval_list=interval_list,
                                              random_reverse=random_reverse)

        frame_list = []
        gt_list = []
        for frame_idx in frame_idxs:
            file_name = f'{frame_idx:08d}.png'
            frame_list.append(
                self.read_img(os.path.join(fileroot, video_name, file_name)))
            gt_list.append(
                self.read_img(os.path.join(gtroot, video_name, file_name)))

        if mode in ('train', 'valid'):
            # Random crop: pick an LQ window, then the same region in GT
            # (HR) coordinates.
            H, W = frame_list[0].shape[:2]
            lq_size = crop_size // scale
            rnd_h = random.randint(0, max(0, H - lq_size))
            rnd_w = random.randint(0, max(0, W - lq_size))
            frame_list = [
                v[rnd_h:rnd_h + lq_size, rnd_w:rnd_w + lq_size, :]
                for v in frame_list
            ]
            rnd_h_hr, rnd_w_hr = int(rnd_h * scale), int(rnd_w * scale)
            gt_list = [
                v[rnd_h_hr:rnd_h_hr + crop_size,
                  rnd_w_hr:rnd_w_hr + crop_size, :] for v in gt_list
            ]

            # Augment LQ and GT in a single call so both receive the same
            # random flips/transpose and stay aligned.
            num_lq = len(frame_list)
            augmented = self.img_augment(frame_list + gt_list, use_flip,
                                         use_rot)
            frame_list = augmented[:num_lq]
            gt_list = augmented[num_lq:]

        # HWC -> CHW per frame, then stack along a new leading frame axis,
        # giving (N, C, H, W).
        img_LQs = np.stack(
            [v.transpose(2, 0, 1).astype('float32') for v in frame_list],
            axis=0)
        img_GTs = np.stack(
            [v.transpose(2, 0, 1).astype('float32') for v in gt_list], axis=0)

        return img_LQs, img_GTs

    def get_neighbor_frames(self, frame_name, number_frames, interval_list,
                            random_reverse):
        """Return ``number_frames`` frame indices starting at ``frame_name``.

        The stride is drawn at random from ``interval_list``. When
        ``random_reverse`` is set, the order is reversed with probability 0.5.

        Args:
            frame_name (str): Start frame index as a string, e.g. ``'042'``.
            number_frames (int): Number of indices to return.
            interval_list (list[int]): Candidate strides.
            random_reverse (bool): Allow random reversal of the order.

        Returns:
            list[int]: The selected frame indices.
        """
        frame_idx = int(frame_name)
        interval = random.choice(interval_list)
        # The stop must scale with the stride: a stop of
        # ``frame_idx + number_frames`` gave too few indices for interval > 1.
        neighbor_list = list(
            range(frame_idx, frame_idx + number_frames * interval, interval))
        if random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        assert len(neighbor_list) == number_frames, \
            "frames slected have length({}), but it should be ({})".format(len(neighbor_list), number_frames)

        return neighbor_list

    def read_img(self, path, size=None):
        """Read an image with cv2.

        Args:
            path (str): Image file path.
            size: Unused; kept for interface compatibility.

        Returns:
            np.ndarray: float32 array in HWC layout. Integer-typed files are
            scaled to [0, 1] by their dtype's maximum. 3/4-channel images are
            returned as 3-channel RGB; single-channel images keep one channel.

        Raises:
            IOError: If cv2 cannot read the file.
        """
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread reports failure by returning None, not by raising.
            raise IOError(f'Failed to read image: {path}')
        if np.issubdtype(img.dtype, np.integer):
            # 255 for 8-bit, 65535 for 16-bit files (IMREAD_UNCHANGED keeps
            # the original bit depth).
            max_value = float(np.iinfo(img.dtype).max)
            img = img.astype(np.float32) / max_value
        else:
            img = img.astype(np.float32)
        if img.ndim == 2:
            img = np.expand_dims(img, axis=2)
        # some images have 4 channels; drop the extra one
        if img.shape[2] > 3:
            img = img[:, :, :3]
        if img.shape[2] == 1:
            # COLOR_BGR2RGB requires 3 channels; a grey image has no order.
            return img
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def img_augment(self, img_list, hflip=True, rot=True):
        """Apply one shared random set of flips/transpose to every image.

        All images receive the same random decisions, so paired frames stay
        aligned. With ``hflip`` a horizontal flip is applied with
        probability 0.5. With ``rot`` a vertical flip and an H/W transpose
        are each applied with probability 0.5. Combined, the three flags
        cover all 8 flip / 90-degree-rotation variants.

        Args:
            img_list (list[np.ndarray]): HWC images.
            hflip (bool): Enable the random horizontal flip.
            rot (bool): Enable the random vertical flip and transpose.

        Returns:
            list[np.ndarray]: The augmented images (possibly views).
        """
        hflip = hflip and random.random() < 0.5
        vflip = rot and random.random() < 0.5
        rot90 = rot and random.random() < 0.5

        def _augment(img):
            if hflip:
                img = img[:, ::-1, :]
            if vflip:
                img = img[::-1, :, :]
            if rot90:
                img = img.transpose(1, 0, 2)
            return img

        return [_augment(img) for img in img_list]

    def __len__(self):
        """Length of the dataset.

        Returns:
            int: Number of clips selected for the current mode.
        """
        return len(self.data_infos)
|