You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

352 lines
11 KiB

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import numpy as np
import paddle
from .builder import METRICS
@METRICS.register()
class PSNR(paddle.metric.Metric):
    """PSNR metric accumulated over many prediction/ground-truth pairs.

    Each :meth:`update` call scores pairs with :func:`calculate_psnr` and
    stores the values in ``self.results``; :meth:`accumulate` returns their
    mean (gathered across ranks when running with more than one process).

    Args:
        crop_border (int): Pixels cropped from each image edge before scoring.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Score only the Y channel of YCbCr.
            Default: False.
    """

    def __init__(self, crop_border, input_order='HWC', test_y_channel=False):
        self.crop_border = crop_border
        self.input_order = input_order
        self.test_y_channel = test_y_channel
        self.reset()

    def reset(self):
        """Discard all stored scores."""
        self.results = []

    def update(self, preds, gts, is_seq=False):
        """Score prediction/ground-truth pairs and store the result.

        Args:
            preds: A single image or a list/tuple of images.
            gts: A single image or a list/tuple of images, paired with
                ``preds`` element-wise (``zip`` ignores any surplus items).
            is_seq (bool): If True, the pairs are frames of one sequence and a
                single value (the mean over the frames) is stored; otherwise
                one value is stored per pair. Default: False.
        """
        if not isinstance(preds, (list, tuple)):
            preds = [preds]
        if not isinstance(gts, (list, tuple)):
            gts = [gts]

        values = [
            calculate_psnr(pred, gt, self.crop_border, self.input_order,
                           self.test_y_channel)
            for pred, gt in zip(preds, gts)
        ]

        if not is_seq:
            self.results.extend(values)
        elif values:
            # An empty sequence would make np.mean return NaN (with a
            # RuntimeWarning) and poison every later accumulate().
            self.results.append(np.mean(values))

    def accumulate(self):
        """Return the mean of all stored scores, or 0. if there are none.

        With more than one process the scores of every rank are gathered
        first. The gathered values are kept in a local variable so that
        ``self.results`` stays a list: later ``update`` calls can still
        append to it, and calling ``accumulate`` again does not re-gather
        (and double count) already-gathered values.
        """
        results = self.results
        if paddle.distributed.get_world_size() > 1:
            gathered = []
            paddle.distributed.all_gather(gathered,
                                          paddle.to_tensor(self.results))
            results = paddle.concat(gathered).numpy()

        if len(results) <= 0:
            return 0.
        return np.mean(results)

    def name(self):
        """Return the metric's display name."""
        return 'PSNR'
@METRICS.register()
class SSIM(PSNR):
    """SSIM metric with the same bookkeeping as :class:`PSNR`.

    Only the per-pair score differs: each pair is scored with
    :func:`calculate_ssim`. ``reset`` and ``accumulate`` are inherited.
    """

    def update(self, preds, gts, is_seq=False):
        """Score prediction/ground-truth pairs with SSIM and store the result.

        Args:
            preds: A single image or a list/tuple of images.
            gts: A single image or a list/tuple of images, paired with
                ``preds`` element-wise (``zip`` ignores any surplus items).
            is_seq (bool): If True, the pairs are frames of one sequence and a
                single value (the mean over the frames) is stored; otherwise
                one value is stored per pair. Default: False.
        """
        if not isinstance(preds, (list, tuple)):
            preds = [preds]
        if not isinstance(gts, (list, tuple)):
            gts = [gts]

        values = [
            calculate_ssim(pred, gt, self.crop_border, self.input_order,
                           self.test_y_channel)
            for pred, gt in zip(preds, gts)
        ]

        if not is_seq:
            self.results.extend(values)
        elif values:
            # An empty sequence would make np.mean return NaN (with a
            # RuntimeWarning) and poison every later accumulate().
            self.results.append(np.mean(values))

    def name(self):
        """Return the metric's display name."""
        return 'SSIM'
