Update rcan_model.py

Alleviate gradient explosion, but convergence is still difficult
kongdebug authored 3 years ago, committed via GitHub
parent 393e570a68
commit 7e74e85ab7
1 changed file, 25 lines changed: paddlers/custom_models/gan/rcan_model.py

@@ -27,7 +27,6 @@ from ...models.ppgan.modules.init import reset_parameters
 class RCANModel(BaseModel):
     """Base SR model for single image super-resolution.
     """
     def __init__(self, generator, pixel_criterion=None, use_init_weight=False):
         """
         Args:
@@ -37,7 +36,8 @@ class RCANModel(BaseModel):
         super(RCANModel, self).__init__()
         self.nets['generator'] = build_generator(generator)
         self.error_last = 1e8
+        self.batch = 0
         if pixel_criterion:
             self.pixel_criterion = build_criterion(pixel_criterion)
         if use_init_weight:
@@ -63,8 +63,21 @@ class RCANModel(BaseModel):
         loss_pixel = self.pixel_criterion(self.output, self.gt)
         self.losses['loss_pixel'] = loss_pixel

-        loss_pixel.backward()
-        optims['optim'].step()
+        skip_threshold = 1e6
+        if loss_pixel.item() < skip_threshold * self.error_last:
+            loss_pixel.backward()
+            optims['optim'].step()
+        else:
+            print('Skip this batch {}! (Loss: {})'.format(
+                self.batch + 1, loss_pixel.item()
+            ))
+
+        self.batch += 1
+        if self.batch % 1000 == 0:
+            self.error_last = loss_pixel.item()/1000
+            print("update error_last:{}".format(self.error_last))

     def test_iter(self, metrics=None):
         self.nets['generator'].eval()
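The hunk above replaces the unconditional backward/step with the loss-spike skipping heuristic named in the commit message: the optimizer step runs only while the current pixel loss stays below skip_threshold * error_last, and error_last is re-estimated from the latest loss every 1000 batches. A minimal, framework-free sketch of the same bookkeeping (the helper name make_skip_filter and the driver loop are illustrative, not part of PaddleRS):

def make_skip_filter(skip_threshold=1e6, initial_error=1e8, update_every=1000):
    # Same state the commit keeps on the model: a reference loss (error_last)
    # and a batch counter. Returns a callable that decides whether the
    # optimizer step for a given loss value should run or be skipped.
    state = {"error_last": initial_error, "batch": 0}

    def should_update(loss_value):
        apply_step = loss_value < skip_threshold * state["error_last"]
        if not apply_step:
            print('Skip this batch {}! (Loss: {})'.format(
                state["batch"] + 1, loss_value))
        state["batch"] += 1
        # Re-estimate the reference loss every `update_every` batches,
        # mirroring error_last = loss_pixel.item()/1000 in the commit.
        if state["batch"] % update_every == 0:
            state["error_last"] = loss_value / update_every
        return apply_step

    return should_update

should_update = make_skip_filter(update_every=3)  # short period, just for the demo
for loss in [0.9, 0.7, 0.5, 0.4, 1e7, 0.3]:
    if should_update(loss):
        pass  # the real code calls loss_pixel.backward() and optims['optim'].step() here

Note that with the values in the commit (error_last initialised to 1e8 and skip_threshold of 1e6) a batch is only skipped once loss_pixel.item() reaches 1e14, so the guard effectively engages only after error_last has first been re-estimated from real losses at batch 1000.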
@@ -86,8 +99,8 @@ class RCANModel(BaseModel):
 def init_sr_weight(net):
     def reset_func(m):
-        if hasattr(m, 'weight') and (
-                not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D))):
+        if hasattr(m, 'weight') and (not isinstance(
+                m, (nn.BatchNorm, nn.BatchNorm2D))):
             reset_parameters(m)

     net.apply(reset_func)
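The last hunk only re-wraps the isinstance check, but the surrounding pattern is worth seeing in isolation: Layer.apply walks every sublayer and re-initialises the ones that own a weight, skipping batch norms. A self-contained sketch, assuming a plain normal(0, 0.02) fill as a stand-in for ppgan's reset_parameters (the helper name reinit_non_bn_weights is illustrative):

import paddle
import paddle.nn as nn

def reinit_non_bn_weights(net):
    def reset_func(m):
        # Re-initialise every sublayer that owns a `weight`, except batch
        # norms, whose affine parameters are left at their defaults.
        if hasattr(m, 'weight') and (not isinstance(
                m, (nn.BatchNorm, nn.BatchNorm2D))):
            m.weight.set_value(
                paddle.normal(mean=0.0, std=0.02, shape=m.weight.shape))

    net.apply(reset_func)

net = nn.Sequential(nn.Conv2D(3, 16, 3), nn.BatchNorm2D(16), nn.Conv2D(16, 3, 3))
reinit_non_bn_weights(net)  # only the two Conv2D weights are re-drawn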
