Refactor hungarian assigner (#4259)

* add focal_loss in hungarian_assigner.py

* update hungarian_assigner.py

* update detr config

* update doc and codes

* fix unitest

* fix unitest

* modify code format

* update docstring

* fix unitest

* fix name

* fix docformat

* fix docformat

* fix format

Co-authored-by: fangkairen <fangkairen@sensetime.com>
pull/4348/head
kellenf 4 years ago committed by GitHub
parent 23ded99365
commit 0c84b65b57
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 6
      configs/detr/detr_r50_8x2_150e_coco.py
  2. 52
      mmdet/core/bbox/assigners/hungarian_assigner.py
  3. 7
      mmdet/core/bbox/match_costs/__init__.py
  4. 8
      mmdet/core/bbox/match_costs/builder.py
  5. 184
      mmdet/core/bbox/match_costs/match_cost.py
  6. 19
      mmdet/models/dense_heads/transformer_head.py
  7. 18
      tests/test_assigner.py
  8. 8
      tests/test_models/test_heads.py

@ -44,8 +44,10 @@ model = dict(
# training and testing settings
train_cfg = dict(
assigner=dict(
type='HungarianAssigner', cls_weight=1., bbox_weight=5.,
iou_weight=2.))
type='HungarianAssigner',
cls_cost=dict(type='ClassificationCost', weight=1.),
reg_cost=dict(type='BBoxL1Cost', weight=5.0),
iou_cost=dict(type='IoUCost', iou_mode='giou', weight=2.0)))
test_cfg = dict(max_per_img=100)
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

