implement PAA assign (#3547)

* implement PAA assign

* unpate setup.cfg

* fix multiply ioupred with clsscore

* add score voting

* add 2x and 101 configs

* add voting return

* make sklearm optional

* remove sklearm in setup.cfg

* separate pos loss calculation from reassign

* add unitest and readme

* add model url

* fix reduction and doc

* add 1.5x

* fix config base

* add r50 1.5x results

* mock skm

* change according to comment

* remove return none

* add universal partition interface

* fix docstr

* fix docstr

* fix docstr
pull/3692/head
shilong 5 years ago committed by GitHub
parent f93c00fd05
commit 000d00a06f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 22
      configs/paa/README.md
  2. 4
      configs/paa/paa_r101_fpn_1x_coco.py
  3. 3
      configs/paa/paa_r101_fpn_2x_coco.py
  4. 3
      configs/paa/paa_r50_fpn_1.5x_coco.py
  5. 70
      configs/paa/paa_r50_fpn_1x_coco.py
  6. 3
      configs/paa/paa_r50_fpn_2x_coco.py
  7. 3
      mmdet/models/dense_heads/__init__.py
  8. 620
      mmdet/models/dense_heads/paa_head.py
  9. 3
      mmdet/models/detectors/__init__.py
  10. 17
      mmdet/models/detectors/paa.py
  11. 121
      tests/test_models/test_heads.py

@ -0,0 +1,22 @@
# Probabilistic Anchor Assignment with IoU Prediction for Object Detection
## Results and Models
We provide config files to reproduce the object detection results in the
ECCV 2020 paper for Probabilistic Anchor Assignment with IoU
Prediction for Object Detection.
| Backbone | Lr schd | Mem (GB) | Score voting | box AP | Download |
|:-----------:|:-------:|:--------:|:------------:|:------:|:--------:|
| R-50-FPN | 12e | 3.7 | True | 40.4 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r50_fpn_1x_20200821-936edec3.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r50_fpn_1x_20200821-936edec3.log.json) |
| R-50-FPN | 12e | 3.7 | False | 40.2 | - |
| R-50-FPN | 18e | 3.7 | True | 41.4 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r50_fpn_1.5x_20200823-805d6078.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r50_fpn_1.5x_20200823-805d6078.log.json) |
| R-50-FPN | 18e | 3.7 | False | 41.2 | - |
| R-50-FPN | 24e | 3.7 | True | 41.6 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r50_fpn_2x_20200821-c98bfc4e.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r50_fpn_2x_20200821-c98bfc4e.log.json) |
| R-101-FPN | 12e | 6.2 | True | 42.6 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r101_fpn_1x_20200821-0a1825a4.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r101_fpn_1x_20200821-0a1825a4.log.json) |
| R-101-FPN | 12e | 6.2 | False | 42.4 | - |
| R-101-FPN | 24e | 6.2 | True | 43.5 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r101_fpn_2x_20200821-6829f96b.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r101_fpn_2x_20200821-6829f96b.log.json) |
**Note**:
1. We find that the performance is unstable with 1x setting and may fluctuate by about 0.2 mAP. We report the best results.

@ -0,0 +1,4 @@
_base_ = './paa_r50_fpn_1x_coco.py'
model = dict(pretrained='torchvision://resnet101', backbone=dict(depth=101))
lr_config = dict(step=[16, 22])
total_epochs = 24

@ -0,0 +1,3 @@
_base_ = './paa_r101_fpn_1x_coco.py'
lr_config = dict(step=[16, 22])
total_epochs = 24

@ -0,0 +1,3 @@
_base_ = './paa_r50_fpn_1x_coco.py'
lr_config = dict(step=[12, 16])
total_epochs = 18

@ -0,0 +1,70 @@
_base_ = [
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
type='PAA',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_output',
num_outs=5),
bbox_head=dict(
type='PAAHead',
reg_decoded_bbox=True,
score_voting=True,
topk=9,
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)))
# training and testing settings
train_cfg = dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.1,
neg_iou_thr=0.1,
min_pos_iou=0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False)
test_cfg = dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100)
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)

