|
|
|
@ -25,6 +25,7 @@ from paddle.static import InputSpec |
|
|
|
|
import paddlers |
|
|
|
|
import paddlers.models.ppgan as ppgan |
|
|
|
|
import paddlers.rs_models.res as cmres |
|
|
|
|
import paddlers.models.ppgan.metrics as metrics |
|
|
|
|
import paddlers.utils.logging as logging |
|
|
|
|
from paddlers.models import res_losses |
|
|
|
|
from paddlers.transforms import Resize, decode_image |
|
|
|
@ -32,12 +33,14 @@ from paddlers.transforms.functions import calc_hr_shape |
|
|
|
|
from paddlers.utils import get_single_card_bs |
|
|
|
|
from .base import BaseModel |
|
|
|
|
from .utils.res_adapters import GANAdapter, OptimizerAdapter |
|
|
|
|
from .utils.infer_nets import InferResNet |
|
|
|
|
|
|
|
|
|
__all__ = [] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseRestorer(BaseModel): |
|
|
|
|
MIN_MAX = (0., 255.) |
|
|
|
|
MIN_MAX = (0., 1.) |
|
|
|
|
TEST_OUT_KEY = None |
|
|
|
|
|
|
|
|
|
def __init__(self, model_name, losses=None, sr_factor=None, **params): |
|
|
|
|
self.init_params = locals() |
|
|
|
@ -63,9 +66,10 @@ class BaseRestorer(BaseModel): |
|
|
|
|
def _build_inference_net(self): |
|
|
|
|
# For GAN models, only the generator will be used for inference. |
|
|
|
|
if isinstance(self.net, GANAdapter): |
|
|
|
|
infer_net = self.net.generator |
|
|
|
|
infer_net = InferResNet( |
|
|
|
|
self.net.generator, out_key=self.TEST_OUT_KEY) |
|
|
|
|
else: |
|
|
|
|
infer_net = self.net |
|
|
|
|
infer_net = InferResNet(self.net, out_key=self.TEST_OUT_KEY) |
|
|
|
|
infer_net.eval() |
|
|
|
|
return infer_net |
|
|
|
|
|
|
|
|
@ -108,15 +112,18 @@ class BaseRestorer(BaseModel): |
|
|
|
|
outputs = OrderedDict() |
|
|
|
|
|
|
|
|
|
if mode == 'test': |
|
|
|
|
if isinstance(net, GANAdapter): |
|
|
|
|
net_out = net.generator(inputs[0]) |
|
|
|
|
else: |
|
|
|
|
net_out = net(inputs[0]) |
|
|
|
|
tar_shape = inputs[1] |
|
|
|
|
if self.status == 'Infer': |
|
|
|
|
net_out = net(inputs[0]) |
|
|
|
|
res_map_list = self._postprocess( |
|
|
|
|
net_out, tar_shape, transforms=inputs[2]) |
|
|
|
|
else: |
|
|
|
|
if isinstance(net, GANAdapter): |
|
|
|
|
net_out = net.generator(inputs[0]) |
|
|
|
|
else: |
|
|
|
|
net_out = net(inputs[0]) |
|
|
|
|
if self.TEST_OUT_KEY is not None: |
|
|
|
|
net_out = net_out[self.TEST_OUT_KEY] |
|
|
|
|
pred = self._postprocess( |
|
|
|
|
net_out, tar_shape, transforms=inputs[2]) |
|
|
|
|
res_map_list = [] |
|
|
|
@ -130,13 +137,15 @@ class BaseRestorer(BaseModel): |
|
|
|
|
net_out = net.generator(inputs[0]) |
|
|
|
|
else: |
|
|
|
|
net_out = net(inputs[0]) |
|
|
|
|
if self.TEST_OUT_KEY is not None: |
|
|
|
|
net_out = net_out[self.TEST_OUT_KEY] |
|
|
|
|
tar = inputs[1] |
|
|
|
|
tar_shape = [tar.shape[-2:]] |
|
|
|
|
pred = self._postprocess( |
|
|
|
|
net_out, tar_shape, transforms=inputs[2])[0] # NCHW |
|
|
|
|
pred = self._tensor_to_images(pred) |
|
|
|
|
outputs['pred'] = pred |
|
|
|
|
tar = self.tensor_to_images(tar) |
|
|
|
|
tar = self._tensor_to_images(tar) |
|
|
|
|
outputs['tar'] = tar |
|
|
|
|
|
|
|
|
|
if mode == 'train': |
|
|
|
@ -386,10 +395,11 @@ class BaseRestorer(BaseModel): |
|
|
|
|
self.eval_data_loader = self.build_data_loader( |
|
|
|
|
eval_dataset, batch_size=batch_size, mode='eval') |
|
|
|
|
# XXX: Hard-code crop_border and test_y_channel |
|
|
|
|
psnr = ppgan.metrics.PSNR(crop_border=4, test_y_channel=True) |
|
|
|
|
ssim = ppgan.metrics.SSIM(crop_border=4, test_y_channel=True) |
|
|
|
|
psnr = metrics.PSNR(crop_border=4, test_y_channel=True) |
|
|
|
|
ssim = metrics.SSIM(crop_border=4, test_y_channel=True) |
|
|
|
|
with paddle.no_grad(): |
|
|
|
|
for step, data in enumerate(self.eval_data_loader): |
|
|
|
|
data.append(eval_dataset.transforms.transforms) |
|
|
|
|
outputs = self.run(self.net, data, 'eval') |
|
|
|
|
psnr.update(outputs['pred'], outputs['tar']) |
|
|
|
|
ssim.update(outputs['pred'], outputs['tar']) |
|
|
|
@ -520,10 +530,9 @@ class BaseRestorer(BaseModel): |
|
|
|
|
def _postprocess(self, batch_pred, batch_tar_shape, transforms): |
|
|
|
|
batch_restore_list = BaseRestorer.get_transforms_shape_info( |
|
|
|
|
batch_tar_shape, transforms) |
|
|
|
|
if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer': |
|
|
|
|
if self.status == 'Infer': |
|
|
|
|
return self._infer_postprocess( |
|
|
|
|
batch_res_map=batch_pred[0], |
|
|
|
|
batch_restore_list=batch_restore_list) |
|
|
|
|
batch_res_map=batch_pred, batch_restore_list=batch_restore_list) |
|
|
|
|
results = [] |
|
|
|
|
if batch_pred.dtype == paddle.float32: |
|
|
|
|
mode = 'bilinear' |
|
|
|
@ -546,7 +555,7 @@ class BaseRestorer(BaseModel): |
|
|
|
|
|
|
|
|
|
def _infer_postprocess(self, batch_res_map, batch_restore_list): |
|
|
|
|
res_maps = [] |
|
|
|
|
for score_map, restore_list in zip(batch_res_map, batch_restore_list): |
|
|
|
|
for res_map, restore_list in zip(batch_res_map, batch_restore_list): |
|
|
|
|
if not isinstance(res_map, np.ndarray): |
|
|
|
|
res_map = paddle.unsqueeze(res_map, axis=0) |
|
|
|
|
for item in restore_list[::-1]: |
|
|
|
@ -557,15 +566,15 @@ class BaseRestorer(BaseModel): |
|
|
|
|
res_map, (w, h), interpolation=cv2.INTER_LINEAR) |
|
|
|
|
else: |
|
|
|
|
res_map = F.interpolate( |
|
|
|
|
score_map, (h, w), |
|
|
|
|
res_map, (h, w), |
|
|
|
|
mode='bilinear', |
|
|
|
|
data_format='NHWC') |
|
|
|
|
elif item[0] == 'padding': |
|
|
|
|
x, y = item[2] |
|
|
|
|
if isinstance(res_map, np.ndarray): |
|
|
|
|
res_map = res_map[..., y:y + h, x:x + w] |
|
|
|
|
res_map = res_map[y:y + h, x:x + w] |
|
|
|
|
else: |
|
|
|
|
res_map = res_map[:, :, y:y + h, x:x + w] |
|
|
|
|
res_map = res_map[:, y:y + h, x:x + w, :] |
|
|
|
|
else: |
|
|
|
|
pass |
|
|
|
|
res_map = res_map.squeeze() |
|
|
|
@ -585,18 +594,25 @@ class BaseRestorer(BaseModel): |
|
|
|
|
def set_losses(self, losses): |
|
|
|
|
self.losses = losses |
|
|
|
|
|
|
|
|
|
def _tensor_to_images(self, tensor, squeeze=True, quantize=True): |
|
|
|
|
def _tensor_to_images(self, |
|
|
|
|
tensor, |
|
|
|
|
transpose=True, |
|
|
|
|
squeeze=True, |
|
|
|
|
quantize=True): |
|
|
|
|
if transpose: |
|
|
|
|
tensor = paddle.transpose(tensor, perm=[0, 2, 3, 1]) # NHWC |
|
|
|
|
if squeeze: |
|
|
|
|
tensor = tensor.squeeze() |
|
|
|
|
images = tensor.numpy().astype('float32') |
|
|
|
|
images = np.clip(images, self.MIN_MAX[0], self.MIN_MAX[1]) |
|
|
|
|
images = self._normalize(images, copy=True, quantize=quantize) |
|
|
|
|
images = self._normalize( |
|
|
|
|
images, copy=True, clip=True, quantize=quantize) |
|
|
|
|
return images |
|
|
|
|
|
|
|
|
|
def _normalize(self, im, copy=False, quantize=True): |
|
|
|
|
def _normalize(self, im, copy=False, clip=True, quantize=True): |
|
|
|
|
if copy: |
|
|
|
|
im = im.copy() |
|
|
|
|
if clip: |
|
|
|
|
im = np.clip(im, self.MIN_MAX[0], self.MIN_MAX[1]) |
|
|
|
|
im -= im.min() |
|
|
|
|
im /= im.max() + 1e-32 |
|
|
|
|
if quantize: |
|
|
|
@ -605,32 +621,9 @@ class BaseRestorer(BaseModel): |
|
|
|
|
return im |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RCAN(BaseRestorer): |
|
|
|
|
def __init__(self, |
|
|
|
|
losses=None, |
|
|
|
|
sr_factor=4, |
|
|
|
|
n_resgroups=10, |
|
|
|
|
n_resblocks=20, |
|
|
|
|
n_feats=64, |
|
|
|
|
n_colors=3, |
|
|
|
|
rgb_range=255, |
|
|
|
|
kernel_size=3, |
|
|
|
|
reduction=16, |
|
|
|
|
**params): |
|
|
|
|
params.update({ |
|
|
|
|
'n_resgroups': n_resgroups, |
|
|
|
|
'n_resblocks': n_resblocks, |
|
|
|
|
'n_feats': n_feats, |
|
|
|
|
'n_colors': n_colors, |
|
|
|
|
'rgb_range': rgb_range, |
|
|
|
|
'kernel_size': kernel_size, |
|
|
|
|
'reduction': reduction |
|
|
|
|
}) |
|
|
|
|
super(RCAN, self).__init__( |
|
|
|
|
model_name='RCAN', losses=losses, sr_factor=sr_factor, **params) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DRN(BaseRestorer): |
|
|
|
|
TEST_OUT_KEY = -1 |
|
|
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
|
losses=None, |
|
|
|
|
sr_factor=4, |
|
|
|
@ -638,8 +631,10 @@ class DRN(BaseRestorer): |
|
|
|
|
n_blocks=30, |
|
|
|
|
n_feats=16, |
|
|
|
|
n_colors=3, |
|
|
|
|
rgb_range=255, |
|
|
|
|
rgb_range=1.0, |
|
|
|
|
negval=0.2, |
|
|
|
|
lq_loss_weight=0.1, |
|
|
|
|
dual_loss_weight=0.1, |
|
|
|
|
**params): |
|
|
|
|
if sr_factor != max(scale): |
|
|
|
|
raise ValueError(f"`sr_factor` must be equal to `max(scale)`.") |
|
|
|
@ -651,12 +646,80 @@ class DRN(BaseRestorer): |
|
|
|
|
'rgb_range': rgb_range, |
|
|
|
|
'negval': negval |
|
|
|
|
}) |
|
|
|
|
self.lq_loss_weight = lq_loss_weight |
|
|
|
|
self.dual_loss_weight = dual_loss_weight |
|
|
|
|
super(DRN, self).__init__( |
|
|
|
|
model_name='DRN', losses=losses, sr_factor=sr_factor, **params) |
|
|
|
|
|
|
|
|
|
def build_net(self, **params): |
|
|
|
|
net = ppgan.models.generators.DRNGenerator(**params) |
|
|
|
|
return net |
|
|
|
|
from ppgan.modules.init import init_weights |
|
|
|
|
generators = [ppgan.models.generators.DRNGenerator(**params)] |
|
|
|
|
init_weights(generators[-1]) |
|
|
|
|
for scale in params['scale']: |
|
|
|
|
dual_model = ppgan.models.generators.drn.DownBlock( |
|
|
|
|
params['negval'], params['n_feats'], params['n_colors'], 2) |
|
|
|
|
generators.append(dual_model) |
|
|
|
|
init_weights(generators[-1]) |
|
|
|
|
return GANAdapter(generators, []) |
|
|
|
|
|
|
|
|
|
def default_optimizer(self, parameters, *args, **kwargs): |
|
|
|
|
optims_g = [ |
|
|
|
|
super(DRN, self).default_optimizer(params_g, *args, **kwargs) |
|
|
|
|
for params_g in parameters['params_g'] |
|
|
|
|
] |
|
|
|
|
return OptimizerAdapter(*optims_g) |
|
|
|
|
|
|
|
|
|
def run_gan(self, net, inputs, mode, gan_mode='forward_primary'): |
|
|
|
|
if mode != 'train': |
|
|
|
|
raise ValueError("`mode` is not 'train'.") |
|
|
|
|
outputs = OrderedDict() |
|
|
|
|
if gan_mode == 'forward_primary': |
|
|
|
|
sr = net.generator(inputs[0]) |
|
|
|
|
lr = [inputs[0]] |
|
|
|
|
lr.extend([ |
|
|
|
|
F.interpolate( |
|
|
|
|
inputs[0], scale_factor=s, mode='bicubic') |
|
|
|
|
for s in net.generator.scale[:-1] |
|
|
|
|
]) |
|
|
|
|
loss = self.losses(sr[-1], inputs[1]) |
|
|
|
|
for i in range(1, len(sr)): |
|
|
|
|
if self.lq_loss_weight > 0: |
|
|
|
|
loss += self.losses(sr[i - 1 - len(sr)], |
|
|
|
|
lr[i - len(sr)]) * self.lq_loss_weight |
|
|
|
|
outputs['loss_prim'] = loss |
|
|
|
|
outputs['sr'] = sr |
|
|
|
|
outputs['lr'] = lr |
|
|
|
|
elif gan_mode == 'forward_dual': |
|
|
|
|
sr, lr = inputs[0], inputs[1] |
|
|
|
|
sr2lr = [] |
|
|
|
|
n_scales = len(net.generator.scale) |
|
|
|
|
for i in range(n_scales): |
|
|
|
|
sr2lr_i = net.generators[1 + i](sr[i - n_scales]) |
|
|
|
|
sr2lr.append(sr2lr_i) |
|
|
|
|
loss = self.losses(sr2lr[0], lr[0]) |
|
|
|
|
for i in range(1, n_scales): |
|
|
|
|
if self.dual_loss_weight > 0.0: |
|
|
|
|
loss += self.losses(sr2lr[i], lr[i]) * self.dual_loss_weight |
|
|
|
|
outputs['loss_dual'] = loss |
|
|
|
|
else: |
|
|
|
|
raise ValueError("Invalid `gan_mode`!") |
|
|
|
|
return outputs |
|
|
|
|
|
|
|
|
|
def train_step(self, step, data, net): |
|
|
|
|
outputs = self.run_gan( |
|
|
|
|
net, data, mode='train', gan_mode='forward_primary') |
|
|
|
|
outputs.update( |
|
|
|
|
self.run_gan( |
|
|
|
|
net, (outputs['sr'], outputs['lr']), |
|
|
|
|
mode='train', |
|
|
|
|
gan_mode='forward_dual')) |
|
|
|
|
self.optimizer.clear_grad() |
|
|
|
|
(outputs['loss_prim'] + outputs['loss_dual']).backward() |
|
|
|
|
self.optimizer.step() |
|
|
|
|
return { |
|
|
|
|
'loss_prim': outputs['loss_prim'], |
|
|
|
|
'loss_dual': outputs['loss_dual'] |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LESRCNN(BaseRestorer): |
|
|
|
@ -680,8 +743,6 @@ class LESRCNN(BaseRestorer): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ESRGAN(BaseRestorer): |
|
|
|
|
MIN_MAX = (0., 1.) |
|
|
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
|
losses=None, |
|
|
|
|
sr_factor=4, |
|
|
|
@ -704,7 +765,9 @@ class ESRGAN(BaseRestorer): |
|
|
|
|
model_name='ESRGAN', losses=losses, sr_factor=sr_factor, **params) |
|
|
|
|
|
|
|
|
|
def build_net(self, **params): |
|
|
|
|
from ppgan.modules.init import init_weights |
|
|
|
|
generator = ppgan.models.generators.RRDBNet(**params) |
|
|
|
|
init_weights(generator) |
|
|
|
|
if self.use_gan: |
|
|
|
|
discriminator = ppgan.models.discriminators.VGGDiscriminator128( |
|
|
|
|
in_channels=params['out_nc'], num_feat=64) |
|
|
|
@ -716,10 +779,13 @@ class ESRGAN(BaseRestorer): |
|
|
|
|
|
|
|
|
|
def default_loss(self): |
|
|
|
|
if self.use_gan: |
|
|
|
|
self.losses = { |
|
|
|
|
return { |
|
|
|
|
'pixel': res_losses.L1Loss(loss_weight=0.01), |
|
|
|
|
'perceptual': |
|
|
|
|
res_losses.PerceptualLoss(layer_weights={'34': 1.0}), |
|
|
|
|
'perceptual': res_losses.PerceptualLoss( |
|
|
|
|
layer_weights={'34': 1.0}, |
|
|
|
|
perceptual_weight=1.0, |
|
|
|
|
style_weight=0.0, |
|
|
|
|
norm_img=False), |
|
|
|
|
'gan': res_losses.GANLoss( |
|
|
|
|
gan_mode='vanilla', loss_weight=0.005) |
|
|
|
|
} |
|
|
|
@ -734,7 +800,7 @@ class ESRGAN(BaseRestorer): |
|
|
|
|
parameters['params_d'][0], *args, **kwargs) |
|
|
|
|
return OptimizerAdapter(optim_g, optim_d) |
|
|
|
|
else: |
|
|
|
|
return super(ESRGAN, self).default_optimizer(params, *args, |
|
|
|
|
return super(ESRGAN, self).default_optimizer(parameters, *args, |
|
|
|
|
**kwargs) |
|
|
|
|
|
|
|
|
|
def run_gan(self, net, inputs, mode, gan_mode='forward_g'): |
|
|
|
@ -744,8 +810,8 @@ class ESRGAN(BaseRestorer): |
|
|
|
|
if gan_mode == 'forward_g': |
|
|
|
|
loss_g = 0 |
|
|
|
|
g_pred = net.generator(inputs[0]) |
|
|
|
|
loss_pix = self.losses['pixel'](g_pred, tar) |
|
|
|
|
loss_perc, loss_sty = self.losses['perceptual'](g_pred, tar) |
|
|
|
|
loss_pix = self.losses['pixel'](g_pred, inputs[1]) |
|
|
|
|
loss_perc, loss_sty = self.losses['perceptual'](g_pred, inputs[1]) |
|
|
|
|
loss_g += loss_pix |
|
|
|
|
if loss_perc is not None: |
|
|
|
|
loss_g += loss_perc |
|
|
|
@ -767,14 +833,14 @@ class ESRGAN(BaseRestorer): |
|
|
|
|
elif gan_mode == 'forward_d': |
|
|
|
|
self._set_requires_grad(net.discriminator, True) |
|
|
|
|
# Real |
|
|
|
|
fake_d_pred = net.discriminator(data[0]).detach() |
|
|
|
|
real_d_pred = net.discriminator(data[1]) |
|
|
|
|
fake_d_pred = net.discriminator(inputs[0]).detach() |
|
|
|
|
real_d_pred = net.discriminator(inputs[1]) |
|
|
|
|
loss_d_real = self.losses['gan']( |
|
|
|
|
real_d_pred - paddle.mean(fake_d_pred), True, |
|
|
|
|
is_disc=True) * 0.5 |
|
|
|
|
# Fake |
|
|
|
|
fake_d_pred = self.nets['discriminator'](self.output.detach()) |
|
|
|
|
loss_d_fake = self.gan_criterion( |
|
|
|
|
fake_d_pred = net.discriminator(inputs[0].detach()) |
|
|
|
|
loss_d_fake = self.losses['gan']( |
|
|
|
|
fake_d_pred - paddle.mean(real_d_pred.detach()), |
|
|
|
|
False, |
|
|
|
|
is_disc=True) * 0.5 |
|
|
|
@ -802,30 +868,43 @@ class ESRGAN(BaseRestorer): |
|
|
|
|
outputs['loss_d'].backward() |
|
|
|
|
optim_d.step() |
|
|
|
|
|
|
|
|
|
outputs['loss'] = outupts['loss_g_pps'] + outputs[ |
|
|
|
|
outputs['loss'] = outputs['loss_g_pps'] + outputs[ |
|
|
|
|
'loss_g_gan'] + outputs['loss_d'] |
|
|
|
|
|
|
|
|
|
if isinstance(optim_g._learning_rate, |
|
|
|
|
paddle.optimizer.lr.LRScheduler): |
|
|
|
|
# If ReduceOnPlateau is used as the scheduler, use the loss value as the metric. |
|
|
|
|
if isinstance(optim_g._learning_rate, |
|
|
|
|
paddle.optimizer.lr.ReduceOnPlateau): |
|
|
|
|
optim_g._learning_rate.step(loss.item()) |
|
|
|
|
else: |
|
|
|
|
optim_g._learning_rate.step() |
|
|
|
|
|
|
|
|
|
if isinstance(optim_d._learning_rate, |
|
|
|
|
paddle.optimizer.lr.LRScheduler): |
|
|
|
|
if isinstance(optim_d._learning_rate, |
|
|
|
|
paddle.optimizer.lr.ReduceOnPlateau): |
|
|
|
|
optim_d._learning_rate.step(loss.item()) |
|
|
|
|
else: |
|
|
|
|
optim_d._learning_rate.step() |
|
|
|
|
|
|
|
|
|
return outputs |
|
|
|
|
return { |
|
|
|
|
'loss': outputs['loss'], |
|
|
|
|
'loss_g_pps': outputs['loss_g_pps'], |
|
|
|
|
'loss_g_gan': outputs['loss_g_gan'], |
|
|
|
|
'loss_d': outputs['loss_d'] |
|
|
|
|
} |
|
|
|
|
else: |
|
|
|
|
super(ESRGAN, self).train_step(step, data, net) |
|
|
|
|
return super(ESRGAN, self).train_step(step, data, net) |
|
|
|
|
|
|
|
|
|
def _set_requires_grad(self, net, requires_grad): |
|
|
|
|
for p in net.parameters(): |
|
|
|
|
p.trainable = requires_grad |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RCAN(BaseRestorer): |
|
|
|
|
def __init__(self, |
|
|
|
|
losses=None, |
|
|
|
|
sr_factor=4, |
|
|
|
|
n_resgroups=10, |
|
|
|
|
n_resblocks=20, |
|
|
|
|
n_feats=64, |
|
|
|
|
n_colors=3, |
|
|
|
|
rgb_range=1.0, |
|
|
|
|
kernel_size=3, |
|
|
|
|
reduction=16, |
|
|
|
|
**params): |
|
|
|
|
params.update({ |
|
|
|
|
'n_resgroups': n_resgroups, |
|
|
|
|
'n_resblocks': n_resblocks, |
|
|
|
|
'n_feats': n_feats, |
|
|
|
|
'n_colors': n_colors, |
|
|
|
|
'rgb_range': rgb_range, |
|
|
|
|
'kernel_size': kernel_size, |
|
|
|
|
'reduction': reduction |
|
|
|
|
}) |
|
|
|
|
super(RCAN, self).__init__( |
|
|
|
|
model_name='RCAN', losses=losses, sr_factor=sr_factor, **params) |
|
|
|
|