You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

194 lines
5.8 KiB

# Code was heavily based on https://github.com/swz30/MPRNet
# Users should be careful about adopting these functions for any commercial use.
# https://github.com/swz30/MPRNet/blob/main/LICENSE.md
import os
import random
import numpy as np
import cv2
import paddle
from PIL import Image, ImageEnhance
import numpy as np
import random
import numbers
from paddle.io import Dataset
from .builder import DATASETS
from paddle.vision.transforms.functional import to_tensor, adjust_brightness, adjust_saturation, rotate, hflip, hflip, vflip, center_crop
def is_image_file(filename):
    """Return True when ``filename`` ends with a known image suffix.

    The match is a case-sensitive suffix test against ``jpeg``, ``JPEG``,
    ``jpg``, ``png``, ``JPG``, ``PNG`` and ``gif``. No leading dot is
    required, so e.g. ``"xpng"`` matches, while ``"a.GIF"`` does not.
    """
    image_suffixes = ('jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif')
    # str.endswith accepts a tuple and returns True if any suffix matches.
    return filename.endswith(image_suffixes)
@DATASETS.register()
class MPRTrain(Dataset):
    """Paired training dataset for MPRNet-style image restoration.

    Degraded images are read from ``<rgb_dir>/input`` and ground truth from
    ``<rgb_dir>/target``. Each directory's image filenames are sorted
    independently and paired by position.
    NOTE(review): nothing checks that the two lists actually correspond;
    the directories are presumably expected to hold identically named files.

    Each sample is a random ``patch_size`` x ``patch_size`` crop. The same
    crop, the same optional saturation jitter and the same flip/rotation
    are applied to the input and the target.

    Args:
        rgb_dir (str): root directory containing ``input/`` and ``target/``.
        img_options (dict): must provide ``'patch_size'`` (int); indexing
            it is required, so the ``None`` default cannot be used here.
    """

    def __init__(self, rgb_dir, img_options=None):
        super(MPRTrain, self).__init__()

        inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
        tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))

        self.inp_filenames = [
            os.path.join(rgb_dir, 'input', x) for x in inp_files
            if is_image_file(x)
        ]
        self.tar_filenames = [
            os.path.join(rgb_dir, 'target', x) for x in tar_files
            if is_image_file(x)
        ]

        self.img_options = img_options
        self.sizex = len(self.tar_filenames)  # get the size of target
        self.ps = self.img_options['patch_size']

    def __len__(self):
        return self.sizex

    @staticmethod
    def _reflect_pad(img, padw, padh):
        """Reflect-pad a PIL image by ``padw`` columns on the right and
        ``padh`` rows at the bottom; return a new PIL image."""
        arr = np.asarray(img)
        # Pad only the two spatial axes; a trailing channel axis (if any)
        # is left untouched.
        pad_width = [(0, padh), (0, padw)] + [(0, 0)] * (arr.ndim - 2)
        return Image.fromarray(np.pad(arr, pad_width, mode='reflect'))

    @staticmethod
    def _augment_patch(patch, mode):
        """Apply one of nine flip/rotation variants to an (H, W[, C]) array.

        ``mode`` 0 and 8 leave the patch unchanged; 1 flips vertically,
        2 flips horizontally, 3/4/5 rotate by 90/180/270 degrees, and 6/7
        rotate a vertically/horizontally flipped patch by 90 degrees.
        Returns a C-contiguous copy.
        """
        if mode == 1:
            patch = np.flip(patch, axis=0)
        elif mode == 2:
            patch = np.flip(patch, axis=1)
        elif mode in (3, 4, 5):
            patch = np.rot90(patch, k=mode - 2, axes=(0, 1))
        elif mode == 6:
            patch = np.rot90(np.flip(patch, axis=0), axes=(0, 1))
        elif mode == 7:
            patch = np.rot90(np.flip(patch, axis=1), axes=(0, 1))
        # np.flip / np.rot90 / slicing yield strided (possibly read-only)
        # views; hand a contiguous, writable copy to to_tensor.
        return patch.copy()

    def __getitem__(self, index):
        """Return ``(target_patch, input_patch, filename)``.

        Both patches are produced by ``to_tensor`` from identically cropped
        and augmented arrays. ``filename`` is the target file's basename
        without its extension.
        """
        index_ = index % self.sizex
        ps = self.ps

        inp_path = self.inp_filenames[index_]
        tar_path = self.tar_filenames[index_]

        inp_img = Image.open(inp_path)
        tar_img = Image.open(tar_path)

        # Reflect-pad (right/bottom) in case the image is smaller than
        # patch_size, so a full ps x ps crop is always possible. Pad amounts
        # come from the target size and are applied to both images.
        w, h = tar_img.size
        padw = max(ps - w, 0)
        padh = max(ps - h, 0)
        if padw != 0 or padh != 0:
            inp_img = self._reflect_pad(inp_img, padw, padh)
            tar_img = self._reflect_pad(tar_img, padw, padh)

        # With probability 1/3, jitter saturation by a shared factor in
        # roughly [0.8, 1.2].
        if random.randint(0, 2) == 1:
            sat_factor = 1 + (0.2 - 0.4 * np.random.rand())
            inp_img = adjust_saturation(inp_img, sat_factor)
            tar_img = adjust_saturation(tar_img, sat_factor)

        inp_arr = np.asarray(inp_img)
        tar_arr = np.asarray(tar_img)

        # Crop the same random ps x ps patch from both images.
        hh, ww = tar_arr.shape[0], tar_arr.shape[1]
        rr = random.randint(0, hh - ps)
        cc = random.randint(0, ww - ps)
        inp_arr = inp_arr[rr:rr + ps, cc:cc + ps]
        tar_arr = tar_arr[rr:rr + ps, cc:cc + ps]

        # Geometric augmentation on the square patch. The choice is drawn
        # independently of the saturation decision, and 90-degree rotations
        # of a square patch are exact (no clipping or fill).
        aug = random.randint(0, 8)
        # NOTE(review): to_tensor now receives ndarrays rather than PIL
        # images; for uint8 L/RGB images both paths should give a CHW tensor
        # scaled to [0, 1] -- confirm for other PIL modes.
        inp_img = to_tensor(self._augment_patch(inp_arr, aug))
        tar_img = to_tensor(self._augment_patch(tar_arr, aug))

        filename = os.path.splitext(os.path.split(tar_path)[-1])[0]
        return tar_img, inp_img, filename