@ -0,0 +1,3 @@
_base_ = './paa_r50_fpn_1x_coco.py'
lr_config = dict(step=[16, 22])
total_epochs = 24

@ -11,6 +11,7 @@ from .ga_rpn_head import GARPNHead
from .gfl_head import GFLHead
from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
from .nasfcos_head import NASFCOSHead
from .paa_head import PAAHead
from .pisa_retinanet_head import PISARetinaHead
from .pisa_ssd_head import PISASSDHead
from .reppoints_head import RepPointsHead
@ -24,5 +25,5 @@ __all__ = [
'RPNHead', 'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead',
'SSDHead', 'FCOSHead', 'RepPointsHead', 'FoveaHead',
'FreeAnchorRetinaHead', 'ATSSHead', 'FSAFHead', 'NASFCOSHead',
'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'CornerHead'
'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'CornerHead', 'PAAHead'
]

@ -0,0 +1,620 @@
import numpy as np
import torch
from mmdet.core import force_fp32, multi_apply, multiclass_nms
from mmdet.core.bbox.iou_calculators import bbox_overlaps
from mmdet.models import HEADS
from mmdet.models.dense_heads import ATSSHead
eps = 1e-12
try:
import sklearn.mixture as skm
except ImportError:
skm = None
def levels_to_images(mlvl_tensor):
"""Concat multi-level feature maps by image.
[feature_level0, feature_level1...] -> [feature_image0, feature_image1...]
Convert the shape of each element in mlvl_tensor from (N, C, H, W) to
(N, H*W , C), then split the element to N elements with shape (H*W, C), and
concat elements in same image of all level along first dimension.
Args:
mlvl_tensor (list[torch.Tensor]): list of Tensor which collect from
corresponding level. Each element is of shape (N, C, H, W)
Returns:
list[torch.Tensor]: A list that contains N tensors and each tensor is
of shape (num_elements, C)
"""
batch_size = mlvl_tensor[0].size(0)
batch_list = [[] for _ in range(batch_size)]
channels = mlvl_tensor[0].size(1)
for t in mlvl_tensor:
t = t.permute(0, 2, 3, 1)
t = t.view(batch_size, -1, channels).contiguous()
for img in range(batch_size):
batch_list[img].append(t[img])
return [torch.cat(item, 0) for item in batch_list]
@HEADS.register_module()
class PAAHead(ATSSHead):
"""Head of PAAAssignment: Probabilistic Anchor Assignment with IoU
Prediction for Object Detection.
Code is modified from the `official github repo
<https://github.com/kkhoot/PAA/blob/master/paa_core
/modeling/rpn/paa/loss.py>`_.
More details can be found in the `paper
<https://arxiv.org/abs/2007.08103>`_ .
Args:
topk (int): Select topk samples with smallest loss in
each level.
score_voting (bool): Whether to use score voting in post-process.
"""
def __init__(self, *args, topk=9, score_voting=True, **kwargs):
# topk used in paa reassign process
self.topk = topk
self.with_score_voting = score_voting
super(PAAHead, self).__init__(*args, **kwargs)
@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'iou_preds'))
def loss(self,
cls_scores,
bbox_preds,
iou_preds,
gt_bboxes,
gt_labels,
img_metas,
gt_bboxes_ignore=None):
"""Compute losses of the head.
Args:
cls_scores (list[Tensor]): Box scores for each scale level
Has shape (N, num_anchors * num_classes, H, W)
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_anchors * 4, H, W)
iou_preds (list[Tensor]): iou_preds for each scale
level with shape (N, num_anchors * 1, H, W)
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): class indices corresponding to each box
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes_ignore (list[Tensor] | None): Specify which bounding
boxes can be ignored when are computing the loss.
Returns:
dict[str, Tensor]: A dictionary of loss gmm_assignment.
"""
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == self.anchor_generator.num_levels
device = cls_scores[0].device
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_metas, device=device)
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = self.get_targets(
anchor_list,
valid_flag_list,
gt_bboxes,
img_metas,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels,
)
(labels, labels_weight, bboxes_target, bboxes_weight, pos_inds,
pos_gt_index) = cls_reg_targets
cls_scores = levels_to_images(cls_scores)
cls_scores = [
item.reshape(-1, self.cls_out_channels) for item in cls_scores
]
bbox_preds = levels_to_images(bbox_preds)
bbox_preds = [item.reshape(-1, 4) for item in bbox_preds]
iou_preds = levels_to_images(iou_preds)
iou_preds = [item.reshape(-1, 1) for item in iou_preds]
pos_losses_list, = multi_apply(self.get_pos_loss, anchor_list,
cls_scores, bbox_preds, labels,
labels_weight, bboxes_target,
bboxes_weight, pos_inds)
with torch.no_grad():
labels, label_weights, bbox_weights, num_pos = multi_apply(
self.paa_reassign,
pos_losses_list,
labels,
labels_weight,
bboxes_weight,
pos_inds,
pos_gt_index,
anchor_list,
)
num_pos = sum(num_pos)
# convert all tensor list to a flatten tensor
cls_scores = torch.cat(cls_scores, 0).view(-1, cls_scores[0].size(-1))
bbox_preds = torch.cat(bbox_preds, 0).view(-1, bbox_preds[0].size(-1))
iou_preds = torch.cat(iou_preds, 0).view(-1, iou_preds[0].size(-1))
labels = torch.cat(labels, 0).view(-1)
flatten_anchors = torch.cat(
[torch.cat(item, 0) for item in anchor_list])
labels_weight = torch.cat(labels_weight, 0).view(-1)
bboxes_target = torch.cat(bboxes_target,
0).view(-1, bboxes_target[0].size(-1))
pos_inds_flatten = (
(labels >= 0)
& (labels < self.background_label)).nonzero().reshape(-1)
losses_cls = self.loss_cls(
cls_scores, labels, labels_weight, avg_factor=num_pos)
if num_pos:
pos_bbox_pred = self.bbox_coder.decode(
flatten_anchors[pos_inds_flatten],
bbox_preds[pos_inds_flatten])
pos_bbox_target = bboxes_target[pos_inds_flatten]
iou_target = bbox_overlaps(
pos_bbox_pred.detach(), pos_bbox_target, is_aligned=True)
losses_iou = self.loss_centerness(
iou_preds[pos_inds_flatten],
iou_target.unsqueeze(-1),
avg_factor=num_pos)
losses_bbox = self.loss_bbox(
pos_bbox_pred,
pos_bbox_target,
iou_target.clamp(min=eps),
avg_factor=iou_target.sum())
else:
losses_iou = iou_preds.sum() * 0
losses_bbox = bbox_preds.sum() * 0
return dict(
loss_cls=losses_cls, loss_bbox=losses_bbox, loss_iou=losses_iou)
def get_pos_loss(self, anchors, cls_score, bbox_pred, label, label_weight,
bbox_target, bbox_weight, pos_inds):
"""Calculate loss of all potential positive samples obtained from first
match process.
Args:
anchors (list[Tensor]): Anchors of each scale.
cls_score (Tensor): Box scores of single image with shape
(num_anchors, num_classes)
bbox_pred (Tensor): Box energies / deltas of single image
with shape (num_anchors, 4)
label (Tensor): classification target of each anchor with
shape (num_anchors,)
label_weight (Tensor): Classification loss weight of each
anchor with shape (num_anchors).
bbox_target (dict): Regression target of each anchor with
shape (num_anchors, 4).
bbox_weight (Tensor): Bbox weight of each anchor with shape
(num_anchors, 4).
pos_inds (Tensor): Index of all positive samples got from
first assign process.
Returns:
Tensor: Losses of all positive samples in single image.
"""
anchors_all_level = torch.cat(anchors, 0)
pos_scores = cls_score[pos_inds]
pos_bbox_pred = bbox_pred[pos_inds]
pos_label = label[pos_inds]
pos_label_weight = label_weight[pos_inds]
pos_bbox_target = bbox_target[pos_inds]
pos_bbox_weight = bbox_weight[pos_inds]
pos_anchors = anchors_all_level[pos_inds]
pos_bbox_pred = self.bbox_coder.decode(pos_anchors, pos_bbox_pred)
# to keep loss dimension
loss_cls = self.loss_cls(
pos_scores,
pos_label,
pos_label_weight,
avg_factor=self.loss_cls.loss_weight,
reduction_override='none')
loss_bbox = self.loss_bbox(
pos_bbox_pred,
pos_bbox_target,
pos_bbox_weight,
avg_factor=self.loss_cls.loss_weight,
reduction_override='none')
loss_cls = loss_cls.sum(-1)
pos_loss = loss_bbox + loss_cls
return pos_loss,
def paa_reassign(self, pos_losses, label, label_weight, bbox_weight,
pos_inds, pos_gt_inds, anchors):
"""Fit loss to GMM distribution and separate positive, ignore, negative
samples again with GMM model.
Args:
pos_losses (Tensor): Losses of all positive samples in
single image.
label (Tensor): classification target of each anchor with
shape (num_anchors,)
label_weight (Tensor): Classification loss weight of each
anchor with shape (num_anchors).
bbox_weight (Tensor): Bbox weight of each anchor with shape
(num_anchors, 4).
pos_inds (Tensor): Index of all positive samples got from
first assign process.
pos_gt_inds (Tensor): Gt_index of all positive samples got
from first assign process.
anchors (list[Tensor]): Anchors of each scale.
Returns:
tuple: Usually returns a tuple containing learning targets.
- label (Tensor): classification target of each anchor after
paa assign, with shape (num_anchors,)
- label_weight (Tensor): Classification loss weight of each
anchor after paa assign, with shape (num_anchors).
- bbox_weight (Tensor): Bbox weight of each anchor with shape
(num_anchors, 4).
- num_pos (int): The number of positive samples after paa
assign.
"""
if not len(pos_inds):
return label, label_weight, bbox_weight, 0
num_gt = pos_gt_inds.max() + 1
num_level = len(anchors)
num_anchors_each_level = [item.size(0) for item in anchors]
num_anchors_each_level.insert(0, 0)
inds_level_interval = np.cumsum(num_anchors_each_level)
pos_level_mask = []
for i in range(num_level):
mask = (pos_inds >= inds_level_interval[i]) & (
pos_inds < inds_level_interval[i + 1])
pos_level_mask.append(mask)
pos_inds_after_paa = []
ignore_inds_after_paa = []
for gt_ind in range(num_gt):
pos_inds_gmm = []
pos_loss_gmm = []
gt_mask = pos_gt_inds == gt_ind
for level in range(num_level):
level_mask = pos_level_mask[level]
level_gt_mask = level_mask & gt_mask
value, topk_inds = pos_losses[level_gt_mask].topk(
min(level_gt_mask.sum(), self.topk), largest=False)
pos_inds_gmm.append(pos_inds[level_gt_mask][topk_inds])
pos_loss_gmm.append(value)
pos_inds_gmm = torch.cat(pos_inds_gmm)
pos_loss_gmm = torch.cat(pos_loss_gmm)
# fix gmm need at least two sample
if len(pos_inds_gmm) < 2:
continue
device = pos_inds_gmm.device
pos_loss_gmm, sort_inds = pos_loss_gmm.sort()
pos_inds_gmm = pos_inds_gmm[sort_inds]
pos_loss_gmm = pos_loss_gmm.view(-1, 1).cpu().numpy()
min_loss, max_loss = pos_loss_gmm.min(), pos_loss_gmm.max()
means_init = [[min_loss], [max_loss]]
weights_init = [0.5, 0.5]
precisions_init = [[[1.0]], [[1.0]]]
if skm is None:
raise ImportError('Please run "pip install sklearn" '
'to install sklearn first.')
gmm = skm.GaussianMixture(
2,
weights_init=weights_init,
means_init=means_init,
precisions_init=precisions_init)
gmm.fit(pos_loss_gmm)
gmm_assignment = gmm.predict(pos_loss_gmm)
scores = gmm.score_samples(pos_loss_gmm)
gmm_assignment = torch.from_numpy(gmm_assignment).to(device)
scores = torch.from_numpy(scores).to(device)
pos_inds_temp, ignore_inds_temp = self.gmm_separation_scheme(
gmm_assignment, scores, pos_inds_gmm)
pos_inds_after_paa.append(pos_inds_temp)
ignore_inds_after_paa.append(ignore_inds_temp)
pos_inds_after_paa = torch.cat(pos_inds_after_paa)
ignore_inds_after_paa = torch.cat(ignore_inds_after_paa)
reassign_mask = (pos_inds.unsqueeze(1) != pos_inds_after_paa).all(1)
reassign_ids = pos_inds[reassign_mask]
label[reassign_ids] = self.background_label
label_weight[ignore_inds_after_paa] = 0
bbox_weight[reassign_ids] = 0
num_pos = len(pos_inds_after_paa)
return label, label_weight, bbox_weight, num_pos
def gmm_separation_scheme(self, gmm_assignment, scores, pos_inds_gmm):
"""A general separation scheme for gmm model.
It separates a GMM distribution of candidate samples into three
parts, 0 1 and uncertain areas, and you can implement other
separation schemes by rewriting this function.
Args:
gmm_assignment (Tensor): The prediction of GMM which is of shape
(num_samples,). The 0/1 value indicates the distribution
that each sample comes from.
scores (Tensor): The probability of sample coming from the
fit GMM distribution. The tensor is of shape (num_samples,).
pos_inds_gmm (Tensor): All the indexes of samples which are used
to fit GMM model. The tensor is of shape (num_samples,)
Returns:
tuple[Tensor]: The indices of positive and ignored samples.
- pos_inds_temp (Tensor): Indices of positive samples.
- ignore_inds_temp (Tensor): Indices of ignore samples.
"""
# The implementation is (c) in Fig.3 in origin paper intead of (b).
# You can refer to issues such as
# https://github.com/kkhoot/PAA/issues/8 and
# https://github.com/kkhoot/PAA/issues/9.
fgs = gmm_assignment == 0
if fgs.nonzero().numel():
_, pos_thr_ind = scores[fgs].topk(1)
pos_inds_temp = pos_inds_gmm[fgs][:pos_thr_ind + 1]
ignore_inds_temp = pos_inds_gmm.new_tensor([])
return pos_inds_temp, ignore_inds_temp
def get_targets(
self,
anchor_list,
valid_flag_list,
gt_bboxes_list,
img_metas,
gt_bboxes_ignore_list=None,
gt_labels_list=None,
label_channels=1,
unmap_outputs=True,
):
"""Get targets for PAA head.
This method is almost the same as `AnchorHead.get_targets()`. We direct
return the results from _get_targets_single instead map it to levels
by images_to_levels function.
Args:
anchor_list (list[list[Tensor]]): Multi level anchors of each
image. The outer list indicates images, and the inner list
corresponds to feature levels of the image. Each element of
the inner list is a tensor of shape (num_anchors, 4).
valid_flag_list (list[list[Tensor]]): Multi level valid flags of
each image. The outer list indicates images, and the inner list
corresponds to feature levels of the image. Each element of
the inner list is a tensor of shape (num_anchors, )
gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
img_metas (list[dict]): Meta info of each image.
gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
ignored.
gt_labels_list (list[Tensor]): Ground truth labels of each box.
label_channels (int): Channel of label.
unmap_outputs (bool): Whether to map outputs back to the original
set of anchors.
Returns:
tuple: Usually returns a tuple containing learning targets.
- labels (list[Tensor]): Labels of all anchors, each with
shape (num_anchors,).
- label_weights (list[Tensor]): Label weights of all anchor.
each with shape (num_anchors,).
- bbox_targets (list[Tensor]): BBox targets of all anchors.
each with shape (num_anchors, 4).
- bbox_weights (list[Tensor]): BBox weights of all anchors.
each with shape (num_anchors, 4).
- pos_inds (list[Tensor]): Contains all index of positive
sample in all anchor.
- gt_inds (list[Tensor]): Contains all gt_index of positive
sample in all anchor.
"""
num_imgs = len(img_metas)
assert len(anchor_list) == len(valid_flag_list) == num_imgs
concat_anchor_list = []
concat_valid_flag_list = []
for i in range(num_imgs):
assert len(anchor_list[i]) == len(valid_flag_list[i])
concat_anchor_list.append(torch.cat(anchor_list[i]))
concat_valid_flag_list.append(torch.cat(valid_flag_list[i]))
# compute targets for each image
if gt_bboxes_ignore_list is None:
gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
if gt_labels_list is None:
gt_labels_list = [None for _ in range(num_imgs)]
results = multi_apply(
self._get_targets_single,
concat_anchor_list,
concat_valid_flag_list,
gt_bboxes_list,
gt_bboxes_ignore_list,
gt_labels_list,
img_metas,
label_channels=label_channels,
unmap_outputs=unmap_outputs)
(labels, label_weights, bbox_targets, bbox_weights, valid_pos_inds,
valid_neg_inds, sampling_result) = results
# Due to valid flag of anchors, we have to calculate the real pos_inds
# in origin anchor set.
pos_inds = []
for i, single_labels in enumerate(labels):
pos_mask = (0 <= single_labels) & (
single_labels < self.background_label)
pos_inds.append(pos_mask.nonzero().view(-1))
gt_inds = [item.pos_assigned_gt_inds for item in sampling_result]
return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
gt_inds)
def _get_targets_single(self,
flat_anchors,
valid_flags,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
img_meta,
label_channels=1,
unmap_outputs=True):
"""Compute regression and classification targets for anchors in a
single image.
This method is same as `AnchorHead._get_targets_single()`.
"""
assert unmap_outputs, 'We must map outputs back to the original' \
'set of anchors in PAAhead'
return super(ATSSHead, self)._get_targets_single(
flat_anchors,
valid_flags,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
img_meta,
label_channels=1,
unmap_outputs=True)
def _get_bboxes_single(self,
cls_scores,
bbox_preds,
iou_preds,
mlvl_anchors,
img_shape,
scale_factor,
cfg,
rescale=False):
"""Transform outputs for a single batch item into labeled boxes.
This method is almost same as `ATSSHead._get_bboxes_single()`.
We use sqrt(iou_preds * cls_scores) in NMS process instead of just
cls_scores. Besides, score voting is used when `` score_voting``
is set to True.
"""
assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
mlvl_bboxes = []
mlvl_scores = []
mlvl_iou_preds = []
for cls_score, bbox_pred, iou_preds, anchors in zip(
cls_scores, bbox_preds, iou_preds, mlvl_anchors):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
scores = cls_score.permute(1, 2, 0).reshape(
-1, self.cls_out_channels).sigmoid()
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
iou_preds = iou_preds.permute(1, 2, 0).reshape(-1).sigmoid()
nms_pre = cfg.get('nms_pre', -1)
if nms_pre > 0 and scores.shape[0] > nms_pre:
max_scores, _ = (scores * iou_preds[:, None]).sqrt().max(dim=1)
_, topk_inds = max_scores.topk(nms_pre)
anchors = anchors[topk_inds, :]
bbox_pred = bbox_pred[topk_inds, :]
scores = scores[topk_inds, :]
iou_preds = iou_preds[topk_inds]
bboxes = self.bbox_coder.decode(
anchors, bbox_pred, max_shape=img_shape)
mlvl_bboxes.append(bboxes)
mlvl_scores.append(scores)
mlvl_iou_preds.append(iou_preds)
mlvl_bboxes = torch.cat(mlvl_bboxes)
if rescale:
mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
mlvl_scores = torch.cat(mlvl_scores)
# Add a dummy background class to the backend when using sigmoid
# remind that we set FG labels to [0, num_class-1] since mmdet v2.0
# BG cat_id: num_class
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
mlvl_iou_preds = torch.cat(mlvl_iou_preds)
mlvl_nms_scores = (mlvl_scores * mlvl_iou_preds[:, None]).sqrt()
det_bboxes, det_labels = multiclass_nms(
mlvl_bboxes,
mlvl_nms_scores,
cfg.score_thr,
cfg.nms,
cfg.max_per_img,
score_factors=None)
if self.with_score_voting:
det_bboxes, det_labels = self.score_voting(det_bboxes, det_labels,
mlvl_bboxes,
mlvl_nms_scores,
cfg.score_thr)
return det_bboxes, det_labels
def score_voting(self, det_bboxes, det_labels, mlvl_bboxes,
mlvl_nms_scores, score_thr):
"""Implementation of score voting method works on each remaining boxes
after NMS procedure.
Args:
det_bboxes (Tensor): Remaining boxes after NMS procedure,
with shape (k, 5), each dimension means
(x1, y1, x2, y2, score).
det_labels (Tensor): The label of remaining boxes, with shape
(k, 1),Labels are 0-based.
mlvl_bboxes (Tensor): All boxes before the NMS procedure,
with shape (num_anchors,4).
mlvl_nms_scores (Tensor): The scores of all boxes which is used
in the NMS procedure, with shape (num_anchors, num_class)
mlvl_iou_preds (Tensot): The predictions of IOU of all boxes
before the NMS procedure, with shape (num_anchors, 1)
score_thr (float): The score threshold of bboxes.
Returns:
tuple: Usually returns a tuple containing voting results.
- det_bboxes_voted (Tensor): Remaining boxes after
score voting procedure, with shape (k, 5), each
dimension means (x1, y1, x2, y2, score).
- det_labels_voted (Tensor): Label of remaining bboxes
after voting, with shape (num_anchors,).
"""
candidate_mask = mlvl_nms_scores > score_thr
candidate_mask_nozeros = candidate_mask.nonzero()
candidate_inds = candidate_mask_nozeros[:, 0]
candidate_labels = candidate_mask_nozeros[:, 1]
candidate_bboxes = mlvl_bboxes[candidate_inds]
candidate_scores = mlvl_nms_scores[candidate_mask]
det_bboxes_voted = []
det_labels_voted = []
for cls in range(self.cls_out_channels):
candidate_cls_mask = candidate_labels == cls
if not candidate_cls_mask.any():
continue
candidate_cls_scores = candidate_scores[candidate_cls_mask]
candidate_cls_bboxes = candidate_bboxes[candidate_cls_mask]
det_cls_mask = det_labels == cls
det_cls_bboxes = det_bboxes[det_cls_mask].view(
-1, det_bboxes.size(-1))
det_candidate_ious = bbox_overlaps(det_cls_bboxes[:, :4],
candidate_cls_bboxes)
for det_ind in range(len(det_cls_bboxes)):
single_det_ious = det_candidate_ious[det_ind]
pos_ious_mask = single_det_ious > 0.01
pos_ious = single_det_ious[pos_ious_mask]
pos_bboxes = candidate_cls_bboxes[pos_ious_mask]
pos_scores = candidate_cls_scores[pos_ious_mask]
pis = (torch.exp(-(1 - pos_ious)**2 / 0.025) *
pos_scores)[:, None]
voted_box = torch.sum(
pis * pos_bboxes, dim=0) / torch.sum(
pis, dim=0)
voted_score = det_cls_bboxes[det_ind][-1:][None, :]
det_bboxes_voted.append(
torch.cat((voted_box[None, :], voted_score), dim=1))
det_labels_voted.append(cls)
det_bboxes_voted = torch.cat(det_bboxes_voted, dim=0)
det_labels_voted = det_labels.new_tensor(det_labels_voted)
return det_bboxes_voted, det_labels_voted

