You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
159 lines
5.6 KiB
159 lines
5.6 KiB
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
|
|
import os |
|
import random |
|
from natsort import natsorted |
|
from glob import glob |
|
import numpy as np |
|
import cv2 |
|
from PIL import Image |
|
import paddle |
|
from .base_predictor import BasePredictor |
|
from ppgan.models.generators import MPRNet |
|
from ppgan.utils.download import get_path_from_url |
|
from ppgan.utils.visual import make_grid, tensor2img, save_image |
|
from ppgan.datasets.mpr_dataset import to_tensor |
|
from paddle.vision.transforms import Pad |
|
from tqdm import tqdm |
|
|
|
# Pretrained MPRNet settings, keyed by restoration task.
#   model_urls        -- checkpoint fetched when the predictor gets no
#                        weight_path.
#   n_feat,
#   scale_unetfeats,
#   scale_orsnetfeats -- MPRNet constructor arguments used to build the
#                        network that the checkpoint is loaded into.
model_cfgs = dict(
    Deblurring=dict(
        model_urls='https://paddlegan.bj.bcebos.com/models/MPR_Deblurring.pdparams',
        n_feat=96,
        scale_unetfeats=48,
        scale_orsnetfeats=32,
    ),
    Denoising=dict(
        model_urls='https://paddlegan.bj.bcebos.com/models/MPR_Denoising.pdparams',
        n_feat=80,
        scale_unetfeats=48,
        scale_orsnetfeats=32,
    ),
    Deraining=dict(
        model_urls='https://paddlegan.bj.bcebos.com/models/MPR_Deraining.pdparams',
        n_feat=40,
        scale_unetfeats=20,
        scale_orsnetfeats=16,
    ),
)
|
|
|
|
|
class MPRPredictor(BasePredictor):
    """Restore images (deblur / denoise / derain) with a pretrained MPRNet.

    Args:
        images_path (str): a single image file, or a directory whose
            ``.jpg`` / ``.jpeg`` / ``.png`` files (extension matched
            case-insensitively) are processed.
        output_path (str): root output directory. Results are written to
            ``<output_path>/<task>/<stem>_restoration<ext>``.
        weight_path (str, optional): checkpoint to load. When omitted, the
            pretrained checkpoint for ``task`` is downloaded.
        seed (int, optional): if given, seeds ``paddle``, ``random`` and
            ``numpy``.
        task (str): a key of ``model_cfgs``. Always required, because the
            MPRNet size hyper-parameters come from ``model_cfgs[task]`` even
            when ``weight_path`` is supplied.
        max_size (int): images whose longer side exceeds this are downscaled
            (aspect ratio kept) before restoration.
        img_multiple_of (int): inputs are reflect-padded on the bottom/right
            so height and width are multiples of this. The padding is
            cropped from the output.

    Raises:
        ValueError: if ``task`` is not a key of ``model_cfgs``.
    """

    # Extensions (compared lower-cased) that get_images picks up from a
    # directory.
    _IMAGE_EXTS = ('.jpg', '.jpeg', '.png')

    def __init__(self,
                 images_path=None,
                 output_path='output_dir',
                 weight_path=None,
                 seed=None,
                 task=None,
                 max_size=640,
                 img_multiple_of=8):
        # The network is sized from model_cfgs[task] whether or not a custom
        # checkpoint is given. Validate the task before loading any weights.
        if task not in model_cfgs:
            raise ValueError(
                f'Unknown task {task!r}; expected one of {sorted(model_cfgs)}. '
                'A task is required even when weight_path is given, because '
                'it selects the network configuration.')

        self.output_path = output_path
        self.images_path = images_path
        self.task = task
        self.max_size = max_size
        self.img_multiple_of = img_multiple_of

        cfg = model_cfgs[task]
        if weight_path is None:
            weight_path = get_path_from_url(cfg['model_urls'])
        checkpoint = paddle.load(weight_path)

        self.generator = MPRNet(
            n_feat=cfg['n_feat'],
            scale_unetfeats=cfg['scale_unetfeats'],
            scale_orsnetfeats=cfg['scale_orsnetfeats'])
        self.generator.set_state_dict(checkpoint)
        self.generator.eval()

        if seed is not None:
            paddle.seed(seed)
            random.seed(seed)
            np.random.seed(seed)

    def get_images(self, images_path):
        """Return the list of image files to process.

        A directory yields its non-hidden regular files whose extension
        (lower-cased) is in ``_IMAGE_EXTS``, in natural sort order. Any other
        path is returned as a one-element list and treated as an image file.
        """
        if not os.path.isdir(images_path):
            return [images_path]

        # Filter one directory listing by lower-cased extension instead of
        # globbing '*.jpg' and '*.JPG' separately. Glob matching is
        # case-insensitive on Windows, so the old approach listed every file
        # twice there. Dotfiles are skipped to match glob's '*' behaviour.
        files = []
        for name in os.listdir(images_path):
            if name.startswith('.'):
                continue
            if os.path.splitext(name)[1].lower() not in self._IMAGE_EXTS:
                continue
            path = os.path.join(images_path, name)
            if os.path.isfile(path):
                files.append(path)
        return natsorted(files)

    def read_image(self, image_file):
        """Load ``image_file`` as RGB, shrinking it so its longer side is at
        most ``self.max_size`` (aspect ratio preserved)."""
        # convert() returns an independent, fully loaded image, so the file
        # can be closed right away.
        with Image.open(image_file) as src:
            img = src.convert('RGB')

        longest = max(img.width, img.height)
        if longest > self.max_size:
            scale = self.max_size / longest
            # round() instead of int() so float error cannot drop a pixel;
            # max(1, ...) keeps an extreme aspect ratio from shrinking a side
            # to zero pixels.
            size = (max(1, round(img.width * scale)),
                    max(1, round(img.height * scale)))
            img = img.resize(size)
        return img

    def _restoration_path(self, task_path, image_file):
        """Return ``<task_path>/<stem>_restoration<ext>`` for ``image_file``.

        splitext keeps dotted stems such as ``a.b.jpg`` intact. cv2.imwrite
        chooses its encoder from the extension, so an extension-less input
        falls back to ``.png``.
        """
        stem, ext = os.path.splitext(os.path.basename(image_file))
        return os.path.join(task_path, f'{stem}_restoration{ext or ".png"}')

    def _restore(self, img):
        """Run the generator on a PIL RGB image; return an HxWx3 uint8 RGB
        array of the same size as ``img``."""
        input_ = paddle.to_tensor(to_tensor(img))
        # Channel-first layout: dims 1 and 2 are height and width.
        h, w = input_.shape[1], input_.shape[2]

        # Reflect-pad bottom/right up to the next multiple of
        # img_multiple_of. (-n) % m is 0 when n is already a multiple.
        padh = -h % self.img_multiple_of
        padw = -w % self.img_multiple_of
        transform = Pad((0, 0, padw, padh), padding_mode='reflect')
        input_ = transform(input_).unsqueeze(0)  # add batch dim

        with paddle.no_grad():
            # The generator returns a sequence; index 0 is used as the final
            # restoration (as in the original code).
            restored = paddle.clip(self.generator(input_)[0], 0, 1)

        # Crop the padding off, take the single batch item, CHW -> HWC.
        restored = restored[:, :, :h, :w].numpy()[0].transpose(1, 2, 0)
        # Round rather than truncate, which would bias every pixel down by up
        # to one 8-bit level.
        return np.round(restored * 255).astype(np.uint8)

    def run(self):
        """Restore every image from ``images_path`` into
        ``<output_path>/<task>``.

        Raises:
            OSError: if OpenCV reports that a result image could not be
                written.
        """
        task_path = os.path.join(self.output_path, self.task)
        # makedirs also creates output_path, so one call suffices.
        os.makedirs(task_path, exist_ok=True)

        for image_file in tqdm(self.get_images(self.images_path)):
            restored = self._restore(self.read_image(image_file))
            save_path = self._restoration_path(task_path, image_file)
            # OpenCV expects BGR order. imwrite signals some failures by
            # returning False rather than raising, so check the result.
            ok = cv2.imwrite(save_path,
                             cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))
            if not ok:
                raise OSError(f'Failed to write restored image: {save_path}')

        print('Done, output path is:', task_path)
|
|
|