You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
194 lines
5.8 KiB
194 lines
5.8 KiB
# code was heavily based on https://github.com/swz30/MPRNet |
|
# Users should be careful about adopting these functions in any commercial matters. |
|
# https://github.com/swz30/MPRNet/blob/main/LICENSE.md |
|
|
|
import os |
|
import random |
|
import numpy as np |
|
import cv2 |
|
import paddle |
|
from PIL import Image, ImageEnhance |
|
import numpy as np |
|
import random |
|
import numbers |
|
from paddle.io import Dataset |
|
from .builder import DATASETS |
|
from paddle.vision.transforms.functional import to_tensor, adjust_brightness, adjust_saturation, rotate, hflip, hflip, vflip, center_crop |
|
|
|
|
|
def is_image_file(filename):
    """Return True if *filename* has a recognised image extension.

    The extension must follow a dot and is compared case-insensitively, so
    ``a.jpg``, ``a.JPG`` and ``a.Png`` match while ``notapng`` (no dot) or
    ``a.txt`` do not.

    Args:
        filename (str): File name or path to test.

    Returns:
        bool: Whether the extension is one of jpeg/jpg/png/gif.
    """
    extension = os.path.splitext(filename)[1].lower()
    return extension in {'.jpeg', '.jpg', '.png', '.gif'}
|
|
|
|
|
@DATASETS.register()
class MPRTrain(Dataset):
    """Paired training dataset for MPRNet-style image restoration.

    Expects ``rgb_dir/input`` and ``rgb_dir/target`` directories; image files
    in each are sorted by name and paired by position. Each item is a
    randomly cropped, identically augmented ``(target, input, filename)``
    triple of CHW tensors plus the target's file stem.

    Args:
        rgb_dir (str): Root directory containing ``input`` and ``target``.
        img_options (dict): Must provide ``'patch_size'``, the side length of
            the square training patch (``None`` is not supported here).
    """

    def __init__(self, rgb_dir, img_options=None):
        super(MPRTrain, self).__init__()

        inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
        tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))

        self.inp_filenames = [
            os.path.join(rgb_dir, 'input', x) for x in inp_files
            if is_image_file(x)
        ]
        self.tar_filenames = [
            os.path.join(rgb_dir, 'target', x) for x in tar_files
            if is_image_file(x)
        ]

        self.img_options = img_options
        # Dataset length follows the target list; the input list is assumed
        # to pair 1:1 with it by sorted order (not checked here).
        self.sizex = len(self.tar_filenames)

        self.ps = self.img_options['patch_size']

    def __len__(self):
        return self.sizex

    @staticmethod
    def _flip_rotate(img, aug):
        """Apply the ``aug``-th flip/rotation variant to a CHW tensor.

        Values 1-7 select a vertical flip, horizontal flip, rotation by
        90/180/270 degrees in the H-W plane, or a flip followed by a 90-degree
        rotation. Any other value (0 and 8 from ``randint(0, 8)``) returns
        ``img`` unchanged. Rotations are exact (``rot90``), so a square patch
        keeps its shape and no pixels are lost.
        """
        if aug == 1:
            return paddle.flip(img, axis=[1])
        if aug == 2:
            return paddle.flip(img, axis=[2])
        if aug == 3:
            return paddle.rot90(img, k=1, axes=[1, 2])
        if aug == 4:
            return paddle.rot90(img, k=2, axes=[1, 2])
        if aug == 5:
            return paddle.rot90(img, k=3, axes=[1, 2])
        if aug == 6:
            return paddle.rot90(paddle.flip(img, axis=[1]), k=1, axes=[1, 2])
        if aug == 7:
            return paddle.rot90(paddle.flip(img, axis=[2]), k=1, axes=[1, 2])
        return img

    def __getitem__(self, index):
        """Load, pad, augment and crop the ``index``-th image pair.

        Returns:
            tuple: ``(target_patch, input_patch, filename)`` where both
            patches are ``ps x ps`` CHW tensors and ``filename`` is the
            target's base name without extension.
        """
        # Local import: only this method needs reflect padding of PIL images.
        from paddle.vision.transforms.functional import pad

        index_ = index % self.sizex
        ps = self.ps

        inp_path = self.inp_filenames[index_]
        tar_path = self.tar_filenames[index_]

        inp_img = Image.open(inp_path)
        tar_img = Image.open(tar_path)

        # Reflect-pad on the right/bottom when the image is smaller than the
        # patch, so the random crop below always fits. Padding order is
        # (left, top, right, bottom).
        w, h = tar_img.size
        padw = max(ps - w, 0)
        padh = max(ps - h, 0)
        if padw != 0 or padh != 0:
            inp_img = pad(inp_img, (0, 0, padw, padh), padding_mode='reflect')
            tar_img = pad(tar_img, (0, 0, padw, padh), padding_mode='reflect')

        # NOTE(review): a brightness factor of 1 is the identity, so this
        # branch never changes the images (the MPRNet reference has the same
        # no-op with adjust_gamma). Kept, including its RNG draw; confirm the
        # intended factor.
        if random.randint(0, 2) == 1:
            inp_img = adjust_brightness(inp_img, 1)
            tar_img = adjust_brightness(tar_img, 1)

        # Photometric augmentation: the same random saturation for both.
        if random.randint(0, 2) == 1:
            sat_factor = 1 + (0.2 - 0.4 * np.random.rand())
            inp_img = adjust_saturation(inp_img, sat_factor)
            tar_img = adjust_saturation(tar_img, sat_factor)

        inp_img = to_tensor(inp_img)
        tar_img = to_tensor(tar_img)

        # Random square crop at the same location in both CHW tensors.
        hh, ww = tar_img.shape[1], tar_img.shape[2]
        rr = random.randint(0, hh - ps)
        cc = random.randint(0, ww - ps)
        inp_img = inp_img[:, rr:rr + ps, cc:cc + ps]
        tar_img = tar_img[:, rr:rr + ps, cc:cc + ps]

        # Geometric augmentation on the square patch, with its own draw so it
        # is independent of the photometric ones and all 8 variants are
        # reachable.
        aug = random.randint(0, 8)
        inp_img = self._flip_rotate(inp_img, aug)
        tar_img = self._flip_rotate(tar_img, aug)

        filename = os.path.splitext(os.path.split(tar_path)[-1])[0]

        return tar_img, inp_img, filename
