Refactor hungarian assigner (#4259)
* add focal_loss in hungarian_assigner.py * update hungarian_assigner.py * update detr config * update doc and codes * fix unitest * fix unitest * modify code format * update docstring * fix unitest * fix name * fix docformat * fix docformat * fix format Co-authored-by: fangkairen <fangkairen@sensetime.com>pull/4348/head
parent
23ded99365
commit
0c84b65b57
8 changed files with 246 additions and 56 deletions
@ -0,0 +1,7 @@ |
||||
from .builder import build_match_cost |
||||
from .match_cost import BBoxL1Cost, ClassificationCost, FocalLossCost, IoUCost |
||||
|
||||
__all__ = [ |
||||
'build_match_cost', 'ClassificationCost', 'BBoxL1Cost', 'IoUCost', |
||||
'FocalLossCost' |
||||
] |
@ -0,0 +1,8 @@ |
||||
from mmcv.utils import Registry, build_from_cfg |
||||
|
||||
MATCH_COST = Registry('Match Cost') |
||||
|
||||
|
||||
def build_match_cost(cfg, default_args=None): |
||||
"""Builder of IoU calculator.""" |
||||
return build_from_cfg(cfg, MATCH_COST, default_args) |
@ -0,0 +1,184 @@ |
||||
import torch |
||||
|
||||
from mmdet.core.bbox.iou_calculators import bbox_overlaps |
||||
from mmdet.core.bbox.transforms import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh |
||||
from .builder import MATCH_COST |
||||
|
||||
|
||||
@MATCH_COST.register_module() |
||||
class BBoxL1Cost(object): |
||||
"""BBoxL1Cost. |
||||
|
||||
Args: |
||||
weight (int | float, optional): loss_weight |
||||
box_format (str, optional): 'xyxy' for DETR, 'xywh' for Sparse_RCNN |
||||
|
||||
Examples: |
||||
>>> from mmdet.core.bbox.match_costs.match_cost import BBoxL1Cost |
||||
>>> import torch |
||||
>>> self = BBoxL1Cost() |
||||
>>> bbox_pred = torch.rand(1, 4) |
||||
>>> gt_bboxes= torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]]) |
||||
>>> factor = torch.tensor([10, 8, 10, 8]) |
||||
>>> self(bbox_pred, gt_bboxes, factor) |
||||
tensor([[1.6172, 1.6422]]) |
||||
""" |
||||
|
||||
def __init__(self, weight=1., box_format='xyxy'): |
||||
self.weight = weight |
||||
assert box_format in ['xyxy', 'xywh'] |
||||
self.box_format = box_format |
||||
|
||||
def __call__(self, bbox_pred, gt_bboxes): |
||||
""" |
||||
Args: |
||||
bbox_pred (Tensor): Predicted boxes with normalized coordinates |
||||
(cx, cy, w, h), which are all in range [0, 1]. Shape |
||||
[num_query, 4]. |
||||
gt_bboxes (Tensor): Ground truth boxes with normalized |
||||
coordinates (x1, y1, x2, y2). Shape [num_gt, 4]. |
||||
|
||||
Returns: |
||||
torch.Tensor: bbox_cost value with weight |
||||
""" |
||||
if self.box_format == 'xywh': |
||||
gt_bboxes = bbox_xyxy_to_cxcywh(gt_bboxes) |
||||
elif self.box_format == 'xyxy': |
||||
bbox_pred = bbox_cxcywh_to_xyxy(bbox_pred) |
||||
bbox_cost = torch.cdist(bbox_pred, gt_bboxes, p=1) |
||||
return bbox_cost * self.weight |
||||
|
||||
|
||||
@MATCH_COST.register_module() |
||||
class FocalLossCost(object): |
||||
"""FocalLossCost. |
||||
|
||||
Args: |
||||
weight (int | float, optional): loss_weight |
||||
alpha (int | float, optional): focal_loss alpha |
||||
gamma (int | float, optional): focal_loss gamma |
||||
eps (float, optional): default 1e-12 |
||||
|
||||
Examples: |
||||
>>> from mmdet.core.bbox.match_costs.match_cost import FocalLossCost |
||||
>>> import torch |
||||
>>> self = FocalLossCost() |
||||
>>> cls_pred = torch.rand(4, 3) |
||||
>>> gt_labels = torch.tensor([0, 1, 2]) |
||||
>>> factor = torch.tensor([10, 8, 10, 8]) |
||||
>>> self(cls_pred, gt_labels) |
||||
tensor([[-0.3236, -0.3364, -0.2699], |
||||
[-0.3439, -0.3209, -0.4807], |
||||
[-0.4099, -0.3795, -0.2929], |
||||
[-0.1950, -0.1207, -0.2626]]) |
||||
""" |
||||
|
||||
def __init__(self, weight=1., alpha=0.25, gamma=2, eps=1e-12): |
||||
self.weight = weight |
||||
self.alpha = alpha |
||||
self.gamma = gamma |
||||
self.eps = eps |
||||
|
||||
def __call__(self, cls_pred, gt_labels): |
||||
""" |
||||
Args: |
||||
cls_pred (Tensor): Predicted classification logits, shape |
||||
[num_query, num_class]. |
||||
gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). |
||||
|
||||
Returns: |
||||
torch.Tensor: cls_cost value with weight |
||||
""" |
||||
cls_pred = cls_pred.sigmoid() |
||||
neg_cost = -(1 - cls_pred + self.eps).log() * ( |
||||
1 - self.alpha) * cls_pred.pow(self.gamma) |
||||
pos_cost = -(cls_pred + self.eps).log() * self.alpha * ( |
||||
1 - cls_pred).pow(self.gamma) |
||||
cls_cost = pos_cost[:, gt_labels] - neg_cost[:, gt_labels] |
||||
return cls_cost * self.weight |
||||
|
||||
|
||||
@MATCH_COST.register_module() |
||||
class ClassificationCost(object): |
||||
"""ClsSoftmaxCost. |
||||
|
||||
Args: |
||||
weight (int | float, optional): loss_weight |
||||
|
||||
Examples: |
||||
>>> from mmdet.core.bbox.match_costs.match_cost import \ |
||||
... ClassificationCost |
||||
>>> import torch |
||||
>>> self = ClassificationCost() |
||||
>>> cls_pred = torch.rand(4, 3) |
||||
>>> gt_labels = torch.tensor([0, 1, 2]) |
||||
>>> factor = torch.tensor([10, 8, 10, 8]) |
||||
>>> self(cls_pred, gt_labels) |
||||
tensor([[-0.3430, -0.3525, -0.3045], |
||||
[-0.3077, -0.2931, -0.3992], |
||||
[-0.3664, -0.3455, -0.2881], |
||||
[-0.3343, -0.2701, -0.3956]]) |
||||
""" |
||||
|
||||
def __init__(self, weight=1.): |
||||
self.weight = weight |
||||
|
||||
def __call__(self, cls_pred, gt_labels): |
||||
""" |
||||
Args: |
||||
cls_pred (Tensor): Predicted classification logits, shape |
||||
[num_query, num_class]. |
||||
gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). |
||||
|
||||
Returns: |
||||
torch.Tensor: cls_cost value with weight |
||||
""" |
||||
# Following the official DETR repo, contrary to the loss that |
||||
# NLL is used, we approximate it in 1 - cls_score[gt_label]. |
||||
# The 1 is a constant that doesn't change the matching, |
||||
# so it can be ommitted. |
||||
cls_score = cls_pred.softmax(-1) |
||||
cls_cost = -cls_score[:, gt_labels] |
||||
return cls_cost * self.weight |
||||
|
||||
|
||||
@MATCH_COST.register_module() |
||||
class IoUCost(object): |
||||
"""IoUCost. |
||||
|
||||
Args: |
||||
iou_mode (str, optional): iou mode such as 'iou' | 'giou' |
||||
weight (int | float, optional): loss weight |
||||
|
||||
Examples: |
||||
>>> from mmdet.core.bbox.match_costs.match_cost import IoUCost |
||||
>>> import torch |
||||
>>> self = IoUCost() |
||||
>>> bboxes = torch.FloatTensor([[1,1, 2, 2], [2, 2, 3, 4]]) |
||||
>>> gt_bboxes = torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]]) |
||||
>>> self(bboxes, gt_bboxes) |
||||
tensor([[-0.1250, 0.1667], |
||||
[ 0.1667, -0.5000]]) |
||||
""" |
||||
|
||||
def __init__(self, iou_mode='giou', weight=1.): |
||||
self.weight = weight |
||||
self.iou_mode = iou_mode |
||||
|
||||
def __call__(self, bboxes, gt_bboxes): |
||||
""" |
||||
Args: |
||||
bboxes (Tensor): Predicted boxes with unnormalized coordinates |
||||
(x1, y1, x2, y2). Shape [num_query, 4]. |
||||
gt_bboxes (Tensor): Ground truth boxes with unnormalized |
||||
coordinates (x1, y1, x2, y2). Shape [num_gt, 4]. |
||||
|
||||
Returns: |
||||
torch.Tensor: iou_cost value with weight |
||||
""" |
||||
# overlaps: [num_bboxes, num_gt] |
||||
overlaps = bbox_overlaps( |
||||
bboxes, gt_bboxes, mode=self.iou_mode, is_aligned=False) |
||||
# The 1 is a constant that doesn't change the matching, so ommitted. |
||||
iou_cost = -overlaps |
||||
return iou_cost * self.weight |
Loading…
Reference in new issue