From 0c84b65b57b394c572d9320e6747f19256fd57be Mon Sep 17 00:00:00 2001 From: kellenf <984677144@qq.com> Date: Wed, 23 Dec 2020 18:06:19 +0800 Subject: [PATCH] Refactor hungarian assigner (#4259) * add focal_loss in hungarian_assigner.py * update hungarian_assigner.py * update detr config * update doc and codes * fix unitest * fix unitest * modify code format * update docstring * fix unitest * fix name * fix docformat * fix docformat * fix format Co-authored-by: fangkairen --- configs/detr/detr_r50_8x2_150e_coco.py | 6 +- .../core/bbox/assigners/hungarian_assigner.py | 52 ++--- mmdet/core/bbox/match_costs/__init__.py | 7 + mmdet/core/bbox/match_costs/builder.py | 8 + mmdet/core/bbox/match_costs/match_cost.py | 184 ++++++++++++++++++ mmdet/models/dense_heads/transformer_head.py | 19 +- tests/test_assigner.py | 18 +- tests/test_models/test_heads.py | 8 +- 8 files changed, 246 insertions(+), 56 deletions(-) create mode 100644 mmdet/core/bbox/match_costs/__init__.py create mode 100644 mmdet/core/bbox/match_costs/builder.py create mode 100644 mmdet/core/bbox/match_costs/match_cost.py diff --git a/configs/detr/detr_r50_8x2_150e_coco.py b/configs/detr/detr_r50_8x2_150e_coco.py index c28157e62..ca1f9262e 100644 --- a/configs/detr/detr_r50_8x2_150e_coco.py +++ b/configs/detr/detr_r50_8x2_150e_coco.py @@ -44,8 +44,10 @@ model = dict( # training and testing settings train_cfg = dict( assigner=dict( - type='HungarianAssigner', cls_weight=1., bbox_weight=5., - iou_weight=2.)) + type='HungarianAssigner', + cls_cost=dict(type='ClassificationCost', weight=1.), + reg_cost=dict(type='BBoxL1Cost', weight=5.0), + iou_cost=dict(type='IoUCost', iou_mode='giou', weight=2.0))) test_cfg = dict(max_per_img=100) img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) diff --git a/mmdet/core/bbox/assigners/hungarian_assigner.py b/mmdet/core/bbox/assigners/hungarian_assigner.py index 404118597..437b32a28 100644 --- 
a/mmdet/core/bbox/assigners/hungarian_assigner.py +++ b/mmdet/core/bbox/assigners/hungarian_assigner.py @@ -1,8 +1,8 @@ import torch from ..builder import BBOX_ASSIGNERS -from ..iou_calculators import build_iou_calculator -from ..transforms import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh +from ..match_costs import build_match_cost +from ..transforms import bbox_cxcywh_to_xyxy from .assign_result import AssignResult from .base_assigner import BaseAssigner @@ -42,17 +42,12 @@ class HungarianAssigner(BaseAssigner): """ def __init__(self, - cls_weight=1., - bbox_weight=1., - iou_weight=1., - iou_calculator=dict(type='BboxOverlaps2D'), - iou_mode='giou'): - # defaultly giou cost is used in the official DETR repo. - self.iou_mode = iou_mode - self.cls_weight = cls_weight - self.bbox_weight = bbox_weight - self.iou_weight = iou_weight - self.iou_calculator = build_iou_calculator(iou_calculator) + cls_cost=dict(type='ClassificationCost', weight=1.), + reg_cost=dict(type='BBoxL1Cost', weight=1.0), + iou_cost=dict(type='IoUCost', iou_mode='giou', weight=1.0)): + self.cls_cost = build_match_cost(cls_cost) + self.reg_cost = build_match_cost(reg_cost) + self.iou_cost = build_match_cost(iou_cost) def assign(self, bbox_pred, @@ -113,36 +108,21 @@ class HungarianAssigner(BaseAssigner): assigned_gt_inds[:] = 0 return AssignResult( num_gts, assigned_gt_inds, None, labels=assigned_labels) - - # 2. compute the weighted costs - # classification cost. - # Following the official DETR repo, contrary to the loss that - # NLL is used, we approximate it in 1 - cls_score[gt_label]. - # The 1 is a constant that doesn't change the matching, - # so it can be ommitted. 
- cls_score = cls_pred.softmax(-1) - cls_cost = -cls_score[:, gt_labels] # [num_bboxes, num_gt] - - # regression L1 cost img_h, img_w, _ = img_meta['img_shape'] factor = torch.Tensor([img_w, img_h, img_w, img_h]).unsqueeze(0).to(gt_bboxes.device) - gt_bboxes_normalized = gt_bboxes / factor - bbox_cost = torch.cdist( - bbox_pred, bbox_xyxy_to_cxcywh(gt_bboxes_normalized), - p=1) # [num_bboxes, num_gt] + # 2. compute the weighted costs + # classification and bboxcost. + cls_cost = self.cls_cost(cls_pred, gt_labels) + # regression L1 cost + normalize_gt_bboxes = gt_bboxes / factor + reg_cost = self.reg_cost(bbox_pred, normalize_gt_bboxes) # regression iou cost, defaultly giou is used in official DETR. bboxes = bbox_cxcywh_to_xyxy(bbox_pred) * factor - # overlaps: [num_bboxes, num_gt] - overlaps = self.iou_calculator( - bboxes, gt_bboxes, mode=self.iou_mode, is_aligned=False) - # The 1 is a constant that doesn't change the matching, so ommitted. - iou_cost = -overlaps - + iou_cost = self.iou_cost(bboxes, gt_bboxes) # weighted sum of above three costs - cost = self.cls_weight * cls_cost + self.bbox_weight * bbox_cost - cost = cost + self.iou_weight * iou_cost + cost = cls_cost + reg_cost + iou_cost # 3. 
do Hungarian matching on CPU using linear_sum_assignment cost = cost.detach().cpu() diff --git a/mmdet/core/bbox/match_costs/__init__.py b/mmdet/core/bbox/match_costs/__init__.py new file mode 100644 index 000000000..add5e0d39 --- /dev/null +++ b/mmdet/core/bbox/match_costs/__init__.py @@ -0,0 +1,7 @@ +from .builder import build_match_cost +from .match_cost import BBoxL1Cost, ClassificationCost, FocalLossCost, IoUCost + +__all__ = [ + 'build_match_cost', 'ClassificationCost', 'BBoxL1Cost', 'IoUCost', + 'FocalLossCost' +] diff --git a/mmdet/core/bbox/match_costs/builder.py b/mmdet/core/bbox/match_costs/builder.py new file mode 100644 index 000000000..6894017d4 --- /dev/null +++ b/mmdet/core/bbox/match_costs/builder.py @@ -0,0 +1,8 @@ +from mmcv.utils import Registry, build_from_cfg + +MATCH_COST = Registry('Match Cost') + + +def build_match_cost(cfg, default_args=None): + """Builder of match cost.""" + return build_from_cfg(cfg, MATCH_COST, default_args) diff --git a/mmdet/core/bbox/match_costs/match_cost.py b/mmdet/core/bbox/match_costs/match_cost.py new file mode 100644 index 000000000..7c4d20ccc --- /dev/null +++ b/mmdet/core/bbox/match_costs/match_cost.py @@ -0,0 +1,184 @@ +import torch + +from mmdet.core.bbox.iou_calculators import bbox_overlaps +from mmdet.core.bbox.transforms import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh +from .builder import MATCH_COST + + +@MATCH_COST.register_module() +class BBoxL1Cost(object): + """BBoxL1Cost. 
+ + Args: + weight (int | float, optional): loss_weight + box_format (str, optional): 'xyxy' for DETR, 'xywh' for Sparse_RCNN + + Examples: + >>> from mmdet.core.bbox.match_costs.match_cost import BBoxL1Cost + >>> import torch + >>> self = BBoxL1Cost() + >>> bbox_pred = torch.rand(1, 4) + >>> gt_bboxes= torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]]) + >>> factor = torch.tensor([10, 8, 10, 8]) + >>> self(bbox_pred, gt_bboxes) + tensor([[1.6172, 1.6422]]) + """ + + def __init__(self, weight=1., box_format='xyxy'): + self.weight = weight + assert box_format in ['xyxy', 'xywh'] + self.box_format = box_format + + def __call__(self, bbox_pred, gt_bboxes): + """ + Args: + bbox_pred (Tensor): Predicted boxes with normalized coordinates + (cx, cy, w, h), which are all in range [0, 1]. Shape + [num_query, 4]. + gt_bboxes (Tensor): Ground truth boxes with normalized + coordinates (x1, y1, x2, y2). Shape [num_gt, 4]. + + Returns: + torch.Tensor: bbox_cost value with weight + """ + if self.box_format == 'xywh': + gt_bboxes = bbox_xyxy_to_cxcywh(gt_bboxes) + elif self.box_format == 'xyxy': + bbox_pred = bbox_cxcywh_to_xyxy(bbox_pred) + bbox_cost = torch.cdist(bbox_pred, gt_bboxes, p=1) + return bbox_cost * self.weight + + +@MATCH_COST.register_module() +class FocalLossCost(object): + """FocalLossCost. 
+ + Args: + weight (int | float, optional): loss_weight + alpha (int | float, optional): focal_loss alpha + gamma (int | float, optional): focal_loss gamma + eps (float, optional): default 1e-12 + + Examples: + >>> from mmdet.core.bbox.match_costs.match_cost import FocalLossCost + >>> import torch + >>> self = FocalLossCost() + >>> cls_pred = torch.rand(4, 3) + >>> gt_labels = torch.tensor([0, 1, 2]) + >>> factor = torch.tensor([10, 8, 10, 8]) + >>> self(cls_pred, gt_labels) + tensor([[-0.3236, -0.3364, -0.2699], + [-0.3439, -0.3209, -0.4807], + [-0.4099, -0.3795, -0.2929], + [-0.1950, -0.1207, -0.2626]]) + """ + + def __init__(self, weight=1., alpha=0.25, gamma=2, eps=1e-12): + self.weight = weight + self.alpha = alpha + self.gamma = gamma + self.eps = eps + + def __call__(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classification logits, shape + [num_query, num_class]. + gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). + + Returns: + torch.Tensor: cls_cost value with weight + """ + cls_pred = cls_pred.sigmoid() + neg_cost = -(1 - cls_pred + self.eps).log() * ( + 1 - self.alpha) * cls_pred.pow(self.gamma) + pos_cost = -(cls_pred + self.eps).log() * self.alpha * ( + 1 - cls_pred).pow(self.gamma) + cls_cost = pos_cost[:, gt_labels] - neg_cost[:, gt_labels] + return cls_cost * self.weight + + +@MATCH_COST.register_module() +class ClassificationCost(object): + """ClsSoftmaxCost. + + Args: + weight (int | float, optional): loss_weight + + Examples: + >>> from mmdet.core.bbox.match_costs.match_cost import \ + ... 
+ ClassificationCost + >>> import torch + >>> self = ClassificationCost() + >>> cls_pred = torch.rand(4, 3) + >>> gt_labels = torch.tensor([0, 1, 2]) + >>> factor = torch.tensor([10, 8, 10, 8]) + >>> self(cls_pred, gt_labels) + tensor([[-0.3430, -0.3525, -0.3045], + [-0.3077, -0.2931, -0.3992], + [-0.3664, -0.3455, -0.2881], + [-0.3343, -0.2701, -0.3956]]) + """ + + def __init__(self, weight=1.): + self.weight = weight + + def __call__(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classification logits, shape + [num_query, num_class]. + gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). + + Returns: + torch.Tensor: cls_cost value with weight + """ + # Following the official DETR repo, contrary to the loss that + # NLL is used, we approximate it in 1 - cls_score[gt_label]. + # The 1 is a constant that doesn't change the matching, + # so it can be omitted. + cls_score = cls_pred.softmax(-1) + cls_cost = -cls_score[:, gt_labels] + return cls_cost * self.weight + + +@MATCH_COST.register_module() +class IoUCost(object): + """IoUCost. + + Args: + iou_mode (str, optional): iou mode such as 'iou' | 'giou' + weight (int | float, optional): loss weight + + Examples: + >>> from mmdet.core.bbox.match_costs.match_cost import IoUCost + >>> import torch + >>> self = IoUCost() + >>> bboxes = torch.FloatTensor([[1,1, 2, 2], [2, 2, 3, 4]]) + >>> gt_bboxes = torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]]) + >>> self(bboxes, gt_bboxes) + tensor([[-0.1250, 0.1667], + [ 0.1667, -0.5000]]) + """ + + def __init__(self, iou_mode='giou', weight=1.): + self.weight = weight + self.iou_mode = iou_mode + + def __call__(self, bboxes, gt_bboxes): + """ + Args: + bboxes (Tensor): Predicted boxes with unnormalized coordinates + (x1, y1, x2, y2). Shape [num_query, 4]. + gt_bboxes (Tensor): Ground truth boxes with unnormalized + coordinates (x1, y1, x2, y2). Shape [num_gt, 4]. 
+ + Returns: + torch.Tensor: iou_cost value with weight + """ + # overlaps: [num_bboxes, num_gt] + overlaps = bbox_overlaps( + bboxes, gt_bboxes, mode=self.iou_mode, is_aligned=False) + # The 1 is a constant that doesn't change the matching, so omitted. + iou_cost = -overlaps + return iou_cost * self.weight diff --git a/mmdet/models/dense_heads/transformer_head.py b/mmdet/models/dense_heads/transformer_head.py index df3fc9411..eab6cf0cd 100644 --- a/mmdet/models/dense_heads/transformer_head.py +++ b/mmdet/models/dense_heads/transformer_head.py @@ -77,11 +77,10 @@ class TransformerHead(AnchorFreeHead): train_cfg=dict( assigner=dict( type='HungarianAssigner', - cls_weight=1., - bbox_weight=5., - iou_weight=2., - iou_calculator=dict(type='BboxOverlaps2D'), - iou_mode='giou')), + cls_cost=dict(type='ClassificationCost', weight=1.), + reg_cost=dict(type='BBoxL1Cost', weight=5.0), + iou_cost=dict( + type='IoUCost', iou_mode='giou', weight=2.0))), test_cfg=dict(max_per_img=100), **kwargs): # NOTE here use `AnchorFreeHead` instead of `TransformerHead`, @@ -124,13 +123,13 @@ class TransformerHead(AnchorFreeHead): assert 'assigner' in train_cfg, 'assigner should be provided '\ 'when train_cfg is set.' assigner = train_cfg['assigner'] - assert loss_cls['loss_weight'] == assigner['cls_weight'], \ + assert loss_cls['loss_weight'] == assigner['cls_cost']['weight'], \ 'The classification weight for loss and matcher should be' \ 'exactly the same.' - assert loss_bbox['loss_weight'] == assigner['bbox_weight'], \ - 'The regression L1 weight for loss and matcher should be' \ - 'exactly the same.' - assert loss_iou['loss_weight'] == assigner['iou_weight'], \ + assert loss_bbox['loss_weight'] == assigner['reg_cost'][ + 'weight'], 'The regression L1 weight for loss and matcher ' \ + 'should be exactly the same.' + assert loss_iou['loss_weight'] == assigner['iou_cost']['weight'], \ 'The regression iou weight for loss and matcher should be' \ 'exactly the same.' 
self.assigner = build_assigner(assigner) diff --git a/tests/test_assigner.py b/tests/test_assigner.py index 8e2d4b7e2..2f7a16ff8 100644 --- a/tests/test_assigner.py +++ b/tests/test_assigner.py @@ -380,7 +380,7 @@ def test_center_region_assigner_with_empty_gts(): def test_hungarian_match_assigner(): self = HungarianAssigner() - assert self.iou_mode == 'giou' + assert self.iou_cost.iou_mode == 'giou' # test no gt bboxes bbox_pred = torch.rand((10, 4)) @@ -403,8 +403,20 @@ def test_hungarian_match_assigner(): assert (assign_result.labels > -1).sum() == gt_bboxes.size(0) # test iou mode - self = HungarianAssigner(iou_mode='iou') - assert self.iou_mode == 'iou' + self = HungarianAssigner( + iou_cost=dict(type='IoUCost', iou_mode='iou', weight=1.0)) + assert self.iou_cost.iou_mode == 'iou' + assign_result = self.assign(bbox_pred, cls_pred, gt_bboxes, gt_labels, + img_meta) + assert torch.all(assign_result.gt_inds > -1) + assert (assign_result.gt_inds > 0).sum() == gt_bboxes.size(0) + assert (assign_result.labels > -1).sum() == gt_bboxes.size(0) + + # test focal loss mode + self = HungarianAssigner( + iou_cost=dict(type='IoUCost', iou_mode='giou', weight=1.0), + cls_cost=dict(type='FocalLossCost', weight=1.)) + assert self.iou_cost.iou_mode == 'giou' assign_result = self.assign(bbox_pred, cls_pred, gt_bboxes, gt_labels, img_meta) assert torch.all(assign_result.gt_inds > -1) diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py index a598c6596..04ef584e7 100644 --- a/tests/test_models/test_heads.py +++ b/tests/test_models/test_heads.py @@ -1243,11 +1243,9 @@ def test_transformer_head_loss(): train_cfg = dict( assigner=dict( type='HungarianAssigner', - cls_weight=1., - bbox_weight=5., - iou_weight=2., - iou_calculator=dict(type='BboxOverlaps2D'), - iou_mode='giou')) + cls_cost=dict(type='ClassificationCost', weight=1.0), + reg_cost=dict(type='BBoxL1Cost', weight=5.0), + iou_cost=dict(type='IoUCost', iou_mode='giou', weight=2.0))) transformer_cfg = 
dict( type='Transformer', embed_dims=4,