implement PAA assign (#3547)
* implement PAA assign * update setup.cfg * fix multiply ioupred with clsscore * add score voting * add 2x and 101 configs * add voting return * make sklearn optional * remove sklearn in setup.cfg * separate pos loss calculation from reassign * add unit test and readme * add model url * fix reduction and doc * add 1.5x * fix config base * add r50 1.5x results * mock skm * change according to comment * remove return none * add universal partition interface * fix docstr * fix docstr * fix docstr pull/3692/head
parent
f93c00fd05
commit
000d00a06f
11 changed files with 866 additions and 3 deletions
@ -0,0 +1,22 @@ |
|||||||
|
# Probabilistic Anchor Assignment with IoU Prediction for Object Detection |
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Results and Models |
||||||
|
We provide config files to reproduce the object detection results in the |
||||||
|
ECCV 2020 paper for Probabilistic Anchor Assignment with IoU |
||||||
|
Prediction for Object Detection. |
||||||
|
|
||||||
|
| Backbone | Lr schd | Mem (GB) | Score voting | box AP | Download | |
||||||
|
|:-----------:|:-------:|:--------:|:------------:|:------:|:--------:| |
||||||
|
| R-50-FPN | 12e | 3.7 | True | 40.4 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r50_fpn_1x_20200821-936edec3.pth) &#124; [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r50_fpn_1x_20200821-936edec3.log.json) | |
||||||
|
| R-50-FPN | 12e | 3.7 | False | 40.2 | - | |
||||||
|
| R-50-FPN | 18e | 3.7 | True | 41.4 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r50_fpn_1.5x_20200823-805d6078.pth) &#124; [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r50_fpn_1.5x_20200823-805d6078.log.json) | |
||||||
|
| R-50-FPN | 18e | 3.7 | False | 41.2 | - | |
||||||
|
| R-50-FPN | 24e | 3.7 | True | 41.6 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r50_fpn_2x_20200821-c98bfc4e.pth) &#124; [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r50_fpn_2x_20200821-c98bfc4e.log.json) | |
||||||
|
| R-101-FPN | 12e | 6.2 | True | 42.6 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r101_fpn_1x_20200821-0a1825a4.pth) &#124; [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r101_fpn_1x_20200821-0a1825a4.log.json) | |
||||||
|
| R-101-FPN | 12e | 6.2 | False | 42.4 | - | |
||||||
|
| R-101-FPN | 24e | 6.2 | True | 43.5 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r101_fpn_2x_20200821-6829f96b.pth) &#124; [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r101_fpn_2x_20200821-6829f96b.log.json) | |
||||||
|
|
||||||
|
**Note**: |
||||||
|
1. We find that the performance is unstable with the 1x setting and may fluctuate by about 0.2 mAP. We report the best results. |
@ -0,0 +1,4 @@ |
|||||||
|
# ResNet-101 variant of the PAA ResNet-50 config.
#
# Only the backbone depth and its pretrained checkpoint are overridden; the
# training schedule is deliberately NOT overridden here so that the schedule
# of the base config applies. (A 24-epoch override in this file would make
# the config that other files extend as "1x" silently train for 2x.)
_base_ = './paa_r50_fpn_1x_coco.py'
model = dict(pretrained='torchvision://resnet101', backbone=dict(depth=101))
@ -0,0 +1,3 @@ |
|||||||
|
# 24-epoch schedule layered on top of the ResNet-101 PAA config.
_base_ = './paa_r101_fpn_1x_coco.py'

# total training length, in epochs
total_epochs = 24
# learning-rate step milestones matching the 24-epoch run
lr_config = dict(step=[16, 22])
@ -0,0 +1,3 @@ |
|||||||
|
# 18-epoch schedule layered on top of the ResNet-50 PAA config.
_base_ = './paa_r50_fpn_1x_coco.py'

# total training length, in epochs
total_epochs = 18
# learning-rate step milestones matching the 18-epoch run
lr_config = dict(step=[12, 16])
@ -0,0 +1,70 @@ |
|||||||
|
# Base PAA config: ResNet-50 backbone + FPN neck + PAAHead.
# Dataset, schedule and runtime settings come from the shared base files.
_base_ = [
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py',
    '../_base_/default_runtime.py',
]

