# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

from collections import namedtuple

import numpy as np

import paddle
import paddle.nn as nn
from paddle.utils.download import get_weights_path_from_url

from ..models.criterions.perceptual_loss import PerceptualVGG
from .builder import METRICS

VGG16_TORCHVISION_URL = 'https://paddlegan.bj.bcebos.com/models/vgg16_official.pdparams'
LINS_01_VGG_URL = 'https://paddlegan.bj.bcebos.com/models/lins_0.1_vgg.pdparams'


@METRICS.register()
class LPIPSMetric(paddle.metric.Metric):
    """Calculate LPIPS (Learned Perceptual Image Patch Similarity).

    Ref: https://arxiv.org/abs/1801.03924

    Args:
        net (str): Type of backbone net. Default: 'vgg'.
        version (str): Version of the lpips method. Default: '0.1'.
        mean (list): Sequence of means for each channel of the input image. Default: None.
        std (list): Sequence of standard deviations for each channel of the input image. Default: None.

    Returns:
        float: lpips result.
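
    Examples:
        .. code-block:: python

            # A minimal usage sketch: ``update`` accepts uint8 HWC images
            # (numpy arrays), passed singly or as lists.
            import numpy as np

            metric = LPIPSMetric()
            pred = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
            gt = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
            metric.update(pred, gt)
            print(metric.accumulate())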
    """
    def __init__(self, net='vgg', version='0.1', mean=None, std=None):
        self.net = net
        self.version = version

        self.loss_fn = LPIPS(net=net, version=version)

        if mean is None:
            self.mean = [0.5, 0.5, 0.5]
        else:
            self.mean = mean

        if std is None:
            self.std = [0.5, 0.5, 0.5]
        else:
            self.std = std

        self.reset()

    def reset(self):
        self.results = []

    def update(self, preds, gts):
        if not isinstance(preds, (list, tuple)):
            preds = [preds]

        if not isinstance(gts, (list, tuple)):
            gts = [gts]

        for pred, gt in zip(preds, gts):
            # Scale uint8 HWC images to [0, 1], then normalize CHW copies.
            pred = pred.astype(np.float32) / 255.
            gt = gt.astype(np.float32) / 255.
            pred = paddle.vision.transforms.normalize(
                pred.transpose([2, 0, 1]), self.mean, self.std)
            gt = paddle.vision.transforms.normalize(
                gt.transpose([2, 0, 1]), self.mean, self.std)

            with paddle.no_grad():
                value = self.loss_fn(
                    paddle.to_tensor(pred).unsqueeze(0),
                    paddle.to_tensor(gt).unsqueeze(0))

            self.results.append(value.item())

    def accumulate(self):
        if paddle.distributed.get_world_size() > 1:
            # Gather per-sample results from all ranks before averaging.
            results = paddle.to_tensor(self.results)
            results_list = []
            paddle.distributed.all_gather(results_list, results)
            self.results = paddle.concat(results_list).numpy()

        if len(self.results) <= 0:
            return 0.
        return np.mean(self.results)

    def name(self):
        return 'LPIPS'


def spatial_average(in_tens, keepdim=True):
    # Mean over the spatial dimensions (H, W).
    return in_tens.mean([2, 3], keepdim=keepdim)


def upsample(in_tens, out_HW=(64, 64)):
    # Bilinearly resize a NCHW tensor to the target spatial size
    # (H and W get their own scale factors).
    in_H, in_W = in_tens.shape[2], in_tens.shape[3]
    scale_factor_H = 1. * out_HW[0] / in_H
    scale_factor_W = 1. * out_HW[1] / in_W

    return nn.Upsample(scale_factor=(scale_factor_H, scale_factor_W),
                       mode='bilinear',
                       align_corners=False)(in_tens)


def normalize_tensor(in_feat, eps=1e-10):
    # Unit-normalize features along the channel dimension.
    norm_factor = paddle.sqrt(paddle.sum(in_feat**2, 1, keepdim=True))
    return in_feat / (norm_factor + eps)


# Learned perceptual metric
class LPIPS(nn.Layer):
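    """Learned perceptual image patch similarity.

    Deep features from a fixed backbone are unit-normalized per channel and
    compared between the two inputs; with ``lpips=True`` the squared feature
    differences are recalibrated by learned 1x1 convolutions
    (``NetLinLayer``), otherwise they are simply averaged across channels.
    """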
    def __init__(self,
                 pretrained=True,
                 net='vgg',
                 version='0.1',
                 lpips=True,
                 spatial=False,
                 pnet_rand=False,
                 pnet_tune=False,
                 use_dropout=True,
                 model_path=None,
                 eval_mode=True,
                 verbose=True):
        # lpips - [True] means with linear calibration on top of base network
        # pretrained - [True] means load linear weights

        super(LPIPS, self).__init__()
        if verbose:
            print(
                'Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]'
                % ('LPIPS' if lpips else 'baseline', net, version,
                   'on' if spatial else 'off'))

        self.pnet_type = net
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial

        # False means the baseline of just averaging all layers.
        self.lpips = lpips

        self.version = version
        self.scaling_layer = ScalingLayer()

        if self.pnet_type in ['vgg', 'vgg16']:
            net_type = vgg16
            self.chns = [64, 128, 256, 512, 512]
        elif self.pnet_type == 'alex':
            raise NotImplementedError('alex is not supported yet!')
        elif self.pnet_type == 'squeeze':
            raise NotImplementedError('squeeze is not supported yet!')
        else:
            raise ValueError('Unknown pnet_type: %s' % self.pnet_type)

        self.L = len(self.chns)

        self.net = net_type(pretrained=True, requires_grad=False)

        if lpips:
            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
            self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]

            # TODO: add alex and squeezenet (7 layers for squeezenet)
            self.lins = nn.LayerList(self.lins)

            if pretrained:
                if model_path is None:
                    model_path = get_weights_path_from_url(LINS_01_VGG_URL)

                if verbose:
                    print('Loading model from: %s' % model_path)

                self.lins.set_state_dict(paddle.load(model_path))

        if eval_mode:
            self.eval()

    def forward(self, in0, in1, retPerLayer=False, normalize=False):
        # Turn on this flag if the input is in [0, 1], so it can be
        # rescaled to [-1, +1].
        if normalize:
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1

        # v0.0 - original release had a bug, where input was not scaled
        in0_input, in1_input = (
            self.scaling_layer(in0),
            self.scaling_layer(in1)) if self.version == '0.1' else (in0, in1)
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
        feats0, feats1, diffs = {}, {}, {}

        for kk in range(self.L):
            feats0[kk] = normalize_tensor(outs0[kk])
            feats1[kk] = normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk])**2

        if self.lpips:
            if self.spatial:
                res = [
                    upsample(self.lins[kk].model(diffs[kk]),
                             out_HW=in0.shape[2:]) for kk in range(self.L)
                ]
            else:
                res = [
                    spatial_average(self.lins[kk].model(diffs[kk]),
                                    keepdim=True) for kk in range(self.L)
                ]
        else:
            # Paddle reductions take `axis`, not torch's `dim`.
            if self.spatial:
                res = [
                    upsample(diffs[kk].sum(axis=1, keepdim=True),
                             out_HW=in0.shape[2:]) for kk in range(self.L)
                ]
            else:
                res = [
                    spatial_average(diffs[kk].sum(axis=1, keepdim=True),
                                    keepdim=True) for kk in range(self.L)
                ]

        # Sum the per-layer distances into a single score.
        val = res[0]
        for i in range(1, self.L):
            val += res[i]

        if retPerLayer:
            return (val, res)
        else:
            return val


class ScalingLayer(nn.Layer):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        # shift/scale are the ImageNet mean/std remapped from the [0, 1]
        # range to the [-1, 1] input range.
        self.register_buffer(
            'shift',
            paddle.to_tensor([-.030, -.088, -.188]).reshape([1, 3, 1, 1]))
        self.register_buffer(
            'scale',
            paddle.to_tensor([.458, .448, .450]).reshape([1, 3, 1, 1]))

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Layer):
    ''' A single linear layer which does a 1x1 conv '''
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()

        layers = [nn.Dropout()] if use_dropout else []
        layers += [
            nn.Conv2D(chn_in, chn_out, 1, stride=1, padding=0,
                      bias_attr=False),
        ]
        self.model = nn.Sequential(*layers)


class vgg16(nn.Layer):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        self.vgg16 = PerceptualVGG(['3', '8', '15', '22', '29'], 'vgg16',
                                   False, VGG16_TORCHVISION_URL)

        if not requires_grad:
            for param in self.parameters():
                param.trainable = False

    def forward(self, x):
        out = self.vgg16(x)
        # Name the intermediate activations after the VGG16 relu layers
        # they come from.
        vgg_outputs = namedtuple(
            'VggOutputs',
            ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        out = vgg_outputs(out['3'], out['8'], out['15'], out['22'], out['29'])

        return out
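

# Illustrative smoke-test sketch: assumes the pretrained weights are
# downloadable from the URLs above; shapes and seed are arbitrary.
if __name__ == '__main__':
    paddle.seed(0)
    # Two random "images" in [-1, 1], NCHW layout.
    x = paddle.rand([1, 3, 64, 64]) * 2 - 1
    y = paddle.rand([1, 3, 64, 64]) * 2 - 1
    model = LPIPS(net='vgg', version='0.1')
    print('LPIPS(x, y) = %.4f' % model(x, y).item())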