From 151a803ed0119560f59dbe7b73824dbdcae08fc6 Mon Sep 17 00:00:00 2001
From: "Irving.Gao" <58903762+Irvingao@users.noreply.github.com>
Date: Tue, 24 May 2022 11:38:50 +0800
Subject: [PATCH] [Feature] Support DDOD: Disentangle Your Dense Object
Detector(ACM MM2021 oral) (#7279)
* add ddod feature
* add ddod feature
* modify new
* [Feature] modify ddod code0225
* [Feature] modify ddod code0226
* [Feature] modify ddod code0228
* [Feature] modify ddod code0228#7279
* [Feature] modify ddod code0301
* [Feature] modify ddod code0301 test draft
* [Feature] modify ddod code0301 test
* [Feature] modify ddod code0301 extra
* [Feature] modify ddod code0301 delete src/mmtrack
* [Feature] modify ddod code0302
* [Feature] modify ddod code0302(2)
* [Feature] modify ddod code0303
* [Feature] modify ddod code0303(2)
* [Feature] modify ddod code0303(3)
* [Feature] modify ddod code0305
* [Feature] modify ddod code0305(2) delete diou
* [Feature] modify ddod code0305(3)
* modify ddod code0306
* [Feature] modify ddod code0307
* [Feature] modify ddod code0311
* [Feature] modify ddod code0311(2)
* [Feature] modify ddod code0313
* update
* [Feature] modify ddod code0319
* fix
* fix lint
* [Feature] modify ddod code0321
* update readme
* [0502] compute common vars at once for get_target
* [0504] update ddod conflicts
* [0518] seperate reg and cls loss and get_target compute
* [0518] merge ATSSCostAssigner to ATSSAssigner
* [0518] refine ATSSAssigner
* [0518] refine ATSSAssigner 2
* [0518] refine ATSSAssigner 2
* [0518] refine ATSSAssigner 3
* [0519] fix bugs
* update
* fix lr
* update weight
Co-authored-by: hha <1286304229@qq.com>
---
configs/ddod/README.md | 31 +
configs/ddod/ddod_r50_fpn_1x_coco.py | 67 ++
configs/ddod/metafile.yml | 33 +
mmdet/core/bbox/assigners/atss_assigner.py | 65 +-
mmdet/models/dense_heads/__init__.py | 3 +-
mmdet/models/dense_heads/ddod_head.py | 778 ++++++++++++++++++
mmdet/models/detectors/__init__.py | 3 +-
mmdet/models/detectors/ddod.py | 19 +
.../test_dense_heads/test_ddod_head.py | 72 ++
tests/test_utils/test_assigner.py | 1 +
10 files changed, 1065 insertions(+), 7 deletions(-)
create mode 100644 configs/ddod/README.md
create mode 100644 configs/ddod/ddod_r50_fpn_1x_coco.py
create mode 100644 configs/ddod/metafile.yml
create mode 100644 mmdet/models/dense_heads/ddod_head.py
create mode 100644 mmdet/models/detectors/ddod.py
create mode 100644 tests/test_models/test_dense_heads/test_ddod_head.py
diff --git a/configs/ddod/README.md b/configs/ddod/README.md
new file mode 100644
index 000000000..9ab1f4869
--- /dev/null
+++ b/configs/ddod/README.md
@@ -0,0 +1,31 @@
+# DDOD
+
+> [Disentangle Your Dense Object Detector](https://arxiv.org/pdf/2107.02963.pdf)
+
+
+
+## Abstract
+
+Deep learning-based dense object detectors have achieved great success in the past few years and have been applied to numerous multimedia applications such as video understanding. However, the current training pipeline for dense detectors is compromised to lots of conjunctions that may not hold. In this paper, we investigate three such important conjunctions: 1) only samples assigned as positive in classification head are used to train the regression head; 2) classification and regression share the same input feature and computational fields defined by the parallel head architecture; and 3) samples distributed in different feature pyramid layers are treated equally when computing the loss. We first carry out a series of pilot experiments to show disentangling such conjunctions can lead to persistent performance improvement. Then, based on these findings, we propose Disentangled Dense Object Detector(DDOD), in which simple and effective disentanglement mechanisms are designed and integrated into the current state-of-the-art dense object detectors. Extensive experiments on MS COCO benchmark show that our approach can lead to 2.0 mAP, 2.4 mAP and 2.2 mAP absolute improvements on RetinaNet, FCOS, and ATSS baselines with negligible extra overhead. Notably, our best model reaches 55.0 mAP on the COCO test-dev set and 93.5 AP on the hard subset of WIDER FACE, achieving new state-of-the-art performance on these two competitive benchmarks. Code is available at https://github.com/zehuichen123/DDOD.
+
+
+