@ -1,8 +1,8 @@
import torch
from ..builder import BBOX_ASSIGNERS
from ..iou_calculators import build_iou_calculator
from ..transforms import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh
from ..match_costs import build_match_cost
from ..transforms import bbox_cxcywh_to_xyxy
from .assign_result import AssignResult
from .base_assigner import BaseAssigner
@ -42,17 +42,12 @@ class HungarianAssigner(BaseAssigner):
"""
def __init__(self,
cls_weight=1.,
bbox_weight=1.,
iou_weight=1.,
iou_calculator=dict(type='BboxOverlaps2D'),
iou_mode='giou'):
# defaultly giou cost is used in the official DETR repo.
self.iou_mode = iou_mode
self.cls_weight = cls_weight
self.bbox_weight = bbox_weight
self.iou_weight = iou_weight
self.iou_calculator = build_iou_calculator(iou_calculator)
cls_cost=dict(type='ClassificationCost', weight=1.),
reg_cost=dict(type='BBoxL1Cost', weight=1.0),
iou_cost=dict(type='IoUCost', iou_mode='giou', weight=1.0)):
self.cls_cost = build_match_cost(cls_cost)
self.reg_cost = build_match_cost(reg_cost)
self.iou_cost = build_match_cost(iou_cost)
def assign(self,
bbox_pred,
@ -113,36 +108,21 @@ class HungarianAssigner(BaseAssigner):
assigned_gt_inds[:] = 0
return AssignResult(
num_gts, assigned_gt_inds, None, labels=assigned_labels)
# 2. compute the weighted costs
# classification cost.
# Following the official DETR repo, contrary to the loss that
# NLL is used, we approximate it in 1 - cls_score[gt_label].
# The 1 is a constant that doesn't change the matching,
# so it can be ommitted.
cls_score = cls_pred.softmax(-1)
cls_cost = -cls_score[:, gt_labels] # [num_bboxes, num_gt]
# regression L1 cost
img_h, img_w, _ = img_meta['img_shape']
factor = torch.Tensor([img_w, img_h, img_w,
img_h]).unsqueeze(0).to(gt_bboxes.device)
gt_bboxes_normalized = gt_bboxes / factor
bbox_cost = torch.cdist(
bbox_pred, bbox_xyxy_to_cxcywh(gt_bboxes_normalized),
p=1) # [num_bboxes, num_gt]
# 2. compute the weighted costs
# classification and bbox costs.
cls_cost = self.cls_cost(cls_pred, gt_labels)
# regression L1 cost
normalize_gt_bboxes = gt_bboxes / factor
reg_cost = self.reg_cost(bbox_pred, normalize_gt_bboxes)
# regression iou cost; GIoU is used by default in the official DETR.
bboxes = bbox_cxcywh_to_xyxy(bbox_pred) * factor
# overlaps: [num_bboxes, num_gt]
overlaps = self.iou_calculator(
bboxes, gt_bboxes, mode=self.iou_mode, is_aligned=False)
# The 1 is a constant that doesn't change the matching, so ommitted.
iou_cost = -overlaps
iou_cost = self.iou_cost(bboxes, gt_bboxes)
# weighted sum of above three costs
cost = self.cls_weight * cls_cost + self.bbox_weight * bbox_cost
cost = cost + self.iou_weight * iou_cost
cost = cls_cost + reg_cost + iou_cost
# 3. do Hungarian matching on CPU using linear_sum_assignment
cost = cost.detach().cpu()

@ -0,0 +1,7 @@
from .builder import build_match_cost
from .match_cost import BBoxL1Cost, ClassificationCost, FocalLossCost, IoUCost
__all__ = [
'build_match_cost', 'ClassificationCost', 'BBoxL1Cost', 'IoUCost',
'FocalLossCost'
]

@ -0,0 +1,8 @@
from mmcv.utils import Registry, build_from_cfg
# Registry that the match cost classes register themselves into.
MATCH_COST = Registry('Match Cost')


def build_match_cost(cfg, default_args=None):
    """Build a match cost object from a config.

    Args:
        cfg (dict): Config handed to ``build_from_cfg``, e.g.
            ``dict(type='IoUCost', iou_mode='giou', weight=2.0)``.
        default_args (dict, optional): Forwarded unchanged to
            ``build_from_cfg``. Defaults to None.

    Returns:
        object: Whatever ``build_from_cfg`` creates from the ``MATCH_COST``
        registry for ``cfg``.
    """
    return build_from_cfg(cfg, MATCH_COST, default_args)

@ -0,0 +1,184 @@
import torch
from mmdet.core.bbox.iou_calculators import bbox_overlaps
from mmdet.core.bbox.transforms import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh
from .builder import MATCH_COST
@MATCH_COST.register_module()
class BBoxL1Cost(object):
    """Pairwise L1 distance between predicted and ground truth boxes.

    Args:
        weight (int | float, optional): Factor the cost matrix is multiplied
            by. Defaults to 1.
        box_format (str, optional): Coordinate format the distance is
            computed in. With 'xyxy' the predictions are converted by
            ``bbox_cxcywh_to_xyxy``; with 'xywh' the ground truth boxes are
            converted by ``bbox_xyxy_to_cxcywh``. The original author notes
            'xyxy' for DETR and 'xywh' for Sparse_RCNN. Defaults to 'xyxy'.

    Examples:
        >>> from mmdet.core.bbox.match_costs.match_cost import BBoxL1Cost
        >>> import torch
        >>> self = BBoxL1Cost()
        >>> bbox_pred = torch.rand(1, 4)
        >>> gt_bboxes = torch.FloatTensor([[0.0, 0.0, 0.2, 0.5],
        ...                                [0.1, 0.25, 0.3, 0.5]])
        >>> self(bbox_pred, gt_bboxes).shape
        torch.Size([1, 2])
    """

    def __init__(self, weight=1., box_format='xyxy'):
        self.weight = weight
        assert box_format in ['xyxy', 'xywh'], \
            f"box_format must be 'xyxy' or 'xywh', got {box_format!r}"
        self.box_format = box_format

    def __call__(self, bbox_pred, gt_bboxes):
        """Compute the weighted L1 cost matrix.

        Args:
            bbox_pred (Tensor): Predicted boxes with normalized coordinates
                (cx, cy, w, h), which are all in range [0, 1]. Shape
                [num_query, 4].
            gt_bboxes (Tensor): Ground truth boxes with normalized
                coordinates (x1, y1, x2, y2). Shape [num_gt, 4].

        Returns:
            torch.Tensor: bbox_cost value with weight
        """
        # Bring both operands into the same coordinate format before
        # measuring the distance; only one side is converted per format.
        if self.box_format == 'xywh':
            gt_bboxes = bbox_xyxy_to_cxcywh(gt_bboxes)
        elif self.box_format == 'xyxy':
            bbox_pred = bbox_cxcywh_to_xyxy(bbox_pred)
        bbox_cost = torch.cdist(bbox_pred, gt_bboxes, p=1)
        return bbox_cost * self.weight
@MATCH_COST.register_module()
class FocalLossCost(object):
    """Classification matching cost built from sigmoid focal-loss terms.

    Args:
        weight (int | float, optional): Factor the cost matrix is multiplied
            by. Defaults to 1.
        alpha (int | float, optional): Focal loss alpha. Defaults to 0.25.
        gamma (int | float, optional): Focal loss gamma. Defaults to 2.
        eps (float, optional): Small constant added inside the logarithms.
            Defaults to 1e-12.

    Examples:
        >>> from mmdet.core.bbox.match_costs.match_cost import FocalLossCost
        >>> import torch
        >>> self = FocalLossCost()
        >>> cls_pred = torch.rand(4, 3)
        >>> gt_labels = torch.tensor([0, 1, 2])
        >>> self(cls_pred, gt_labels).shape
        torch.Size([4, 3])
    """

    def __init__(self, weight=1., alpha=0.25, gamma=2, eps=1e-12):
        self.weight = weight
        self.alpha = alpha
        self.gamma = gamma
        self.eps = eps

    def __call__(self, cls_pred, gt_labels):
        """Compute the weighted focal-loss cost matrix.

        Args:
            cls_pred (Tensor): Predicted classification logits, shape
                [num_query, num_class].
            gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).

        Returns:
            torch.Tensor: cls_cost value with weight
        """
        prob = cls_pred.sigmoid()

        # Term for treating each class as background.
        background_term = -(1 - prob + self.eps).log()
        neg_cost = background_term * (1 - self.alpha) * prob.pow(self.gamma)

        # Term for treating each class as foreground.
        foreground_term = -(prob + self.eps).log()
        pos_cost = foreground_term * self.alpha * (1 - prob).pow(self.gamma)

        # Keep only the columns belonging to the ground truth labels.
        matched_pos = pos_cost[:, gt_labels]
        matched_neg = neg_cost[:, gt_labels]
        return (matched_pos - matched_neg) * self.weight
@MATCH_COST.register_module()
class ClassificationCost(object):
    """Classification matching cost based on softmax scores.

    Args:
        weight (int | float, optional): Factor the cost matrix is multiplied
            by. Defaults to 1.

    Examples:
        >>> from mmdet.core.bbox.match_costs.match_cost import \
        ...     ClassificationCost
        >>> import torch
        >>> self = ClassificationCost()
        >>> cls_pred = torch.rand(4, 3)
        >>> gt_labels = torch.tensor([0, 1, 2])
        >>> self(cls_pred, gt_labels).shape
        torch.Size([4, 3])
    """

    def __init__(self, weight=1.):
        self.weight = weight

    def __call__(self, cls_pred, gt_labels):
        """Compute the weighted classification cost matrix.

        Args:
            cls_pred (Tensor): Predicted classification logits, shape
                [num_query, num_class].
            gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).

        Returns:
            torch.Tensor: cls_cost value with weight
        """
        # The official DETR repo uses NLL for the loss but approximates the
        # matching cost as 1 - cls_score[gt_label]. The constant 1 does not
        # affect the matching, so it is omitted here.
        probs = cls_pred.softmax(dim=-1)
        selected = probs[:, gt_labels]
        return -selected * self.weight
@MATCH_COST.register_module()
class IoUCost(object):
    """Matching cost given by the negated overlap between boxes.

    Args:
        iou_mode (str, optional): Mode handed to ``bbox_overlaps``, such as
            'iou' | 'giou'. Defaults to 'giou'.
        weight (int | float, optional): Factor the cost matrix is multiplied
            by. Defaults to 1.

    Examples:
        >>> from mmdet.core.bbox.match_costs.match_cost import IoUCost
        >>> import torch
        >>> self = IoUCost()
        >>> bboxes = torch.FloatTensor([[1, 1, 2, 2], [2, 2, 3, 4]])
        >>> gt_bboxes = torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]])
        >>> self(bboxes, gt_bboxes)
        tensor([[-0.1250,  0.1667],
                [ 0.1667, -0.5000]])
    """

    def __init__(self, iou_mode='giou', weight=1.):
        self.weight = weight
        self.iou_mode = iou_mode

    def __call__(self, bboxes, gt_bboxes):
        """Compute the weighted overlap cost matrix.

        Args:
            bboxes (Tensor): Predicted boxes with unnormalized coordinates
                (x1, y1, x2, y2). Shape [num_query, 4].
            gt_bboxes (Tensor): Ground truth boxes with unnormalized
                coordinates (x1, y1, x2, y2). Shape [num_gt, 4].

        Returns:
            torch.Tensor: iou_cost value with weight
        """
        # Pairwise (not aligned) overlaps: [num_bboxes, num_gt].
        pairwise_overlaps = bbox_overlaps(
            bboxes, gt_bboxes, mode=self.iou_mode, is_aligned=False)
        # The cost would be 1 - overlap; the constant 1 does not change the
        # matching, so only the negated overlap is kept.
        return -pairwise_overlaps * self.weight

@ -77,11 +77,10 @@ class TransformerHead(AnchorFreeHead):
train_cfg=dict(
assigner=dict(
type='HungarianAssigner',
cls_weight=1.,
bbox_weight=5.,
iou_weight=2.,
iou_calculator=dict(type='BboxOverlaps2D'),
iou_mode='giou')),
cls_cost=dict(type='ClassificationCost', weight=1.),
reg_cost=dict(type='BBoxL1Cost', weight=5.0),
iou_cost=dict(
type='IoUCost', iou_mode='giou', weight=2.0))),
test_cfg=dict(max_per_img=100),
**kwargs):
# NOTE here use `AnchorFreeHead` instead of `TransformerHead`,
@ -124,13 +123,13 @@ class TransformerHead(AnchorFreeHead):
assert 'assigner' in train_cfg, 'assigner should be provided '\
'when train_cfg is set.'
assigner = train_cfg['assigner']
assert loss_cls['loss_weight'] == assigner['cls_weight'], \
assert loss_cls['loss_weight'] == assigner['cls_cost']['weight'], \
'The classification weight for loss and matcher should be' \
'exactly the same.'
assert loss_bbox['loss_weight'] == assigner['bbox_weight'], \
'The regression L1 weight for loss and matcher should be' \
'exactly the same.'
assert loss_iou['loss_weight'] == assigner['iou_weight'], \
assert loss_bbox['loss_weight'] == assigner['reg_cost'][
'weight'], 'The regression L1 weight for loss and matcher ' \
'should be exactly the same.'
assert loss_iou['loss_weight'] == assigner['iou_cost']['weight'], \
'The regression iou weight for loss and matcher should be' \
'exactly the same.'
self.assigner = build_assigner(assigner)

@ -380,7 +380,7 @@ def test_center_region_assigner_with_empty_gts():
def test_hungarian_match_assigner():
self = HungarianAssigner()
assert self.iou_mode == 'giou'
assert self.iou_cost.iou_mode == 'giou'
# test no gt bboxes
bbox_pred = torch.rand((10, 4))
@ -403,8 +403,20 @@ def test_hungarian_match_assigner():
assert (assign_result.labels > -1).sum() == gt_bboxes.size(0)
# test iou mode
self = HungarianAssigner(iou_mode='iou')
assert self.iou_mode == 'iou'
self = HungarianAssigner(
iou_cost=dict(type='IoUCost', iou_mode='iou', weight=1.0))
assert self.iou_cost.iou_mode == 'iou'
assign_result = self.assign(bbox_pred, cls_pred, gt_bboxes, gt_labels,
img_meta)
assert torch.all(assign_result.gt_inds > -1)
assert (assign_result.gt_inds > 0).sum() == gt_bboxes.size(0)
assert (assign_result.labels > -1).sum() == gt_bboxes.size(0)
# test focal loss mode
self = HungarianAssigner(
iou_cost=dict(type='IoUCost', iou_mode='giou', weight=1.0),
cls_cost=dict(type='FocalLossCost', weight=1.))
assert self.iou_cost.iou_mode == 'giou'
assign_result = self.assign(bbox_pred, cls_pred, gt_bboxes, gt_labels,
img_meta)
assert torch.all(assign_result.gt_inds > -1)

@ -1243,11 +1243,9 @@ def test_transformer_head_loss():
train_cfg = dict(
assigner=dict(
type='HungarianAssigner',
cls_weight=1.,
bbox_weight=5.,
iou_weight=2.,
iou_calculator=dict(type='BboxOverlaps2D'),
iou_mode='giou'))
cls_cost=dict(type='ClassificationCost', weight=1.0),
reg_cost=dict(type='BBoxL1Cost', weight=5.0),
iou_cost=dict(type='IoUCost', iou_mode='giou', weight=2.0)))
transformer_cfg = dict(
type='Transformer',
embed_dims=4,

Loading…
Cancel
Save