|
|
@ -43,7 +43,12 @@ class BaseRestorer(BaseModel): |
|
|
|
MIN_MAX = (0., 1.) |
|
|
|
MIN_MAX = (0., 1.) |
|
|
|
TEST_OUT_KEY = None |
|
|
|
TEST_OUT_KEY = None |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, model_name, losses=None, sr_factor=None, **params): |
|
|
|
def __init__(self, |
|
|
|
|
|
|
|
model_name, |
|
|
|
|
|
|
|
losses=None, |
|
|
|
|
|
|
|
sr_factor=None, |
|
|
|
|
|
|
|
min_max=None, |
|
|
|
|
|
|
|
**params): |
|
|
|
self.init_params = locals() |
|
|
|
self.init_params = locals() |
|
|
|
if 'with_net' in self.init_params: |
|
|
|
if 'with_net' in self.init_params: |
|
|
|
del self.init_params['with_net'] |
|
|
|
del self.init_params['with_net'] |
|
|
@ -55,6 +60,8 @@ class BaseRestorer(BaseModel): |
|
|
|
params.pop('with_net', None) |
|
|
|
params.pop('with_net', None) |
|
|
|
self.net = self.build_net(**params) |
|
|
|
self.net = self.build_net(**params) |
|
|
|
self.find_unused_parameters = True |
|
|
|
self.find_unused_parameters = True |
|
|
|
|
|
|
|
if min_max is None: |
|
|
|
|
|
|
|
self.min_max = self.MIN_MAX |
|
|
|
|
|
|
|
|
|
|
|
def build_net(self, **params): |
|
|
|
def build_net(self, **params): |
|
|
|
# Currently, only use models from cmres. |
|
|
|
# Currently, only use models from cmres. |
|
|
@ -283,11 +290,13 @@ class BaseRestorer(BaseModel): |
|
|
|
exit=True) |
|
|
|
exit=True) |
|
|
|
pretrained_dir = osp.join(save_dir, 'pretrain') |
|
|
|
pretrained_dir = osp.join(save_dir, 'pretrain') |
|
|
|
is_backbone_weights = pretrain_weights == 'IMAGENET' |
|
|
|
is_backbone_weights = pretrain_weights == 'IMAGENET' |
|
|
|
|
|
|
|
# XXX: Currently, do not load optimizer state dict. |
|
|
|
self.initialize_net( |
|
|
|
self.initialize_net( |
|
|
|
pretrain_weights=pretrain_weights, |
|
|
|
pretrain_weights=pretrain_weights, |
|
|
|
save_dir=pretrained_dir, |
|
|
|
save_dir=pretrained_dir, |
|
|
|
resume_checkpoint=resume_checkpoint, |
|
|
|
resume_checkpoint=resume_checkpoint, |
|
|
|
is_backbone_weights=is_backbone_weights) |
|
|
|
is_backbone_weights=is_backbone_weights, |
|
|
|
|
|
|
|
load_optim_state=False) |
|
|
|
|
|
|
|
|
|
|
|
self.train_loop( |
|
|
|
self.train_loop( |
|
|
|
num_epochs=num_epochs, |
|
|
|
num_epochs=num_epochs, |
|
|
@ -434,6 +443,7 @@ class BaseRestorer(BaseModel): |
|
|
|
|
|
|
|
|
|
|
|
return eval_metrics |
|
|
|
return eval_metrics |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@paddle.no_grad() |
|
|
|
def predict(self, img_file, transforms=None): |
|
|
|
def predict(self, img_file, transforms=None): |
|
|
|
""" |
|
|
|
""" |
|
|
|
Do inference. |
|
|
|
Do inference. |
|
|
@ -653,9 +663,9 @@ class BaseRestorer(BaseModel): |
|
|
|
if copy: |
|
|
|
if copy: |
|
|
|
im = im.copy() |
|
|
|
im = im.copy() |
|
|
|
if clip: |
|
|
|
if clip: |
|
|
|
im = np.clip(im, self.MIN_MAX[0], self.MIN_MAX[1]) |
|
|
|
im = np.clip(im, self.min_max[0], self.min_max[1]) |
|
|
|
im -= im.min() |
|
|
|
im -= self.min_max[0] |
|
|
|
im /= im.max() + 1e-32 |
|
|
|
im /= self.min_max[1] - self.min_max[0] |
|
|
|
if quantize: |
|
|
|
if quantize: |
|
|
|
im *= 255 |
|
|
|
im *= 255 |
|
|
|
im = im.astype('uint8') |
|
|
|
im = im.astype('uint8') |
|
|
@ -668,6 +678,7 @@ class DRN(BaseRestorer): |
|
|
|
def __init__(self, |
|
|
|
def __init__(self, |
|
|
|
losses=None, |
|
|
|
losses=None, |
|
|
|
sr_factor=4, |
|
|
|
sr_factor=4, |
|
|
|
|
|
|
|
min_max=None, |
|
|
|
scales=(2, 4), |
|
|
|
scales=(2, 4), |
|
|
|
n_blocks=30, |
|
|
|
n_blocks=30, |
|
|
|
n_feats=16, |
|
|
|
n_feats=16, |
|
|
@ -691,7 +702,11 @@ class DRN(BaseRestorer): |
|
|
|
self.dual_loss_weight = dual_loss_weight |
|
|
|
self.dual_loss_weight = dual_loss_weight |
|
|
|
self.scales = scales |
|
|
|
self.scales = scales |
|
|
|
super(DRN, self).__init__( |
|
|
|
super(DRN, self).__init__( |
|
|
|
model_name='DRN', losses=losses, sr_factor=sr_factor, **params) |
|
|
|
model_name='DRN', |
|
|
|
|
|
|
|
losses=losses, |
|
|
|
|
|
|
|
sr_factor=sr_factor, |
|
|
|
|
|
|
|
min_max=min_max, |
|
|
|
|
|
|
|
**params) |
|
|
|
|
|
|
|
|
|
|
|
def build_net(self, **params): |
|
|
|
def build_net(self, **params): |
|
|
|
from ppgan.modules.init import init_weights |
|
|
|
from ppgan.modules.init import init_weights |
|
|
@ -769,6 +784,7 @@ class LESRCNN(BaseRestorer): |
|
|
|
def __init__(self, |
|
|
|
def __init__(self, |
|
|
|
losses=None, |
|
|
|
losses=None, |
|
|
|
sr_factor=4, |
|
|
|
sr_factor=4, |
|
|
|
|
|
|
|
min_max=None, |
|
|
|
multi_scale=False, |
|
|
|
multi_scale=False, |
|
|
|
group=1, |
|
|
|
group=1, |
|
|
|
**params): |
|
|
|
**params): |
|
|
@ -778,7 +794,11 @@ class LESRCNN(BaseRestorer): |
|
|
|
'group': group |
|
|
|
'group': group |
|
|
|
}) |
|
|
|
}) |
|
|
|
super(LESRCNN, self).__init__( |
|
|
|
super(LESRCNN, self).__init__( |
|
|
|
model_name='LESRCNN', losses=losses, sr_factor=sr_factor, **params) |
|
|
|
model_name='LESRCNN', |
|
|
|
|
|
|
|
losses=losses, |
|
|
|
|
|
|
|
sr_factor=sr_factor, |
|
|
|
|
|
|
|
min_max=min_max, |
|
|
|
|
|
|
|
**params) |
|
|
|
|
|
|
|
|
|
|
|
def build_net(self, **params): |
|
|
|
def build_net(self, **params): |
|
|
|
net = ppgan.models.generators.LESRCNNGenerator(**params) |
|
|
|
net = ppgan.models.generators.LESRCNNGenerator(**params) |
|
|
@ -789,6 +809,7 @@ class ESRGAN(BaseRestorer): |
|
|
|
def __init__(self, |
|
|
|
def __init__(self, |
|
|
|
losses=None, |
|
|
|
losses=None, |
|
|
|
sr_factor=4, |
|
|
|
sr_factor=4, |
|
|
|
|
|
|
|
min_max=None, |
|
|
|
use_gan=True, |
|
|
|
use_gan=True, |
|
|
|
in_channels=3, |
|
|
|
in_channels=3, |
|
|
|
out_channels=3, |
|
|
|
out_channels=3, |
|
|
@ -805,7 +826,11 @@ class ESRGAN(BaseRestorer): |
|
|
|
}) |
|
|
|
}) |
|
|
|
self.use_gan = use_gan |
|
|
|
self.use_gan = use_gan |
|
|
|
super(ESRGAN, self).__init__( |
|
|
|
super(ESRGAN, self).__init__( |
|
|
|
model_name='ESRGAN', losses=losses, sr_factor=sr_factor, **params) |
|
|
|
model_name='ESRGAN', |
|
|
|
|
|
|
|
losses=losses, |
|
|
|
|
|
|
|
sr_factor=sr_factor, |
|
|
|
|
|
|
|
min_max=min_max, |
|
|
|
|
|
|
|
**params) |
|
|
|
|
|
|
|
|
|
|
|
def build_net(self, **params): |
|
|
|
def build_net(self, **params): |
|
|
|
from ppgan.modules.init import init_weights |
|
|
|
from ppgan.modules.init import init_weights |
|
|
@ -932,6 +957,7 @@ class RCAN(BaseRestorer): |
|
|
|
def __init__(self, |
|
|
|
def __init__(self, |
|
|
|
losses=None, |
|
|
|
losses=None, |
|
|
|
sr_factor=4, |
|
|
|
sr_factor=4, |
|
|
|
|
|
|
|
min_max=None, |
|
|
|
n_resgroups=10, |
|
|
|
n_resgroups=10, |
|
|
|
n_resblocks=20, |
|
|
|
n_resblocks=20, |
|
|
|
n_feats=64, |
|
|
|
n_feats=64, |
|
|
@ -950,4 +976,8 @@ class RCAN(BaseRestorer): |
|
|
|
'reduction': reduction |
|
|
|
'reduction': reduction |
|
|
|
}) |
|
|
|
}) |
|
|
|
super(RCAN, self).__init__( |
|
|
|
super(RCAN, self).__init__( |
|
|
|
model_name='RCAN', losses=losses, sr_factor=sr_factor, **params) |
|
|
|
model_name='RCAN', |
|
|
|
|
|
|
|
losses=losses, |
|
|
|
|
|
|
|
sr_factor=sr_factor, |
|
|
|
|
|
|
|
min_max=min_max, |
|
|
|
|
|
|
|
**params) |
|
|
|