You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
47 lines
1.8 KiB
47 lines
1.8 KiB
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import paddle.nn.functional as F |
|
from paddlers.models.ppdet.core.workspace import register, serializable |
|
from .iou_loss import IouLoss |
|
from ..bbox_utils import bbox_iou |
|
|
|
|
|
@register
@serializable
class IouAwareLoss(IouLoss):
    """
    IoU-aware loss, see https://arxiv.org/abs/1912.05992

    Scores a predicted IoU logit against the overlap measured between the
    predicted and ground-truth boxes, using binary cross entropy with
    logits. The measured overlap is treated as a constant target (no
    gradient flows through it).

    Args:
        loss_weight (float): iou aware loss weight, default is 1.0
        giou (bool): forwarded to ``bbox_iou`` as its ``giou`` flag,
            default is False
        diou (bool): forwarded to ``bbox_iou`` as its ``diou`` flag,
            default is False
        ciou (bool): forwarded to ``bbox_iou`` as its ``ciou`` flag,
            default is False
    """

    def __init__(self, loss_weight=1.0, giou=False, diou=False, ciou=False):
        # All configuration is handed to IouLoss.__init__; __call__ below
        # reads it back as self.loss_weight / self.giou / self.diou /
        # self.ciou (presumably stored by the parent -- verify in iou_loss).
        super(IouAwareLoss, self).__init__(
            loss_weight=loss_weight, giou=giou, diou=diou, ciou=ciou)

    def __call__(self, ioup, pbox, gbox):
        """
        Compute the weighted, unreduced IoU-aware loss.

        Args:
            ioup (Tensor): predicted IoU, used as the *logit* input of
                ``binary_cross_entropy_with_logits`` (i.e. before sigmoid).
            pbox (Tensor): predicted boxes, passed as the first argument
                of ``bbox_iou``. Box layout is whatever ``bbox_iou``
                expects -- TODO confirm against bbox_utils.
            gbox (Tensor): ground-truth boxes, passed as the second
                argument of ``bbox_iou``.

        Returns:
            Tensor: element-wise loss (``reduction='none'``) multiplied by
            ``self.loss_weight``; the caller is responsible for reducing.
        """
        iou = bbox_iou(
            pbox, gbox, giou=self.giou, diou=self.diou, ciou=self.ciou)
        # The measured overlap is only a regression target: block gradients
        # so the box branch is not trained through this loss.
        iou.stop_gradient = True
        # NOTE(review): if bbox_iou returns GIoU/DIoU/CIoU values when the
        # corresponding flag is set, they may be negative, which is outside
        # the usual [0, 1] range of a BCE target -- confirm this is intended.
        loss_iou_aware = F.binary_cross_entropy_with_logits(
            ioup, iou, reduction='none')
        loss_iou_aware = loss_iou_aware * self.loss_weight
        return loss_iou_aware
|
|
|