From 7e74e85ab7fca097103ad954d69f4733083cac33 Mon Sep 17 00:00:00 2001
From: kongdebug <52785738+kongdebug@users.noreply.github.com>
Date: Sun, 3 Apr 2022 15:24:44 +0800
Subject: [PATCH] Update rcan_model.py

Alleviate gradient explosion, but convergence is still difficult

---
 paddlers/custom_models/gan/rcan_model.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/paddlers/custom_models/gan/rcan_model.py b/paddlers/custom_models/gan/rcan_model.py
index 5ad1e3e..0676e75 100644
--- a/paddlers/custom_models/gan/rcan_model.py
+++ b/paddlers/custom_models/gan/rcan_model.py
@@ -27,7 +27,6 @@ from ...models.ppgan.modules.init import reset_parameters
 class RCANModel(BaseModel):
     """Base SR model for single image super-resolution.
     """
-
     def __init__(self, generator, pixel_criterion=None, use_init_weight=False):
         """
         Args:
@@ -37,7 +36,8 @@ class RCANModel(BaseModel):
         super(RCANModel, self).__init__()
 
         self.nets['generator'] = build_generator(generator)
-
+        self.error_last = 1e8
+        self.batch = 0
         if pixel_criterion:
             self.pixel_criterion = build_criterion(pixel_criterion)
         if use_init_weight:
@@ -63,8 +63,21 @@ class RCANModel(BaseModel):
         loss_pixel = self.pixel_criterion(self.output, self.gt)
         self.losses['loss_pixel'] = loss_pixel
 
-        loss_pixel.backward()
-        optims['optim'].step()
+        skip_threshold = 1e6
+
+        if loss_pixel.item() < skip_threshold * self.error_last:
+            loss_pixel.backward()
+            optims['optim'].step()
+        else:
+            print('Skip this batch {}! (Loss: {})'.format(
+                self.batch + 1, loss_pixel.item()
+            ))
+        self.batch += 1
+
+        if self.batch % 1000 == 0:
+            self.error_last = loss_pixel.item()/1000
+            print("update error_last:{}".format(self.error_last))
+
 
     def test_iter(self, metrics=None):
         self.nets['generator'].eval()
@@ -86,8 +99,8 @@ class RCANModel(BaseModel):
 
 def init_sr_weight(net):
     def reset_func(m):
-        if hasattr(m, 'weight') and (
-                not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D))):
+        if hasattr(m, 'weight') and (not isinstance(
+                m, (nn.BatchNorm, nn.BatchNorm2D))):
             reset_parameters(m)
 
     net.apply(reset_func)
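
For reviewers who want to sanity-check the new control flow in isolation, below is a minimal sketch of the skip-threshold rule this patch adds to train_iter(), replayed on plain floats so it runs without building a PaddlePaddle model. The simulated_losses list and the shorter refresh_every interval are made-up demo values; the patch itself refreshes error_last every 1000 batches.

# Standalone sketch of the batch-skipping rule added to train_iter(),
# using plain floats instead of a model and optimizer.
# simulated_losses and refresh_every are demo values only;
# the patch refreshes error_last every 1000 batches.

skip_threshold = 1e6     # same constant as in the patch
error_last = 1e8         # same initial value as in the patch
refresh_every = 2        # demo value; the patch uses 1000
batch = 0

simulated_losses = [0.9, 0.7, 5e4, 0.6]  # third value mimics a loss spike

for loss in simulated_losses:
    if loss < skip_threshold * error_last:
        # in the model this is where loss_pixel.backward() and
        # optims['optim'].step() would run
        pass
    else:
        print('Skip this batch {}! (Loss: {})'.format(batch + 1, loss))
    batch += 1

    if batch % refresh_every == 0:
        # periodically refresh the reference loss used by the guard
        error_last = loss / 1000
        print("update error_last:{}".format(error_last))

Run as-is, only the spiked third value triggers the skip branch, and error_last is refreshed twice; with the patch's initial error_last of 1e8, the guard only takes effect once error_last has been updated from observed losses.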