model = dict(
    type='PAA',
    pretrained='torchvision://resnet50',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch',
    ),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=1,
        add_extra_convs='on_output',
        num_outs=5,
    ),
    bbox_head=dict(
        type='PAAHead',
        reg_decoded_bbox=True,
        # PAAHead-specific options: enable score voting after NMS and keep
        # the 9 lowest-loss candidates per level during re-assignment.
        score_voting=True,
        topk=9,
        num_classes=80,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        # a single square anchor per location on each of the five levels
        anchor_generator=dict(
            type='AnchorGenerator',
            ratios=[1.0],
            octave_base_scale=8,
            scales_per_octave=1,
            strides=[8, 16, 32, 64, 128],
        ),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[0.0, 0.0, 0.0, 0.0],
            target_stds=[0.1, 0.1, 0.2, 0.2],
        ),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0,
        ),
        loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
        # PAAHead feeds its IoU predictions through this loss slot.
        loss_centerness=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            loss_weight=0.5,
        ),
    ),
)

# training and testing settings
train_cfg = dict(
    # First-stage assignment with a low IoU threshold (0.1); PAAHead then
    # re-assigns the resulting positive candidates.
    assigner=dict(
        type='MaxIoUAssigner',
        pos_iou_thr=0.1,
        neg_iou_thr=0.1,
        min_pos_iou=0,
        ignore_iof_thr=-1,
    ),
    allowed_border=-1,
    pos_weight=-1,
    debug=False,
)
test_cfg = dict(
    nms_pre=1000,
    min_bbox_size=0,
    score_thr=0.05,
    nms=dict(type='nms', iou_threshold=0.6),
    max_per_img=100,
)

# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
@ -0,0 +1,3 @@ |
|||||||
|
# 24-epoch schedule layered on top of the ResNet-50 PAA config.
_base_ = './paa_r50_fpn_1x_coco.py'

# total training length, in epochs
total_epochs = 24
# learning-rate step milestones matching the 24-epoch run
lr_config = dict(step=[16, 22])
@ -0,0 +1,620 @@ |
|||||||
|
import numpy as np |
||||||
|
import torch |
||||||
|
|
||||||
|
from mmdet.core import force_fp32, multi_apply, multiclass_nms |
||||||
|
from mmdet.core.bbox.iou_calculators import bbox_overlaps |
||||||
|
from mmdet.models import HEADS |
||||||
|
from mmdet.models.dense_heads import ATSSHead |
||||||
|
|
||||||
|
eps = 1e-12 |
||||||
|
try: |
||||||
|
import sklearn.mixture as skm |
||||||
|
except ImportError: |
||||||
|
skm = None |
||||||
|
|
||||||
|
|
||||||
|
def levels_to_images(mlvl_tensor):
    """Regroup per-level feature maps into per-image tensors.

    [feature_level0, feature_level1...] -> [feature_image0, feature_image1...]
    Every level tensor of shape (N, C, H, W) is rearranged to (N, H*W, C);
    then, for each of the N images, the rows of all levels are concatenated
    along the first dimension in level order.

    Args:
        mlvl_tensor (list[torch.Tensor]): One tensor per level, each of
            shape (N, C, H, W). The batch size and channel count are read
            from the first element.

    Returns:
        list[torch.Tensor]: N tensors, each of shape (num_elements, C).
    """
    num_imgs = mlvl_tensor[0].size(0)
    num_channels = mlvl_tensor[0].size(1)
    # channel-last layout, spatial dims flattened: (N, H*W, C) per level
    per_level = [
        lvl.permute(0, 2, 3, 1).view(num_imgs, -1, num_channels).contiguous()
        for lvl in mlvl_tensor
    ]
    return [
        torch.cat([lvl[img_id] for lvl in per_level], 0)
        for img_id in range(num_imgs)
    ]