@DATASETS.register()
class MPRVal(Dataset):
    """Paired validation dataset for MPRNet-style image restoration.

    Degraded images are read from ``<rgb_dir>/input`` and ground truth from
    ``<rgb_dir>/target``. The sorted image filenames of each directory are
    paired by position.
    NOTE(review): nothing checks that the two lists actually correspond.

    Args:
        rgb_dir (str): root directory containing ``input/`` and ``target/``.
        img_options (dict | None): when given, must provide ``'patch_size'``.
            A patch size of ``None`` (or ``img_options=None``) disables the
            center crop, and full images are returned.
        rgb_dir2: accepted for interface compatibility; not used.
    """

    def __init__(self, rgb_dir, img_options=None, rgb_dir2=None):
        super(MPRVal, self).__init__()

        inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
        tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))

        self.inp_filenames = [
            os.path.join(rgb_dir, 'input', x) for x in inp_files
            if is_image_file(x)
        ]
        self.tar_filenames = [
            os.path.join(rgb_dir, 'target', x) for x in tar_files
            if is_image_file(x)
        ]

        self.img_options = img_options
        self.sizex = len(self.tar_filenames)  # get the size of target
        # img_options defaults to None: treat that as "no patch size" so the
        # center crop in __getitem__ is skipped instead of failing here.
        self.ps = None if img_options is None else img_options['patch_size']

    def __len__(self):
        return self.sizex

    def __getitem__(self, index):
        """Return ``(target, input, filename)`` for sample ``index``.

        ``index`` wraps modulo the dataset size. ``filename`` is the target
        file's basename without its extension.
        """
        index_ = index % self.sizex
        ps = self.ps

        inp_path = self.inp_filenames[index_]
        tar_path = self.tar_filenames[index_]

        inp_img = Image.open(inp_path)
        tar_img = Image.open(tar_path)

        # Validate on center crop
        if ps is not None:
            inp_img = center_crop(inp_img, (ps, ps))
            tar_img = center_crop(tar_img, (ps, ps))

        inp_img = to_tensor(inp_img)
        tar_img = to_tensor(tar_img)

        filename = os.path.splitext(os.path.split(tar_path)[-1])[0]
        return tar_img, inp_img, filename
@DATASETS.register()
class MPRTest(Dataset):
    """Unpaired test dataset made of the image files directly in ``inp_dir``.

    Args:
        inp_dir (str): directory scanned (non-recursively) for files accepted
            by ``is_image_file``; they are served in sorted filename order.
        img_options: stored as ``self.img_options``; not otherwise used here.
    """

    def __init__(self, inp_dir, img_options):
        super(MPRTest, self).__init__()
        self.inp_filenames = []
        for entry in sorted(os.listdir(inp_dir)):
            if is_image_file(entry):
                self.inp_filenames.append(os.path.join(inp_dir, entry))
        self.inp_size = len(self.inp_filenames)
        self.img_options = img_options

    def __len__(self):
        return self.inp_size

    def __getitem__(self, index):
        """Return ``(tensor, stem)`` for the ``index``-th input image.

        ``stem`` is the file's basename with its extension removed.
        """
        image_path = self.inp_filenames[index]
        stem, _ = os.path.splitext(os.path.basename(image_path))
        tensor = to_tensor(Image.open(image_path))
        return tensor, stem