@ -13,6 +13,7 @@ from .htc import HybridTaskCascade
from .mask_rcnn import MaskRCNN
from .mask_scoring_rcnn import MaskScoringRCNN
from .nasfcos import NASFCOS
from .paa import PAA
from .point_rend import PointRend
from .reppoints_detector import RepPointsDetector
from .retinanet import RetinaNet
@ -24,5 +25,5 @@ __all__ = [
'ATSS', 'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade',
'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN', 'RepPointsDetector',
'FOVEA', 'FSAF', 'NASFCOS', 'PointRend', 'GFL', 'CornerNet'
'FOVEA', 'FSAF', 'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA'
]

@ -0,0 +1,17 @@
from ..builder import DETECTORS
from .single_stage import SingleStageDetector
@DETECTORS.register_module()
class PAA(SingleStageDetector):
"""Implementation of `PAA <https://arxiv.org/pdf/2007.08103.pdf>`_."""
def __init__(self,
backbone,
neck,
bbox_head,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(PAA, self).__init__(backbone, neck, bbox_head, train_cfg,
test_cfg, pretrained)

@ -1,14 +1,133 @@
import mmcv
import numpy as np
import torch
from mmdet.core import bbox2roi, build_assigner, build_sampler
from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps
from mmdet.models.dense_heads import (AnchorHead, CornerHead, FCOSHead,
FSAFHead, GuidedAnchorHead)
FSAFHead, GuidedAnchorHead, PAAHead,
paa_head)
from mmdet.models.dense_heads.paa_head import levels_to_images
from mmdet.models.roi_heads.bbox_heads import BBoxHead
from mmdet.models.roi_heads.mask_heads import FCNMaskHead, MaskIoUHead
def test_paa_head_loss():
"""Tests paa head loss when truth is empty and non-empty."""
class mock_skm(object):
def GaussianMixture(self, *args, **kwargs):
return self
def fit(self, loss):
pass
def predict(self, loss):
components = np.zeros_like(loss, dtype=np.long)
return components.reshape(-1)
def score_samples(self, loss):
scores = np.random.random(len(loss))
return scores
paa_head.skm = mock_skm()
s = 256
img_metas = [{
'img_shape': (s, s, 3),
'scale_factor': 1,
'pad_shape': (s, s, 3)
}]
train_cfg = mmcv.Config(
dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.1,
neg_iou_thr=0.1,
min_pos_iou=0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False))
# since Focal Loss is not supported on CPU
self = PAAHead(
num_classes=4,
in_channels=1,
train_cfg=train_cfg,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5))
feat = [
torch.rand(1, 1, s // feat_size, s // feat_size)
for feat_size in [4, 8, 16, 32, 64]
]
self.init_weights()
cls_scores, bbox_preds, iou_preds = self(feat)
# Test that empty ground truth encourages the network to predict background
gt_bboxes = [torch.empty((0, 4))]
gt_labels = [torch.LongTensor([])]
gt_bboxes_ignore = None
empty_gt_losses = self.loss(cls_scores, bbox_preds, iou_preds, gt_bboxes,
gt_labels, img_metas, gt_bboxes_ignore)
# When there is no truth, the cls loss should be nonzero but there should
# be no box loss.
empty_cls_loss = empty_gt_losses['loss_cls']
empty_box_loss = empty_gt_losses['loss_bbox']
empty_iou_loss = empty_gt_losses['loss_iou']
assert empty_cls_loss.item() > 0, 'cls loss should be non-zero'
assert empty_box_loss.item() == 0, (
'there should be no box loss when there are no true boxes')
assert empty_iou_loss.item() == 0, (
'there should be no box loss when there are no true boxes')
# When truth is non-empty then both cls and box loss should be nonzero for
# random inputs
gt_bboxes = [
torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
]
gt_labels = [torch.LongTensor([2])]
one_gt_losses = self.loss(cls_scores, bbox_preds, iou_preds, gt_bboxes,
gt_labels, img_metas, gt_bboxes_ignore)
onegt_cls_loss = one_gt_losses['loss_cls']
onegt_box_loss = one_gt_losses['loss_bbox']
onegt_iou_loss = one_gt_losses['loss_iou']
assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero'
assert onegt_box_loss.item() > 0, 'box loss should be non-zero'
assert onegt_iou_loss.item() > 0, 'box loss should be non-zero'
n, c, h, w = 10, 4, 20, 20
mlvl_tensor = [torch.ones(n, c, h, w) for i in range(5)]
results = levels_to_images(mlvl_tensor)
assert len(results) == n
assert results[0].size() == (h * w * 5, c)
assert self.with_score_voting
cls_scores = [torch.ones(4, 5, 5)]
bbox_preds = [torch.ones(4, 5, 5)]
iou_preds = [torch.ones(1, 5, 5)]
mlvl_anchors = [torch.ones(5 * 5, 4)]
img_shape = None
scale_factor = [0.5, 0.5]
cfg = mmcv.Config(
dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100))
rescale = False
self._get_bboxes_single(
cls_scores,
bbox_preds,
iou_preds,
mlvl_anchors,
img_shape,
scale_factor,
cfg,
rescale=rescale)
def test_fcos_head_loss():
"""Tests fcos head loss when truth is empty and non-empty."""
s = 256

Loading…
Cancel
Save