||||||
|
|
||||||
|
|
||||||
|
@HEADS.register_module()
class PAAHead(ATSSHead):
    """Head of PAAAssignment: Probabilistic Anchor Assignment with IoU
    Prediction for Object Detection.

    Code is modified from the `official github repo
    <https://github.com/kkhoot/PAA/blob/master/paa_core
    /modeling/rpn/paa/loss.py>`_.

    More details can be found in the `paper
    <https://arxiv.org/abs/2007.08103>`_ .

    Training needs ``sklearn.mixture``; if it cannot be imported an
    ``ImportError`` is raised the first time a GMM has to be fitted.

    Args:
        topk (int): Select topk samples with smallest loss in
            each level.
        score_voting (bool): Whether to use score voting in post-process.
    """

    def __init__(self, *args, topk=9, score_voting=True, **kwargs):
        # topk used in paa reassign process
        self.topk = topk
        self.with_score_voting = score_voting
        super(PAAHead, self).__init__(*args, **kwargs)

    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'iou_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             iou_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                Has shape (N, num_anchors * num_classes, H, W)
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W)
            iou_preds (list[Tensor]): iou_preds for each scale
                level with shape (N, num_anchors * 1, H, W)
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (list[Tensor] | None): Specify which bounding
                boxes can be ignored when are computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components with the
                keys ``loss_cls``, ``loss_bbox`` and ``loss_iou``.
        """

        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.anchor_generator.num_levels

        device = cls_scores[0].device
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
        )
        (labels, labels_weight, bboxes_target, bboxes_weight, pos_inds,
         pos_gt_index) = cls_reg_targets

        # regroup the per-level predictions into one tensor per image
        cls_scores = levels_to_images(cls_scores)
        cls_scores = [
            item.reshape(-1, self.cls_out_channels) for item in cls_scores
        ]
        bbox_preds = levels_to_images(bbox_preds)
        bbox_preds = [item.reshape(-1, 4) for item in bbox_preds]
        iou_preds = levels_to_images(iou_preds)
        iou_preds = [item.reshape(-1, 1) for item in iou_preds]

        # per-sample loss of every candidate from the first assignment
        pos_losses_list, = multi_apply(self.get_pos_loss, anchor_list,
                                       cls_scores, bbox_preds, labels,
                                       labels_weight, bboxes_target,
                                       bboxes_weight, pos_inds)

        with torch.no_grad():
            # paa_reassign edits its tensor arguments in place and returns
            # them; rebinding the same names makes that explicit.
            labels, labels_weight, bboxes_weight, num_pos = multi_apply(
                self.paa_reassign,
                pos_losses_list,
                labels,
                labels_weight,
                bboxes_weight,
                pos_inds,
                pos_gt_index,
                anchor_list,
            )
            num_pos = sum(num_pos)
        # convert all tensor list to a flatten tensor
        cls_scores = torch.cat(cls_scores, 0).view(-1, cls_scores[0].size(-1))
        bbox_preds = torch.cat(bbox_preds, 0).view(-1, bbox_preds[0].size(-1))
        iou_preds = torch.cat(iou_preds, 0).view(-1, iou_preds[0].size(-1))
        labels = torch.cat(labels, 0).view(-1)
        flatten_anchors = torch.cat(
            [torch.cat(item, 0) for item in anchor_list])
        labels_weight = torch.cat(labels_weight, 0).view(-1)
        bboxes_target = torch.cat(bboxes_target,
                                  0).view(-1, bboxes_target[0].size(-1))

        pos_inds_flatten = (
            (labels >= 0)
            & (labels < self.background_label)).nonzero().reshape(-1)

        # num_pos can be 0 (no positives in the whole batch); never use 0 as
        # the normalisation factor.
        losses_cls = self.loss_cls(
            cls_scores, labels, labels_weight, avg_factor=max(num_pos, 1))
        if num_pos:
            pos_bbox_pred = self.bbox_coder.decode(
                flatten_anchors[pos_inds_flatten],
                bbox_preds[pos_inds_flatten])
            pos_bbox_target = bboxes_target[pos_inds_flatten]
            # IoU between the (detached) predicted box and its target is
            # both the target of the IoU branch and the bbox loss weight.
            iou_target = bbox_overlaps(
                pos_bbox_pred.detach(), pos_bbox_target, is_aligned=True)
            losses_iou = self.loss_centerness(
                iou_preds[pos_inds_flatten],
                iou_target.unsqueeze(-1),
                avg_factor=num_pos)
            # guard against a zero denominator when no prediction overlaps
            # its target at all
            losses_bbox = self.loss_bbox(
                pos_bbox_pred,
                pos_bbox_target,
                iou_target.clamp(min=eps),
                avg_factor=iou_target.sum().clamp(min=eps))
        else:
            # keep both branches attached to the graph with a zero loss
            losses_iou = iou_preds.sum() * 0
            losses_bbox = bbox_preds.sum() * 0

        return dict(
            loss_cls=losses_cls, loss_bbox=losses_bbox, loss_iou=losses_iou)

    def get_pos_loss(self, anchors, cls_score, bbox_pred, label, label_weight,
                     bbox_target, bbox_weight, pos_inds):
        """Calculate loss of all potential positive samples obtained from first
        match process.

        Args:
            anchors (list[Tensor]): Anchors of each scale.
            cls_score (Tensor): Box scores of single image with shape
                (num_anchors, num_classes)
            bbox_pred (Tensor): Box energies / deltas of single image
                with shape (num_anchors, 4)
            label (Tensor): classification target of each anchor with
                shape (num_anchors,)
            label_weight (Tensor): Classification loss weight of each
                anchor with shape (num_anchors).
            bbox_target (Tensor): Regression target of each anchor with
                shape (num_anchors, 4).
            bbox_weight (Tensor): Bbox weight of each anchor with shape
                (num_anchors, 4).
            pos_inds (Tensor): Index of all positive samples got from
                first assign process.

        Returns:
            tuple[Tensor]: One-element tuple holding the losses of all
                positive samples in single image (empty when ``pos_inds``
                is empty).
        """
        if not len(pos_inds):
            # Nothing to score; paa_reassign returns early for this image
            # without reading the losses, so skip the loss modules.
            return (cls_score.new_zeros((0, )), )

        anchors_all_level = torch.cat(anchors, 0)
        pos_scores = cls_score[pos_inds]
        pos_bbox_pred = bbox_pred[pos_inds]
        pos_label = label[pos_inds]
        pos_label_weight = label_weight[pos_inds]
        pos_bbox_target = bbox_target[pos_inds]
        pos_bbox_weight = bbox_weight[pos_inds]
        pos_anchors = anchors_all_level[pos_inds]
        pos_bbox_pred = self.bbox_coder.decode(pos_anchors, pos_bbox_pred)

        # to keep loss dimension
        # NOTE(review): avg_factor is passed together with
        # reduction_override='none'; presumably it does not alter the
        # per-sample values - confirm against the loss implementations.
        loss_cls = self.loss_cls(
            pos_scores,
            pos_label,
            pos_label_weight,
            avg_factor=self.loss_cls.loss_weight,
            reduction_override='none')

        loss_bbox = self.loss_bbox(
            pos_bbox_pred,
            pos_bbox_target,
            pos_bbox_weight,
            avg_factor=self.loss_cls.loss_weight,
            reduction_override='none')

        # collapse the class dimension so both terms are per-sample
        loss_cls = loss_cls.sum(-1)
        pos_loss = loss_bbox + loss_cls
        return (pos_loss, )

    def paa_reassign(self, pos_losses, label, label_weight, bbox_weight,
                     pos_inds, pos_gt_inds, anchors):
        """Fit loss to GMM distribution and separate positive, ignore, negative
        samples again with GMM model.

        ``label``, ``label_weight`` and ``bbox_weight`` are modified in
        place and also returned.

        Args:
            pos_losses (Tensor): Losses of all positive samples in
                single image.
            label (Tensor): classification target of each anchor with
                shape (num_anchors,)
            label_weight (Tensor): Classification loss weight of each
                anchor with shape (num_anchors).
            bbox_weight (Tensor): Bbox weight of each anchor with shape
                (num_anchors, 4).
            pos_inds (Tensor): Index of all positive samples got from
                first assign process.
            pos_gt_inds (Tensor): Gt_index of all positive samples got
                from first assign process.
            anchors (list[Tensor]): Anchors of each scale.

        Returns:
            tuple: Usually returns a tuple containing learning targets.

                - label (Tensor): classification target of each anchor after
                  paa assign, with shape (num_anchors,)
                - label_weight (Tensor): Classification loss weight of each
                  anchor after paa assign, with shape (num_anchors).
                - bbox_weight (Tensor): Bbox weight of each anchor with shape
                  (num_anchors, 4).
                - num_pos (int): The number of positive samples after paa
                  assign.

        Raises:
            ImportError: If a GMM must be fitted but ``sklearn.mixture``
                could not be imported.
        """
        if not len(pos_inds):
            return label, label_weight, bbox_weight, 0

        num_gt = int(pos_gt_inds.max()) + 1
        num_level = len(anchors)
        # boundaries of each level inside the flattened anchor index space
        num_anchors_each_level = [item.size(0) for item in anchors]
        num_anchors_each_level.insert(0, 0)
        inds_level_interval = np.cumsum(num_anchors_each_level)
        pos_level_mask = []
        for i in range(num_level):
            mask = (pos_inds >= inds_level_interval[i]) & (
                pos_inds < inds_level_interval[i + 1])
            pos_level_mask.append(mask)
        pos_inds_after_paa = []
        ignore_inds_after_paa = []
        for gt_ind in range(num_gt):
            pos_inds_gmm = []
            pos_loss_gmm = []
            gt_mask = pos_gt_inds == gt_ind
            # per level, keep at most ``topk`` lowest-loss candidates
            for level in range(num_level):
                level_mask = pos_level_mask[level]
                level_gt_mask = level_mask & gt_mask
                value, topk_inds = pos_losses[level_gt_mask].topk(
                    min(int(level_gt_mask.sum()), self.topk), largest=False)
                pos_inds_gmm.append(pos_inds[level_gt_mask][topk_inds])
                pos_loss_gmm.append(value)
            pos_inds_gmm = torch.cat(pos_inds_gmm)
            pos_loss_gmm = torch.cat(pos_loss_gmm)
            # fix gmm need at least two sample
            if len(pos_inds_gmm) < 2:
                continue
            device = pos_inds_gmm.device
            # sorted ascending so index order equals loss order
            pos_loss_gmm, sort_inds = pos_loss_gmm.sort()
            pos_inds_gmm = pos_inds_gmm[sort_inds]
            pos_loss_gmm = pos_loss_gmm.view(-1, 1).cpu().numpy()
            min_loss, max_loss = pos_loss_gmm.min(), pos_loss_gmm.max()
            # component 0 starts at the lowest loss, component 1 at the
            # highest
            means_init = [[min_loss], [max_loss]]
            weights_init = [0.5, 0.5]
            precisions_init = [[[1.0]], [[1.0]]]
            if skm is None:
                # the PyPI distribution providing ``sklearn`` is
                # ``scikit-learn``
                raise ImportError('Please run "pip install scikit-learn" '
                                  'to install sklearn first.')
            gmm = skm.GaussianMixture(
                2,
                weights_init=weights_init,
                means_init=means_init,
                precisions_init=precisions_init)
            gmm.fit(pos_loss_gmm)
            gmm_assignment = gmm.predict(pos_loss_gmm)
            scores = gmm.score_samples(pos_loss_gmm)
            gmm_assignment = torch.from_numpy(gmm_assignment).to(device)
            scores = torch.from_numpy(scores).to(device)

            pos_inds_temp, ignore_inds_temp = self.gmm_separation_scheme(
                gmm_assignment, scores, pos_inds_gmm)
            pos_inds_after_paa.append(pos_inds_temp)
            ignore_inds_after_paa.append(ignore_inds_temp)

        if pos_inds_after_paa:
            pos_inds_after_paa = torch.cat(pos_inds_after_paa)
            ignore_inds_after_paa = torch.cat(ignore_inds_after_paa)
        else:
            # No gt had >= 2 candidates, so no GMM was fitted; torch.cat
            # rejects an empty list. Every candidate becomes background.
            pos_inds_after_paa = pos_inds.new_tensor([])
            ignore_inds_after_paa = pos_inds.new_tensor([])
        # candidates that were not kept as positives are turned to background
        reassign_mask = (pos_inds.unsqueeze(1) != pos_inds_after_paa).all(1)
        reassign_ids = pos_inds[reassign_mask]
        label[reassign_ids] = self.background_label
        label_weight[ignore_inds_after_paa] = 0
        bbox_weight[reassign_ids] = 0
        num_pos = len(pos_inds_after_paa)
        return label, label_weight, bbox_weight, num_pos

    def gmm_separation_scheme(self, gmm_assignment, scores, pos_inds_gmm):
        """A general separation scheme for gmm model.

        It separates a GMM distribution of candidate samples into three
        parts, 0 1 and uncertain areas, and you can implement other
        separation schemes by rewriting this function.

        Args:
            gmm_assignment (Tensor): The prediction of GMM which is of shape
                (num_samples,). The 0/1 value indicates the distribution
                that each sample comes from.
            scores (Tensor): The probability of sample coming from the
                fit GMM distribution. The tensor is of shape (num_samples,).
            pos_inds_gmm (Tensor): All the indexes of samples which are used
                to fit GMM model. The tensor is of shape (num_samples,)

        Returns:
            tuple[Tensor]: The indices of positive and ignored samples.

                - pos_inds_temp (Tensor): Indices of positive samples.
                - ignore_inds_temp (Tensor): Indices of ignore samples.
                  Both are empty when no sample is assigned to component 0.
        """
        # The implementation is (c) in Fig.3 in origin paper instead of (b).
        # You can refer to issues such as
        # https://github.com/kkhoot/PAA/issues/8 and
        # https://github.com/kkhoot/PAA/issues/9.
        fgs = gmm_assignment == 0
        # Defaults cover the case where component 0 received no sample;
        # without them both names would be unbound at the return.
        pos_inds_temp = pos_inds_gmm.new_tensor([])
        ignore_inds_temp = pos_inds_gmm.new_tensor([])
        if fgs.nonzero().numel():
            # keep the component-0 samples up to (and including) the one
            # with the highest score
            _, pos_thr_ind = scores[fgs].topk(1)
            pos_inds_temp = pos_inds_gmm[fgs][:int(pos_thr_ind) + 1]
        return pos_inds_temp, ignore_inds_temp

    def get_targets(
        self,
        anchor_list,
        valid_flag_list,
        gt_bboxes_list,
        img_metas,
        gt_bboxes_ignore_list=None,
        gt_labels_list=None,
        label_channels=1,
        unmap_outputs=True,
    ):
        """Get targets for PAA head.

        This method is almost the same as `AnchorHead.get_targets()`. We direct
        return the results from _get_targets_single instead map it to levels
        by images_to_levels function.

        Args:
            anchor_list (list[list[Tensor]]): Multi level anchors of each
                image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_anchors, 4).
            valid_flag_list (list[list[Tensor]]): Multi level valid flags of
                each image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_anchors, )
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
                ignored.
            gt_labels_list (list[Tensor]): Ground truth labels of each box.
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple: Usually returns a tuple containing learning targets.

                - labels (list[Tensor]): Labels of all anchors, each with
                    shape (num_anchors,).
                - label_weights (list[Tensor]): Label weights of all anchor.
                    each with shape (num_anchors,).
                - bbox_targets (list[Tensor]): BBox targets of all anchors.
                    each with shape (num_anchors, 4).
                - bbox_weights (list[Tensor]): BBox weights of all anchors.
                    each with shape (num_anchors, 4).
                - pos_inds (list[Tensor]): Contains all index of positive
                    sample in all anchor.
                - gt_inds (list[Tensor]): Contains all gt_index of positive
                    sample in all anchor.
        """

        num_imgs = len(img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs
        concat_anchor_list = []
        concat_valid_flag_list = []
        for i in range(num_imgs):
            assert len(anchor_list[i]) == len(valid_flag_list[i])
            concat_anchor_list.append(torch.cat(anchor_list[i]))
            concat_valid_flag_list.append(torch.cat(valid_flag_list[i]))

        # compute targets for each image
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
        if gt_labels_list is None:
            gt_labels_list = [None for _ in range(num_imgs)]
        results = multi_apply(
            self._get_targets_single,
            concat_anchor_list,
            concat_valid_flag_list,
            gt_bboxes_list,
            gt_bboxes_ignore_list,
            gt_labels_list,
            img_metas,
            label_channels=label_channels,
            unmap_outputs=unmap_outputs)

        (labels, label_weights, bbox_targets, bbox_weights, valid_pos_inds,
         valid_neg_inds, sampling_result) = results

        # Due to valid flag of anchors, we have to calculate the real pos_inds
        # in origin anchor set.
        pos_inds = []
        for single_labels in labels:
            pos_mask = (0 <= single_labels) & (
                single_labels < self.background_label)
            pos_inds.append(pos_mask.nonzero().view(-1))

        gt_inds = [item.pos_assigned_gt_inds for item in sampling_result]
        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                gt_inds)

    def _get_targets_single(self,
                            flat_anchors,
                            valid_flags,
                            gt_bboxes,
                            gt_bboxes_ignore,
                            gt_labels,
                            img_meta,
                            label_channels=1,
                            unmap_outputs=True):
        """Compute regression and classification targets for anchors in a
        single image.

        This method is same as `AnchorHead._get_targets_single()`.
        """
        assert unmap_outputs, 'We must map outputs back to the original ' \
            'set of anchors in PAAhead'
        # ``super(ATSSHead, self)`` deliberately skips ATSSHead's own
        # override and dispatches to the next class in the MRO.
        # NOTE(review): ``label_channels`` is fixed to 1 rather than
        # forwarded; presumably because the labels are used as class
        # indices by this head - confirm before changing.
        return super(ATSSHead, self)._get_targets_single(
            flat_anchors,
            valid_flags,
            gt_bboxes,
            gt_bboxes_ignore,
            gt_labels,
            img_meta,
            label_channels=1,
            unmap_outputs=True)

    def _get_bboxes_single(self,
                           cls_scores,
                           bbox_preds,
                           iou_preds,
                           mlvl_anchors,
                           img_shape,
                           scale_factor,
                           cfg,
                           rescale=False):
        """Transform outputs for a single batch item into labeled boxes.

        This method is almost same as `ATSSHead._get_bboxes_single()`.
        We use sqrt(iou_preds * cls_scores) in NMS process instead of just
        cls_scores. Besides, score voting is used when ``score_voting``
        is set to True.
        """
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_iou_preds = []
        nms_pre = cfg.get('nms_pre', -1)
        # ``iou_pred`` (singular) avoids shadowing the ``iou_preds`` argument
        for cls_score, bbox_pred, iou_pred, anchors in zip(
                cls_scores, bbox_preds, iou_preds, mlvl_anchors):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]

            scores = cls_score.permute(1, 2, 0).reshape(
                -1, self.cls_out_channels).sigmoid()
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            iou_pred = iou_pred.permute(1, 2, 0).reshape(-1).sigmoid()
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                # pre-select the nms_pre best anchors of this level by the
                # combined score
                max_scores, _ = (scores * iou_pred[:, None]).sqrt().max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                anchors = anchors[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
                iou_pred = iou_pred[topk_inds]

            bboxes = self.bbox_coder.decode(
                anchors, bbox_pred, max_shape=img_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_iou_preds.append(iou_pred)

        mlvl_bboxes = torch.cat(mlvl_bboxes)
        if rescale:
            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
        mlvl_scores = torch.cat(mlvl_scores)
        # Add a dummy background class to the backend when using sigmoid
        # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
        # BG cat_id: num_class
        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
        mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
        mlvl_iou_preds = torch.cat(mlvl_iou_preds)
        mlvl_nms_scores = (mlvl_scores * mlvl_iou_preds[:, None]).sqrt()
        det_bboxes, det_labels = multiclass_nms(
            mlvl_bboxes,
            mlvl_nms_scores,
            cfg.score_thr,
            cfg.nms,
            cfg.max_per_img,
            score_factors=None)
        if self.with_score_voting:
            det_bboxes, det_labels = self.score_voting(det_bboxes, det_labels,
                                                       mlvl_bboxes,
                                                       mlvl_nms_scores,
                                                       cfg.score_thr)

        return det_bboxes, det_labels

    def score_voting(self, det_bboxes, det_labels, mlvl_bboxes,
                     mlvl_nms_scores, score_thr):
        """Implementation of score voting method works on each remaining boxes
        after NMS procedure.

        Args:
            det_bboxes (Tensor): Remaining boxes after NMS procedure,
                with shape (k, 5), each dimension means
                (x1, y1, x2, y2, score).
            det_labels (Tensor): The label of remaining boxes, one per box
                in ``det_bboxes``. Labels are 0-based.
            mlvl_bboxes (Tensor): All boxes before the NMS procedure,
                with shape (num_anchors,4).
            mlvl_nms_scores (Tensor): The scores of all boxes which is used
                in the NMS procedure, with shape (num_anchors, num_class)
            score_thr (float): The score threshold of bboxes.

        Returns:
            tuple: Usually returns a tuple containing voting results.

                - det_bboxes_voted (Tensor): Remaining boxes after
                    score voting procedure, with shape (k, 5), each
                    dimension means (x1, y1, x2, y2, score).
                - det_labels_voted (Tensor): Label of remaining bboxes
                    after voting, with shape (k,).
                Both are empty when there is nothing to vote on.
        """
        candidate_mask = mlvl_nms_scores > score_thr
        candidate_mask_nozeros = candidate_mask.nonzero()
        candidate_inds = candidate_mask_nozeros[:, 0]
        candidate_labels = candidate_mask_nozeros[:, 1]
        candidate_bboxes = mlvl_bboxes[candidate_inds]
        candidate_scores = mlvl_nms_scores[candidate_mask]
        det_bboxes_voted = []
        det_labels_voted = []
        for cls in range(self.cls_out_channels):
            candidate_cls_mask = candidate_labels == cls
            if not candidate_cls_mask.any():
                continue
            candidate_cls_scores = candidate_scores[candidate_cls_mask]
            candidate_cls_bboxes = candidate_bboxes[candidate_cls_mask]
            det_cls_mask = det_labels == cls
            det_cls_bboxes = det_bboxes[det_cls_mask].view(
                -1, det_bboxes.size(-1))
            det_candidate_ious = bbox_overlaps(det_cls_bboxes[:, :4],
                                               candidate_cls_bboxes)
            for det_ind in range(len(det_cls_bboxes)):
                single_det_ious = det_candidate_ious[det_ind]
                # only candidates that overlap the detection take part
                pos_ious_mask = single_det_ious > 0.01
                pos_ious = single_det_ious[pos_ious_mask]
                pos_bboxes = candidate_cls_bboxes[pos_ious_mask]
                pos_scores = candidate_cls_scores[pos_ious_mask]
                # vote weight: score scaled by a Gaussian-shaped function
                # of (1 - IoU)
                pis = (torch.exp(-(1 - pos_ious)**2 / 0.025) *
                       pos_scores)[:, None]
                voted_box = torch.sum(
                    pis * pos_bboxes, dim=0) / torch.sum(
                        pis, dim=0)
                # the detection keeps its own score; only the box is voted
                voted_score = det_cls_bboxes[det_ind][-1:][None, :]
                det_bboxes_voted.append(
                    torch.cat((voted_box[None, :], voted_score), dim=1))
                det_labels_voted.append(cls)

        if not det_bboxes_voted:
            # No detection to vote on (e.g. NMS returned nothing);
            # torch.cat rejects an empty list, so return empty results.
            return (det_bboxes.new_zeros((0, det_bboxes.size(-1))),
                    det_labels.new_zeros((0, )))

        det_bboxes_voted = torch.cat(det_bboxes_voted, dim=0)
        det_labels_voted = det_labels.new_tensor(det_labels_voted)
        return det_bboxes_voted, det_labels_voted
@ -0,0 +1,17 @@ |
|||||||
|
from ..builder import DETECTORS |
||||||
|
from .single_stage import SingleStageDetector |
||||||
|
|
||||||
|
|
||||||
|
@DETECTORS.register_module()
class PAA(SingleStageDetector):
    """Implementation of `PAA <https://arxiv.org/pdf/2007.08103.pdf>`_.

    The class adds no behaviour of its own: all constructor arguments are
    handed, in the same order, to the parent class constructor.
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        parent_args = (backbone, neck, bbox_head, train_cfg, test_cfg,
                       pretrained)
        super(PAA, self).__init__(*parent_args)
Loading…
Reference in new issue