You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
300 lines
10 KiB
300 lines
10 KiB
# code was heavily based on https://github.com/mseitzer/pytorch-fid |
|
# This implementation is licensed under the Apache License 2.0. |
|
# Copyright (c) mseitzer |
|
|
|
import os |
|
import fnmatch |
|
import numpy as np |
|
import cv2 |
|
import paddle |
|
from PIL import Image |
|
from cv2 import imread |
|
from scipy import linalg |
|
from .inception import InceptionV3 |
|
from paddle.utils.download import get_weights_path_from_url |
|
from .builder import METRICS |
|
|
|
try:
    from tqdm import tqdm
except ImportError:
    # Narrowed from a bare ``except:`` — only a missing tqdm should trigger
    # the fallback; other errors (e.g. KeyboardInterrupt) must propagate.
    def tqdm(x):
        """No-op stand-in for tqdm: iterate without a progress bar."""
        return x

# Based on
# https://github.com/mit-han-lab/gan-compression/blob/master/metric/fid_score.py
#
# The InceptionV3 pretrained model was converted from pytorch; the original
# pretrain_model url is
# https://paddle-gan-models.bj.bcebos.com/params_inceptionV3.tar.gz
INCEPTIONV3_WEIGHT_URL = "https://paddlegan.bj.bcebos.com/InceptionV3.pdparams"
|
|
|
|
|
@METRICS.register()
class FID(paddle.metric.Metric):
    """Frechet Inception Distance (FID) metric.

    Collects InceptionV3 activations for generated (``preds``) and
    ground-truth (``gts``) images through ``update`` and reduces them to a
    single FID score in ``accumulate``.
    """
    def __init__(self,
                 batch_size=1,
                 use_GPU=True,
                 dims=2048,
                 premodel_path=None,
                 model=None):
        """
        Args:
            batch_size (int): images per forward pass of the Inception model.
            use_GPU (bool): whether inference should run on GPU.
            dims (int): dimensionality of the Inception features to extract.
            premodel_path (str|None): path to pretrained weights; downloaded
                from INCEPTIONV3_WEIGHT_URL when None.
            model (paddle.nn.Layer|None): custom feature extractor; an
                InceptionV3 is built when None.
        """
        # Initialize the paddle.metric.Metric base class — the original code
        # skipped this call.
        super().__init__()
        self.batch_size = batch_size
        self.use_GPU = use_GPU
        self.dims = dims
        self.premodel_path = premodel_path
        if model is None:
            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
            model = InceptionV3([block_idx], normalize_input=False)
        if premodel_path is None:
            premodel_path = get_weights_path_from_url(INCEPTIONV3_WEIGHT_URL)
        self.model = model
        param_dict = paddle.load(premodel_path)
        self.model.load_dict(param_dict)
        self.model.eval()
        self.reset()

    def reset(self):
        """Clear all accumulated activations and results."""
        self.preds = []
        self.gts = []
        self.results = []

    def update(self, preds, gts):
        """Extract and store Inception activations for one batch."""
        preds_inception, gts_inception = calculate_inception_val(
            preds, gts, self.batch_size, self.model, self.use_GPU, self.dims)
        self.preds.append(preds_inception)
        self.gts.append(gts_inception)

    def accumulate(self):
        """Concatenate stored activations, compute FID, and reset state.

        Returns:
            float: the FID value over everything seen since the last reset.
        """
        self.preds = np.concatenate(self.preds, axis=0)
        self.gts = np.concatenate(self.gts, axis=0)
        value = calculate_fid_given_img(self.preds, self.gts)
        self.reset()
        return value

    def name(self):
        """Return the metric's display name."""
        return 'FID'