def calculate_psnr(img1,
                   img2,
                   crop_border,
                   input_order='HWC',
                   test_y_channel=False):
    """Calculate PSNR (Peak Signal-to-Noise Ratio).

    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the PSNR calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: psnr result; ``inf`` when the compared images are identical
        (zero mean squared error).

    Raises:
        AssertionError: If the two images have different shapes (checked
            with ``assert``, so skipped under ``python -O``).
        ValueError: If ``input_order`` is not 'HWC' or 'CHW'.
    """
    assert img1.shape == img2.shape, (
        f'Image shapes are different: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(
            f'Wrong input_order {input_order}. Supported input_orders are '
            '"HWC" and "CHW"')

    # astype already returns a new array, so the caller's images are never
    # modified and no extra .copy() is needed.
    img1 = reorder_image(img1.astype('float32'), input_order=input_order)
    img2 = reorder_image(img2.astype('float32'), input_order=input_order)

    if crop_border != 0:
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    if test_y_channel:
        img1 = to_y_channel(img1)
        img2 = to_y_channel(img2)

    mse = np.mean((img1 - img2)**2)
    if mse == 0:
        return float('inf')
    return 20. * np.log10(255. / np.sqrt(mse))
def _ssim(img1, img2):
    """Compute SSIM between two single-channel images.

    Helper for :func:`calculate_ssim`, which passes one 2-D channel at a
    time. Local statistics are Gaussian-weighted over an 11x11 window
    (sigma 1.5). Only the region where the window fits entirely inside the
    image (5 pixels trimmed from every edge) contributes to the final mean.

    Args:
        img1 (ndarray): 2-D image with range [0, 255].
        img2 (ndarray): 2-D image with range [0, 255].

    Returns:
        float: Mean of the SSIM map.
    """
    # Stabilizing constants for dynamic range L = 255 (K1 = 0.01, K2 = 0.03).
    c1 = (0.01 * 255)**2
    c2 = (0.03 * 255)**2

    x = img1.astype(np.float64)
    y = img2.astype(np.float64)

    gauss_1d = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss_1d, gauss_1d.transpose())

    def local_mean(arr):
        # Gaussian-weighted local average, keeping only the interior region.
        return cv2.filter2D(arr, -1, window)[5:-5, 5:-5]

    mu_x = local_mean(x)
    mu_y = local_mean(y)
    mu_x2 = mu_x**2
    mu_y2 = mu_y**2
    mu_xy = mu_x * mu_y

    var_x = local_mean(x**2) - mu_x2
    var_y = local_mean(y**2) - mu_y2
    cov_xy = local_mean(x * y) - mu_xy

    numerator = (2 * mu_xy + c1) * (2 * cov_xy + c2)
    denominator = (mu_x2 + mu_y2 + c1) * (var_x + var_y + c2)
    return (numerator / denominator).mean()
def calculate_ssim(img1,
                   img2,
                   crop_border,
                   input_order='HWC',
                   test_y_channel=False):
    """Calculate SSIM (structural similarity).

    Ref:
    Image quality assessment: From error visibility to structural similarity

    The results are the same as that of the official released MATLAB code in
    https://ece.uwaterloo.ca/~z70wang/research/ssim/.

    For three-channel images, SSIM is calculated for each channel and then
    averaged.

    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the SSIM calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: ssim result.

    Raises:
        AssertionError: If the two images have different shapes (checked
            with ``assert``, so skipped under ``python -O``).
        ValueError: If ``input_order`` is not 'HWC' or 'CHW'.
    """
    assert img1.shape == img2.shape, (
        f'Image shapes are different: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(
            f'Wrong input_order {input_order}. Supported input_orders are '
            '"HWC" and "CHW"')

    # astype already returns a new array, so the caller's images are never
    # modified and no extra .copy() is needed.
    img1 = reorder_image(img1.astype('float32'), input_order=input_order)
    img2 = reorder_image(img2.astype('float32'), input_order=input_order)

    if crop_border != 0:
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    if test_y_channel:
        img1 = to_y_channel(img1)
        img2 = to_y_channel(img2)

    # reorder_image guarantees an (h, w, c) array, so axis 2 is channels.
    ssims = [_ssim(img1[..., i], img2[..., i]) for i in range(img1.shape[2])]
    return np.array(ssims).mean()
