You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

238 lines
8.3 KiB

3 years ago
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import random
import numpy as np
import cv2
from paddle.io import Dataset
from .builder import DATASETS
logger = logging.getLogger(__name__)
@DATASETS.register()
class SRREDSMultipleGTDataset(Dataset):
    """REDS dataset for video super resolution with recurrent networks.

    Each item is a randomly positioned window of ``number_frames`` frames
    taken from one clip, loaded as paired LQ (Low-Quality) and GT
    (Ground-Truth) images. Frames are read from
    ``<folder>/<clip>/<frame index, 8 digits>.png``. In 'train' and 'valid'
    mode the frames are randomly cropped and optionally flipped/transposed.

    Args:
        mode (str): Dataset mode. 'train' selects the clips outside the
            validation partition, any other value selects the validation
            clips. Cropping and augmentation only happen for 'train' and
            'valid'.
        lq_folder (str | :obj:`Path`): Path to a lq folder.
        gt_folder (str | :obj:`Path`): Path to a gt folder.
        crop_size (int): Side length of the square GT crop. The LQ crop is
            ``crop_size // scale``. Default: 256.
        interval_list (list[int] | None): Candidate temporal strides; one
            is drawn at random for every sample. ``None`` means ``[1]``.
            Default: None.
        random_reverse (bool): Reverse the frame order with probability
            0.5. Default: False.
        number_frames (int): Number of frames per sample. Default: 15.
        use_flip (bool): Enable the random horizontal flip. Default: False.
        use_rot (bool): Enable the random vertical flip and the random
            transpose. Default: False.
        scale (int): Upsampling scale ratio. Default: 4.
        val_partition (str): Validation partition mode. Choices
            ['official' or 'REDS4']. Default: 'REDS4'.
        batch_size (int): Stored on the instance; not used by this class.
            Default: 4.
        num_clips (int): Number of clips, named '000', '001', ...
            Default: 270.
        frames_per_clip (int): Number of frames available in every clip,
            used to bound the random start index. Default: 100.
    """

    def __init__(self,
                 mode,
                 lq_folder,
                 gt_folder,
                 crop_size=256,
                 interval_list=None,
                 random_reverse=False,
                 number_frames=15,
                 use_flip=False,
                 use_rot=False,
                 scale=4,
                 val_partition='REDS4',
                 batch_size=4,
                 num_clips=270,
                 frames_per_clip=100):
        super().__init__()
        self.mode = mode
        self.fileroot = str(lq_folder)
        self.gtroot = str(gt_folder)
        self.crop_size = crop_size
        # A fresh list per instance: a list literal in the signature would
        # be shared by every instance that relies on the default.
        self.interval_list = [1] if interval_list is None else interval_list
        self.random_reverse = random_reverse
        self.number_frames = number_frames
        self.use_flip = use_flip
        self.use_rot = use_rot
        self.scale = scale
        self.val_partition = val_partition
        self.batch_size = batch_size
        self.num_clips = num_clips  # number of clips before partitioning
        self.frames_per_clip = frames_per_clip
        self.data_infos = self.load_annotations()

    def __getitem__(self, idx):
        """Get item at each call.

        Args:
            idx (int): Index for getting each item.

        Returns:
            dict: 'lq' and 'gt' hold the stacked LQ / GT frames returned by
                ``get_sample_data``; 'lq_path' holds the clip key (e.g.
                '000'), not a file system path.

        Raises:
            ValueError: If the requested window does not fit into a clip.
        """
        clip_key = self.data_infos[idx]
        # Reserve room for the widest possible window so that every stride
        # in interval_list stays inside the clip. With a stride of 1 this is
        # exactly ``frames_per_clip - number_frames``.
        widest_stride = max(self.interval_list)
        widest_span = (self.number_frames - 1) * widest_stride + 1
        last_start = self.frames_per_clip - widest_span
        if last_start < 0:
            raise ValueError(
                f'Cannot sample {self.number_frames} frames with a stride '
                f'of up to {widest_stride} from clips of '
                f'{self.frames_per_clip} frames.')
        start = random.randint(0, last_start)
        item = f'{clip_key}_{start:03d}'
        img_LQs, img_GTs = self.get_sample_data(
            item, self.number_frames, self.interval_list, self.random_reverse,
            self.gtroot, self.fileroot, self.crop_size, self.scale,
            self.use_flip, self.use_rot, self.mode)
        return {'lq': img_LQs, 'gt': img_GTs, 'lq_path': clip_key}

    def load_annotations(self):
        """Load annotations for REDS dataset.

        Returns:
            list[str]: Zero-padded clip keys ('000', '001', ...) belonging
                to the current mode: the non-validation clips for 'train',
                the validation clips otherwise.

        Raises:
            ValueError: If ``val_partition`` is not 'REDS4' or 'official'.
        """
        # generate keys
        keys = [f'{i:03d}' for i in range(0, self.num_clips)]
        if self.val_partition == 'REDS4':
            val_partition = ['000', '011', '015', '020']
        elif self.val_partition == 'official':
            val_partition = [f'{i:03d}' for i in range(240, 270)]
        else:
            raise ValueError(
                f'Wrong validation partition {self.val_partition}. '
                f'Supported ones are ["official", "REDS4"]')
        if self.mode == 'train':
            return [key for key in keys if key not in val_partition]
        return [key for key in keys if key in val_partition]

    def get_sample_data(self,
                        item,
                        number_frames,
                        interval_list,
                        random_reverse,
                        gtroot,
                        fileroot,
                        crop_size,
                        scale,
                        use_flip,
                        use_rot,
                        mode='train'):
        """Read, crop and augment the LQ/GT frames of one sample.

        Args:
            item (str): '<clip key>_<start frame>', e.g. '000_012'.
            number_frames (int): Number of frames to load.
            interval_list (list[int]): Candidate temporal strides.
            random_reverse (bool): Allow random reversal of the frame order.
            gtroot (str): Root folder of the GT clips.
            fileroot (str): Root folder of the LQ clips.
            crop_size (int): Side length of the GT crop.
            scale (int): Upsampling scale ratio between LQ and GT.
            use_flip (bool): Enable the random horizontal flip.
            use_rot (bool): Enable the random vertical flip / transpose.
            mode (str): Cropping and augmentation are applied only for
                'train' and 'valid'. Default: 'train'.

        Returns:
            tuple[np.ndarray, np.ndarray]: float32 LQ and GT arrays, each
                laid out as (frames, channels, height, width).
        """
        parts = item.split('_')
        video_name, frame_name = parts[0], parts[1]
        frame_idxs = self.get_neighbor_frames(frame_name,
                                              number_frames=number_frames,
                                              interval_list=interval_list,
                                              random_reverse=random_reverse)
        lq_frames = []
        gt_frames = []
        for frame_idx in frame_idxs:
            file_name = f'{frame_idx:08d}.png'
            lq_frames.append(
                self.read_img(os.path.join(fileroot, video_name, file_name)))
            gt_frames.append(
                self.read_img(os.path.join(gtroot, video_name, file_name)))
        height, width = lq_frames[0].shape[:2]
        transform = mode in ('train', 'valid')
        # add random crop
        if transform:
            lq_size = crop_size // scale
            top = random.randint(0, max(0, height - lq_size))
            left = random.randint(0, max(0, width - lq_size))
            lq_frames = [
                v[top:top + lq_size, left:left + lq_size, :]
                for v in lq_frames
            ]
            # The GT window is the LQ window scaled up, so both crops cover
            # the same image region.
            top_hr, left_hr = int(top * scale), int(left * scale)
            gt_frames = [
                v[top_hr:top_hr + crop_size, left_hr:left_hr + crop_size, :]
                for v in gt_frames
            ]
        # add random flip and rotation; LQ and GT go through one call so
        # that both receive the very same random transform
        num_lq = len(lq_frames)
        merged = lq_frames + gt_frames
        if transform:
            merged = self.img_augment(merged, use_flip, use_rot)
        lq_frames = merged[:num_lq]
        gt_frames = merged[num_lq:]
        # HWC -> CHW per frame, then stack to NCHW, N is the frame number
        lq_frames = [
            v.transpose(2, 0, 1).astype('float32') for v in lq_frames
        ]
        gt_frames = [
            v.transpose(2, 0, 1).astype('float32') for v in gt_frames
        ]
        img_LQs = np.stack(lq_frames, axis=0)
        img_GTs = np.stack(gt_frames, axis=0)
        return img_LQs, img_GTs

    def get_neighbor_frames(self, frame_name, number_frames, interval_list,
                            random_reverse):
        """Build the list of frame indices of one sample.

        Args:
            frame_name (str): Index of the first frame, as a decimal string.
            number_frames (int): Number of indices to generate.
            interval_list (list[int]): Candidate strides; one is drawn at
                random.
            random_reverse (bool): Reverse the result with probability 0.5.

        Returns:
            list[int]: ``number_frames`` indices starting at ``frame_name``
                and spaced by the drawn stride, possibly reversed.

        Raises:
            ValueError: If the drawn stride is smaller than 1.
        """
        frame_idx = int(frame_name)
        interval = random.choice(interval_list)
        if interval < 1:
            raise ValueError(
                f'Frame interval must be a positive integer, got {interval}.')
        # The end bound scales with the stride; a bound of
        # ``frame_idx + number_frames`` would yield too few indices for any
        # stride larger than 1.
        neighbor_list = list(
            range(frame_idx, frame_idx + number_frames * interval, interval))
        if random_reverse and random.random() < 0.5:
            neighbor_list.reverse()
        return neighbor_list

    def read_img(self, path, size=None):
        """Read an image with cv2.

        Args:
            path (str): Path of the image file.
            size: Unused; kept for interface compatibility.

        Returns:
            np.ndarray: float32 array in HWC layout, pixel values divided by
                255 (i.e. [0, 1] for 8-bit images). Images with three or
                more channels are returned as 3-channel RGB; single-channel
                images are returned as HxWx1.

        Raises:
            FileNotFoundError: If ``path`` does not point to a file.
            OSError: If the file exists but cv2 cannot decode it.
        """
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals failure by returning None instead of
            # raising, so report the offending path explicitly.
            if not os.path.isfile(path):
                raise FileNotFoundError(f'Image file not found: {path}')
            raise OSError(f'Failed to decode image file: {path}')
        img = img.astype(np.float32) / 255.
        if img.ndim == 2:
            # Grayscale: there is no channel order to swap, and cvtColor
            # with COLOR_BGR2RGB rejects single-channel input.
            return np.expand_dims(img, axis=2)
        # some images have 4 channels
        if img.shape[2] > 3:
            img = img[:, :, :3]
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def img_augment(self, img_list, hflip=True, rot=True):
        """Apply one random flip/transpose combination to all images.

        The random decisions are drawn once per call, so every image of
        ``img_list`` receives the same transform.

        Args:
            img_list (list[np.ndarray]): Images in HWC layout.
            hflip (bool): Allow a horizontal flip (probability 0.5).
            rot (bool): Allow a vertical flip and a transpose of the two
                spatial axes (probability 0.5 each, drawn independently).

        Returns:
            list[np.ndarray]: The transformed images, in input order.
        """
        hflip = hflip and random.random() < 0.5
        vflip = rot and random.random() < 0.5
        rot90 = rot and random.random() < 0.5

        def _augment(img):
            if hflip:
                img = img[:, ::-1, :]
            if vflip:
                img = img[::-1, :, :]
            if rot90:
                img = img.transpose(1, 0, 2)
            return img

        return [_augment(img) for img in img_list]

    def __len__(self):
        """Length of the dataset.

        Returns:
            int: Number of clips selected for the current mode.
        """
        return len(self.data_infos)