Bobholamovic 3 years ago
parent 5dcc6cd078
commit 752d8e41cf
Changed files (lines changed):
  21  paddlers/deploy/predictor.py
   0  paddlers/models/ppdet/metrics/json_results.py
   0  paddlers/models/ppdet/modeling/architectures/centernet.py
   0  paddlers/models/ppdet/modeling/architectures/fairmot.py
   0  paddlers/models/ppdet/modeling/backbones/darknet.py
   0  paddlers/models/ppdet/modeling/backbones/dla.py
   0  paddlers/models/ppdet/modeling/backbones/resnet.py
   0  paddlers/models/ppdet/modeling/backbones/vgg.py
   0  paddlers/models/ppdet/modeling/heads/centernet_head.py
   0  paddlers/models/ppdet/modeling/losses/fairmot_loss.py
   0  paddlers/models/ppdet/modeling/necks/centernet_fpn.py
   0  paddlers/models/ppdet/modeling/reid/fairmot_embedding_head.py
   0  paddlers/models/ppseg/models/losses/focal_loss.py
   0  paddlers/models/ppseg/models/losses/kl_loss.py
  27  paddlers/rs_models/res/generators/param_init.py
  24  paddlers/rs_models/res/generators/rcan.py
  23  paddlers/tasks/base.py
   8  paddlers/tasks/change_detector.py
 241  paddlers/tasks/restorer.py
  12  paddlers/tasks/segmenter.py
  39  paddlers/tasks/utils/infer_nets.py
   6  paddlers/tasks/utils/res_adapters.py
   5  paddlers/transforms/operators.py
   2  paddlers/utils/__init__.py
  33  paddlers/utils/utils.py
   1  tutorials/train/README.md
  15  tutorials/train/image_restoration/drn.py
  15  tutorials/train/image_restoration/esrgan.py
  15  tutorials/train/image_restoration/lesrcnn.py

@@ -163,17 +163,27 @@ class Predictor(object):
'image2': preprocessed_samples[1],
'ori_shape': preprocessed_samples[2]
}
elif self._model.model_type == 'restorer':
preprocessed_samples = {
'image': preprocessed_samples[0],
'tar_shape': preprocessed_samples[1]
}
else:
logging.error(
"Invalid model type {}".format(self._model.model_type),
exit=True)
return preprocessed_samples
def postprocess(self, net_outputs, topk=1, ori_shape=None, transforms=None):
def postprocess(self,
net_outputs,
topk=1,
ori_shape=None,
tar_shape=None,
transforms=None):
if self._model.model_type == 'classifier':
true_topk = min(self._model.num_classes, topk)
if self._model._postprocess is None:
self._model.build_postprocess_from_labels(topk)
self._model.build_postprocess_from_labels(true_topk)
# XXX: Convert ndarray to tensor as self._model._postprocess requires
assert len(net_outputs) == 1
net_outputs = paddle.to_tensor(net_outputs[0])
@@ -201,6 +211,12 @@ class Predictor(object):
for k, v in zip(['bbox', 'bbox_num', 'mask'], net_outputs)
}
preds = self._model._postprocess(net_outputs)
elif self._model.model_type == 'restorer':
res_maps = self._model._postprocess(
net_outputs[0],
batch_tar_shape=tar_shape,
transforms=transforms.transforms)
preds = [{'res_map': res_map} for res_map in res_maps]
else:
logging.error(
"Invalid model type {}.".format(self._model.model_type),
@@ -244,6 +260,7 @@ class Predictor(object):
net_outputs,
topk,
ori_shape=preprocessed_input.get('ori_shape', None),
tar_shape=preprocessed_input.get('tar_shape', None),
transforms=transforms)
self.timer.postprocess_time_s.end(iter_num=len(images))
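With both branches in place, deployment-time restoration follows the same path as the other tasks: preprocess packs the image together with the recorded target shape, and postprocess restores each output map to its original geometry. A minimal usage sketch, with hypothetical exported-model and image paths:

    from paddlers.deploy import Predictor

    # Hypothetical paths; any exported restorer works the same way.
    predictor = Predictor('./inference_model', use_gpu=True)
    result = predictor.predict('./demo.png')
    res_map = result['res_map']  # restored image as an HWC array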

@@ -0,0 +1,27 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from paddlers.models.ppgan.modules.init import reset_parameters
def init_sr_weight(net):
def reset_func(m):
if hasattr(m, 'weight') and (
not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D))):
reset_parameters(m)
net.apply(reset_func)
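A quick usage sketch for the new helper: `init_sr_weight` walks every sublayer and re-initializes anything that carries a `weight`, skipping batch-norm layers so their affine parameters keep Paddle's defaults.

    import paddle.nn as nn

    # Toy network; the conv weights are reset, BatchNorm2D is left untouched.
    net = nn.Sequential(
        nn.Conv2D(3, 16, 3), nn.BatchNorm2D(16), nn.Conv2D(16, 3, 3))
    init_sr_weight(net)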

@@ -1,9 +1,26 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Based on https://github.com/kongdebug/RCAN-Paddle
import math
import paddle
import paddle.nn as nn
from .param_init import init_sr_weight
def default_conv(in_channels, out_channels, kernel_size, bias=True):
weight_attr = paddle.ParamAttr(
@@ -61,8 +78,10 @@ class RCAB(nn.Layer):
bias=True,
bn=False,
act=nn.ReLU(),
res_scale=1):
res_scale=1,
use_init_weight=False):
super(RCAB, self).__init__()
modules_body = []
for i in range(2):
modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
@@ -72,6 +91,9 @@ class RCAB(nn.Layer):
self.body = nn.Sequential(*modules_body)
self.res_scale = res_scale
if use_init_weight:
init_sr_weight(self)
def forward(self, x):
res = self.body(x)
res += x
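The initialization is opt-in, so existing RCAN checkpoints load unchanged unless `use_init_weight=True` is passed. A hypothetical construction, assuming the usual RCAN signature `RCAB(conv, n_feat, kernel_size, reduction, ...)`:

    # Sketch only; `default_conv` is the conv factory defined earlier in rcan.py.
    rcab = RCAB(
        default_conv,
        n_feat=64,
        kernel_size=3,
        reduction=16,
        use_init_weight=True)  # triggers init_sr_weight(self) after build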

@@ -30,10 +30,10 @@ from paddleslim import L1NormFilterPruner, FPGMFilterPruner
import paddlers
import paddlers.utils.logging as logging
from paddlers.utils import (seconds_to_hms, get_single_card_bs, dict2str,
get_pretrain_weights, load_pretrain_weights,
load_checkpoint, SmoothedValue, TrainingStats,
_get_shared_memory_size_in_M, EarlyStop)
from paddlers.utils import (
seconds_to_hms, get_single_card_bs, dict2str, get_pretrain_weights,
load_pretrain_weights, load_checkpoint, SmoothedValue, TrainingStats,
_get_shared_memory_size_in_M, EarlyStop, to_data_parallel, scheduler_step)
from .slim.prune import _pruner_eval_fn, _pruner_template_input, sensitive_prune
@@ -320,10 +320,10 @@ class BaseModel(metaclass=ModelMeta):
if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
):
paddle.distributed.init_parallel_env()
ddp_net = paddle.DataParallel(
ddp_net = to_data_parallel(
self.net, find_unused_parameters=find_unused_parameters)
else:
ddp_net = paddle.DataParallel(
ddp_net = to_data_parallel(
self.net, find_unused_parameters=find_unused_parameters)
if use_vdl:
@@ -368,6 +368,8 @@ class BaseModel(metaclass=ModelMeta):
else:
outputs = self.train_step(step, data, self.net)
scheduler_step(self.optimizer, outputs['loss'])
train_avg_metrics.update(outputs)
lr = self.optimizer.get_lr()
outputs['lr'] = lr
@@ -662,15 +664,6 @@ class BaseModel(metaclass=ModelMeta):
self.optimizer.step()
self.optimizer.clear_grad()
if isinstance(self.optimizer._learning_rate,
paddle.optimizer.lr.LRScheduler):
# If ReduceOnPlateau is used as the scheduler, use the loss value as the metric.
if isinstance(self.optimizer._learning_rate,
paddle.optimizer.lr.ReduceOnPlateau):
self.optimizer._learning_rate.step(loss.item())
else:
self.optimizer._learning_rate.step()
return outputs
def _check_transforms(self, transforms, mode):

@@ -796,11 +796,11 @@ class BaseChangeDetector(BaseModel):
elif item[0] == 'padding':
x, y = item[2]
if isinstance(label_map, np.ndarray):
label_map = label_map[..., y:y + h, x:x + w]
score_map = score_map[..., y:y + h, x:x + w]
label_map = label_map[y:y + h, x:x + w]
score_map = score_map[y:y + h, x:x + w]
else:
label_map = label_map[:, :, y:y + h, x:x + w]
score_map = score_map[:, :, y:y + h, x:x + w]
label_map = label_map[:, y:y + h, x:x + w, :]
score_map = score_map[:, y:y + h, x:x + w, :]
else:
pass
label_map = label_map.squeeze()
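Both the change-detection and segmentation postprocessors (the same fix appears in segmenter.py below) now receive NHWC outputs from the exported inference nets, so padding is stripped along axes 1 and 2 rather than the trailing two axes. A small shape check under that assumption:

    import numpy as np

    # NHWC score map with 4 padded rows/columns (illustrative sizes).
    score_map = np.zeros((1, 260, 260, 2), dtype='float32')
    x, y, h, w = 0, 0, 256, 256
    cropped = score_map[:, y:y + h, x:x + w, :]  # crop H and W, keep N and C
    assert cropped.shape == (1, 256, 256, 2)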

@@ -25,6 +25,7 @@ from paddle.static import InputSpec
import paddlers
import paddlers.models.ppgan as ppgan
import paddlers.rs_models.res as cmres
import paddlers.models.ppgan.metrics as metrics
import paddlers.utils.logging as logging
from paddlers.models import res_losses
from paddlers.transforms import Resize, decode_image
@@ -32,12 +33,14 @@ from paddlers.transforms.functions import calc_hr_shape
from paddlers.utils import get_single_card_bs
from .base import BaseModel
from .utils.res_adapters import GANAdapter, OptimizerAdapter
from .utils.infer_nets import InferResNet
__all__ = []
class BaseRestorer(BaseModel):
MIN_MAX = (0., 255.)
MIN_MAX = (0., 1.)
TEST_OUT_KEY = None
def __init__(self, model_name, losses=None, sr_factor=None, **params):
self.init_params = locals()
@@ -63,9 +66,10 @@ class BaseRestorer(BaseModel):
def _build_inference_net(self):
# For GAN models, only the generator will be used for inference.
if isinstance(self.net, GANAdapter):
infer_net = self.net.generator
infer_net = InferResNet(
self.net.generator, out_key=self.TEST_OUT_KEY)
else:
infer_net = self.net
infer_net = InferResNet(self.net, out_key=self.TEST_OUT_KEY)
infer_net.eval()
return infer_net
@@ -108,15 +112,18 @@ class BaseRestorer(BaseModel):
outputs = OrderedDict()
if mode == 'test':
if isinstance(net, GANAdapter):
net_out = net.generator(inputs[0])
else:
net_out = net(inputs[0])
tar_shape = inputs[1]
if self.status == 'Infer':
net_out = net(inputs[0])
res_map_list = self._postprocess(
net_out, tar_shape, transforms=inputs[2])
else:
if isinstance(net, GANAdapter):
net_out = net.generator(inputs[0])
else:
net_out = net(inputs[0])
if self.TEST_OUT_KEY is not None:
net_out = net_out[self.TEST_OUT_KEY]
pred = self._postprocess(
net_out, tar_shape, transforms=inputs[2])
res_map_list = []
@@ -130,13 +137,15 @@ class BaseRestorer(BaseModel):
net_out = net.generator(inputs[0])
else:
net_out = net(inputs[0])
if self.TEST_OUT_KEY is not None:
net_out = net_out[self.TEST_OUT_KEY]
tar = inputs[1]
tar_shape = [tar.shape[-2:]]
pred = self._postprocess(
net_out, tar_shape, transforms=inputs[2])[0] # NCHW
pred = self._tensor_to_images(pred)
outputs['pred'] = pred
tar = self.tensor_to_images(tar)
tar = self._tensor_to_images(tar)
outputs['tar'] = tar
if mode == 'train':
@@ -386,10 +395,11 @@ class BaseRestorer(BaseModel):
self.eval_data_loader = self.build_data_loader(
eval_dataset, batch_size=batch_size, mode='eval')
# XXX: Hard-code crop_border and test_y_channel
psnr = ppgan.metrics.PSNR(crop_border=4, test_y_channel=True)
ssim = ppgan.metrics.SSIM(crop_border=4, test_y_channel=True)
psnr = metrics.PSNR(crop_border=4, test_y_channel=True)
ssim = metrics.SSIM(crop_border=4, test_y_channel=True)
with paddle.no_grad():
for step, data in enumerate(self.eval_data_loader):
data.append(eval_dataset.transforms.transforms)
outputs = self.run(self.net, data, 'eval')
psnr.update(outputs['pred'], outputs['tar'])
ssim.update(outputs['pred'], outputs['tar'])
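The evaluation loop now feeds the NHWC uint8 arrays produced by `_tensor_to_images` into ppgan-style metrics, which accumulate over batches. A minimal sketch of that cycle, assuming ppgan's `update`/`accumulate` API:

    import numpy as np
    from paddlers.models.ppgan.metrics import PSNR

    psnr = PSNR(crop_border=4, test_y_channel=True)
    pred = np.random.randint(0, 256, (64, 64, 3), dtype='uint8')  # HWC uint8
    tar = np.random.randint(0, 256, (64, 64, 3), dtype='uint8')
    psnr.update(pred, tar)    # accumulate one (pred, target) pair
    print(psnr.accumulate())  # mean PSNR over all updates so far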
@@ -520,10 +530,9 @@ class BaseRestorer(BaseModel):
def _postprocess(self, batch_pred, batch_tar_shape, transforms):
batch_restore_list = BaseRestorer.get_transforms_shape_info(
batch_tar_shape, transforms)
if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
if self.status == 'Infer':
return self._infer_postprocess(
batch_res_map=batch_pred[0],
batch_restore_list=batch_restore_list)
batch_res_map=batch_pred, batch_restore_list=batch_restore_list)
results = []
if batch_pred.dtype == paddle.float32:
mode = 'bilinear'
@@ -546,7 +555,7 @@ class BaseRestorer(BaseModel):
def _infer_postprocess(self, batch_res_map, batch_restore_list):
res_maps = []
for score_map, restore_list in zip(batch_res_map, batch_restore_list):
for res_map, restore_list in zip(batch_res_map, batch_restore_list):
if not isinstance(res_map, np.ndarray):
res_map = paddle.unsqueeze(res_map, axis=0)
for item in restore_list[::-1]:
@@ -557,15 +566,15 @@ class BaseRestorer(BaseModel):
res_map, (w, h), interpolation=cv2.INTER_LINEAR)
else:
res_map = F.interpolate(
score_map, (h, w),
res_map, (h, w),
mode='bilinear',
data_format='NHWC')
elif item[0] == 'padding':
x, y = item[2]
if isinstance(res_map, np.ndarray):
res_map = res_map[..., y:y + h, x:x + w]
res_map = res_map[y:y + h, x:x + w]
else:
res_map = res_map[:, :, y:y + h, x:x + w]
res_map = res_map[:, y:y + h, x:x + w, :]
else:
pass
res_map = res_map.squeeze()
@@ -585,18 +594,25 @@ class BaseRestorer(BaseModel):
def set_losses(self, losses):
self.losses = losses
def _tensor_to_images(self, tensor, squeeze=True, quantize=True):
def _tensor_to_images(self,
tensor,
transpose=True,
squeeze=True,
quantize=True):
if transpose:
tensor = paddle.transpose(tensor, perm=[0, 2, 3, 1]) # NHWC
if squeeze:
tensor = tensor.squeeze()
images = tensor.numpy().astype('float32')
images = np.clip(images, self.MIN_MAX[0], self.MIN_MAX[1])
images = self._normalize(images, copy=True, quantize=quantize)
images = self._normalize(
images, copy=True, clip=True, quantize=quantize)
return images
def _normalize(self, im, copy=False, quantize=True):
def _normalize(self, im, copy=False, clip=True, quantize=True):
if copy:
im = im.copy()
if clip:
im = np.clip(im, self.MIN_MAX[0], self.MIN_MAX[1])
im -= im.min()
im /= im.max() + 1e-32
if quantize:
@@ -605,32 +621,9 @@ class BaseRestorer(BaseModel):
return im
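With `MIN_MAX` now (0., 1.) and the transpose handled inside `_tensor_to_images`, `_normalize` optionally clips to that range before min-max stretching. A standalone sketch of the effect; the uint8 quantization step is an assumption, since the hunk cuts off inside `if quantize:`:

    import numpy as np

    def normalize_sketch(im, min_max=(0., 1.), clip=True, quantize=True):
        # Mirrors BaseRestorer._normalize as shown above.
        im = im.astype('float32', copy=True)
        if clip:
            im = np.clip(im, min_max[0], min_max[1])
        im -= im.min()
        im /= im.max() + 1e-32
        if quantize:
            im = (im * 255).round().astype('uint8')  # assumed quantization
        return im

    print(normalize_sketch(np.array([[-0.5, 0.25], [0.5, 2.0]])))  # [[0 64] [128 255]]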
class RCAN(BaseRestorer):
def __init__(self,
losses=None,
sr_factor=4,
n_resgroups=10,
n_resblocks=20,
n_feats=64,
n_colors=3,
rgb_range=255,
kernel_size=3,
reduction=16,
**params):
params.update({
'n_resgroups': n_resgroups,
'n_resblocks': n_resblocks,
'n_feats': n_feats,
'n_colors': n_colors,
'rgb_range': rgb_range,
'kernel_size': kernel_size,
'reduction': reduction
})
super(RCAN, self).__init__(
model_name='RCAN', losses=losses, sr_factor=sr_factor, **params)
class DRN(BaseRestorer):
TEST_OUT_KEY = -1
def __init__(self,
losses=None,
sr_factor=4,
@@ -638,8 +631,10 @@ class DRN(BaseRestorer):
n_blocks=30,
n_feats=16,
n_colors=3,
rgb_range=255,
rgb_range=1.0,
negval=0.2,
lq_loss_weight=0.1,
dual_loss_weight=0.1,
**params):
if sr_factor != max(scale):
raise ValueError(f"`sr_factor` must be equal to `max(scale)`.")
@@ -651,12 +646,80 @@ class DRN(BaseRestorer):
'rgb_range': rgb_range,
'negval': negval
})
self.lq_loss_weight = lq_loss_weight
self.dual_loss_weight = dual_loss_weight
super(DRN, self).__init__(
model_name='DRN', losses=losses, sr_factor=sr_factor, **params)
def build_net(self, **params):
net = ppgan.models.generators.DRNGenerator(**params)
return net
from ppgan.modules.init import init_weights
generators = [ppgan.models.generators.DRNGenerator(**params)]
init_weights(generators[-1])
for scale in params['scale']:
dual_model = ppgan.models.generators.drn.DownBlock(
params['negval'], params['n_feats'], params['n_colors'], 2)
generators.append(dual_model)
init_weights(generators[-1])
return GANAdapter(generators, [])
def default_optimizer(self, parameters, *args, **kwargs):
optims_g = [
super(DRN, self).default_optimizer(params_g, *args, **kwargs)
for params_g in parameters['params_g']
]
return OptimizerAdapter(*optims_g)
def run_gan(self, net, inputs, mode, gan_mode='forward_primary'):
if mode != 'train':
raise ValueError("`mode` is not 'train'.")
outputs = OrderedDict()
if gan_mode == 'forward_primary':
sr = net.generator(inputs[0])
lr = [inputs[0]]
lr.extend([
F.interpolate(
inputs[0], scale_factor=s, mode='bicubic')
for s in net.generator.scale[:-1]
])
loss = self.losses(sr[-1], inputs[1])
for i in range(1, len(sr)):
if self.lq_loss_weight > 0:
loss += self.losses(sr[i - 1 - len(sr)],
lr[i - len(sr)]) * self.lq_loss_weight
outputs['loss_prim'] = loss
outputs['sr'] = sr
outputs['lr'] = lr
elif gan_mode == 'forward_dual':
sr, lr = inputs[0], inputs[1]
sr2lr = []
n_scales = len(net.generator.scale)
for i in range(n_scales):
sr2lr_i = net.generators[1 + i](sr[i - n_scales])
sr2lr.append(sr2lr_i)
loss = self.losses(sr2lr[0], lr[0])
for i in range(1, n_scales):
if self.dual_loss_weight > 0.0:
loss += self.losses(sr2lr[i], lr[i]) * self.dual_loss_weight
outputs['loss_dual'] = loss
else:
raise ValueError("Invalid `gan_mode`!")
return outputs
def train_step(self, step, data, net):
outputs = self.run_gan(
net, data, mode='train', gan_mode='forward_primary')
outputs.update(
self.run_gan(
net, (outputs['sr'], outputs['lr']),
mode='train',
gan_mode='forward_dual'))
self.optimizer.clear_grad()
(outputs['loss_prim'] + outputs['loss_dual']).backward()
self.optimizer.step()
return {
'loss': outputs['loss_prim'] + outputs['loss_dual'],
'loss_prim': outputs['loss_prim'],
'loss_dual': outputs['loss_dual']
}
class LESRCNN(BaseRestorer):
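The hunks above implement DRN's dual-regression training: the primary pass supervises every scale of the generator against a bicubic LR pyramid, while the dual pass maps each SR output back down through the per-scale DownBlocks appended to the GANAdapter. A condensed sketch of the combined objective; `drn_total_loss` is a hypothetical helper and `l1` stands in for the configured restoration loss:

    import paddle.nn.functional as F

    def drn_total_loss(generator, dual_blocks, l1, lr_img, hr_img,
                       lq_loss_weight=0.1, dual_loss_weight=0.1):
        sr = generator(lr_img)  # predictions at every scale; sr[-1] is full scale
        lr = [lr_img] + [
            F.interpolate(lr_img, scale_factor=s, mode='bicubic')
            for s in generator.scale[:-1]
        ]
        loss = l1(sr[-1], hr_img)                # primary, full-scale term
        for i in range(1, len(sr)):              # weighted low-quality terms
            loss += lq_loss_weight * l1(sr[i - 1 - len(sr)], lr[i - len(sr)])
        n = len(dual_blocks)                     # dual pass: SR -> LR regressors
        sr2lr = [blk(sr[i - n]) for i, blk in enumerate(dual_blocks)]
        loss = loss + l1(sr2lr[0], lr[0])
        for i in range(1, n):
            loss += dual_loss_weight * l1(sr2lr[i], lr[i])
        return loss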
@@ -680,8 +743,6 @@ class LESRCNN(BaseRestorer):
class ESRGAN(BaseRestorer):
MIN_MAX = (0., 1.)
def __init__(self,
losses=None,
sr_factor=4,
@@ -704,7 +765,9 @@ class ESRGAN(BaseRestorer):
model_name='ESRGAN', losses=losses, sr_factor=sr_factor, **params)
def build_net(self, **params):
from ppgan.modules.init import init_weights
generator = ppgan.models.generators.RRDBNet(**params)
init_weights(generator)
if self.use_gan:
discriminator = ppgan.models.discriminators.VGGDiscriminator128(
in_channels=params['out_nc'], num_feat=64)
@@ -716,10 +779,13 @@ class ESRGAN(BaseRestorer):
def default_loss(self):
if self.use_gan:
self.losses = {
return {
'pixel': res_losses.L1Loss(loss_weight=0.01),
'perceptual':
res_losses.PerceptualLoss(layer_weights={'34': 1.0}),
'perceptual': res_losses.PerceptualLoss(
layer_weights={'34': 1.0},
perceptual_weight=1.0,
style_weight=0.0,
norm_img=False),
'gan': res_losses.GANLoss(
gan_mode='vanilla', loss_weight=0.005)
}
@@ -734,7 +800,7 @@ class ESRGAN(BaseRestorer):
parameters['params_d'][0], *args, **kwargs)
return OptimizerAdapter(optim_g, optim_d)
else:
return super(ESRGAN, self).default_optimizer(params, *args,
return super(ESRGAN, self).default_optimizer(parameters, *args,
**kwargs)
def run_gan(self, net, inputs, mode, gan_mode='forward_g'):
@@ -744,8 +810,8 @@ class ESRGAN(BaseRestorer):
if gan_mode == 'forward_g':
loss_g = 0
g_pred = net.generator(inputs[0])
loss_pix = self.losses['pixel'](g_pred, tar)
loss_perc, loss_sty = self.losses['perceptual'](g_pred, tar)
loss_pix = self.losses['pixel'](g_pred, inputs[1])
loss_perc, loss_sty = self.losses['perceptual'](g_pred, inputs[1])
loss_g += loss_pix
if loss_perc is not None:
loss_g += loss_perc
@@ -767,14 +833,14 @@ class ESRGAN(BaseRestorer):
elif gan_mode == 'forward_d':
self._set_requires_grad(net.discriminator, True)
# Real
fake_d_pred = net.discriminator(data[0]).detach()
real_d_pred = net.discriminator(data[1])
fake_d_pred = net.discriminator(inputs[0]).detach()
real_d_pred = net.discriminator(inputs[1])
loss_d_real = self.losses['gan'](
real_d_pred - paddle.mean(fake_d_pred), True,
is_disc=True) * 0.5
# Fake
fake_d_pred = self.nets['discriminator'](self.output.detach())
loss_d_fake = self.gan_criterion(
fake_d_pred = net.discriminator(inputs[0].detach())
loss_d_fake = self.losses['gan'](
fake_d_pred - paddle.mean(real_d_pred.detach()),
False,
is_disc=True) * 0.5
@@ -802,30 +868,43 @@ class ESRGAN(BaseRestorer):
outputs['loss_d'].backward()
optim_d.step()
outputs['loss'] = outupts['loss_g_pps'] + outputs[
outputs['loss'] = outputs['loss_g_pps'] + outputs[
'loss_g_gan'] + outputs['loss_d']
if isinstance(optim_g._learning_rate,
paddle.optimizer.lr.LRScheduler):
# If ReduceOnPlateau is used as the scheduler, use the loss value as the metric.
if isinstance(optim_g._learning_rate,
paddle.optimizer.lr.ReduceOnPlateau):
optim_g._learning_rate.step(loss.item())
else:
optim_g._learning_rate.step()
if isinstance(optim_d._learning_rate,
paddle.optimizer.lr.LRScheduler):
if isinstance(optim_d._learning_rate,
paddle.optimizer.lr.ReduceOnPlateau):
optim_d._learning_rate.step(loss.item())
else:
optim_d._learning_rate.step()
return outputs
return {
'loss': outputs['loss'],
'loss_g_pps': outputs['loss_g_pps'],
'loss_g_gan': outputs['loss_g_gan'],
'loss_d': outputs['loss_d']
}
else:
super(ESRGAN, self).train_step(step, data, net)
return super(ESRGAN, self).train_step(step, data, net)
def _set_requires_grad(self, net, requires_grad):
for p in net.parameters():
p.trainable = requires_grad
class RCAN(BaseRestorer):
def __init__(self,
losses=None,
sr_factor=4,
n_resgroups=10,
n_resblocks=20,
n_feats=64,
n_colors=3,
rgb_range=1.0,
kernel_size=3,
reduction=16,
**params):
params.update({
'n_resgroups': n_resgroups,
'n_resblocks': n_resblocks,
'n_feats': n_feats,
'n_colors': n_colors,
'rgb_range': rgb_range,
'kernel_size': kernel_size,
'reduction': reduction
})
super(RCAN, self).__init__(
model_name='RCAN', losses=losses, sr_factor=sr_factor, **params)
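The ESRGAN discriminator hunks above follow the relativistic average GAN formulation: each logit is shifted by the mean logit of the opposite class before the vanilla GAN loss is applied, and the fake logits are recomputed with gradients enabled when training the discriminator. A condensed sketch, with `gan_loss` standing in for `res_losses.GANLoss(gan_mode='vanilla', ...)`:

    import paddle

    def relativistic_d_loss(discriminator, fake_img, real_img, gan_loss):
        # Real branch: real logits relative to the mean (detached) fake logit.
        fake_pred = discriminator(fake_img).detach()
        real_pred = discriminator(real_img)
        loss_real = gan_loss(
            real_pred - paddle.mean(fake_pred), True, is_disc=True) * 0.5
        # Fake branch: recompute fake logits with gradients enabled.
        fake_pred = discriminator(fake_img.detach())
        loss_fake = gan_loss(
            fake_pred - paddle.mean(real_pred.detach()), False, is_disc=True) * 0.5
        return loss_real + loss_fake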

@@ -33,7 +33,7 @@ from paddlers.utils import get_single_card_bs, DisablePrint
from paddlers.utils.checkpoint import seg_pretrain_weights_dict
from .base import BaseModel
from .utils import seg_metrics as metrics
from .utils.infer_nets import InferNet
from .utils.infer_nets import InferSegNet
__all__ = ["UNet", "DeepLabV3P", "FastSCNN", "HRNet", "BiSeNetV2", "FarSeg"]
@@ -71,7 +71,7 @@ class BaseSegmenter(BaseModel):
return net
def _build_inference_net(self):
infer_net = InferNet(self.net, self.model_type)
infer_net = InferSegNet(self.net)
infer_net.eval()
return infer_net
@@ -755,11 +755,11 @@ class BaseSegmenter(BaseModel):
elif item[0] == 'padding':
x, y = item[2]
if isinstance(label_map, np.ndarray):
label_map = label_map[..., y:y + h, x:x + w]
score_map = score_map[..., y:y + h, x:x + w]
label_map = label_map[y:y + h, x:x + w]
score_map = score_map[y:y + h, x:x + w]
else:
label_map = label_map[:, :, y:y + h, x:x + w]
score_map = score_map[:, :, y:y + h, x:x + w]
label_map = label_map[:, y:y + h, x:x + w, :]
score_map = score_map[:, y:y + h, x:x + w, :]
else:
pass
label_map = label_map.squeeze()

@@ -15,30 +15,36 @@
import paddle
class PostProcessor(paddle.nn.Layer):
def __init__(self, model_type):
super(PostProcessor, self).__init__()
self.model_type = model_type
class SegPostProcessor(paddle.nn.Layer):
def forward(self, net_outputs):
# label_map [NHW], score_map [NHWC]
logit = net_outputs[0]
outputs = paddle.argmax(logit, axis=1, keepdim=False, dtype='int32'), \
paddle.transpose(paddle.nn.functional.softmax(logit, axis=1), perm=[0, 2, 3, 1])
return outputs
class ResPostProcessor(paddle.nn.Layer):
def __init__(self, out_key=None):
super(ResPostProcessor, self).__init__()
self.out_key = out_key
def forward(self, net_outputs):
if self.out_key is not None:
net_outputs = net_outputs[self.out_key]
outputs = paddle.transpose(net_outputs, perm=[0, 2, 3, 1])
return outputs
class InferNet(paddle.nn.Layer):
def __init__(self, net, model_type):
super(InferNet, self).__init__()
class InferSegNet(paddle.nn.Layer):
def __init__(self, net):
super(InferSegNet, self).__init__()
self.net = net
self.postprocessor = PostProcessor(model_type)
self.postprocessor = SegPostProcessor()
def forward(self, x):
net_outputs = self.net(x)
outputs = self.postprocessor(net_outputs)
return outputs
@@ -46,10 +52,21 @@ class InferCDNet(paddle.nn.Layer):
def __init__(self, net):
super(InferCDNet, self).__init__()
self.net = net
self.postprocessor = PostProcessor('change_detector')
self.postprocessor = SegPostProcessor()
def forward(self, x1, x2):
net_outputs = self.net(x1, x2)
outputs = self.postprocessor(net_outputs)
return outputs
class InferResNet(paddle.nn.Layer):
def __init__(self, net, out_key=None):
super(InferResNet, self).__init__()
self.net = net
self.postprocessor = ResPostProcessor(out_key=out_key)
def forward(self, x):
net_outputs = self.net(x)
outputs = self.postprocessor(net_outputs)
return outputs
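`InferResNet` bundles the generator with the NCHW-to-NHWC transpose (plus optional output-key selection, e.g. DRN's `TEST_OUT_KEY = -1`), so exported restorers emit deploy-friendly tensors. A minimal smoke test with a stand-in generator:

    import paddle

    net = paddle.nn.Conv2D(3, 3, 3, padding=1)    # stand-in generator
    infer_net = InferResNet(net, out_key=None)
    infer_net.eval()
    out = infer_net(paddle.rand([1, 3, 64, 64]))  # NCHW in
    print(out.shape)                              # [1, 64, 64, 3]: NHWC out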

@@ -122,7 +122,11 @@ class OptimizerAdapter(Adapter):
__ducktype__ = paddle.optimizer.Optimizer
__ava__ = ('state_dict', 'set_state_dict', 'clear_grad', 'step', 'get_lr')
# Special dispatching rule
def set_state_dict(self, state_dicts):
# Special dispatching rule
for optim, state_dict in zip(self, state_dicts):
optim.set_state_dict(state_dict)
def get_lr(self):
# Return the lr of the first optimizer
return self[0].get_lr()
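`OptimizerAdapter` duck-types a single paddle Optimizer over a list of member optimizers: most calls in `__ava__` fan out to every member, while `set_state_dict` zips one state dict per member and `get_lr` reads the first. A hypothetical two-optimizer sketch; the list-valued `state_dict()` return is an assumption about the `Adapter` base:

    import paddle

    net_g = paddle.nn.Linear(4, 4)
    net_d = paddle.nn.Linear(4, 1)
    optim = OptimizerAdapter(
        paddle.optimizer.Adam(learning_rate=1e-4, parameters=net_g.parameters()),
        paddle.optimizer.Adam(learning_rate=1e-4, parameters=net_d.parameters()))
    loss = paddle.mean(net_d(net_g(paddle.rand([2, 4]))))
    optim.clear_grad()                        # fans out to both members
    loss.backward()
    optim.step()
    optim.set_state_dict(optim.state_dict())  # per-member states, zipped back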

@@ -1207,7 +1207,7 @@ class RandomCrop(Transform):
if 'target' in sample:
if 'sr_factor' in sample:
sample['target'] = self.apply_im(
sample['image'],
sample['target'],
calc_hr_shape(crop_box, sample['sr_factor']))
else:
sample['target'] = self.apply_im(sample['target'], crop_box)
@@ -1993,8 +1993,9 @@ class ArrangeDetector(Arrange):
class ArrangeRestorer(Arrange):
def apply(self, sample):
image = permute(sample['image'], False)
if 'target' in sample:
target = permute(sample['target'], False)
image = permute(sample['image'], False)
if self.mode == 'train':
return image, target
if self.mode == 'eval':

@@ -16,7 +16,7 @@ from . import logging
from . import utils
from .utils import (seconds_to_hms, get_encoding, get_single_card_bs, dict2str,
EarlyStop, norm_path, is_pic, MyEncoder, DisablePrint,
Timer)
Timer, to_data_parallel, scheduler_step)
from .checkpoint import get_pretrain_weights, load_pretrain_weights, load_checkpoint
from .env import get_environ_info, get_num_workers, init_parallel_env
from .download import download_and_decompress, decompress

@@ -20,11 +20,12 @@ import math
import imghdr
import chardet
import json
import platform
import numpy as np
import paddle
from . import logging
import platform
import paddlers
@@ -237,3 +238,33 @@ class Timer(Times):
self.postprocess_time_s.reset()
self.img_num = 0
self.repeats = 0
def to_data_parallel(layers, *args, **kwargs):
from paddlers.tasks.utils.res_adapters import GANAdapter
if isinstance(layers, GANAdapter):
# Inplace modification for efficiency
layers.generators = [
paddle.DataParallel(g, *args, **kwargs) for g in layers.generators
]
layers.discriminators = [
paddle.DataParallel(d, *args, **kwargs)
for d in layers.discriminators
]
else:
layers = paddle.DataParallel(layers, *args, **kwargs)
return layers
def scheduler_step(optimizer, loss=None):
from paddlers.tasks.utils.res_adapters import OptimizerAdapter
if not isinstance(optimizer, OptimizerAdapter):
optimizer = [optimizer]
for optim in optimizer:
if isinstance(optim._learning_rate, paddle.optimizer.lr.LRScheduler):
# If ReduceOnPlateau is used as the scheduler, use the loss value as the metric.
if isinstance(optim._learning_rate,
paddle.optimizer.lr.ReduceOnPlateau):
optim._learning_rate.step(loss.item())
else:
optim._learning_rate.step()
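With the `loss` parameter added to the signature above (the body already references it), `scheduler_step` covers both plain optimizers and `OptimizerAdapter` instances. A minimal sketch driving a `ReduceOnPlateau` scheduler, which consumes the loss as its metric:

    import paddle

    layer = paddle.nn.Linear(4, 4)
    sched = paddle.optimizer.lr.ReduceOnPlateau(learning_rate=0.01, patience=2)
    optim = paddle.optimizer.Adam(
        learning_rate=sched, parameters=layer.parameters())
    loss = paddle.mean(layer(paddle.rand([2, 4])))
    scheduler_step(optim, loss)  # steps the scheduler with loss.item()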

@@ -20,7 +20,6 @@
|image_restoration/drn.py | Image restoration | DRN |
|image_restoration/esrgan.py | Image restoration | ESRGAN |
|image_restoration/lesrcnn.py | Image restoration | LESRCNN |
|image_restoration/rcan.py | Image restoration | RCAN |
|object_detection/faster_rcnn.py | Object detection | Faster R-CNN |
|object_detection/ppyolo.py | Object detection | PP-YOLO |
|object_detection/ppyolotiny.py | Object detection | PP-YOLO Tiny |

@@ -25,8 +25,8 @@ pdrs.utils.download_and_decompress(
train_transforms = T.Compose([
# Decode the image
T.DecodeImg(),
# Resize the input image to 256x256
T.Resize(target_size=256),
# Randomly crop 96x96 patches from the input image
T.RandomCrop(crop_size=96),
# Randomly flip horizontally with a probability of 50%
T.RandomHorizontalFlip(prob=0.5),
# Randomly flip vertically with a probability of 50%
@@ -39,6 +39,7 @@ train_transforms = T.Compose([
eval_transforms = T.Compose([
T.DecodeImg(),
# Resize the input image to 256x256
T.Resize(target_size=256),
# Normalization at validation time must match that used at training time
T.Normalize(
@@ -52,14 +53,16 @@ train_dataset = pdrs.datasets.ResDataset(
file_list=TRAIN_FILE_LIST_PATH,
transforms=train_transforms,
num_workers=0,
shuffle=True)
shuffle=True,
sr_factor=4)
eval_dataset = pdrs.datasets.ResDataset(
data_dir=DATA_DIR,
file_list=EVAL_FILE_LIST_PATH,
transforms=eval_transforms,
num_workers=0,
shuffle=False)
shuffle=False,
sr_factor=4)
# Build a DRN model with default parameters
# For the list of currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
@@ -74,10 +77,10 @@ model.train(
eval_dataset=eval_dataset,
save_interval_epochs=1,
# Log once every this many iterations
log_interval_steps=50,
log_interval_steps=5,
save_dir=EXP_DIR,
# Initial learning rate
learning_rate=0.01,
learning_rate=0.001,
# Whether to use early stopping, i.e., stop training early when accuracy no longer improves
early_stop=False,
# Whether to enable VisualDL logging

@@ -25,8 +25,8 @@ pdrs.utils.download_and_decompress(
train_transforms = T.Compose([
# Decode the image
T.DecodeImg(),
# Resize the input image to 256x256
T.Resize(target_size=256),
# Randomly crop 32x32 patches from the input image
T.RandomCrop(crop_size=32),
# Randomly flip horizontally with a probability of 50%
T.RandomHorizontalFlip(prob=0.5),
# Randomly flip vertically with a probability of 50%
@@ -39,6 +39,7 @@ train_transforms = T.Compose([
eval_transforms = T.Compose([
T.DecodeImg(),
# Resize the input image to 256x256
T.Resize(target_size=256),
# Normalization at validation time must match that used at training time
T.Normalize(
@@ -52,14 +53,16 @@ train_dataset = pdrs.datasets.ResDataset(
file_list=TRAIN_FILE_LIST_PATH,
transforms=train_transforms,
num_workers=0,
shuffle=True)
shuffle=True,
sr_factor=4)
eval_dataset = pdrs.datasets.ResDataset(
data_dir=DATA_DIR,
file_list=EVAL_FILE_LIST_PATH,
transforms=eval_transforms,
num_workers=0,
shuffle=False)
shuffle=False,
sr_factor=4)
# Build an ESRGAN model with default parameters
# For the list of currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
@@ -74,10 +77,10 @@ model.train(
eval_dataset=eval_dataset,
save_interval_epochs=1,
# Log once every this many iterations
log_interval_steps=50,
log_interval_steps=5,
save_dir=EXP_DIR,
# Initial learning rate
learning_rate=0.01,
learning_rate=0.001,
# Whether to use early stopping, i.e., stop training early when accuracy no longer improves
early_stop=False,
# Whether to enable VisualDL logging

@@ -25,8 +25,8 @@ pdrs.utils.download_and_decompress(
train_transforms = T.Compose([
# Decode the image
T.DecodeImg(),
# Resize the input image to 256x256
T.Resize(target_size=256),
# Randomly crop 32x32 patches from the input image
T.RandomCrop(crop_size=32),
# Randomly flip horizontally with a probability of 50%
T.RandomHorizontalFlip(prob=0.5),
# Randomly flip vertically with a probability of 50%
@@ -39,6 +39,7 @@ train_transforms = T.Compose([
eval_transforms = T.Compose([
T.DecodeImg(),
# Resize the input image to 256x256
T.Resize(target_size=256),
# Normalization at validation time must match that used at training time
T.Normalize(
@@ -52,14 +53,16 @@ train_dataset = pdrs.datasets.ResDataset(
file_list=TRAIN_FILE_LIST_PATH,
transforms=train_transforms,
num_workers=0,
shuffle=True)
shuffle=True,
sr_factor=4)
eval_dataset = pdrs.datasets.ResDataset(
data_dir=DATA_DIR,
file_list=EVAL_FILE_LIST_PATH,
transforms=eval_transforms,
num_workers=0,
shuffle=False)
shuffle=False,
sr_factor=4)
# Build a LESRCNN model with default parameters
# For the list of currently supported models, see: https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/intro/model_zoo.md
@@ -74,10 +77,10 @@ model.train(
eval_dataset=eval_dataset,
save_interval_epochs=1,
# Log once every this many iterations
log_interval_steps=50,
log_interval_steps=5,
save_dir=EXP_DIR,
# Initial learning rate
learning_rate=0.01,
learning_rate=0.001,
# Whether to use early stopping, i.e., stop training early when accuracy no longer improves
early_stop=False,
# Whether to enable VisualDL logging