def reorder_image(img, input_order='HWC'):
    """Return ``img`` laid out as (h, w, c).

    * A 2-D image (h, w) gets a trailing channel axis, giving (h, w, 1).
      This happens regardless of ``input_order``.
    * A 'CHW' image (c, h, w) is transposed to (h, w, c).
    * An 'HWC' image (h, w, c) is returned unchanged (the same object).

    Args:
        img (ndarray): Input image.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Ignored for 2-D images. Default: 'HWC'.

    Returns:
        ndarray: The image in HWC order.

    Raises:
        ValueError: If ``input_order`` is not 'HWC' or 'CHW'.
    """
    supported = ('HWC', 'CHW')
    if input_order not in supported:
        raise ValueError(
            f'Wrong input_order {input_order}. Supported input_orders are '
            "'HWC' and 'CHW'")

    if len(img.shape) == 2:
        return img[..., None]
    if input_order == 'HWC':
        return img
    return img.transpose(1, 2, 0)
def bgr2ycbcr(img, y_only=False):
    """Convert a BGR image to YCbCr image.

    The BGR version of rgb2ycbcr.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image, channels on the last axis in B, G, R
            order. It accepts:
                1. np.uint8 type with range [0, 255];
                2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image, in the same range convention as
        the input. For uint8 input the values are rounded to integers on the
        [0, 255] scale (the array itself has a float dtype). For float input
        the values are on the [0, 1] scale. The input array is not modified.
    """
    img_type = img.dtype
    if img_type != np.uint8:
        # Float input is on the [0, 1] scale; bring it to [0, 255] out of
        # place so the caller's array is left untouched.
        img = img * 255.
    if y_only:
        out_img = np.dot(img, [24.966, 128.553, 65.481]) / 255. + 16.0
    else:
        # Rows are the input B, G, R channels; columns the Y, Cb, Cr outputs.
        # The BT.601 coefficients expect channels in [0, 1], hence the /255.
        out_img = np.matmul(
            img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                  [65.481, -37.797, 112.0]]) / 255. + [16, 128, 128]
    if img_type != np.uint8:
        out_img /= 255.
    else:
        out_img = out_img.round()
    return out_img
def rgb2ycbcr(img, y_only=False):
    """Convert a RGB image to YCbCr image.

    The RGB version of bgr2ycbcr.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image, channels on the last axis in R, G, B
            order. It accepts:
                1. np.uint8 type with range [0, 255];
                2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image, in the same range convention as
        the input. For uint8 input the values are rounded to integers on the
        [0, 255] scale (the array itself has a float dtype). For float input
        the values are on the [0, 1] scale. The input array is not modified.
    """
    img_type = img.dtype
    if img_type != np.uint8:
        # Float input is on the [0, 1] scale; bring it to [0, 255] out of
        # place (the old in-place `*=` silently rescaled the caller's array).
        img = img * 255.
    if y_only:
        out_img = np.dot(img, [65.481, 128.553, 24.966]) / 255. + 16.0
    else:
        # Rows are the input R, G, B channels; columns the Y, Cb, Cr outputs.
        # The BT.601 coefficients expect channels in [0, 1], hence the /255.
        out_img = np.matmul(
            img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                  [24.966, 112.0, -18.214]]) / 255. + [16, 128, 128]
    if img_type != np.uint8:
        out_img /= 255.
    else:
        out_img = out_img.round()
    return out_img
def to_y_channel(img):
    """Extract the Y channel of YCbCr from an image on the [0, 255] scale.

    Only an image with exactly 3 channels on its last axis is converted
    (handed to :func:`rgb2ycbcr` as RGB). The resulting Y plane gets a
    trailing axis of size 1. Any other input keeps its shape and is only
    converted to float.

    Args:
        img (ndarray): Images with range [0, 255].

    Returns:
        (ndarray): Images with range [0, 255] (float type) without round.
    """
    unit = img.astype(np.float32) / 255.
    is_three_channel = unit.ndim == 3 and unit.shape[2] == 3
    if not is_three_channel:
        return unit * 255.
    luma = rgb2ycbcr(unit, y_only=True)
    return luma[..., None] * 255.