+
+
+## Results and Models
+
+| Model | Backbone | Style | Lr schd | Mem (GB) | box AP | Config | Download |
+| :-------: | :------: | :-----: | :-----: | :------: | :----: | :--------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| DDOD-ATSS | R-50 | pytorch | 1x | 3.4 | 41.7 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ddod/ddod_r50_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/ddod/ddod_r50_fpn_1x_coco/ddod_r50_fpn_1x_coco_20220523_223737-29b2fc67.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/ddod/ddod_r50_fpn_1x_coco/ddod_r50_fpn_1x_coco_20220523_223737.log.json) |
+
+## Citation
+
+```latex
+@inproceedings{chen2021disentangle,
+title={Disentangle Your Dense Object Detector},
+author={Chen, Zehui and Yang, Chenhongyi and Li, Qiaofei and Zhao, Feng and Zha, Zheng-Jun and Wu, Feng},
+booktitle={Proceedings of the 29th ACM International Conference on Multimedia},
+pages={4939--4948},
+year={2021}
+}
+```
diff --git a/configs/ddod/ddod_r50_fpn_1x_coco.py b/configs/ddod/ddod_r50_fpn_1x_coco.py
new file mode 100644
index 000000000..02dd2fe89
--- /dev/null
+++ b/configs/ddod/ddod_r50_fpn_1x_coco.py
@@ -0,0 +1,67 @@
+_base_ = [
+ '../_base_/datasets/coco_detection.py',
+ '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
+]
+
+model = dict(
+ type='DDOD',
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=1,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=True,
+ style='pytorch',
+ init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
+ neck=dict(
+ type='FPN',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ start_level=1,
+ add_extra_convs='on_output',
+ num_outs=5),
+ bbox_head=dict(
+ type='DDODHead',
+ num_classes=80,
+ in_channels=256,
+ stacked_convs=4,
+ feat_channels=256,
+ anchor_generator=dict(
+ type='AnchorGenerator',
+ ratios=[1.0],
+ octave_base_scale=8,
+ scales_per_octave=1,
+ strides=[8, 16, 32, 64, 128]),
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ target_means=[.0, .0, .0, .0],
+ target_stds=[0.1, 0.1, 0.2, 0.2]),
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
+ loss_iou=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
+ train_cfg=dict(
+ # assigner is mean cls_assigner
+ assigner=dict(type='ATSSAssigner', topk=9, alpha=0.8),
+ reg_assigner=dict(type='ATSSAssigner', topk=9, alpha=0.5),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False),
+ test_cfg=dict(
+ nms_pre=1000,
+ min_bbox_size=0,
+ score_thr=0.05,
+ nms=dict(type='nms', iou_threshold=0.6),
+ max_per_img=100))
+
+# This `persistent_workers` is only valid when PyTorch>=1.7.0
+data = dict(persistent_workers=True)
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
diff --git a/configs/ddod/metafile.yml b/configs/ddod/metafile.yml
new file mode 100644
index 000000000..c22395002
--- /dev/null
+++ b/configs/ddod/metafile.yml
@@ -0,0 +1,33 @@
+Collections:
+ - Name: DDOD
+ Metadata:
+ Training Data: COCO
+ Training Techniques:
+ - SGD with Momentum
+ - Weight Decay
+ Training Resources: 8x V100 GPUs
+ Architecture:
+ - DDOD
+ - FPN
+ - ResNet
+ Paper:
+ URL: https://arxiv.org/pdf/2107.02963.pdf
+ Title: 'Disentangle Your Dense Object Detector'
+ README: configs/ddod/README.md
+ Code:
+ URL: https://github.com/open-mmlab/mmdetection/blob/v2.25.0/mmdet/models/detectors/ddod.py#L6
+ Version: v2.25.0
+
+Models:
+ - Name: ddod_r50_fpn_1x_coco
+ In Collection: DDOD
+ Config: configs/ddod/ddod_r50_fpn_1x_coco.py
+ Metadata:
+ Training Memory (GB): 3.4
+ Epochs: 12
+ Results:
+ - Task: Object Detection
+ Dataset: COCO
+ Metrics:
+ box AP: 41.7
+ Weights: https://download.openmmlab.com/mmdetection/v2.0/ddod/ddod_r50_fpn_1x_coco/ddod_r50_fpn_1x_coco_20220523_223737-29b2fc67.pth
diff --git a/mmdet/core/bbox/assigners/atss_assigner.py b/mmdet/core/bbox/assigners/atss_assigner.py
index 7b195303e..79c8281e5 100644
--- a/mmdet/core/bbox/assigners/atss_assigner.py
+++ b/mmdet/core/bbox/assigners/atss_assigner.py
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
import torch
from ..builder import BBOX_ASSIGNERS
@@ -17,26 +19,44 @@ class ATSSAssigner(BaseAssigner):
- 0: negative sample, no assigned gt
- positive integer: positive sample, index (1-based) of assigned gt
+ If ``alpha`` is not None, it means that the dynamic cost
+ ATSSAssigner is adopted, which is currently only used in the DDOD.
+
Args:
topk (float): number of bbox selected in each level
"""
def __init__(self,
topk,
+ alpha=None,
iou_calculator=dict(type='BboxOverlaps2D'),
ignore_iof_thr=-1):
self.topk = topk
+ self.alpha = alpha
self.iou_calculator = build_iou_calculator(iou_calculator)
self.ignore_iof_thr = ignore_iof_thr
- # https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py
+ """Assign a corresponding gt bbox or background to each bbox.
+ Args:
+ topk (int): number of bbox selected in each level.
+ alpha (float): param of cost rate for each proposal only in DDOD.
+ Default None.
+ iou_calculator (dict): builder of IoU calculator.
+ Default dict(type='BboxOverlaps2D').
+ ignore_iof_thr (int): whether ignore max overlaps or not.
+ Default -1 (1 or -1).
+ """
+
+ # https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py
def assign(self,
bboxes,
num_level_bboxes,
gt_bboxes,
gt_bboxes_ignore=None,
- gt_labels=None):
+ gt_labels=None,
+ cls_scores=None,
+ bbox_preds=None):
"""Assign gt to bboxes.
The assignment is done in following steps
@@ -52,14 +72,24 @@ class ATSSAssigner(BaseAssigner):
the threshold as positive
6. limit the positive sample's center in gt
+ If ``alpha`` is not None, and ``cls_scores`` and `bbox_preds`
+ are not None, the overlaps calculation in the first step
+ will also include dynamic cost, which is currently only used in
+ the DDOD.
Args:
bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
num_level_bboxes (List): num of bboxes in each level
gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
- labelled as `ignored`, e.g., crowd boxes in COCO.
+ labelled as `ignored`, e.g., crowd boxes in COCO. Default None.
gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+ cls_scores (list[Tensor]): Classification scores for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_base_priors * num_classes. Default None.
+ bbox_preds (list[Tensor]): Box energies / deltas for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_base_priors * 4. Default None.
Returns:
:obj:`AssignResult`: The assign result.
@@ -68,8 +98,31 @@ class ATSSAssigner(BaseAssigner):
bboxes = bboxes[:, :4]
num_gt, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
- # compute iou between all bbox and gt
- overlaps = self.iou_calculator(bboxes, gt_bboxes)
+ message = 'Invalid alpha parameter because cls_scores or ' \
+ 'bbox_preds are None. If you want to use the ' \
+ 'cost-based ATSSAssigner, please set cls_scores, ' \
+ 'bbox_preds and self.alpha at the same time. '
+
+ if self.alpha is None:
+ # ATSSAssigner
+ overlaps = self.iou_calculator(bboxes, gt_bboxes)
+ if cls_scores is not None or bbox_preds is not None:
+ warnings.warn(message)
+ else:
+ # Dynamic cost ATSSAssigner in DDOD
+ assert cls_scores is not None and bbox_preds is not None, message
+
+ # compute cls cost for bbox and GT
+ cls_cost = torch.sigmoid(cls_scores[:, gt_labels])
+
+ # compute iou between all bbox and gt
+ overlaps = self.iou_calculator(bbox_preds, gt_bboxes)
+
+ # make sure that we are in element-wise multiplication
+ assert cls_cost.shape == overlaps.shape
+
+ # overlaps is actually a cost matrix
+ overlaps = cls_cost**(1 - self.alpha) * overlaps**self.alpha
# assign 0 by default
assigned_gt_inds = overlaps.new_full((num_bboxes, ),
@@ -121,6 +174,7 @@ class ATSSAssigner(BaseAssigner):
end_idx = start_idx + bboxes_per_level
distances_per_level = distances[start_idx:end_idx, :]
selectable_k = min(self.topk, bboxes_per_level)
+
_, topk_idxs_per_level = distances_per_level.topk(
selectable_k, dim=0, largest=False)
candidate_idxs.append(topk_idxs_per_level + start_idx)
@@ -152,6 +206,7 @@ class ATSSAssigner(BaseAssigner):
r_ = gt_bboxes[:, 2] - ep_bboxes_cx[candidate_idxs].view(-1, num_gt)
b_ = gt_bboxes[:, 3] - ep_bboxes_cy[candidate_idxs].view(-1, num_gt)
is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01
+
is_pos = is_pos & is_in_gts
# if an anchor box is assigned to multiple gts,
diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py
index 375197a69..1a3e31fda 100644
--- a/mmdet/models/dense_heads/__init__.py
+++ b/mmdet/models/dense_heads/__init__.py
@@ -7,6 +7,7 @@ from .cascade_rpn_head import CascadeRPNHead, StageCascadeRPNHead
from .centernet_head import CenterNetHead
from .centripetal_head import CentripetalHead
from .corner_head import CornerHead
+from .ddod_head import DDODHead
from .deformable_detr_head import DeformableDETRHead
from .detr_head import DETRHead
from .embedding_rpn_head import EmbeddingRPNHead
@@ -52,5 +53,5 @@ __all__ = [
'AutoAssignHead', 'DETRHead', 'YOLOFHead', 'DeformableDETRHead',
'SOLOHead', 'DecoupledSOLOHead', 'CenterNetHead', 'YOLOXHead',
'DecoupledSOLOLightHead', 'LADHead', 'TOODHead', 'MaskFormerHead',
- 'Mask2FormerHead'
+ 'DDODHead', 'Mask2FormerHead'
]
diff --git a/mmdet/models/dense_heads/ddod_head.py b/mmdet/models/dense_heads/ddod_head.py
new file mode 100644
index 000000000..b2ff22334
--- /dev/null
+++ b/mmdet/models/dense_heads/ddod_head.py
@@ -0,0 +1,778 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, Scale, bias_init_with_prob, normal_init
+from mmcv.runner import force_fp32
+
+from mmdet.core import (anchor_inside_flags, build_assigner, build_sampler,
+ images_to_levels, multi_apply, reduce_mean, unmap)
+from mmdet.core.bbox import bbox_overlaps
+from ..builder import HEADS, build_loss
+from .anchor_head import AnchorHead
+
+EPS = 1e-12
+
+
+@HEADS.register_module()
+class DDODHead(AnchorHead):
+ """DDOD head decomposes conjunctions lying in most current one-stage
+ detectors via label assignment disentanglement, spatial feature
+ disentanglement, and pyramid supervision disentanglement.
+
+ https://arxiv.org/abs/2107.02963
+
+ Args:
+ num_classes (int): Number of categories excluding the
+ background category.
+ in_channels (int): Number of channels in the input feature map.
+ stacked_convs (int): The number of stacked Conv. Default: 4.
+ conv_cfg (dict): Conv config of ddod head. Default: None.
+ use_dcn (bool): Use dcn, Same as ATSS when False. Default: True.
+ norm_cfg (dict): Normal config of ddod head. Default:
+ dict(type='GN', num_groups=32, requires_grad=True).
+ loss_iou (dict): Config of IoU loss. Default:
+ dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0).
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ stacked_convs=4,
+ conv_cfg=None,
+ use_dcn=True,
+ norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
+ loss_iou=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0),
+ **kwargs):
+ self.stacked_convs = stacked_convs
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.use_dcn = use_dcn
+ super(DDODHead, self).__init__(num_classes, in_channels, **kwargs)
+
+ self.sampling = False
+ if self.train_cfg:
+ self.cls_assigner = build_assigner(self.train_cfg.assigner)
+ self.reg_assigner = build_assigner(self.train_cfg.reg_assigner)
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+ self.loss_iou = build_loss(loss_iou)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.relu = nn.ReLU(inplace=True)
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=dict(type='DCN', deform_groups=1)
+ if i == 0 and self.use_dcn else self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.reg_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=dict(type='DCN', deform_groups=1)
+ if i == 0 and self.use_dcn else self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.atss_cls = nn.Conv2d(
+ self.feat_channels,
+ self.num_base_priors * self.cls_out_channels,
+ 3,
+ padding=1)
+ self.atss_reg = nn.Conv2d(
+ self.feat_channels, self.num_base_priors * 4, 3, padding=1)
+ self.atss_iou = nn.Conv2d(
+ self.feat_channels, self.num_base_priors * 1, 3, padding=1)
+ self.scales = nn.ModuleList(
+ [Scale(1.0) for _ in self.prior_generator.strides])
+
+ # we use the global list in loss
+ self.cls_num_pos_samples_per_level = [
+ 0. for _ in range(len(self.prior_generator.strides))
+ ]
+ self.reg_num_pos_samples_per_level = [
+ 0. for _ in range(len(self.prior_generator.strides))
+ ]
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ for m in self.cls_convs:
+ normal_init(m.conv, std=0.01)
+ for m in self.reg_convs:
+ normal_init(m.conv, std=0.01)
+ normal_init(self.atss_reg, std=0.01)
+ normal_init(self.atss_iou, std=0.01)
+ bias_cls = bias_init_with_prob(0.01)
+ normal_init(self.atss_cls, std=0.01, bias=bias_cls)
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple: Usually a tuple of classification scores and bbox prediction
+ cls_scores (list[Tensor]): Classification scores for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_base_priors * num_classes.
+ bbox_preds (list[Tensor]): Box energies / deltas for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_base_priors * 4.
+ iou_preds (list[Tensor]): IoU scores for all scale levels,
+ each is a 4D-tensor, the channels number is
+ num_base_priors * 1.
+ """
+ return multi_apply(self.forward_single, feats, self.scales)
+
+ def forward_single(self, x, scale):
+ """Forward feature of a single scale level.
+
+ Args:
+ x (Tensor): Features of a single scale level.
+ scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize
+ the bbox prediction.
+
+ Returns:
+ tuple:
+ - cls_score (Tensor): Cls scores for a single scale level \
+ the channels number is num_base_priors * num_classes.
+ - bbox_pred (Tensor): Box energies / deltas for a single \
+ scale level, the channels number is num_base_priors * 4.
+ - iou_pred (Tensor): Iou for a single scale level, the \
+ channel number is (N, num_base_priors * 1, H, W).
+ """
+ cls_feat = x
+ reg_feat = x
+ for cls_conv in self.cls_convs:
+ cls_feat = cls_conv(cls_feat)
+ for reg_conv in self.reg_convs:
+ reg_feat = reg_conv(reg_feat)
+ cls_score = self.atss_cls(cls_feat)
+ # we just follow atss, not apply exp in bbox_pred
+ bbox_pred = scale(self.atss_reg(reg_feat)).float()
+ iou_pred = self.atss_iou(reg_feat)
+ return cls_score, bbox_pred, iou_pred
+
+ def loss_cls_single(self, cls_score, labels, label_weights,
+ reweight_factor, num_total_samples):
+ """Compute cls loss of a single scale level.
+
+ Args:
+ cls_score (Tensor): Box scores for each scale level
+ Has shape (N, num_base_priors * num_classes, H, W).
+ labels (Tensor): Labels of each anchors with shape
+ (N, num_total_anchors).
+ label_weights (Tensor): Label weights of each anchor with shape
+ (N, num_total_anchors)
+ reweight_factor (list[int]): Reweight factor for cls and reg
+ loss.
+ num_total_samples (int): Number of positive samples that is
+ reduced over all GPUs.
+
+ Returns:
+ tuple[Tensor]: A tuple of loss components.
+ """
+ cls_score = cls_score.permute(0, 2, 3, 1).reshape(
+ -1, self.cls_out_channels).contiguous()
+ labels = labels.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+ loss_cls = self.loss_cls(
+ cls_score, labels, label_weights, avg_factor=num_total_samples)
+ return reweight_factor * loss_cls,
+
+ def loss_reg_single(self, anchors, bbox_pred, iou_pred, labels,
+ label_weights, bbox_targets, bbox_weights,
+ reweight_factor, num_total_samples):
+ """Compute reg loss of a single scale level.
+
+ Args:
+ anchors (Tensor): Box reference for each scale level with shape
+ (N, num_total_anchors, 4).
+ bbox_pred (Tensor): Box energies / deltas for each scale
+ level with shape (N, num_base_priors * 4, H, W).
+ iou_pred (Tensor): Iou for a single scale level, the
+ channel number is (N, num_base_priors * 1, H, W).
+ labels (Tensor): Labels of each anchors with shape
+ (N, num_total_anchors).
+ label_weights (Tensor): Label weights of each anchor with shape
+ (N, num_total_anchors)
+ bbox_targets (Tensor): BBox regression targets of each anchor
+ weight shape (N, num_total_anchors, 4).
+ bbox_weights (Tensor): BBox weights of all anchors in the
+ image with shape (N, 4)
+ reweight_factor (list[int]): Reweight factor for cls and reg
+ loss.
+ num_total_samples (int): Number of positive samples that is
+ reduced over all GPUs.
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ anchors = anchors.reshape(-1, 4)
+ bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
+ iou_pred = iou_pred.permute(0, 2, 3, 1).reshape(-1, )
+ bbox_targets = bbox_targets.reshape(-1, 4)
+ bbox_weights = bbox_weights.reshape(-1, 4)
+ labels = labels.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+
+ iou_targets = label_weights.new_zeros(labels.shape)
+ iou_weights = label_weights.new_zeros(labels.shape)
+ iou_weights[(bbox_weights.sum(axis=1) > 0).nonzero(
+ as_tuple=False)] = 1.
+
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ bg_class_ind = self.num_classes
+ pos_inds = ((labels >= 0)
+ &
+ (labels < bg_class_ind)).nonzero(as_tuple=False).squeeze(1)
+
+ if len(pos_inds) > 0:
+ pos_bbox_targets = bbox_targets[pos_inds]
+ pos_bbox_pred = bbox_pred[pos_inds]
+ pos_anchors = anchors[pos_inds]
+
+ pos_decode_bbox_pred = self.bbox_coder.decode(
+ pos_anchors, pos_bbox_pred)
+ pos_decode_bbox_targets = self.bbox_coder.decode(
+ pos_anchors, pos_bbox_targets)
+
+ # regression loss
+ loss_bbox = self.loss_bbox(
+ pos_decode_bbox_pred,
+ pos_decode_bbox_targets,
+ avg_factor=num_total_samples)
+
+ iou_targets[pos_inds] = bbox_overlaps(
+ pos_decode_bbox_pred.detach(),
+ pos_decode_bbox_targets,
+ is_aligned=True)
+ loss_iou = self.loss_iou(
+ iou_pred,
+ iou_targets,
+ iou_weights,
+ avg_factor=num_total_samples)
+ else:
+ loss_bbox = bbox_pred.sum() * 0
+ loss_iou = iou_pred.sum() * 0
+
+ return reweight_factor * loss_bbox, reweight_factor * loss_iou
+
+ def calc_reweight_factor(self, labels_list):
+ """Compute reweight_factor for regression and classification loss."""
+ # get pos samples for each level
+ bg_class_ind = self.num_classes
+ for ii, each_level_label in enumerate(labels_list):
+ pos_inds = ((each_level_label >= 0) &
+ (each_level_label < bg_class_ind)).nonzero(
+ as_tuple=False).squeeze(1)
+ self.cls_num_pos_samples_per_level[ii] += len(pos_inds)
+ # get reweight factor from 1 ~ 2 with bilinear interpolation
+ min_pos_samples = min(self.cls_num_pos_samples_per_level)
+ max_pos_samples = max(self.cls_num_pos_samples_per_level)
+ interval = 1. / (max_pos_samples - min_pos_samples + 1e-10)
+ reweight_factor_per_level = []
+ for pos_samples in self.cls_num_pos_samples_per_level:
+ factor = 2. - (pos_samples - min_pos_samples) * interval
+ reweight_factor_per_level.append(factor)
+ return reweight_factor_per_level
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'iou_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ iou_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_base_priors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_base_priors * 4, H, W)
+ iou_preds (list[Tensor]): Score factor for all scale level,
+ each is a 4D-tensor, has shape (batch_size, 1, H, W).
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (list[Tensor] | None): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.prior_generator.num_levels
+
+ device = cls_scores[0].device
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+
+ # calculate common vars for cls and reg assigners at once
+ targets_com = self.process_predictions_and_anchors(
+ anchor_list, valid_flag_list, cls_scores, bbox_preds, img_metas,
+ gt_bboxes_ignore)
+ (anchor_list, valid_flag_list, num_level_anchors_list, cls_score_list,
+ bbox_pred_list, gt_bboxes_ignore_list) = targets_com
+
+ # classification branch assigner
+ cls_targets = self.get_cls_targets(
+ anchor_list,
+ valid_flag_list,
+ num_level_anchors_list,
+ cls_score_list,
+ bbox_pred_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore_list,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels)
+ if cls_targets is None:
+ return None
+
+ (cls_anchor_list, labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, num_total_pos, num_total_neg) = cls_targets
+
+ num_total_samples = reduce_mean(
+ torch.tensor(num_total_pos, dtype=torch.float,
+ device=device)).item()
+ num_total_samples = max(num_total_samples, 1.0)
+
+ reweight_factor_per_level = self.calc_reweight_factor(labels_list)
+
+ cls_losses_cls, = multi_apply(
+ self.loss_cls_single,
+ cls_scores,
+ labels_list,
+ label_weights_list,
+ reweight_factor_per_level,
+ num_total_samples=num_total_samples)
+
+ # regression branch assigner
+ reg_targets = self.get_reg_targets(
+ anchor_list,
+ valid_flag_list,
+ num_level_anchors_list,
+ cls_score_list,
+ bbox_pred_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore_list,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels)
+ if reg_targets is None:
+ return None
+
+ (reg_anchor_list, labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, num_total_pos, num_total_neg) = reg_targets
+
+ num_total_samples = reduce_mean(
+ torch.tensor(num_total_pos, dtype=torch.float,
+ device=device)).item()
+ num_total_samples = max(num_total_samples, 1.0)
+
+ reweight_factor_per_level = self.calc_reweight_factor(labels_list)
+
+ reg_losses_bbox, reg_losses_iou = multi_apply(
+ self.loss_reg_single,
+ reg_anchor_list,
+ bbox_preds,
+ iou_preds,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ reweight_factor_per_level,
+ num_total_samples=num_total_samples)
+
+ return dict(
+ loss_cls=cls_losses_cls,
+ loss_bbox=reg_losses_bbox,
+ loss_iou=reg_losses_iou)
+
+ def process_predictions_and_anchors(self, anchor_list, valid_flag_list,
+ cls_scores, bbox_preds, img_metas,
+ gt_bboxes_ignore_list):
+ """Compute common vars for regression and classification targets.
+
+ Args:
+ anchor_list (list[Tensor]): anchors of each image.
+ valid_flag_list (list[Tensor]): Valid flags of each image.
+ cls_scores (list[Tensor]): Classification scores for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_base_priors * num_classes.
+ bbox_preds (list[Tensor]): Box energies / deltas for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_base_priors * 4.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore_list (list[Tensor] | None): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Return:
+ tuple[Tensor]: A tuple of common loss vars.
+ """
+ num_imgs = len(img_metas)
+ assert len(anchor_list) == len(valid_flag_list) == num_imgs
+
+ # anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+ num_level_anchors_list = [num_level_anchors] * num_imgs
+
+ anchor_list_ = []
+ valid_flag_list_ = []
+ # concat all level anchors and flags to a single tensor
+ for i in range(num_imgs):
+ assert len(anchor_list[i]) == len(valid_flag_list[i])
+ anchor_list_.append(torch.cat(anchor_list[i]))
+ valid_flag_list_.append(torch.cat(valid_flag_list[i]))
+
+ # compute targets for each image
+ if gt_bboxes_ignore_list is None:
+ gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+
+ num_levels = len(cls_scores)
+ cls_score_list = []
+ bbox_pred_list = []
+
+ mlvl_cls_score_list = [
+ cls_score.permute(0, 2, 3, 1).reshape(
+ num_imgs, -1, self.num_base_priors * self.cls_out_channels)
+ for cls_score in cls_scores
+ ]
+ mlvl_bbox_pred_list = [
+ bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
+ self.num_base_priors * 4)
+ for bbox_pred in bbox_preds
+ ]
+
+ for i in range(num_imgs):
+ mlvl_cls_tensor_list = [
+ mlvl_cls_score_list[j][i] for j in range(num_levels)
+ ]
+ mlvl_bbox_tensor_list = [
+ mlvl_bbox_pred_list[j][i] for j in range(num_levels)
+ ]
+ cat_mlvl_cls_score = torch.cat(mlvl_cls_tensor_list, dim=0)
+ cat_mlvl_bbox_pred = torch.cat(mlvl_bbox_tensor_list, dim=0)
+ cls_score_list.append(cat_mlvl_cls_score)
+ bbox_pred_list.append(cat_mlvl_bbox_pred)
+ return (anchor_list_, valid_flag_list_, num_level_anchors_list,
+ cls_score_list, bbox_pred_list, gt_bboxes_ignore_list)
+
+ def get_cls_targets(self,
+ anchor_list,
+ valid_flag_list,
+ num_level_anchors_list,
+ cls_score_list,
+ bbox_pred_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ label_channels=1,
+ unmap_outputs=True):
+ """Get cls targets for DDOD head.
+
+ This method is almost the same as `AnchorHead.get_targets()`.
+ Besides returning the targets as the parent method does,
+ it also returns the anchors as the first element of the
+ returned tuple.
+
+ Args:
+ anchor_list (list[Tensor]): anchors of each image.
+ valid_flag_list (list[Tensor]): Valid flags of each image.
+ num_level_anchors_list (list[Tensor]): Number of anchors of each
+ scale level of all image.
+ cls_score_list (list[Tensor]): Classification scores for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_base_priors * num_classes.
+ bbox_pred_list (list[Tensor]): Box energies / deltas for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_base_priors * 4.
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore_list (list[Tensor] | None): specify which bounding
+ boxes can be ignored when computing the loss.
+ gt_labels_list (list[Tensor]): class indices corresponding to
+ each box.
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+
+ Return:
+ tuple[Tensor]: A tuple of cls targets components.
+ """
+ (all_anchors, all_labels, all_label_weights, all_bbox_targets,
+ all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
+ self._get_target_single,
+ anchor_list,
+ valid_flag_list,
+ cls_score_list,
+ bbox_pred_list,
+ num_level_anchors_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ img_metas,
+ label_channels=label_channels,
+ unmap_outputs=unmap_outputs,
+ is_cls_assigner=True)
+ # no valid anchors
+ if any([labels is None for labels in all_labels]):
+ return None
+ # sampled anchors of all images
+ num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+ num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+ # split targets to a list w.r.t. multiple levels
+ anchors_list = images_to_levels(all_anchors, num_level_anchors_list[0])
+ labels_list = images_to_levels(all_labels, num_level_anchors_list[0])
+ label_weights_list = images_to_levels(all_label_weights,
+ num_level_anchors_list[0])
+ bbox_targets_list = images_to_levels(all_bbox_targets,
+ num_level_anchors_list[0])
+ bbox_weights_list = images_to_levels(all_bbox_weights,
+ num_level_anchors_list[0])
+ return (anchors_list, labels_list, label_weights_list,
+ bbox_targets_list, bbox_weights_list, num_total_pos,
+ num_total_neg)
+
+ def get_reg_targets(self,
+ anchor_list,
+ valid_flag_list,
+ num_level_anchors_list,
+ cls_score_list,
+ bbox_pred_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ label_channels=1,
+ unmap_outputs=True):
+ """Get reg targets for DDOD head.
+
+ This method is almost the same as `AnchorHead.get_targets()` when
+ is_cls_assigner is False. Besides returning the targets as the parent
+ method does, it also returns the anchors as the first element of the
+ returned tuple.
+
+ Args:
+ anchor_list (list[Tensor]): anchors of each image.
+ valid_flag_list (list[Tensor]): Valid flags of each image.
+ num_level_anchors (int): Number of anchors of each scale level.
+ cls_scores (list[Tensor]): Classification scores for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_base_priors * num_classes.
+ bbox_preds (list[Tensor]): Box energies / deltas for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_base_priors * 4.
+ gt_labels_list (list[Tensor]): class indices corresponding to
+ each box.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore_list (list[Tensor] | None): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Return:
+ tuple[Tensor]: A tuple of reg targets components.
+ """
+ (all_anchors, all_labels, all_label_weights, all_bbox_targets,
+ all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
+ self._get_target_single,
+ anchor_list,
+ valid_flag_list,
+ cls_score_list,
+ bbox_pred_list,
+ num_level_anchors_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ img_metas,
+ label_channels=label_channels,
+ unmap_outputs=unmap_outputs,
+ is_cls_assigner=False)
+ # no valid anchors
+ if any([labels is None for labels in all_labels]):
+ return None
+ # sampled anchors of all images
+ num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+ num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+ # split targets to a list w.r.t. multiple levels
+ anchors_list = images_to_levels(all_anchors, num_level_anchors_list[0])
+ labels_list = images_to_levels(all_labels, num_level_anchors_list[0])
+ label_weights_list = images_to_levels(all_label_weights,
+ num_level_anchors_list[0])
+ bbox_targets_list = images_to_levels(all_bbox_targets,
+ num_level_anchors_list[0])
+ bbox_weights_list = images_to_levels(all_bbox_weights,
+ num_level_anchors_list[0])
+ return (anchors_list, labels_list, label_weights_list,
+ bbox_targets_list, bbox_weights_list, num_total_pos,
+ num_total_neg)
+
+ def _get_target_single(self,
+ flat_anchors,
+ valid_flags,
+ cls_scores,
+ bbox_preds,
+ num_level_anchors,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ img_meta,
+ label_channels=1,
+ unmap_outputs=True,
+ is_cls_assigner=True):
+ """Compute regression, classification targets for anchors in a single
+ image.
+
+ Args:
+ flat_anchors (Tensor): Multi-level anchors of the image,
+ which are concatenated into a single tensor of shape
+ (num_base_priors, 4).
+ valid_flags (Tensor): Multi level valid flags of the image,
+ which are concatenated into a single tensor of
+ shape (num_base_priors,).
+ cls_scores (Tensor): Classification scores for all scale
+ levels of the image.
+ bbox_preds (Tensor): Box energies / deltas for all scale
+ levels of the image.
+ num_level_anchors (list[int]): Number of anchors of each
+ scale level.
+ gt_bboxes (Tensor): Ground truth bboxes of the image,
+ shape (num_gts, 4).
+ gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, ).
+ gt_labels (Tensor): Ground truth labels of each box,
+ shape (num_gts, ).
+ img_meta (dict): Meta info of the image.
+ label_channels (int): Channel of label. Default: 1.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors. Default: True.
+ is_cls_assigner (bool): Classification or regression.
+ Default: True.
+
+ Returns:
+ tuple: N is the number of total anchors in the image.
+ - labels (Tensor): Labels of all anchors in the image with \
+ shape (N, ).
+ - label_weights (Tensor): Label weights of all anchor in the \
+ image with shape (N, ).
+ - bbox_targets (Tensor): BBox targets of all anchors in the \
+ image with shape (N, 4).
+ - bbox_weights (Tensor): BBox weights of all anchors in the \
+ image with shape (N, 4)
+ - pos_inds (Tensor): Indices of positive anchor with shape \
+ (num_pos, ).
+ - neg_inds (Tensor): Indices of negative anchor with shape \
+ (num_neg, ).
+ """
+ inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
+ img_meta['img_shape'][:2],
+ self.train_cfg.allowed_border)
+ if not inside_flags.any():
+ return (None, ) * 7
+ # assign gt and sample anchors
+ anchors = flat_anchors[inside_flags, :]
+
+ num_level_anchors_inside = self.get_num_level_anchors_inside(
+ num_level_anchors, inside_flags)
+ bbox_preds_valid = bbox_preds[inside_flags, :]
+ cls_scores_valid = cls_scores[inside_flags, :]
+
+ assigner = self.cls_assigner if is_cls_assigner else self.reg_assigner
+
+ # decode prediction out of assigner
+ bbox_preds_valid = self.bbox_coder.decode(anchors, bbox_preds_valid)
+ assign_result = assigner.assign(anchors, num_level_anchors_inside,
+ gt_bboxes, gt_bboxes_ignore, gt_labels,
+ cls_scores_valid, bbox_preds_valid)
+ sampling_result = self.sampler.sample(assign_result, anchors,
+ gt_bboxes)
+
+ num_valid_anchors = anchors.shape[0]
+ bbox_targets = torch.zeros_like(anchors)
+ bbox_weights = torch.zeros_like(anchors)
+ labels = anchors.new_full((num_valid_anchors, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+ if len(pos_inds) > 0:
+ if hasattr(self, 'bbox_coder'):
+ pos_bbox_targets = self.bbox_coder.encode(
+ sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
+ else:
+ # used in VFNetHead
+ pos_bbox_targets = sampling_result.pos_gt_bboxes
+ bbox_targets[pos_inds, :] = pos_bbox_targets
+ bbox_weights[pos_inds, :] = 1.0
+ if gt_labels is None:
+ # Only rpn gives gt_labels as None
+ # Foreground is the first class since v2.5.0
+ labels[pos_inds] = 0
+ else:
+ labels[pos_inds] = gt_labels[
+ sampling_result.pos_assigned_gt_inds]
+ if self.train_cfg.pos_weight <= 0:
+ label_weights[pos_inds] = 1.0
+ else:
+ label_weights[pos_inds] = self.train_cfg.pos_weight
+ if len(neg_inds) > 0:
+ label_weights[neg_inds] = 1.0
+
+ # map up to original set of anchors
+ if unmap_outputs:
+ num_total_anchors = flat_anchors.size(0)
+ anchors = unmap(anchors, num_total_anchors, inside_flags)
+ labels = unmap(
+ labels, num_total_anchors, inside_flags, fill=self.num_classes)
+ label_weights = unmap(label_weights, num_total_anchors,
+ inside_flags)
+ bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
+ bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
+
+ return (anchors, labels, label_weights, bbox_targets, bbox_weights,
+ pos_inds, neg_inds)
+
+ def get_num_level_anchors_inside(self, num_level_anchors, inside_flags):
+ """Get the anchors of each scale level inside.
+
+ Args:
+ num_level_anchors (list[int]): Number of anchors of each
+ scale level.
+ inside_flags (Tensor): Multi level inside flags of the image,
+ which are concatenated into a single tensor of
+ shape (num_base_priors,).
+
+ Returns:
+ list[int]: Number of anchors of each scale level inside.
+ """
+ split_inside_flags = torch.split(inside_flags, num_level_anchors)
+ num_level_anchors_inside = [
+ int(flags.sum()) for flags in split_inside_flags
+ ]
+ return num_level_anchors_inside
diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py
index 5f2b3088d..5a39ccf21 100644
--- a/mmdet/models/detectors/__init__.py
+++ b/mmdet/models/detectors/__init__.py
@@ -5,6 +5,7 @@ from .base import BaseDetector
from .cascade_rcnn import CascadeRCNN
from .centernet import CenterNet
from .cornernet import CornerNet
+from .ddod import DDOD
from .deformable_detr import DeformableDETR
from .detr import DETR
from .fast_rcnn import FastRCNN
@@ -52,5 +53,5 @@ __all__ = [
'VFNet', 'DETR', 'TridentFasterRCNN', 'SparseRCNN', 'SCNet', 'SOLO',
'DeformableDETR', 'AutoAssign', 'YOLOF', 'CenterNet', 'YOLOX',
'TwoStagePanopticSegmentor', 'PanopticFPN', 'QueryInst', 'LAD', 'TOOD',
- 'MaskFormer', 'Mask2Former'
+ 'MaskFormer', 'DDOD', 'Mask2Former'
]
diff --git a/mmdet/models/detectors/ddod.py b/mmdet/models/detectors/ddod.py
new file mode 100644
index 000000000..2ae0a7417
--- /dev/null
+++ b/mmdet/models/detectors/ddod.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class DDOD(SingleStageDetector):
+ """Implementation of `DDOD `_."""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(DDOD, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained, init_cfg)
diff --git a/tests/test_models/test_dense_heads/test_ddod_head.py b/tests/test_models/test_dense_heads/test_ddod_head.py
new file mode 100644
index 000000000..c9e658efa
--- /dev/null
+++ b/tests/test_models/test_dense_heads/test_ddod_head.py
@@ -0,0 +1,72 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import torch
+
+from mmdet.models.dense_heads import DDODHead
+
+
+def test_ddod_head_loss():
+ """Tests ddod head loss when truth is empty and non-empty."""
+ s = 256
+ img_metas = [{
+ 'img_shape': (s, s, 3),
+ 'scale_factor': 1,
+ 'pad_shape': (s, s, 3)
+ }]
+ train_cfg = mmcv.Config(
+ dict( # ATSSAssigner
+ assigner=dict(type='ATSSAssigner', topk=9, alpha=0.8),
+ reg_assigner=dict(type='ATSSAssigner', topk=9, alpha=0.5),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False))
+ self = DDODHead(
+ num_classes=4,
+ in_channels=1,
+ anchor_generator=dict(
+ type='AnchorGenerator',
+ ratios=[1.0],
+ octave_base_scale=8,
+ scales_per_octave=1,
+ strides=[8, 16, 32, 64, 128]),
+ train_cfg=train_cfg,
+ norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
+ loss_iou=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0))
+ feat = [
+ torch.rand(1, 1, s // feat_size, s // feat_size)
+ for feat_size in [4, 8, 16, 32, 64]
+ ]
+ cls_scores, bbox_preds, iou_preds = self.forward(feat)
+
+ # Test that empty ground truth encourages the network to predict background
+ gt_bboxes = [torch.empty((0, 4))]
+ gt_labels = [torch.LongTensor([])]
+ gt_bboxes_ignore = None
+ empty_gt_losses = self.loss(cls_scores, bbox_preds, iou_preds, gt_bboxes,
+ gt_labels, img_metas, gt_bboxes_ignore)
+ # When there is no truth, the cls loss should be nonzero but there should
+ # be no box loss.
+ empty_cls_loss = sum(empty_gt_losses['loss_cls'])
+ empty_box_loss = sum(empty_gt_losses['loss_bbox'])
+ empty_iou_loss = sum(empty_gt_losses['loss_iou'])
+ assert empty_cls_loss.item() > 0, 'cls loss should be non-zero'
+ assert empty_box_loss.item() == 0, (
+ 'there should be no box loss when there are no true boxes')
+ assert empty_iou_loss.item() == 0, (
+ 'there should be no iou loss when there are no true boxes')
+
+ # When truth is non-empty then both cls and box loss should be nonzero for
+ # random inputs
+ gt_bboxes = [
+ torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
+ ]
+ gt_labels = [torch.LongTensor([2])]
+ one_gt_losses = self.loss(cls_scores, bbox_preds, iou_preds, gt_bboxes,
+ gt_labels, img_metas, gt_bboxes_ignore)
+ onegt_cls_loss = sum(one_gt_losses['loss_cls'])
+ onegt_box_loss = sum(one_gt_losses['loss_bbox'])
+ onegt_iou_loss = sum(one_gt_losses['loss_iou'])
+ assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero'
+ assert onegt_box_loss.item() > 0, 'box loss should be non-zero'
+ assert onegt_iou_loss.item() > 0, 'iou loss should be non-zero'
diff --git a/tests/test_utils/test_assigner.py b/tests/test_utils/test_assigner.py
index 3e52cdd0b..a53d5304b 100644
--- a/tests/test_utils/test_assigner.py
+++ b/tests/test_utils/test_assigner.py
@@ -402,6 +402,7 @@ def test_hungarian_match_assigner():
gt_labels = torch.LongTensor([1, 20])
assign_result = self.assign(bbox_pred, cls_pred, gt_bboxes, gt_labels,
img_meta)
+
assert torch.all(assign_result.gt_inds > -1)
assert (assign_result.gt_inds > 0).sum() == gt_bboxes.size(0)
assert (assign_result.labels > -1).sum() == gt_bboxes.size(0)