You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
101 lines
3.9 KiB
101 lines
3.9 KiB
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import paddle |
|
import paddle.nn.functional as F |
|
from paddlers.models.ppdet.core.workspace import register, serializable |
|
|
|
# Names exported by ``from <module> import *``.
__all__ = [
    'SOLOv2Loss',
]
|
|
|
|
|
@register
@serializable
class SOLOv2Loss(object):
    """
    SOLOv2Loss: dice loss on the instance (mask) branch plus sigmoid focal
    loss on the category branch.

    Args:
        ins_loss_weight (float): Weight applied to the instance (dice) loss.
        focal_loss_gamma (float): Gamma parameter for the sigmoid focal loss.
        focal_loss_alpha (float): Alpha parameter for the sigmoid focal loss.
    """

    def __init__(self,
                 ins_loss_weight=3.0,
                 focal_loss_gamma=2.0,
                 focal_loss_alpha=0.25):
        self.ins_loss_weight = ins_loss_weight
        self.focal_loss_gamma = focal_loss_gamma
        self.focal_loss_alpha = focal_loss_alpha

    def _dice_loss(self, input, target):
        """
        Per-sample dice loss ``1 - 2*sum(p*t) / (sum(p*p) + sum(t*t))``.

        Both tensors are flattened to ``[N, -1]`` (N = size of axis 0), so
        each sample is reduced independently. A 0.001 term is added to each
        squared sum so the denominator can never be zero.

        Args:
            input (Tensor): Predicted mask values; axis 0 indexes samples.
            target (Tensor): Target masks with the same number of samples
                as ``input``.

        Returns:
            Tensor: Dice loss per sample, shape ``[N]``.
        """
        input = paddle.reshape(input, shape=(paddle.shape(input)[0], -1))
        target = paddle.reshape(target, shape=(paddle.shape(target)[0], -1))
        intersection = paddle.sum(input * target, axis=1)
        pred_sq = paddle.sum(input * input, axis=1) + 0.001
        target_sq = paddle.sum(target * target, axis=1) + 0.001
        return 1 - (2 * intersection) / (pred_sq + target_sq)

    def __call__(self, ins_pred_list, ins_label_list, cate_preds, cate_labels,
                 num_ins):
        """
        Compute the SOLOv2 instance and category losses.

        Args:
            ins_pred_list (list): Instance-branch mask logits, one entry per
                level; ``None`` entries are skipped.
            ins_label_list (list): Instance mask labels per level, paired
                positionally with ``ins_pred_list``. Each is cast to float32
                and reshaped to ``[-1, H, W]``, where ``H, W`` are the last
                two dims of the matching prediction.
            cate_preds (Tensor): Concatenated category-branch logits; the
                size of the last axis is taken as the number of classes.
            cate_labels (Tensor): Concatenated integer category labels for
                ``cate_preds``. Label ``k`` selects class column ``k - 1``;
                label 0 yields an all-zero target row.
            num_ins (int): Number of positive samples in the mini-batch;
                ``num_ins + 1.`` is used as the focal-loss normalizer.

        Returns:
            loss_ins (Tensor): Sum of dice losses over masks that have at
                least one foreground pixel, divided by the number of such
                masks (clamped to >= 1), times ``ins_loss_weight``.
            loss_cate (Tensor): Sigmoid focal loss of ``cate_preds``
                against the one-hot category labels.
        """
        # 1. Use dice loss to compute the instance loss.
        loss_ins = []
        total_weights = paddle.zeros(shape=[1], dtype='float32')
        for pred, target in zip(ins_pred_list, ins_label_list):
            if pred is None:
                continue
            target = paddle.cast(target, 'float32')
            target = paddle.reshape(
                target,
                shape=[-1, paddle.shape(pred)[-2], paddle.shape(pred)[-1]])
            # 1.0 for masks with at least one foreground pixel, else 0.0;
            # empty target masks do not contribute to the loss.
            weights = paddle.cast(
                paddle.sum(target, axis=[1, 2]) > 0, 'float32')
            pred = F.sigmoid(pred)
            dice_out = paddle.multiply(self._dice_loss(pred, target), weights)
            total_weights += paddle.sum(weights)
            loss_ins.append(dice_out)

        if loss_ins:
            loss_ins = paddle.sum(paddle.concat(loss_ins))
        else:
            # Every level was None; paddle.concat would fail on an empty list.
            loss_ins = paddle.zeros(shape=[1], dtype='float32')
        # Clamp the divisor to >= 1: with no foreground masks the numerator is
        # 0, and dividing by total_weights == 0 would yield NaN.
        loss_ins = loss_ins / paddle.clip(total_weights, min=1.0)
        loss_ins = loss_ins * self.ins_loss_weight

        # 2. Use sigmoid focal loss to compute the category loss.
        num_classes = cate_preds.shape[-1]
        # One-hot over num_classes + 1 values, then drop column 0, so that
        # label 0 becomes an all-zero row and label k maps to column k - 1.
        # NOTE(review): the [:, 1:] slice assumes cate_labels is 1-D, making
        # the one-hot result [N, num_classes + 1] — confirm with callers.
        cate_labels_bin = F.one_hot(cate_labels, num_classes=num_classes + 1)
        cate_labels_bin = cate_labels_bin[:, 1:]

        loss_cate = F.sigmoid_focal_loss(
            cate_preds,
            label=cate_labels_bin,
            normalizer=num_ins + 1.,
            gamma=self.focal_loss_gamma,
            alpha=self.focal_loss_alpha)

        return loss_ins, loss_cate
|
|
|