|
|
|
|
|
def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): |
|
m1 = np.atleast_1d(mu1) |
|
m2 = np.atleast_1d(mu2) |
|
|
|
sigma1 = np.atleast_2d(sigma1) |
|
sigma2 = np.atleast_2d(sigma2) |
|
|
|
assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths' |
|
assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions' |
|
|
|
diff = mu1 - mu2 |
|
|
|
t = sigma1.dot(sigma2) |
|
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) |
|
if not np.isfinite(covmean).all(): |
|
msg = ('fid calculation produces singular product; ' |
|
'adding %s to diagonal of cov estimates') % eps |
|
print(msg) |
|
offset = np.eye(sigma1.shape[0]) * eps |
|
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) |
|
|
|
# Numerical error might give slight imaginary component |
|
if np.iscomplexobj(covmean): |
|
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): |
|
m = np.max(np.abs(covmean.imag)) |
|
raise ValueError('Imaginary component {}'.format(m)) |
|
covmean = covmean.real |
|
tr_covmean = np.trace(covmean) |
|
|
|
return ( |
|
diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean) |
|
|
|
|
|
def _get_activations_from_ims(img, model, batch_size, dims, use_gpu):
    """Run ``model`` over ``img`` in mini-batches and collect activations.

    Args:
        img (np.ndarray): stack of images; converted to NCHW when the
            channel axis is not already first.
        model: feature-extraction network; its first output is used.
        batch_size (int): images per forward pass.
        dims (int): activation dimensionality.
        use_gpu (bool): kept for interface compatibility; not read here.

    Returns:
        np.ndarray: ``(len(img), dims)`` array of activations.
    """
    total = len(img)
    activations = np.empty((total, dims))

    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        batch = img[start:end]
        # NHWC -> NCHW when the channel axis is not already first.
        # assumes inputs are either NCHW with 3 channels or NHWC — TODO confirm
        if batch.shape[1] != 3:
            batch = batch.transpose((0, 3, 1, 2))

        features = model(paddle.to_tensor(batch))[0][0]
        activations[start:end] = features.reshape([end - start,
                                                   -1]).cpu().numpy()
    return activations
|
|
|
|
|
def _compute_statistic_of_img(act): |
|
mu = np.mean(act, axis=0) |
|
sigma = np.cov(act, rowvar=False) |
|
return mu, sigma |
|
|
|
|
|
def calculate_inception_val(img_fake,
                            img_real,
                            batch_size,
                            model,
                            use_gpu=True,
                            dims=2048):
    """Extract Inception activations for fake and real image batches.

    Returns:
        tuple: (fake activations, real activations), each of shape (N, dims).
    """
    fake_acts = _get_activations_from_ims(img_fake, model, batch_size, dims,
                                          use_gpu)
    real_acts = _get_activations_from_ims(img_real, model, batch_size, dims,
                                          use_gpu)
    return fake_acts, real_acts
|
|
|
|
|
def calculate_fid_given_img(act_fake, act_real):
    """Compute FID from precomputed fake/real activation arrays."""
    mu_fake, cov_fake = _compute_statistic_of_img(act_fake)
    mu_real, cov_real = _compute_statistic_of_img(act_real)
    return _calculate_frechet_distance(mu_fake, cov_fake, mu_real, cov_real)
|
|
|
|
|
def _get_activations(files,
                     model,
                     batch_size,
                     dims,
                     use_gpu,
                     premodel_path,
                     style=None):
    """Load images from ``files`` and compute Inception activations.

    Args:
        files (list): image file paths.
        model: dygraph feature extractor (not used on the 'stargan' branch).
        batch_size (int): images per forward pass; files beyond the last
            full batch are silently dropped.
        dims (int): activation dimensionality.
        use_gpu (bool): kept for interface compatibility; not read here.
        premodel_path (str): weights path (dygraph params, or an inference
            model for the 'stargan' branch).
        style (str|None): 'stargan' selects stargan-v2 preprocessing and the
            static-graph inference path.

    Returns:
        np.ndarray: ``(n_used_imgs, dims)`` activation array.
    """
    if len(files) % batch_size != 0:
        print(('Warning: number of images is not a multiple of the '
               'batch size. Some samples are going to be ignored.'))
    if batch_size > len(files):
        print(('Warning: batch size is bigger than the datasets size. '
               'Setting batch size to datasets size'))
        batch_size = len(files)

    # Only complete batches are processed.
    n_batches = len(files) // batch_size
    n_used_imgs = n_batches * batch_size

    pred_arr = np.empty((n_used_imgs, dims))
    for i in tqdm(range(n_batches)):
        start = i * batch_size
        end = start + batch_size

        # same as stargan-v2 official implementation: resize to 256 first, then resize to 299
        if style == 'stargan':
            img_list = []
            for f in files[start:end]:
                im = Image.open(str(f)).convert('RGB')
                if im.size[0] != 299:
                    # resample=2 is PIL's bilinear filter.
                    im = im.resize((256, 256), 2)
                    im = im.resize((299, 299), 2)

                img_list.append(np.array(im).astype('float32'))

            images = np.array(img_list)
        else:
            images = np.array(
                [imread(str(f)).astype(np.float32) for f in files[start:end]])

        if len(images.shape) != 4:
            # NOTE(review): this fallback rebuilds the batch from only
            # files[start] and yields a 3-D (1, H, W) array, yet the 4-axis
            # transpose below expects 4-D — looks broken; confirm intent.
            images = imread(str(files[start]))
            images = cv2.cvtColor(images, cv2.COLOR_BGR2GRAY)
            images = np.array([images.astype(np.float32)])

        # NHWC -> NCHW, then scale pixel values to [0, 1].
        images = images.transpose((0, 3, 1, 2))
        images /= 255

        # imagenet normalization
        if style == 'stargan':
            mean = np.array([0.485, 0.456, 0.406]).astype('float32')
            std = np.array([0.229, 0.224, 0.225]).astype('float32')
            images[:] = (images[:] - mean[:, None, None]) / std[:, None, None]

        if style == 'stargan':
            pred_arr[start:end] = inception_infer(images, premodel_path)
        else:
            # NOTE(review): paddle.guard() is not a paddle>=2.0 API (it was
            # fluid.dygraph.guard) — confirm the targeted paddle version.
            with paddle.guard():
                images = paddle.to_tensor(images)
                # NOTE(review): paddle.load returns a single state dict in
                # paddle 2.x; the 2-tuple unpack matches the legacy
                # load_dygraph API — confirm. Also, the weights are re-read
                # from disk on every batch; this could be hoisted.
                param_dict, _ = paddle.load(premodel_path)
                model.set_dict(param_dict)
                model.eval()

                pred = model(images)[0][0].numpy()

                pred_arr[start:end] = pred.reshape(end - start, -1)

    return pred_arr