|
|
|
|
|
@DATASETS.register()
class MPRVal(Dataset):
    """Paired validation dataset for MPRNet-style image restoration.

    Reads ``rgb_dir/input`` and ``rgb_dir/target``, pairs image files by
    sorted filename order and returns ``(target, input, filename)`` triples,
    optionally center-cropped to a square patch.

    Args:
        rgb_dir (str): Root directory containing ``input`` and ``target``.
        img_options (dict, optional): When given, its ``'patch_size'`` entry
            sets the center-crop size; a ``None`` patch size, or
            ``img_options=None``, validates on full images.
        rgb_dir2: Accepted for interface compatibility; not used.
    """

    def __init__(self, rgb_dir, img_options=None, rgb_dir2=None):
        super(MPRVal, self).__init__()

        inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
        tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))

        self.inp_filenames = [
            os.path.join(rgb_dir, 'input', x) for x in inp_files
            if is_image_file(x)
        ]
        self.tar_filenames = [
            os.path.join(rgb_dir, 'target', x) for x in tar_files
            if is_image_file(x)
        ]

        self.img_options = img_options
        self.sizex = len(self.tar_filenames)  # get the size of target

        # The default img_options=None used to crash here even though
        # __getitem__ supports ps=None (no crop); treat it as "no crop".
        self.ps = (self.img_options['patch_size']
                   if self.img_options is not None else None)

    def __len__(self):
        return self.sizex

    def __getitem__(self, index):
        """Return ``(target, input, filename)`` for the ``index``-th pair.

        Both images are converted with ``to_tensor``; when a patch size is
        set they are first center-cropped to ``ps x ps``. ``filename`` is the
        target's base name without extension.
        """
        index_ = index % self.sizex
        ps = self.ps

        inp_path = self.inp_filenames[index_]
        tar_path = self.tar_filenames[index_]

        inp_img = Image.open(inp_path)
        tar_img = Image.open(tar_path)

        # Validate on center crop
        if ps is not None:
            inp_img = center_crop(inp_img, (ps, ps))
            tar_img = center_crop(tar_img, (ps, ps))

        inp_img = to_tensor(inp_img)
        tar_img = to_tensor(tar_img)

        filename = os.path.splitext(os.path.split(tar_path)[-1])[0]

        return tar_img, inp_img, filename
|
|
|
|
|
@DATASETS.register()
class MPRTest(Dataset):
    """Test-time dataset yielding every image file found in one directory.

    Args:
        inp_dir (str): Directory listed (non-recursively) for image files;
            matches are visited in sorted filename order.
        img_options: Stored as ``self.img_options``; not read by this class.
    """

    def __init__(self, inp_dir, img_options):
        super(MPRTest, self).__init__()

        listing = sorted(os.listdir(inp_dir))
        self.inp_filenames = [
            os.path.join(inp_dir, entry)
            for entry in listing
            if is_image_file(entry)
        ]

        self.inp_size = len(self.inp_filenames)
        self.img_options = img_options

    def __len__(self):
        return self.inp_size

    def __getitem__(self, index):
        """Return ``(image_tensor, stem)`` for the ``index``-th input file.

        ``stem`` is the file's base name with its extension removed.
        """
        image_path = self.inp_filenames[index]
        stem = os.path.splitext(os.path.basename(image_path))[0]
        tensor = to_tensor(Image.open(image_path))
        return tensor, stem
|
|
|