[Feature] Support DDOD: Disentangle Your Dense Object Detector(ACM MM2021 oral) (#7279)

* add ddod feature

* add ddod feature

* modify new

* [Feature] modify ddod code0225

* [Feature] modify ddod code0226

* [Feature] modify ddod code0228

* [Feature] modify ddod code0228#7279

* [Feature] modify ddod code0301

* [Feature] modify ddod code0301 test draft

* [Feature] modify ddod code0301 test

* [Feature] modify ddod code0301 extra

* [Feature] modify ddod code0301 delete src/mmtrack

* [Feature] modify ddod code0302

* [Feature] modify ddod code0302(2)

* [Feature] modify ddod code0303

* [Feature] modify ddod code0303(2)

* [Feature] modify ddod code0303(3)

* [Feature] modify ddod code0305

* [Feature] modify ddod code0305(2) delete diou

* [Feature] modify ddod code0305(3)

* modify ddod code0306

* [Feature] modify ddod code0307

* [Feature] modify ddod code0311

* [Feature] modify ddod code0311(2)

* [Feature] modify ddod code0313

* update

* [Feature] modify ddod code0319

* fix

* fix lint

* [Feature] modify ddod code0321

* update readme

* [0502] compute common vars at once for get_target

* [0504] update ddod conflicts

* [0518] separate reg and cls loss and get_target compute

* [0518] merge ATSSCostAssigner to ATSSAssigner

* [0518] refine ATSSAssigner

* [0518] refine ATSSAssigner 2

* [0518] refine ATSSAssigner 2

* [0518] refine ATSSAssigner 3

* [0519] fix bugs

* update

* fix lr

* update weight

Co-authored-by: hha <1286304229@qq.com>
Irving.Gao 3 years ago committed by GitHub
parent 1fd48f7318
commit 151a803ed0
  1. configs/ddod/README.md (31 lines changed)
  2. configs/ddod/ddod_r50_fpn_1x_coco.py (67 lines changed)
  3. configs/ddod/metafile.yml (33 lines changed)
  4. mmdet/core/bbox/assigners/atss_assigner.py (65 lines changed)
  5. mmdet/models/dense_heads/__init__.py (3 lines changed)
  6. mmdet/models/dense_heads/ddod_head.py (778 lines changed)
  7. mmdet/models/detectors/__init__.py (3 lines changed)
  8. mmdet/models/detectors/ddod.py (19 lines changed)
  9. tests/test_models/test_dense_heads/test_ddod_head.py (72 lines changed)
  10. tests/test_utils/test_assigner.py (1 line changed)

@ -0,0 +1,31 @@
# DDOD
> [Disentangle Your Dense Object Detector](https://arxiv.org/pdf/2107.02963.pdf)
<!-- [ALGORITHM] -->
## Abstract
Deep learning-based dense object detectors have achieved great success in the past few years and have been applied to numerous multimedia applications such as video understanding. However, the current training pipeline for dense detectors is compromised to lots of conjunctions that may not hold. In this paper, we investigate three such important conjunctions: 1) only samples assigned as positive in classification head are used to train the regression head; 2) classification and regression share the same input feature and computational fields defined by the parallel head architecture; and 3) samples distributed in different feature pyramid layers are treated equally when computing the loss. We first carry out a series of pilot experiments to show disentangling such conjunctions can lead to persistent performance improvement. Then, based on these findings, we propose Disentangled Dense Object Detector(DDOD), in which simple and effective disentanglement mechanisms are designed and integrated into the current state-of-the-art dense object detectors. Extensive experiments on MS COCO benchmark show that our approach can lead to 2.0 mAP, 2.4 mAP and 2.2 mAP absolute improvements on RetinaNet, FCOS, and ATSS baselines with negligible extra overhead. Notably, our best model reaches 55.0 mAP on the COCO test-dev set and 93.5 AP on the hard subset of WIDER FACE, achieving new state-of-the-art performance on these two competitive benchmarks. Code is available at https://github.com/zehuichen123/DDOD.
<div align=center>
<img src="https://user-images.githubusercontent.com/17425982/159212920-2e99d433-82c9-46cf-8f3a-32fdf3c566f5.png"/>
</div>
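
The third point in the abstract, treating pyramid levels differently when computing the loss, is implemented in this commit by `DDODHead.calc_reweight_factor`. The block below is a minimal, framework-free sketch of that rule, assuming only the per-level positive-sample counts are known. The function name `reweight_factors` and the sample counts are illustrative and not part of mmdetection.

```python
def reweight_factors(pos_counts, eps=1e-10):
    """Map per-level positive-sample counts to loss weights in [1, 2].

    The level with the fewest positives gets weight 2, the level with the
    most positives gets a weight of (approximately) 1, and the levels in
    between are placed by linear interpolation.
    """
    lo, hi = min(pos_counts), max(pos_counts)
    # eps keeps the division finite when every level has the same count
    interval = 1.0 / (hi - lo + eps)
    return [2.0 - (n - lo) * interval for n in pos_counts]


# Hypothetical accumulated counts for five FPN levels.
factors = reweight_factors([120, 300, 210, 120, 300])
```

In the head itself the counts are kept in a list on the module and updated with `+=`, so the weights appear to follow the accumulated statistics rather than a single batch.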
## Results and Models
| Model | Backbone | Style | Lr schd | Mem (GB) | box AP | Config | Download |
| :-------: | :------: | :-----: | :-----: | :------: | :----: | :--------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| DDOD-ATSS | R-50 | pytorch | 1x | 3.4 | 41.7 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ddod/ddod_r50_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/ddod/ddod_r50_fpn_1x_coco/ddod_r50_fpn_1x_coco_20220523_223737-29b2fc67.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/ddod/ddod_r50_fpn_1x_coco/ddod_r50_fpn_1x_coco_20220523_223737.log.json) |
## Citation
```latex
@inproceedings{chen2021disentangle,
title={Disentangle Your Dense Object Detector},
author={Chen, Zehui and Yang, Chenhongyi and Li, Qiaofei and Zhao, Feng and Zha, Zheng-Jun and Wu, Feng},
booktitle={Proceedings of the 29th ACM International Conference on Multimedia},
pages={4939--4948},
year={2021}
}
```

@ -0,0 +1,67 @@
_base_ = [
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
type='DDOD',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_output',
num_outs=5),
bbox_head=dict(
type='DDODHead',
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
loss_iou=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
train_cfg=dict(
# `assigner` is the assigner for the classification branch (cls_assigner)
assigner=dict(type='ATSSAssigner', topk=9, alpha=0.8),
reg_assigner=dict(type='ATSSAssigner', topk=9, alpha=0.5),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100))
# This `persistent_workers` is only valid when PyTorch>=1.7.0
data = dict(persistent_workers=True)
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
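
For orientation, the anchor sizes implied by the `anchor_generator` settings above can be worked out by hand. This sketch assumes mmdetection's usual convention that the base size of each level equals its stride and that the octave scales are `octave_base_scale * 2**(i / scales_per_octave)`. The helper `anchor_sizes` is hypothetical and only reproduces the arithmetic.

```python
def anchor_sizes(strides, octave_base_scale, scales_per_octave):
    """Return the anchor side lengths (in pixels) for each pyramid level."""
    octave_scales = [
        2 ** (i / scales_per_octave) for i in range(scales_per_octave)
    ]
    return [[stride * octave_base_scale * s for s in octave_scales]
            for stride in strides]


# Values copied from the config: one scale per octave, base scale 8.
sizes = anchor_sizes([8, 16, 32, 64, 128],
                     octave_base_scale=8,
                     scales_per_octave=1)
print(sizes)
```

With `ratios=[1.0]` this corresponds to a single square anchor per location on each of the five levels.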

@ -0,0 +1,33 @@
Collections:
- Name: DDOD
Metadata:
Training Data: COCO
Training Techniques:
- SGD with Momentum
- Weight Decay
Training Resources: 8x V100 GPUs
Architecture:
- DDOD
- FPN
- ResNet
Paper:
URL: https://arxiv.org/pdf/2107.02963.pdf
Title: 'Disentangle Your Dense Object Detector'
README: configs/ddod/README.md
Code:
URL: https://github.com/open-mmlab/mmdetection/blob/v2.25.0/mmdet/models/detectors/ddod.py#L6
Version: v2.25.0
Models:
- Name: ddod_r50_fpn_1x_coco
In Collection: DDOD
Config: configs/ddod/ddod_r50_fpn_1x_coco.py
Metadata:
Training Memory (GB): 3.4
Epochs: 12
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 41.7
Weights: https://download.openmmlab.com/mmdetection/v2.0/ddod/ddod_r50_fpn_1x_coco/ddod_r50_fpn_1x_coco_20220523_223737-29b2fc67.pth

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch
from ..builder import BBOX_ASSIGNERS
@ -17,26 +19,44 @@ class ATSSAssigner(BaseAssigner):
- 0: negative sample, no assigned gt
- positive integer: positive sample, index (1-based) of assigned gt
If ``alpha`` is not None, it means that the dynamic cost
ATSSAssigner is adopted, which is currently only used in the DDOD.
Args:
topk (float): number of bbox selected in each level
"""
def __init__(self,
topk,
alpha=None,
iou_calculator=dict(type='BboxOverlaps2D'),
ignore_iof_thr=-1):
self.topk = topk
self.alpha = alpha
self.iou_calculator = build_iou_calculator(iou_calculator)
self.ignore_iof_thr = ignore_iof_thr
# https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py
"""Assign a corresponding gt bbox or background to each bbox.
Args:
topk (int): number of bbox selected in each level.
alpha (float): Exponent that balances the IoU term against the
classification term in the dynamic cost; only used in DDOD.
Default None.
iou_calculator (dict): builder of IoU calculator.
Default dict(type='BboxOverlaps2D').
ignore_iof_thr (int): whether ignore max overlaps or not.
Default -1 (1 or -1).
"""
# https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py
def assign(self,
bboxes,
num_level_bboxes,
gt_bboxes,
gt_bboxes_ignore=None,
gt_labels=None):
gt_labels=None,
cls_scores=None,
bbox_preds=None):
"""Assign gt to bboxes.
The assignment is done in following steps
@ -52,14 +72,24 @@ class ATSSAssigner(BaseAssigner):
the threshold as positive
6. limit the positive sample's center in gt
If ``alpha`` is not None, and ``cls_scores`` and ``bbox_preds``
are not None, the overlaps calculation in the first step
will also include dynamic cost, which is currently only used in
the DDOD.
Args:
bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
num_level_bboxes (List): num of bboxes in each level
gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
labelled as `ignored`, e.g., crowd boxes in COCO.
labelled as `ignored`, e.g., crowd boxes in COCO. Default None.
gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
cls_scores (list[Tensor]): Classification scores for all scale
levels, each is a 4D-tensor, the channels number is
num_base_priors * num_classes. Default None.
bbox_preds (list[Tensor]): Box energies / deltas for all scale
levels, each is a 4D-tensor, the channels number is
num_base_priors * 4. Default None.
Returns:
:obj:`AssignResult`: The assign result.
@ -68,8 +98,31 @@ class ATSSAssigner(BaseAssigner):
bboxes = bboxes[:, :4]
num_gt, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
# compute iou between all bbox and gt
overlaps = self.iou_calculator(bboxes, gt_bboxes)
message = 'Invalid alpha parameter because cls_scores or ' \
'bbox_preds are None. If you want to use the ' \
'cost-based ATSSAssigner, please set cls_scores, ' \
'bbox_preds and self.alpha at the same time. '
if self.alpha is None:
# ATSSAssigner
overlaps = self.iou_calculator(bboxes, gt_bboxes)
if cls_scores is not None or bbox_preds is not None:
warnings.warn(message)
else:
# Dynamic cost ATSSAssigner in DDOD
assert cls_scores is not None and bbox_preds is not None, message
# compute cls cost for bbox and GT
cls_cost = torch.sigmoid(cls_scores[:, gt_labels])
# compute iou between all bbox and gt
overlaps = self.iou_calculator(bbox_preds, gt_bboxes)
# make sure that we are in element-wise multiplication
assert cls_cost.shape == overlaps.shape
# overlaps is actually a cost matrix
overlaps = cls_cost**(1 - self.alpha) * overlaps**self.alpha
# assign 0 by default
assigned_gt_inds = overlaps.new_full((num_bboxes, ),
@ -121,6 +174,7 @@ class ATSSAssigner(BaseAssigner):
end_idx = start_idx + bboxes_per_level
distances_per_level = distances[start_idx:end_idx, :]
selectable_k = min(self.topk, bboxes_per_level)
_, topk_idxs_per_level = distances_per_level.topk(
selectable_k, dim=0, largest=False)
candidate_idxs.append(topk_idxs_per_level + start_idx)
@ -152,6 +206,7 @@ class ATSSAssigner(BaseAssigner):
r_ = gt_bboxes[:, 2] - ep_bboxes_cx[candidate_idxs].view(-1, num_gt)
b_ = gt_bboxes[:, 3] - ep_bboxes_cy[candidate_idxs].view(-1, num_gt)
is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01
is_pos = is_pos & is_in_gts
# if an anchor box is assigned to multiple gts,
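
The cost-based branch added to `ATSSAssigner.assign` above combines a classification probability and an IoU into a single score. Below is a scalar sketch of that formula for one anchor/ground-truth pair. `dynamic_cost`, `sigmoid` and the input numbers are illustrative; the real code operates on whole tensors with `torch.sigmoid`.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def dynamic_cost(cls_logit, iou, alpha):
    """Compute cls_prob**(1 - alpha) * iou**alpha for a single pair."""
    cls_prob = sigmoid(cls_logit)
    return cls_prob ** (1.0 - alpha) * iou ** alpha


# alpha values taken from ddod_r50_fpn_1x_coco.py:
# 0.8 for the cls assigner, 0.5 for the reg assigner.
cls_cost = dynamic_cost(cls_logit=0.0, iou=0.64, alpha=0.8)
reg_cost = dynamic_cost(cls_logit=0.0, iou=0.64, alpha=0.5)
```

A larger `alpha` weights the IoU term more heavily: at `alpha=1` the score reduces to the plain IoU, and at `alpha=0` it reduces to the classification probability.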

@ -7,6 +7,7 @@ from .cascade_rpn_head import CascadeRPNHead, StageCascadeRPNHead
from .centernet_head import CenterNetHead
from .centripetal_head import CentripetalHead
from .corner_head import CornerHead
from .ddod_head import DDODHead
from .deformable_detr_head import DeformableDETRHead
from .detr_head import DETRHead
from .embedding_rpn_head import EmbeddingRPNHead
@ -52,5 +53,5 @@ __all__ = [
'AutoAssignHead', 'DETRHead', 'YOLOFHead', 'DeformableDETRHead',
'SOLOHead', 'DecoupledSOLOHead', 'CenterNetHead', 'YOLOXHead',
'DecoupledSOLOLightHead', 'LADHead', 'TOODHead', 'MaskFormerHead',
'Mask2FormerHead'
'DDODHead', 'Mask2FormerHead'
]

@ -0,0 +1,778 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, Scale, bias_init_with_prob, normal_init
from mmcv.runner import force_fp32
from mmdet.core import (anchor_inside_flags, build_assigner, build_sampler,
images_to_levels, multi_apply, reduce_mean, unmap)
from mmdet.core.bbox import bbox_overlaps
from ..builder import HEADS, build_loss
from .anchor_head import AnchorHead
EPS = 1e-12
@HEADS.register_module()
class DDODHead(AnchorHead):
"""DDOD head decomposes conjunctions lying in most current one-stage
detectors via label assignment disentanglement, spatial feature
disentanglement, and pyramid supervision disentanglement.
https://arxiv.org/abs/2107.02963
Args:
num_classes (int): Number of categories excluding the
background category.
in_channels (int): Number of channels in the input feature map.
stacked_convs (int): The number of stacked Conv. Default: 4.
conv_cfg (dict): Conv config of ddod head. Default: None.
use_dcn (bool): Whether to use DCN in the first conv of each branch.
The head is the same as ATSS when False. Default: True.
norm_cfg (dict): Normalization config of ddod head. Default:
dict(type='GN', num_groups=32, requires_grad=True).
loss_iou (dict): Config of IoU loss. Default:
dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0).
"""
def __init__(self,
num_classes,
in_channels,
stacked_convs=4,
conv_cfg=None,
use_dcn=True,
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
loss_iou=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0),
**kwargs):
self.stacked_convs = stacked_convs
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.use_dcn = use_dcn
super(DDODHead, self).__init__(num_classes, in_channels, **kwargs)
self.sampling = False
if self.train_cfg:
self.cls_assigner = build_assigner(self.train_cfg.assigner)
self.reg_assigner = build_assigner(self.train_cfg.reg_assigner)
sampler_cfg = dict(type='PseudoSampler')
self.sampler = build_sampler(sampler_cfg, context=self)
self.loss_iou = build_loss(loss_iou)
def _init_layers(self):
"""Initialize layers of the head."""
self.relu = nn.ReLU(inplace=True)
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
self.cls_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=dict(type='DCN', deform_groups=1)
if i == 0 and self.use_dcn else self.conv_cfg,
norm_cfg=self.norm_cfg))
self.reg_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=dict(type='DCN', deform_groups=1)
if i == 0 and self.use_dcn else self.conv_cfg,
norm_cfg=self.norm_cfg))
self.atss_cls = nn.Conv2d(
self.feat_channels,
self.num_base_priors * self.cls_out_channels,
3,
padding=1)
self.atss_reg = nn.Conv2d(
self.feat_channels, self.num_base_priors * 4, 3, padding=1)
self.atss_iou = nn.Conv2d(
self.feat_channels, self.num_base_priors * 1, 3, padding=1)
self.scales = nn.ModuleList(
[Scale(1.0) for _ in self.prior_generator.strides])
# we use the global list in loss
self.cls_num_pos_samples_per_level = [
0. for _ in range(len(self.prior_generator.strides))
]
self.reg_num_pos_samples_per_level = [
0. for _ in range(len(self.prior_generator.strides))
]
def init_weights(self):
"""Initialize weights of the head."""
for m in self.cls_convs:
normal_init(m.conv, std=0.01)
for m in self.reg_convs:
normal_init(m.conv, std=0.01)
normal_init(self.atss_reg, std=0.01)
normal_init(self.atss_iou, std=0.01)
bias_cls = bias_init_with_prob(0.01)
normal_init(self.atss_cls, std=0.01, bias=bias_cls)
def forward(self, feats):
"""Forward features from the upstream network.
Args:
feats (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
Returns:
tuple: A tuple of classification scores, bbox predictions and IoU
predictions.
cls_scores (list[Tensor]): Classification scores for all scale
levels, each is a 4D-tensor, the channels number is
num_base_priors * num_classes.
bbox_preds (list[Tensor]): Box energies / deltas for all scale
levels, each is a 4D-tensor, the channels number is
num_base_priors * 4.
iou_preds (list[Tensor]): IoU scores for all scale levels,
each is a 4D-tensor, the channels number is
num_base_priors * 1.
"""
return multi_apply(self.forward_single, feats, self.scales)
def forward_single(self, x, scale):
"""Forward feature of a single scale level.
Args:
x (Tensor): Features of a single scale level.
scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
the bbox prediction.
Returns:
tuple:
- cls_score (Tensor): Cls scores for a single scale level, \
the channels number is num_base_priors * num_classes.
- bbox_pred (Tensor): Box energies / deltas for a single \
scale level, the channels number is num_base_priors * 4.
- iou_pred (Tensor): IoU prediction for a single scale \
level, with shape (N, num_base_priors * 1, H, W).
"""
cls_feat = x
reg_feat = x
for cls_conv in self.cls_convs:
cls_feat = cls_conv(cls_feat)
for reg_conv in self.reg_convs:
reg_feat = reg_conv(reg_feat)
cls_score = self.atss_cls(cls_feat)
# we just follow atss, not apply exp in bbox_pred
bbox_pred = scale(self.atss_reg(reg_feat)).float()
iou_pred = self.atss_iou(reg_feat)
return cls_score, bbox_pred, iou_pred
def loss_cls_single(self, cls_score, labels, label_weights,
reweight_factor, num_total_samples):
"""Compute cls loss of a single scale level.
Args:
cls_score (Tensor): Box scores for each scale level
Has shape (N, num_base_priors * num_classes, H, W).
labels (Tensor): Labels of each anchor with shape
(N, num_total_anchors).
label_weights (Tensor): Label weights of each anchor with shape
(N, num_total_anchors)
reweight_factor (float): Reweight factor of this scale level for
the cls loss.
num_total_samples (int): Number of positive samples that is
reduced over all GPUs.
Returns:
tuple[Tensor]: A tuple of loss components.
"""
cls_score = cls_score.permute(0, 2, 3, 1).reshape(
-1, self.cls_out_channels).contiguous()
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
loss_cls = self.loss_cls(
cls_score, labels, label_weights, avg_factor=num_total_samples)
return reweight_factor * loss_cls,
def loss_reg_single(self, anchors, bbox_pred, iou_pred, labels,
label_weights, bbox_targets, bbox_weights,
reweight_factor, num_total_samples):
"""Compute reg loss of a single scale level.
Args:
anchors (Tensor): Box reference for each scale level with shape
(N, num_total_anchors, 4).
bbox_pred (Tensor): Box energies / deltas for each scale
level with shape (N, num_base_priors * 4, H, W).
iou_pred (Tensor): IoU prediction for a single scale level with
shape (N, num_base_priors * 1, H, W).
labels (Tensor): Labels of each anchor with shape
(N, num_total_anchors).
label_weights (Tensor): Label weights of each anchor with shape
(N, num_total_anchors)
bbox_targets (Tensor): BBox regression targets of each anchor
with shape (N, num_total_anchors, 4).
bbox_weights (Tensor): BBox weights of all anchors in the
image with shape (N, 4)
reweight_factor (float): Reweight factor of this scale level for
the reg and IoU losses.
num_total_samples (int): Number of positive samples that is
reduced over all GPUs.
Returns:
tuple[Tensor]: The bbox loss and the IoU loss of this scale level,
each multiplied by ``reweight_factor``.
"""
anchors = anchors.reshape(-1, 4)
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
iou_pred = iou_pred.permute(0, 2, 3, 1).reshape(-1, )
bbox_targets = bbox_targets.reshape(-1, 4)
bbox_weights = bbox_weights.reshape(-1, 4)
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
iou_targets = label_weights.new_zeros(labels.shape)
iou_weights = label_weights.new_zeros(labels.shape)
iou_weights[(bbox_weights.sum(axis=1) > 0).nonzero(
as_tuple=False)] = 1.
# FG cat_id: [0, num_classes -1], BG cat_id: num_classes
bg_class_ind = self.num_classes
pos_inds = ((labels >= 0)
&
(labels < bg_class_ind)).nonzero(as_tuple=False).squeeze(1)
if len(pos_inds) > 0:
pos_bbox_targets = bbox_targets[pos_inds]
pos_bbox_pred = bbox_pred[pos_inds]
pos_anchors = anchors[pos_inds]
pos_decode_bbox_pred = self.bbox_coder.decode(
pos_anchors, pos_bbox_pred)
pos_decode_bbox_targets = self.bbox_coder.decode(
pos_anchors, pos_bbox_targets)
# regression loss
loss_bbox = self.loss_bbox(
pos_decode_bbox_pred,
pos_decode_bbox_targets,
avg_factor=num_total_samples)
iou_targets[pos_inds] = bbox_overlaps(
pos_decode_bbox_pred.detach(),
pos_decode_bbox_targets,
is_aligned=True)
loss_iou = self.loss_iou(
iou_pred,
iou_targets,
iou_weights,
avg_factor=num_total_samples)
else:
loss_bbox = bbox_pred.sum() * 0
loss_iou = iou_pred.sum() * 0
return reweight_factor * loss_bbox, reweight_factor * loss_iou
def calc_reweight_factor(self, labels_list):
"""Compute reweight_factor for regression and classification loss."""
# get pos samples for each level
bg_class_ind = self.num_classes
for ii, each_level_label in enumerate(labels_list):
pos_inds = ((each_level_label >= 0) &
(each_level_label < bg_class_ind)).nonzero(
as_tuple=False).squeeze(1)
self.cls_num_pos_samples_per_level[ii] += len(pos_inds)
# get reweight factor in the range [1, 2] with linear interpolation
min_pos_samples = min(self.cls_num_pos_samples_per_level)
max_pos_samples = max(self.cls_num_pos_samples_per_level)
interval = 1. / (max_pos_samples - min_pos_samples + 1e-10)
reweight_factor_per_level = []
for pos_samples in self.cls_num_pos_samples_per_level:
factor = 2. - (pos_samples - min_pos_samples) * interval
reweight_factor_per_level.append(factor)
return reweight_factor_per_level
@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'iou_preds'))
def loss(self,
cls_scores,
bbox_preds,
iou_preds,
gt_bboxes,
gt_labels,
img_metas,
gt_bboxes_ignore=None):
"""Compute losses of the head.
Args:
cls_scores (list[Tensor]): Box scores for each scale level
Has shape (N, num_base_priors * num_classes, H, W)
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_base_priors * 4, H, W)
iou_preds (list[Tensor]): Score factor (IoU prediction) for all scale
levels, each is a 4D-tensor with shape (batch_size, 1, H, W).
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): class indices corresponding to each box
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes_ignore (list[Tensor] | None): specify which bounding
boxes can be ignored when computing the loss.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == self.prior_generator.num_levels
device = cls_scores[0].device
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_metas, device=device)
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
# calculate common vars for cls and reg assigners at once
targets_com = self.process_predictions_and_anchors(
anchor_list, valid_flag_list, cls_scores, bbox_preds, img_metas,
gt_bboxes_ignore)
(anchor_list, valid_flag_list, num_level_anchors_list, cls_score_list,
bbox_pred_list, gt_bboxes_ignore_list) = targets_com
# classification branch assigner
cls_targets = self.get_cls_targets(
anchor_list,
valid_flag_list,
num_level_anchors_list,
cls_score_list,
bbox_pred_list,
gt_bboxes,
img_metas,
gt_bboxes_ignore_list=gt_bboxes_ignore_list,
gt_labels_list=gt_labels,
label_channels=label_channels)
if cls_targets is None:
return None
(cls_anchor_list, labels_list, label_weights_list, bbox_targets_list,
bbox_weights_list, num_total_pos, num_total_neg) = cls_targets
num_total_samples = reduce_mean(
torch.tensor(num_total_pos, dtype=torch.float,
device=device)).item()
num_total_samples = max(num_total_samples, 1.0)
reweight_factor_per_level = self.calc_reweight_factor(labels_list)
cls_losses_cls, = multi_apply(
self.loss_cls_single,
cls_scores,
labels_list,
label_weights_list,
reweight_factor_per_level,
num_total_samples=num_total_samples)
# regression branch assigner
reg_targets = self.get_reg_targets(
anchor_list,
valid_flag_list,
num_level_anchors_list,
cls_score_list,
bbox_pred_list,
gt_bboxes,
img_metas,
gt_bboxes_ignore_list=gt_bboxes_ignore_list,
gt_labels_list=gt_labels,
label_channels=label_channels)
if reg_targets is None:
return None
(reg_anchor_list, labels_list, label_weights_list, bbox_targets_list,
bbox_weights_list, num_total_pos, num_total_neg) = reg_targets
num_total_samples = reduce_mean(
torch.tensor(num_total_pos, dtype=torch.float,
device=device)).item()
num_total_samples = max(num_total_samples, 1.0)
reweight_factor_per_level = self.calc_reweight_factor(labels_list)
reg_losses_bbox, reg_losses_iou = multi_apply(
self.loss_reg_single,
reg_anchor_list,
bbox_preds,
iou_preds,
labels_list,
label_weights_list,
bbox_targets_list,
bbox_weights_list,
reweight_factor_per_level,
num_total_samples=num_total_samples)
return dict(
loss_cls=cls_losses_cls,
loss_bbox=reg_losses_bbox,
loss_iou=reg_losses_iou)
def process_predictions_and_anchors(self, anchor_list, valid_flag_list,
cls_scores, bbox_preds, img_metas,
gt_bboxes_ignore_list):
"""Compute common vars for regression and classification targets.
Args:
anchor_list (list[Tensor]): anchors of each image.
valid_flag_list (list[Tensor]): Valid flags of each image.
cls_scores (list[Tensor]): Classification scores for all scale
levels, each is a 4D-tensor, the channels number is
num_base_priors * num_classes.
bbox_preds (list[Tensor]): Box energies / deltas for all scale
levels, each is a 4D-tensor, the channels number is
num_base_priors * 4.
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes_ignore_list (list[Tensor] | None): specify which bounding
boxes can be ignored when computing the loss.
Return:
tuple[Tensor]: A tuple of common loss vars.
"""
num_imgs = len(img_metas)
assert len(anchor_list) == len(valid_flag_list) == num_imgs
# anchor number of multi levels
num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
num_level_anchors_list = [num_level_anchors] * num_imgs
anchor_list_ = []
valid_flag_list_ = []
# concat all level anchors and flags to a single tensor
for i in range(num_imgs):
assert len(anchor_list[i]) == len(valid_flag_list[i])
anchor_list_.append(torch.cat(anchor_list[i]))
valid_flag_list_.append(torch.cat(valid_flag_list[i]))
# compute targets for each image
if gt_bboxes_ignore_list is None:
gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
num_levels = len(cls_scores)
cls_score_list = []
bbox_pred_list = []
mlvl_cls_score_list = [
cls_score.permute(0, 2, 3, 1).reshape(
num_imgs, -1, self.num_base_priors * self.cls_out_channels)
for cls_score in cls_scores
]
mlvl_bbox_pred_list = [
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
self.num_base_priors * 4)
for bbox_pred in bbox_preds
]
for i in range(num_imgs):
mlvl_cls_tensor_list = [
mlvl_cls_score_list[j][i] for j in range(num_levels)
]
mlvl_bbox_tensor_list = [
mlvl_bbox_pred_list[j][i] for j in range(num_levels)
]
cat_mlvl_cls_score = torch.cat(mlvl_cls_tensor_list, dim=0)
cat_mlvl_bbox_pred = torch.cat(mlvl_bbox_tensor_list, dim=0)
cls_score_list.append(cat_mlvl_cls_score)
bbox_pred_list.append(cat_mlvl_bbox_pred)
return (anchor_list_, valid_flag_list_, num_level_anchors_list,
cls_score_list, bbox_pred_list, gt_bboxes_ignore_list)
def get_cls_targets(self,
anchor_list,
valid_flag_list,
num_level_anchors_list,
cls_score_list,
bbox_pred_list,
gt_bboxes_list,
img_metas,
gt_bboxes_ignore_list=None,
gt_labels_list=None,
label_channels=1,
unmap_outputs=True):
"""Get cls targets for DDOD head.
This method is almost the same as `AnchorHead.get_targets()`.
Besides returning the targets as the parent method does,
it also returns the anchors as the first element of the
returned tuple.
Args:
anchor_list (list[Tensor]): anchors of each image.
valid_flag_list (list[Tensor]): Valid flags of each image.
num_level_anchors_list (list[list[int]]): Number of anchors of each
scale level, repeated for every image.
cls_score_list (list[Tensor]): Classification scores of each image,
with all scale levels concatenated. Each is a 2D-tensor whose last
dimension is num_base_priors * num_classes.
bbox_pred_list (list[Tensor]): Box energies / deltas of each image,
with all scale levels concatenated. Each is a 2D-tensor whose last
dimension is num_base_priors * 4.
gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes_ignore_list (list[Tensor] | None): specify which bounding
boxes can be ignored when computing the loss.
gt_labels_list (list[Tensor]): class indices corresponding to
each box.
label_channels (int): Channel of label.
unmap_outputs (bool): Whether to map outputs back to the original
set of anchors.
Return:
tuple[Tensor]: A tuple of cls targets components.
"""
(all_anchors, all_labels, all_label_weights, all_bbox_targets,
all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
self._get_target_single,
anchor_list,
valid_flag_list,
cls_score_list,
bbox_pred_list,
num_level_anchors_list,
gt_bboxes_list,
gt_bboxes_ignore_list,
gt_labels_list,
img_metas,
label_channels=label_channels,
unmap_outputs=unmap_outputs,
is_cls_assigner=True)
# no valid anchors
if any([labels is None for labels in all_labels]):
return None
# sampled anchors of all images
num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
# split targets to a list w.r.t. multiple levels
anchors_list = images_to_levels(all_anchors, num_level_anchors_list[0])
labels_list = images_to_levels(all_labels, num_level_anchors_list[0])
label_weights_list = images_to_levels(all_label_weights,
num_level_anchors_list[0])
bbox_targets_list = images_to_levels(all_bbox_targets,
num_level_anchors_list[0])
bbox_weights_list = images_to_levels(all_bbox_weights,
num_level_anchors_list[0])
return (anchors_list, labels_list, label_weights_list,
bbox_targets_list, bbox_weights_list, num_total_pos,
num_total_neg)
def get_reg_targets(self,
anchor_list,
valid_flag_list,
num_level_anchors_list,
cls_score_list,
bbox_pred_list,
gt_bboxes_list,
img_metas,
gt_bboxes_ignore_list=None,
gt_labels_list=None,
label_channels=1,
unmap_outputs=True):
"""Get reg targets for DDOD head.
This method is almost the same as `AnchorHead.get_targets()`; it calls
`_get_target_single` with ``is_cls_assigner=False``. Besides returning
the targets as the parent method does, it also returns the anchors as
the first element of the returned tuple.
Args:
anchor_list (list[Tensor]): anchors of each image.
valid_flag_list (list[Tensor]): Valid flags of each image.
num_level_anchors_list (list[list[int]]): Number of anchors of each
scale level, repeated for every image.
cls_score_list (list[Tensor]): Classification scores of each image,
with all scale levels concatenated. Each is a 2D-tensor whose last
dimension is num_base_priors * num_classes.
bbox_pred_list (list[Tensor]): Box energies / deltas of each image,
with all scale levels concatenated. Each is a 2D-tensor whose last
dimension is num_base_priors * 4.
gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes_ignore_list (list[Tensor] | None): specify which bounding
boxes can be ignored when computing the loss.
gt_labels_list (list[Tensor]): class indices corresponding to
each box.
label_channels (int): Channel of label.
unmap_outputs (bool): Whether to map outputs back to the original
set of anchors.
Return:
tuple[Tensor]: A tuple of reg targets components.
"""
(all_anchors, all_labels, all_label_weights, all_bbox_targets,
all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
self._get_target_single,
anchor_list,
valid_flag_list,
cls_score_list,
bbox_pred_list,
num_level_anchors_list,
gt_bboxes_list,
gt_bboxes_ignore_list,
gt_labels_list,
img_metas,
label_channels=label_channels,
unmap_outputs=unmap_outputs,
is_cls_assigner=False)
# no valid anchors
if any([labels is None for labels in all_labels]):
return None
# sampled anchors of all images
num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
# split targets to a list w.r.t. multiple levels
anchors_list = images_to_levels(all_anchors, num_level_anchors_list[0])
labels_list = images_to_levels(all_labels, num_level_anchors_list[0])
label_weights_list = images_to_levels(all_label_weights,
num_level_anchors_list[0])
bbox_targets_list = images_to_levels(all_bbox_targets,
num_level_anchors_list[0])
bbox_weights_list = images_to_levels(all_bbox_weights,
num_level_anchors_list[0])
return (anchors_list, labels_list, label_weights_list,
bbox_targets_list, bbox_weights_list, num_total_pos,
num_total_neg)
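The regrouping step at the end of `get_reg_targets` can be sketched without any dependencies. This is an illustration only: `images_to_levels_sketch` is a hypothetical name, and it uses plain Python lists where `images_to_levels` works on tensors. It assumes every image has the same per-level anchor counts, which is why the code above passes only `num_level_anchors_list[0]`.

```python
def images_to_levels_sketch(per_image_targets, num_level_anchors):
    """Regroup flat per-image targets into one group per scale level.

    per_image_targets: one flat sequence per image, levels concatenated.
    num_level_anchors: number of anchors on each level (same for all images).
    """
    level_targets = []
    start = 0
    for n in num_level_anchors:
        end = start + n
        # Take the same slice from every image so each level keeps
        # an entry per image.
        level_targets.append([img[start:end] for img in per_image_targets])
        start = end
    return level_targets


# Two images, two levels with 3 and 2 anchors respectively.
labels = [
    ["a0", "a1", "a2", "a3", "a4"],
    ["b0", "b1", "b2", "b3", "b4"],
]
by_level = images_to_levels_sketch(labels, [3, 2])
print(by_level)
```

The result has one entry per level, and each entry still holds one slice per image, which is the layout the per-level loss computation expects.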
def _get_target_single(self,
flat_anchors,
valid_flags,
cls_scores,
bbox_preds,
num_level_anchors,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
img_meta,
label_channels=1,
unmap_outputs=True,
is_cls_assigner=True):
"""Compute regression, classification targets for anchors in a single
image.
Args:
flat_anchors (Tensor): Multi-level anchors of the image,
which are concatenated into a single tensor of shape
(num_anchors, 4).
valid_flags (Tensor): Multi level valid flags of the image,
which are concatenated into a single tensor of
shape (num_anchors,).
cls_scores (Tensor): Classification scores for all scale
levels of the image.
bbox_preds (Tensor): Box energies / deltas for all scale
levels of the image.
num_level_anchors (list[int]): Number of anchors of each
scale level.
gt_bboxes (Tensor): Ground truth bboxes of the image,
shape (num_gts, 4).
gt_bboxes_ignore (Tensor): Ground truth bboxes to be
ignored, shape (num_ignored_gts, 4).
gt_labels (Tensor): Ground truth labels of each box,
shape (num_gts, ).
img_meta (dict): Meta info of the image.
label_channels (int): Channel of label. Default: 1.
unmap_outputs (bool): Whether to map outputs back to the original
set of anchors. Default: True.
is_cls_assigner (bool): Classification or regression.
Default: True.
Returns:
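tuple: N is the number of total anchors in the image. If no anchor
lies inside the image, a tuple of seven ``None`` is returned.
- anchors (Tensor): All anchors in the image with shape (N, 4).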
- labels (Tensor): Labels of all anchors in the image with \
shape (N, ).
- label_weights (Tensor): Label weights of all anchor in the \
image with shape (N, ).
- bbox_targets (Tensor): BBox targets of all anchors in the \
image with shape (N, 4).
- bbox_weights (Tensor): BBox weights of all anchors in the \
image with shape (N, 4)
- pos_inds (Tensor): Indices of positive anchor with shape \
(num_pos, ).
- neg_inds (Tensor): Indices of negative anchor with shape \
(num_neg, ).
"""
inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
img_meta['img_shape'][:2],
self.train_cfg.allowed_border)
if not inside_flags.any():
return (None, ) * 7
# assign gt and sample anchors
anchors = flat_anchors[inside_flags, :]
num_level_anchors_inside = self.get_num_level_anchors_inside(
num_level_anchors, inside_flags)
bbox_preds_valid = bbox_preds[inside_flags, :]
cls_scores_valid = cls_scores[inside_flags, :]
assigner = self.cls_assigner if is_cls_assigner else self.reg_assigner
# decode prediction out of assigner
bbox_preds_valid = self.bbox_coder.decode(anchors, bbox_preds_valid)
assign_result = assigner.assign(anchors, num_level_anchors_inside,
gt_bboxes, gt_bboxes_ignore, gt_labels,
cls_scores_valid, bbox_preds_valid)
sampling_result = self.sampler.sample(assign_result, anchors,
gt_bboxes)
num_valid_anchors = anchors.shape[0]
bbox_targets = torch.zeros_like(anchors)
bbox_weights = torch.zeros_like(anchors)
labels = anchors.new_full((num_valid_anchors, ),
self.num_classes,
dtype=torch.long)
label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
if len(pos_inds) > 0:
if hasattr(self, 'bbox_coder'):
pos_bbox_targets = self.bbox_coder.encode(
sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
else:
# used in VFNetHead
pos_bbox_targets = sampling_result.pos_gt_bboxes
bbox_targets[pos_inds, :] = pos_bbox_targets
bbox_weights[pos_inds, :] = 1.0
if gt_labels is None:
# Only rpn gives gt_labels as None
# Foreground is the first class since v2.5.0
labels[pos_inds] = 0
else:
labels[pos_inds] = gt_labels[
sampling_result.pos_assigned_gt_inds]
if self.train_cfg.pos_weight <= 0:
label_weights[pos_inds] = 1.0
else:
label_weights[pos_inds] = self.train_cfg.pos_weight
if len(neg_inds) > 0:
label_weights[neg_inds] = 1.0
# map up to original set of anchors
if unmap_outputs:
num_total_anchors = flat_anchors.size(0)
anchors = unmap(anchors, num_total_anchors, inside_flags)
labels = unmap(
labels, num_total_anchors, inside_flags, fill=self.num_classes)
label_weights = unmap(label_weights, num_total_anchors,
inside_flags)
bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
return (anchors, labels, label_weights, bbox_targets, bbox_weights,
pos_inds, neg_inds)
def get_num_level_anchors_inside(self, num_level_anchors, inside_flags):
"""Get the anchors of each scale level inside.
Args:
num_level_anchors (list[int]): Number of anchors of each
scale level.
inside_flags (Tensor): Multi level inside flags of the image,
which are concatenated into a single tensor of
shape (num_anchors,).
Returns:
list[int]: Number of anchors of each scale level inside.
"""
split_inside_flags = torch.split(inside_flags, num_level_anchors)
num_level_anchors_inside = [
int(flags.sum()) for flags in split_inside_flags
]
return num_level_anchors_inside
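As a rough, torch-free sketch of what `get_num_level_anchors_inside` computes: the flat `inside_flags` vector is cut into consecutive chunks, one per level, and the true entries in each chunk are counted. The helper name `count_inside_per_level` and the sample values are made up for illustration; the method above does the same with `torch.split` and `flags.sum()`.

```python
def count_inside_per_level(inside_flags, num_level_anchors):
    """Count how many anchors of each level are flagged as inside."""
    counts = []
    start = 0
    for n in num_level_anchors:
        chunk = inside_flags[start:start + n]  # flags of one level
        counts.append(sum(1 for flag in chunk if flag))
        start += n
    return counts


# Three levels with 4, 2 and 1 anchors.
flags = [True, True, False, True, False, False, True]
counts = count_inside_per_level(flags, [4, 2, 1])
print(counts)  # [3, 0, 1]
```

These per-level counts are what the assigner receives as `num_level_anchors_inside`, so its per-level top-k selection still lines up after out-of-image anchors have been dropped.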

@@ -5,6 +5,7 @@ from .base import BaseDetector
from .cascade_rcnn import CascadeRCNN
from .centernet import CenterNet
from .cornernet import CornerNet
from .ddod import DDOD
from .deformable_detr import DeformableDETR
from .detr import DETR
from .fast_rcnn import FastRCNN
@@ -52,5 +53,5 @@ __all__ = [
'VFNet', 'DETR', 'TridentFasterRCNN', 'SparseRCNN', 'SCNet', 'SOLO',
'DeformableDETR', 'AutoAssign', 'YOLOF', 'CenterNet', 'YOLOX',
'TwoStagePanopticSegmentor', 'PanopticFPN', 'QueryInst', 'LAD', 'TOOD',
'MaskFormer', 'Mask2Former'
'MaskFormer', 'DDOD', 'Mask2Former'
]

@@ -0,0 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
from ..builder import DETECTORS
from .single_stage import SingleStageDetector
@DETECTORS.register_module()
class DDOD(SingleStageDetector):
"""Implementation of `DDOD <https://arxiv.org/pdf/2107.02963.pdf>`_."""
def __init__(self,
backbone,
neck,
bbox_head,
train_cfg=None,
test_cfg=None,
pretrained=None,
init_cfg=None):
super(DDOD, self).__init__(backbone, neck, bbox_head, train_cfg,
test_cfg, pretrained, init_cfg)

@@ -0,0 +1,72 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch
from mmdet.models.dense_heads import DDODHead
def test_ddod_head_loss():
"""Tests ddod head loss when truth is empty and non-empty."""
s = 256
img_metas = [{
'img_shape': (s, s, 3),
'scale_factor': 1,
'pad_shape': (s, s, 3)
}]
train_cfg = mmcv.Config(
dict( # ATSSAssigner
assigner=dict(type='ATSSAssigner', topk=9, alpha=0.8),
reg_assigner=dict(type='ATSSAssigner', topk=9, alpha=0.5),
allowed_border=-1,
pos_weight=-1,
debug=False))
self = DDODHead(
num_classes=4,
in_channels=1,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128]),
train_cfg=train_cfg,
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
loss_iou=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0))
feat = [
torch.rand(1, 1, s // feat_size, s // feat_size)
for feat_size in [4, 8, 16, 32, 64]
]
cls_scores, bbox_preds, iou_preds = self.forward(feat)
# Test that empty ground truth encourages the network to predict background
gt_bboxes = [torch.empty((0, 4))]
gt_labels = [torch.LongTensor([])]
gt_bboxes_ignore = None
empty_gt_losses = self.loss(cls_scores, bbox_preds, iou_preds, gt_bboxes,
gt_labels, img_metas, gt_bboxes_ignore)
# When there is no truth, the cls loss should be nonzero but there should
# be no box loss.
empty_cls_loss = sum(empty_gt_losses['loss_cls'])
empty_box_loss = sum(empty_gt_losses['loss_bbox'])
empty_iou_loss = sum(empty_gt_losses['loss_iou'])
assert empty_cls_loss.item() > 0, 'cls loss should be non-zero'
assert empty_box_loss.item() == 0, (
'there should be no box loss when there are no true boxes')
assert empty_iou_loss.item() == 0, (
'there should be no iou loss when there are no true boxes')
# When truth is non-empty then both cls and box loss should be nonzero for
# random inputs
gt_bboxes = [
torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
]
gt_labels = [torch.LongTensor([2])]
one_gt_losses = self.loss(cls_scores, bbox_preds, iou_preds, gt_bboxes,
gt_labels, img_metas, gt_bboxes_ignore)
onegt_cls_loss = sum(one_gt_losses['loss_cls'])
onegt_box_loss = sum(one_gt_losses['loss_bbox'])
onegt_iou_loss = sum(one_gt_losses['loss_iou'])
assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero'
assert onegt_box_loss.item() > 0, 'box loss should be non-zero'
assert onegt_iou_loss.item() > 0, 'iou loss should be non-zero'
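For orientation, the arithmetic below works out how many anchors the test's dummy feature maps would yield. It assumes anchors are laid out one grid per feature map with a single base anchor per location (one ratio, one scale per octave, as configured above); the variable names are illustrative and not part of the test.

```python
s = 256                          # image side length used in the test
downsample = [4, 8, 16, 32, 64]  # divisors used to build the dummy feature maps
num_base_priors = 1              # one ratio x one scale per location (assumed)

sizes = [s // d for d in downsample]                  # feature map side lengths
per_level = [h * h * num_base_priors for h in sizes]  # anchors on each level
total = sum(per_level)

print(sizes)      # [64, 32, 16, 8, 4]
print(per_level)  # [4096, 1024, 256, 64, 16]
print(total)      # 5456
```

Under these assumptions the single ground-truth box in the non-empty case is matched against a few thousand anchors, most of them on the finest level.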

@@ -402,6 +402,7 @@ def test_hungarian_match_assigner():
gt_labels = torch.LongTensor([1, 20])
assign_result = self.assign(bbox_pred, cls_pred, gt_bboxes, gt_labels,
img_meta)
assert torch.all(assign_result.gt_inds > -1)
assert (assign_result.gt_inds > 0).sum() == gt_bboxes.size(0)
assert (assign_result.labels > -1).sum() == gt_bboxes.size(0)