|
|
|
|
|
def inception_infer(x, model_path):
    """Run a static-graph Inception inference model on a batch ``x``.

    Args:
        x (np.ndarray): preprocessed NCHW image batch.
        model_path (str): path of the saved inference model.

    Returns:
        np.ndarray: the first fetched output of the inference program.
    """
    executor = paddle.static.Executor()
    program, feed_names, fetch_vars = paddle.static.load_inference_model(
        model_path, executor)
    outputs = executor.run(program,
                           feed={feed_names[0]: x},
                           fetch_list=fetch_vars)
    return outputs[0]
|
|
|
|
|
def _calculate_activation_statistics(files,
                                     model,
                                     premodel_path,
                                     batch_size=50,
                                     dims=2048,
                                     use_gpu=False,
                                     style=None):
    """Compute the activation mean and covariance for a list of image files.

    Returns:
        tuple: (mean vector, covariance matrix) of the activations.
    """
    activations = _get_activations(files, model, batch_size, dims, use_gpu,
                                   premodel_path, style)
    return np.mean(activations, axis=0), np.cov(activations, rowvar=False)
|
|
|
|
|
def _compute_statistics_of_path(path, |
|
model, |
|
batch_size, |
|
dims, |
|
use_gpu, |
|
premodel_path, |
|
style=None): |
|
if path.endswith('.npz'): |
|
f = np.load(path) |
|
m, s = f['mu'][:], f['sigma'][:] |
|
f.close() |
|
else: |
|
files = [] |
|
for root, dirnames, filenames in os.walk(path): |
|
for filename in fnmatch.filter( |
|
filenames, '*.jpg') or fnmatch.filter(filenames, '*.png'): |
|
files.append(os.path.join(root, filename)) |
|
m, s = _calculate_activation_statistics( |
|
files, model, premodel_path, batch_size, dims, use_gpu, style) |
|
return m, s |
|
|
|
|
|
def calculate_fid_given_paths(paths,
                              premodel_path,
                              batch_size,
                              use_gpu,
                              dims,
                              model=None,
                              style=None):
    """Compute the FID between two image directories (or .npz stat files).

    Args:
        paths (sequence): two paths whose statistics are compared.
        premodel_path (str): path of pretrained Inception weights; must
            already exist locally.
        batch_size (int): images per forward pass.
        use_gpu (bool): forwarded to the statistics helpers.
        dims (int): activation dimensionality.
        model: feature extractor; built here when None (non-'stargan' only).
        style (str|None): 'stargan' selects the static-graph branch.

    Returns:
        float: the FID value.

    Raises:
        RuntimeError: if any entry of ``paths`` does not exist.
    """
    assert os.path.exists(
        premodel_path
    ), 'pretrain_model path {} is not exists! Please download it first'.format(
        premodel_path)
    for p in paths:
        if not os.path.exists(p):
            raise RuntimeError('Invalid path: %s' % p)

    if model is None and style != 'stargan':
        # NOTE(review): paddle.guard() is not a paddle>=2.0 API (it was
        # fluid.dygraph.guard) — confirm the targeted paddle version.
        with paddle.guard():
            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
            # NOTE(review): class_dim=1008 presumably matches the converted
            # checkpoint's classifier head — confirm against the weights.
            model = InceptionV3([block_idx], class_dim=1008)

    m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, dims,
                                         use_gpu, premodel_path, style)
    m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, dims,
                                         use_gpu, premodel_path, style)

    fid_value = _calculate_frechet_distance(m1, s1, m2, s2)
    return fid_value
|
|
|