|
|
|
@ -27,7 +27,6 @@ from ...models.ppgan.modules.init import reset_parameters |
|
|
|
|
class RCANModel(BaseModel): |
|
|
|
|
"""Base SR model for single image super-resolution. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__(self, generator, pixel_criterion=None, use_init_weight=False): |
|
|
|
|
""" |
|
|
|
|
Args: |
|
|
|
@ -37,7 +36,8 @@ class RCANModel(BaseModel): |
|
|
|
|
super(RCANModel, self).__init__() |
|
|
|
|
|
|
|
|
|
self.nets['generator'] = build_generator(generator) |
|
|
|
|
|
|
|
|
|
self.error_last = 1e8 |
|
|
|
|
self.batch = 0 |
|
|
|
|
if pixel_criterion: |
|
|
|
|
self.pixel_criterion = build_criterion(pixel_criterion) |
|
|
|
|
if use_init_weight: |
|
|
|
@ -63,8 +63,21 @@ class RCANModel(BaseModel): |
|
|
|
|
loss_pixel = self.pixel_criterion(self.output, self.gt) |
|
|
|
|
self.losses['loss_pixel'] = loss_pixel |
|
|
|
|
|
|
|
|
|
loss_pixel.backward() |
|
|
|
|
optims['optim'].step() |
|
|
|
|
skip_threshold = 1e6 |
|
|
|
|
|
|
|
|
|
if loss_pixel.item() < skip_threshold * self.error_last: |
|
|
|
|
loss_pixel.backward() |
|
|
|
|
optims['optim'].step() |
|
|
|
|
else: |
|
|
|
|
print('Skip this batch {}! (Loss: {})'.format( |
|
|
|
|
self.batch + 1, loss_pixel.item() |
|
|
|
|
)) |
|
|
|
|
self.batch += 1 |
|
|
|
|
|
|
|
|
|
if self.batch % 1000 == 0: |
|
|
|
|
self.error_last = loss_pixel.item()/1000 |
|
|
|
|
print("update error_last:{}".format(self.error_last)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_iter(self, metrics=None): |
|
|
|
|
self.nets['generator'].eval() |
|
|
|
@ -86,8 +99,8 @@ class RCANModel(BaseModel): |
|
|
|
|
|
|
|
|
|
def init_sr_weight(net): |
|
|
|
|
def reset_func(m): |
|
|
|
|
if hasattr(m, 'weight') and ( |
|
|
|
|
not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D))): |
|
|
|
|
if hasattr(m, 'weight') and (not isinstance( |
|
|
|
|
m, (nn.BatchNorm, nn.BatchNorm2D))): |
|
|
|
|
reset_parameters(m) |
|
|
|
|
|
|
|
|
|
net.apply(reset_func) |
|
|
|
|