You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
244 lines
9.2 KiB
244 lines
9.2 KiB
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
import paddle |
|
from .base_model import BaseModel |
|
|
|
from .builder import MODELS |
|
from .generators.builder import build_generator |
|
from .discriminators.builder import build_discriminator |
|
from .criterions import build_criterion |
|
|
|
from ..modules.init import init_weights |
|
from ..utils.image_pool import ImagePool |
|
|
|
|
|
@MODELS.register() |
|
class CycleGANModel(BaseModel): |
|
""" |
|
This class implements the CycleGAN model, for learning image-to-image translation without paired data. |
|
|
|
CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf |
|
""" |
|
def __init__(self, |
|
generator, |
|
discriminator=None, |
|
cycle_criterion=None, |
|
idt_criterion=None, |
|
gan_criterion=None, |
|
pool_size=50, |
|
direction='a2b', |
|
lambda_a=10., |
|
lambda_b=10.): |
|
"""Initialize the CycleGAN class. |
|
|
|
Args: |
|
generator (dict): config of generator. |
|
discriminator (dict): config of discriminator. |
|
cycle_criterion (dict): config of cycle criterion. |
|
""" |
|
super(CycleGANModel, self).__init__() |
|
|
|
self.direction = direction |
|
|
|
self.lambda_a = lambda_a |
|
self.lambda_b = lambda_b |
|
# define generators |
|
# The naming is different from those used in the paper. |
|
# Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) |
|
self.nets['netG_A'] = build_generator(generator) |
|
self.nets['netG_B'] = build_generator(generator) |
|
init_weights(self.nets['netG_A']) |
|
init_weights(self.nets['netG_B']) |
|
|
|
# define discriminators |
|
if discriminator: |
|
self.nets['netD_A'] = build_discriminator(discriminator) |
|
self.nets['netD_B'] = build_discriminator(discriminator) |
|
init_weights(self.nets['netD_A']) |
|
init_weights(self.nets['netD_B']) |
|
|
|
# create image buffer to store previously generated images |
|
self.fake_A_pool = ImagePool(pool_size) |
|
# create image buffer to store previously generated images |
|
self.fake_B_pool = ImagePool(pool_size) |
|
|
|
# define loss functions |
|
if gan_criterion: |
|
self.gan_criterion = build_criterion(gan_criterion) |
|
|
|
if cycle_criterion: |
|
self.cycle_criterion = build_criterion(cycle_criterion) |
|
|
|
if idt_criterion: |
|
self.idt_criterion = build_criterion(idt_criterion) |
|
|
|
def setup_input(self, input): |
|
"""Unpack input data from the dataloader and perform necessary pre-processing steps. |
|
|
|
Args: |
|
input (dict): include the data itself and its metadata information. |
|
|
|
The option 'direction' can be used to swap domain A and domain B. |
|
""" |
|
|
|
AtoB = self.direction == 'a2b' |
|
|
|
if AtoB: |
|
if 'A' in input: |
|
self.real_A = paddle.to_tensor(input['A']) |
|
if 'B' in input: |
|
self.real_B = paddle.to_tensor(input['B']) |
|
else: |
|
if 'B' in input: |
|
self.real_A = paddle.to_tensor(input['B']) |
|
if 'A' in input: |
|
self.real_B = paddle.to_tensor(input['A']) |
|
|
|
if 'A_paths' in input: |
|
self.image_paths = input['A_paths'] |
|
elif 'B_paths' in input: |
|
self.image_paths = input['B_paths'] |
|
|
|
def forward(self): |
|
"""Run forward pass; called by both functions <optimize_parameters> and <test>.""" |
|
if hasattr(self, 'real_A'): |
|
self.fake_B = self.nets['netG_A'](self.real_A) # G_A(A) |
|
self.rec_A = self.nets['netG_B'](self.fake_B) # G_B(G_A(A)) |
|
|
|
# visual |
|
self.visual_items['real_A'] = self.real_A |
|
self.visual_items['fake_B'] = self.fake_B |
|
self.visual_items['rec_A'] = self.rec_A |
|
|
|
if hasattr(self, 'real_B'): |
|
self.fake_A = self.nets['netG_B'](self.real_B) # G_B(B) |
|
self.rec_B = self.nets['netG_A'](self.fake_A) # G_A(G_B(B)) |
|
|
|
# visual |
|
self.visual_items['real_B'] = self.real_B |
|
self.visual_items['fake_A'] = self.fake_A |
|
self.visual_items['rec_B'] = self.rec_B |
|
|
|
def backward_D_basic(self, netD, real, fake): |
|
"""Calculate GAN loss for the discriminator |
|
|
|
Args: |
|
netD (Layer): the discriminator D |
|
real (paddle.Tensor): real images |
|
fake (paddle.Tensor): images generated by a generator |
|
|
|
Return: |
|
the discriminator loss. |
|
|
|
We also call loss_D.backward() to calculate the gradients. |
|
""" |
|
# Real |
|
pred_real = netD(real) |
|
loss_D_real = self.gan_criterion(pred_real, True) |
|
# Fake |
|
pred_fake = netD(fake.detach()) |
|
loss_D_fake = self.gan_criterion(pred_fake, False) |
|
# Combined loss and calculate gradients |
|
loss_D = (loss_D_real + loss_D_fake) * 0.5 |
|
|
|
loss_D.backward() |
|
return loss_D |
|
|
|
def backward_D_A(self): |
|
"""Calculate GAN loss for discriminator D_A""" |
|
fake_B = self.fake_B_pool.query(self.fake_B) |
|
self.loss_D_A = self.backward_D_basic(self.nets['netD_A'], self.real_B, |
|
fake_B) |
|
self.losses['D_A_loss'] = self.loss_D_A |
|
|
|
def backward_D_B(self): |
|
"""Calculate GAN loss for discriminator D_B""" |
|
fake_A = self.fake_A_pool.query(self.fake_A) |
|
self.loss_D_B = self.backward_D_basic(self.nets['netD_B'], self.real_A, |
|
fake_A) |
|
self.losses['D_B_loss'] = self.loss_D_B |
|
|
|
def backward_G(self): |
|
"""Calculate the loss for generators G_A and G_B""" |
|
# Identity loss |
|
if self.idt_criterion: |
|
# G_A should be identity if real_B is fed: ||G_A(B) - B|| |
|
self.idt_A = self.nets['netG_A'](self.real_B) |
|
|
|
self.loss_idt_A = self.idt_criterion(self.idt_A, |
|
self.real_B) * self.lambda_b |
|
# G_B should be identity if real_A is fed: ||G_B(A) - A|| |
|
self.idt_B = self.nets['netG_B'](self.real_A) |
|
|
|
# visual |
|
self.visual_items['idt_A'] = self.idt_A |
|
self.visual_items['idt_B'] = self.idt_B |
|
|
|
self.loss_idt_B = self.idt_criterion(self.idt_B, |
|
self.real_A) * self.lambda_a |
|
else: |
|
self.loss_idt_A = 0 |
|
self.loss_idt_B = 0 |
|
|
|
# GAN loss D_A(G_A(A)) |
|
self.loss_G_A = self.gan_criterion(self.nets['netD_A'](self.fake_B), |
|
True) |
|
# GAN loss D_B(G_B(B)) |
|
self.loss_G_B = self.gan_criterion(self.nets['netD_B'](self.fake_A), |
|
True) |
|
# Forward cycle loss || G_B(G_A(A)) - A|| |
|
self.loss_cycle_A = self.cycle_criterion(self.rec_A, |
|
self.real_A) * self.lambda_a |
|
# Backward cycle loss || G_A(G_B(B)) - B|| |
|
self.loss_cycle_B = self.cycle_criterion(self.rec_B, |
|
self.real_B) * self.lambda_b |
|
|
|
self.losses['G_idt_A_loss'] = self.loss_idt_A |
|
self.losses['G_idt_B_loss'] = self.loss_idt_B |
|
self.losses['G_A_adv_loss'] = self.loss_G_A |
|
self.losses['G_B_adv_loss'] = self.loss_G_B |
|
self.losses['G_A_cycle_loss'] = self.loss_cycle_A |
|
self.losses['G_B_cycle_loss'] = self.loss_cycle_B |
|
# combined loss and calculate gradients |
|
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B |
|
|
|
self.loss_G.backward() |
|
|
|
def train_iter(self, optimizers=None): |
|
"""Calculate losses, gradients, and update network weights; called in every training iteration""" |
|
# forward |
|
# compute fake images and reconstruction images. |
|
self.forward() |
|
# G_A and G_B |
|
# Ds require no gradients when optimizing Gs |
|
self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']], |
|
False) |
|
# set G_A and G_B's gradients to zero |
|
optimizers['optimG'].clear_grad() |
|
# calculate gradients for G_A and G_B |
|
self.backward_G() |
|
# update G_A and G_B's weights |
|
self.optimizers['optimG'].step() |
|
# D_A and D_B |
|
self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']], True) |
|
|
|
# set D_A and D_B's gradients to zero |
|
optimizers['optimD'].clear_grad() |
|
# calculate gradients for D_A |
|
self.backward_D_A() |
|
# calculate graidents for D_B |
|
self.backward_D_B() |
|
# update D_A and D_B's weights |
|
optimizers['optimD'].step()
|
|
|