# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import math
import random
import paddle
import paddle.nn as nn
from .base_model import BaseModel

from .builder import MODELS
from .criterions import build_criterion
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from ..solver import build_lr_scheduler, build_optimizer


def r1_penalty(real_pred, real_img):
    """
    R1 regularization for the discriminator. The core idea is to
    penalize the gradient on real data alone: when the
    generator distribution produces the true data distribution
    and the discriminator is equal to 0 on the data manifold, the
    gradient penalty ensures that the discriminator cannot create
    a non-zero gradient orthogonal to the data manifold without
    suffering a loss in the GAN game.

    Ref:
    Eq. 9 in "Which Training Methods for GANs Do Actually Converge?"
    """
    grad_real = paddle.grad(
        outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
    grad_penalty = (grad_real * grad_real).reshape([grad_real.shape[0],
                                                    -1]).sum(1).mean()
    return grad_penalty
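

# Note: r1_penalty differentiates the discriminator output w.r.t. the input
# pixels, so the caller must set `real_img.stop_gradient = False` beforehand
# (as train_iter does below); otherwise paddle.grad has no path to real_img.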


def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """
    Path length regularization for the generator (StyleGAN2 paper).
    Estimates the Jacobian norm ||J_w^T y|| with a random image-space
    direction y (scaled by 1 / sqrt(H * W) so the scale does not depend
    on resolution) and penalizes its deviation from a running mean,
    encouraging a well-conditioned latent-to-image mapping.
    """
    noise = paddle.randn(fake_img.shape) / math.sqrt(fake_img.shape[2] *
                                                     fake_img.shape[3])
    grad = paddle.grad(
        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
    path_lengths = paddle.sqrt((grad * grad).sum(2).mean(1))

    path_mean = mean_path_length + decay * (path_lengths.mean() -
                                            mean_path_length)

    path_penalty = (
        (path_lengths - path_mean) * (path_lengths - path_mean)).mean()

    return path_penalty, path_lengths.detach().mean(), path_mean.detach()
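

# Note: the updated running mean is returned to the caller rather than stored
# here; StyleGAN2Model keeps it across iterations in self.mean_path_length.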


@MODELS.register()
class StyleGAN2Model(BaseModel):
    """
    This class implements the StyleGAN2 model, a style-based GAN for
    high-quality unconditional image generation.

    StyleGAN2 paper: https://arxiv.org/pdf/1912.04958.pdf
    """

    def __init__(self,
                 generator,
                 discriminator=None,
                 gan_criterion=None,
                 num_style_feat=512,
                 mixing_prob=0.9,
                 r1_reg_weight=10.,
                 path_reg_weight=2.,
                 path_batch_shrink=2.,
                 params=None,
                 max_eval_steps=50000):
        """Initialize the StyleGAN2 model class.

        Args:
            generator (dict): config of generator.
            discriminator (dict): config of discriminator.
            gan_criterion (dict): config of gan criterion.
            num_style_feat (int): dimension of the style feature. Default: 512.
            mixing_prob (float): probability of applying style mixing. Default: 0.9.
            r1_reg_weight (float): weight of R1 regularization. Default: 10.
            path_reg_weight (float): weight of path length regularization. Default: 2.
            path_batch_shrink (float): batch shrink factor for path length
                regularization. Default: 2.
            params (dict): optional training params (gen_iters, disc_iters,
                disc_start_iters, visual_iters).
            max_eval_steps (int): maximum number of evaluation steps. Default: 50000.
        """
        super(StyleGAN2Model, self).__init__(params)
        self.gen_iters = 4 if self.params is None else self.params.get(
            'gen_iters', 4)
        self.disc_iters = 16 if self.params is None else self.params.get(
            'disc_iters', 16)
        self.disc_start_iters = (0 if self.params is None else
                                 self.params.get('disc_start_iters', 0))

        self.visual_iters = (500 if self.params is None else
                             self.params.get('visual_iters', 500))

        self.mixing_prob = mixing_prob
        self.num_style_feat = num_style_feat
        self.r1_reg_weight = r1_reg_weight

        self.path_reg_weight = path_reg_weight
        self.path_batch_shrink = path_batch_shrink
        self.mean_path_length = 0

        self.nets['gen'] = build_generator(generator)
        self.max_eval_steps = max_eval_steps

        # define discriminators
        if discriminator:
            self.nets['disc'] = build_discriminator(discriminator)
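
        # gen_ema holds an exponential moving average of the generator
        # weights; model_ema(0) copies the current weights in so both
        # networks start identical.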
        self.nets['gen_ema'] = build_generator(generator)
        self.model_ema(0)

        self.nets['gen'].train()
        self.nets['gen_ema'].eval()
        self.nets['disc'].train()
        self.current_iter = 1

        # define loss functions
        if gan_criterion:
            self.gan_criterion = build_criterion(gan_criterion)

    def setup_lr_schedulers(self, cfg):
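        # Scale learning rates by N / (N + 1) to compensate for lazy
        # regularization (the regularization terms are only applied every
        # gen_iters / disc_iters steps), following the official StyleGAN2
        # implementation.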
        self.lr_scheduler = dict()
        gen_cfg = cfg.copy()
        net_g_reg_ratio = self.gen_iters / (self.gen_iters + 1)
        gen_cfg['learning_rate'] = cfg['learning_rate'] * net_g_reg_ratio
        self.lr_scheduler['gen'] = build_lr_scheduler(gen_cfg)

        disc_cfg = cfg.copy()
        net_d_reg_ratio = self.disc_iters / (self.disc_iters + 1)
        disc_cfg['learning_rate'] = cfg['learning_rate'] * net_d_reg_ratio
        self.lr_scheduler['disc'] = build_lr_scheduler(disc_cfg)
        return self.lr_scheduler

    def setup_optimizers(self, lr, cfg):
        for opt_name, opt_cfg in cfg.items():
            if opt_name == 'optimG':
                _lr = lr['gen']
            elif opt_name == 'optimD':
                _lr = lr['disc']
            else:
                raise ValueError("opt name must be in ['optimG', 'optimD']")

            cfg_ = opt_cfg.copy()
            net_names = cfg_.pop('net_names')
            parameters = []
            for net_name in net_names:
                parameters += self.nets[net_name].parameters()
            self.optimizers[opt_name] = build_optimizer(cfg_, _lr, parameters)

        return self.optimizers

    def get_bare_model(self, net):
        """Get the bare model, especially when wrapped with DataParallel."""
        if isinstance(net, paddle.DataParallel):
            net = net._layers
        return net

    def model_ema(self, decay=0.999):
        net_g = self.get_bare_model(self.nets['gen'])
        net_g_params = dict(net_g.named_parameters())

        net_g_ema = self.get_bare_model(self.nets['gen_ema'])
        net_g_ema_params = dict(net_g_ema.named_parameters())

        for k in net_g_ema_params.keys():
            net_g_ema_params[k].set_value(net_g_ema_params[k] * decay +
                                          net_g_params[k] * (1 - decay))

    def setup_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Args:
            input (dict): includes the data itself and its metadata information.
        """
        self.real_img = paddle.to_tensor(input['A'])

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        pass

    def make_noise(self, batch, num_noise):
        if num_noise == 1:
            noises = paddle.randn([batch, self.num_style_feat])
        else:
            noises = []
            for _ in range(num_noise):
                noises.append(paddle.randn([batch, self.num_style_feat]))
        return noises
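
    # With probability `prob`, two latent codes are drawn so the generator
    # can apply style mixing regularization; otherwise a single code is used.
    # make_noise returns a bare tensor when num_noise == 1, hence the list
    # wrapping in the single-code branch.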
    def mixing_noise(self, batch, prob):
        if random.random() < prob:
            return self.make_noise(batch, 2)
        else:
            return [self.make_noise(batch, 1)]

    def train_iter(self, optimizers=None):
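        # One iteration: (1) discriminator step with lazy R1 regularization
        # every disc_iters steps; (2) generator step with the non-saturating
        # loss and lazy path length regularization every gen_iters steps;
        # (3) EMA update of gen_ema.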
        current_iter = self.current_iter
        self.set_requires_grad(self.nets['disc'], True)
        optimizers['optimD'].clear_grad()
        batch = self.real_img.shape[0]
        noise = self.mixing_noise(batch, self.mixing_prob)

        fake_img, _ = self.nets['gen'](noise)
        self.visual_items['real_img'] = self.real_img
        self.visual_items['fake_img'] = fake_img
        fake_pred = self.nets['disc'](fake_img.detach())

        real_pred = self.nets['disc'](self.real_img)
        # gan loss for the discriminator (logistic loss via softplus)
        l_d_total = 0.
        l_d = self.gan_criterion(
            real_pred, True, is_disc=True) + self.gan_criterion(
                fake_pred, False, is_disc=True)
        self.losses['l_d'] = l_d
        # In wgan, real_score should be positive and fake_score should be
        # negative
        self.losses['real_score'] = real_pred.detach().mean()
        self.losses['fake_score'] = fake_pred.detach().mean()

        l_d_total += l_d

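        # Lazy R1 regularization: applied only every disc_iters steps, with
        # the penalty scaled by disc_iters so the effective strength matches
        # per-step regularization; `0 * real_pred[0]` keeps real_pred
        # attached to the loss graph.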
        if current_iter % self.disc_iters == 0:
            self.real_img.stop_gradient = False
            real_pred = self.nets['disc'](self.real_img)
            l_d_r1 = r1_penalty(real_pred, self.real_img)
            l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.disc_iters +
                      0 * real_pred[0])

            self.losses['l_d_r1'] = l_d_r1.detach().mean()

            l_d_total += l_d_r1
        l_d_total.backward()

        optimizers['optimD'].step()

        self.set_requires_grad(self.nets['disc'], False)
        optimizers['optimG'].clear_grad()

        noise = self.mixing_noise(batch, self.mixing_prob)
        fake_img, _ = self.nets['gen'](noise)
        fake_pred = self.nets['disc'](fake_img)

        # non-saturating (softplus) gan loss for the generator
        l_g_total = 0.
        l_g = self.gan_criterion(fake_pred, True, is_disc=False)
        self.losses['l_g'] = l_g

        l_g_total += l_g
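
        # Lazy path length regularization: applied every gen_iters steps on
        # a batch shrunk by path_batch_shrink to save memory;
        # `0 * fake_img[0, 0, 0, 0]` keeps fake_img attached to the loss
        # graph.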
        if current_iter % self.gen_iters == 0:
            path_batch_size = max(1, int(batch // self.path_batch_shrink))
            noise = self.mixing_noise(path_batch_size, self.mixing_prob)
            fake_img, latents = self.nets['gen'](noise, return_latents=True)
            l_g_path, path_lengths, self.mean_path_length = g_path_regularize(
                fake_img, latents, self.mean_path_length)

            l_g_path = (self.path_reg_weight * self.gen_iters * l_g_path +
                        0 * fake_img[0, 0, 0, 0])

            l_g_total += l_g_path
            self.losses['l_g_path'] = l_g_path.detach().mean()
            self.losses['path_length'] = path_lengths
        l_g_total.backward()
        optimizers['optimG'].step()

        # EMA
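        # decay = 0.5 ** (32 / (10 * 1000)) matches the official StyleGAN2
        # setting: an EMA half-life of 10k images at batch size 32.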
        self.model_ema(decay=0.5**(32 / (10 * 1000)))

        if self.current_iter % self.visual_iters == 0:
            sample_z = [self.make_noise(1, 1)]
            sample, _ = self.nets['gen_ema'](sample_z)
            self.visual_items['fake_img_ema'] = sample

        self.current_iter += 1

    def test_iter(self, metrics=None):
        self.nets['gen_ema'].eval()
        batch = self.real_img.shape[0]
        noises = [paddle.randn([batch, self.num_style_feat])]
        fake_img, _ = self.nets['gen_ema'](noises)
        with paddle.no_grad():
            if metrics is not None:
                for metric in metrics.values():
                    metric.update(fake_img, self.real_img)
        self.nets['gen_ema'].train()

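    # Thin wrapper used for export: applies the truncation trick (pulling
    # styles toward the mean style) inside a Layer that paddle.jit.save can
    # trace.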
    class InferGenerator(paddle.nn.Layer):
        def set_generator(self, generator):
            self.generator = generator

        def forward(self, style, truncation):
            truncation_latent = self.generator.get_mean_style()
            out = self.generator(
                styles=style,
                truncation=truncation,
                truncation_latent=truncation_latent)
            return out[0]

    def export_model(self,
                     export_model=None,
                     output_dir=None,
                     inputs_size=[[1, 1, 512], [1, 1]]):
        infer_generator = self.InferGenerator()
        infer_generator.set_generator(self.nets['gen'])
        style = paddle.rand(shape=inputs_size[0], dtype='float32')
        truncation = paddle.rand(shape=inputs_size[1], dtype='float32')
        if output_dir is None:
            output_dir = 'inference_model'
        paddle.jit.save(
            infer_generator,
            os.path.join(output_dir, "stylegan2model_gen"),
            input_spec=[style, truncation])