Code Release of ECCV 2020 Spotlight paper for Side-Aware Boundary Localization for More Precise Object Detection (#3603)

* add sabl two stage

* add sabl retina

* ret cfg bug fix

* test bug fix

* minor update

* update

* add r101 two stage

* update

* add cfgs

* add cfgs

* update cfgs

* format

* format

* add readme

* fix isort

* update

* update readme

* add doc string for sabl retina head

* add doc string and rename some functions

* add docstring for bucketing coder

* update docstring

* bucket_num -> num_buckets

* bucket_pw -> bucket_w bucket_ph -> bucket_h

* update label2onehot

* update bucketing bbox coder doc

* update

* typo fix

* bboxes_ -> rescaled_bboxes

* rename some params in sabl head

* init with mmcv.cnn

* update doc

* rename pos->post

* update cfgs

* update test cfg

* update

* add unit test for sabl head

* add unit test for sabl retina

* rename

* minor rename

* minor update

* update docstring

* update

* use F.one_hot

* update docstring

* update test heads

* update ReadMe

* fix
Branch: pull/3680/head
Author: Jiaqi Wang (committed via GitHub)
Parent: 08d1402c1e
Commit: 26562a1d0e
22 files changed (lines changed in parentheses):

1. README.md (1)
2. configs/sabl/README.md (36)
3. configs/sabl/sabl_cascade_rcnn_r101_fpn_1x_coco.py (88)
4. configs/sabl/sabl_cascade_rcnn_r50_fpn_1x_coco.py (86)
5. configs/sabl/sabl_faster_rcnn_r101_fpn_1x_coco.py (36)
6. configs/sabl/sabl_faster_rcnn_r50_fpn_1x_coco.py (34)
7. configs/sabl/sabl_retinanet_r101_fpn_1x_coco.py (52)
8. configs/sabl/sabl_retinanet_r101_fpn_gn_1x_coco.py (54)
9. configs/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_480_960_coco.py (71)
10. configs/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_640_800_coco.py (71)
11. configs/sabl/sabl_retinanet_r50_fpn_1x_coco.py (50)
12. configs/sabl/sabl_retinanet_r50_fpn_gn_1x_coco.py (52)
13. mmdet/core/bbox/__init__.py (7)
14. mmdet/core/bbox/coder/__init__.py (4)
15. mmdet/core/bbox/coder/bucketing_bbox_coder.py (339)
16. mmdet/core/bbox/transforms.py (32)
17. mmdet/models/dense_heads/__init__.py (3)
18. mmdet/models/dense_heads/sabl_retina_head.py (622)
19. mmdet/models/roi_heads/bbox_heads/__init__.py (3)
20. mmdet/models/roi_heads/bbox_heads/sabl_head.py (563)
21. tests/test_config.py (15)
22. tests/test_models/test_heads.py (145)

@@ -97,6 +97,7 @@ Supported methods:
- [x] [DetectoRS](configs/detectors/README.md)
- [x] [Generalized Focal Loss](configs/gfl/README.md)
- [x] [CornerNet](configs/cornernet/README.md)
- [x] [Side-Aware Boundary Localization](configs/sabl/README.md)
- [x] [YOLOv3](configs/yolo/README.md)
Some other methods are also supported in [projects using MMDetection](./docs/projects.md).

@@ -0,0 +1,36 @@
# Side-Aware Boundary Localization for More Precise Object Detection
## Introduction
We provide config files to reproduce the object detection results in the ECCV 2020 Spotlight paper for [Side-Aware Boundary Localization for More Precise Object Detection](https://arxiv.org/abs/1912.04260).
```
@inproceedings{Wang_2020_ECCV,
title = {Side-Aware Boundary Localization for More Precise Object Detection},
author = {Wang, Jiaqi and Zhang, Wenwei and Cao, Yuhang and Chen, Kai and Pang, Jiangmiao and Gong, Tao and Shi, Jianping and Loy, Chen Change and Lin, Dahua},
booktitle = {ECCV},
year = {2020}
}
```
## Results and Models
The results on COCO 2017 val are shown in the table below (results on test-dev are usually slightly higher than on val).
Single-scale testing (1333x800) is adopted for all results.
| Method | Backbone | Lr schd | ms-train | box AP | Download |
| :----------------: | :-------: | :-----: | :------: | :----: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| SABL Faster R-CNN | R-50-FPN | 1x | N | 39.9 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_faster_rcnn_r50_fpn_1x_coco/sabl_faster_rcnn_r50_fpn_1x_coco-e867595b.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_faster_rcnn_r50_fpn_1x_coco/20200830_130324.log.json) |
| SABL Faster R-CNN | R-101-FPN | 1x | N | 41.7 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_faster_rcnn_r101_fpn_1x_coco/sabl_faster_rcnn_r101_fpn_1x_coco-f804c6c1.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_faster_rcnn_r101_fpn_1x_coco/20200830_183949.log.json) |
| SABL Cascade R-CNN | R-50-FPN | 1x | N | 41.6 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_cascade_rcnn_r50_fpn_1x_coco/sabl_cascade_rcnn_r50_fpn_1x_coco-e1748e5e.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_cascade_rcnn_r50_fpn_1x_coco/20200831_033726.log.json) |
| SABL Cascade R-CNN | R-101-FPN | 1x | N | 43.0 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_cascade_rcnn_r101_fpn_1x_coco/sabl_cascade_rcnn_r101_fpn_1x_coco-2b83e87c.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_cascade_rcnn_r101_fpn_1x_coco/20200831_141745.log.json) |
| Method | Backbone | GN | Lr schd | ms-train | box AP | Download |
| :------------: | :-------: | :---: | :-----: | :---------: | :----: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| SABL RetinaNet | R-50-FPN | N | 1x | N | 37.7 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r50_fpn_1x_coco/sabl_retinanet_r50_fpn_1x_coco-6c54fd4f.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r50_fpn_1x_coco/20200830_053451.log.json) |
| SABL RetinaNet | R-50-FPN | Y | 1x | N | 38.8 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r50_fpn_gn_1x_coco/sabl_retinanet_r50_fpn_gn_1x_coco-e16dfcf1.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r50_fpn_gn_1x_coco/20200831_141955.log.json) |
| SABL RetinaNet | R-101-FPN | N | 1x | N | 39.7 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_1x_coco/sabl_retinanet_r101_fpn_1x_coco-42026904.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_1x_coco/20200831_034256.log.json) |
| SABL RetinaNet | R-101-FPN | Y | 1x | N | 40.5 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_gn_1x_coco/sabl_retinanet_r101_fpn_gn_1x_coco-40a893e8.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_gn_1x_coco/20200830_201422.log.json) |
| SABL RetinaNet | R-101-FPN | Y | 2x | Y (640~800) | 42.9 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_640_800_coco/sabl_retinanet_r101_fpn_gn_2x_ms_640_800_coco-1e63382c.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_640_800_coco/20200830_144807.log.json) |
| SABL RetinaNet | R-101-FPN | Y | 2x | Y (480~960) | 43.6 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_480_960_coco/sabl_retinanet_r101_fpn_gn_2x_ms_480_960_coco-5342f857.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_480_960_coco/20200830_164537.log.json) |
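For quick sanity-checking of a released checkpoint, the standard MMDetection v2 high-level API can be used. A minimal inference sketch (the config comes from this PR; the checkpoint filename matches the R-50 Faster R-CNN entry above, and `demo.jpg` is a placeholder image path):
```python
from mmdet.apis import inference_detector, init_detector

# Config from this PR; checkpoint downloaded from the table above.
config_file = 'configs/sabl/sabl_faster_rcnn_r50_fpn_1x_coco.py'
checkpoint_file = 'sabl_faster_rcnn_r50_fpn_1x_coco-e867595b.pth'

# Build the detector and load the released weights.
model = init_detector(config_file, checkpoint_file, device='cuda:0')

# Returns a list of per-class (n, 5) arrays: [x1, y1, x2, y2, score].
result = inference_detector(model, 'demo.jpg')
```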

@@ -0,0 +1,88 @@
_base_ = [
'../_base_/models/cascade_rcnn_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
# model settings
model = dict(
pretrained='torchvision://resnet101',
backbone=dict(depth=101),
roi_head=dict(bbox_head=[
dict(
type='SABLHead',
num_classes=80,
cls_in_channels=256,
reg_in_channels=256,
roi_feat_size=7,
reg_feat_up_ratio=2,
reg_pre_kernel=3,
reg_post_kernel=3,
reg_pre_num=2,
reg_post_num=1,
cls_out_channels=1024,
reg_offset_out_channels=256,
reg_cls_out_channels=256,
num_cls_fcs=1,
num_reg_fcs=0,
reg_class_agnostic=True,
norm_cfg=None,
bbox_coder=dict(
type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.7),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox_reg=dict(type='SmoothL1Loss', beta=0.1,
loss_weight=1.0)),
dict(
type='SABLHead',
num_classes=80,
cls_in_channels=256,
reg_in_channels=256,
roi_feat_size=7,
reg_feat_up_ratio=2,
reg_pre_kernel=3,
reg_post_kernel=3,
reg_pre_num=2,
reg_post_num=1,
cls_out_channels=1024,
reg_offset_out_channels=256,
reg_cls_out_channels=256,
num_cls_fcs=1,
num_reg_fcs=0,
reg_class_agnostic=True,
norm_cfg=None,
bbox_coder=dict(
type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.5),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox_reg=dict(type='SmoothL1Loss', beta=0.1,
loss_weight=1.0)),
dict(
type='SABLHead',
num_classes=80,
cls_in_channels=256,
reg_in_channels=256,
roi_feat_size=7,
reg_feat_up_ratio=2,
reg_pre_kernel=3,
reg_post_kernel=3,
reg_pre_num=2,
reg_post_num=1,
cls_out_channels=1024,
reg_offset_out_channels=256,
reg_cls_out_channels=256,
num_cls_fcs=1,
num_reg_fcs=0,
reg_class_agnostic=True,
norm_cfg=None,
bbox_coder=dict(
type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.3),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox_reg=dict(type='SmoothL1Loss', beta=0.1, loss_weight=1.0))
]))
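The three cascade stages above share every SABLHead setting except the coder's `scale_factor` (1.7 → 1.5 → 1.3), so each refinement stage buckets a progressively tighter region around its input proposals. A quick sketch (assuming mmcv is installed and the path resolves from the repo root) to confirm the per-stage values load as written:
```python
from mmcv import Config

cfg = Config.fromfile('configs/sabl/sabl_cascade_rcnn_r101_fpn_1x_coco.py')
# Every stage keeps num_buckets=14; only the bucketed region shrinks.
for i, head in enumerate(cfg.model.roi_head.bbox_head):
    print(i, head.bbox_coder.num_buckets, head.bbox_coder.scale_factor)
# 0 14 1.7
# 1 14 1.5
# 2 14 1.3
```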

@@ -0,0 +1,86 @@
_base_ = [
'../_base_/models/cascade_rcnn_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
# model settings
model = dict(
roi_head=dict(bbox_head=[
dict(
type='SABLHead',
num_classes=80,
cls_in_channels=256,
reg_in_channels=256,
roi_feat_size=7,
reg_feat_up_ratio=2,
reg_pre_kernel=3,
reg_post_kernel=3,
reg_pre_num=2,
reg_post_num=1,
cls_out_channels=1024,
reg_offset_out_channels=256,
reg_cls_out_channels=256,
num_cls_fcs=1,
num_reg_fcs=0,
reg_class_agnostic=True,
norm_cfg=None,
bbox_coder=dict(
type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.7),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox_reg=dict(type='SmoothL1Loss', beta=0.1,
loss_weight=1.0)),
dict(
type='SABLHead',
num_classes=80,
cls_in_channels=256,
reg_in_channels=256,
roi_feat_size=7,
reg_feat_up_ratio=2,
reg_pre_kernel=3,
reg_post_kernel=3,
reg_pre_num=2,
reg_post_num=1,
cls_out_channels=1024,
reg_offset_out_channels=256,
reg_cls_out_channels=256,
num_cls_fcs=1,
num_reg_fcs=0,
reg_class_agnostic=True,
norm_cfg=None,
bbox_coder=dict(
type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.5),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox_reg=dict(type='SmoothL1Loss', beta=0.1,
loss_weight=1.0)),
dict(
type='SABLHead',
num_classes=80,
cls_in_channels=256,
reg_in_channels=256,
roi_feat_size=7,
reg_feat_up_ratio=2,
reg_pre_kernel=3,
reg_post_kernel=3,
reg_pre_num=2,
reg_post_num=1,
cls_out_channels=1024,
reg_offset_out_channels=256,
reg_cls_out_channels=256,
num_cls_fcs=1,
num_reg_fcs=0,
reg_class_agnostic=True,
norm_cfg=None,
bbox_coder=dict(
type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.3),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox_reg=dict(type='SmoothL1Loss', beta=0.1, loss_weight=1.0))
]))

@@ -0,0 +1,36 @@
_base_ = [
'../_base_/models/faster_rcnn_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
pretrained='torchvision://resnet101',
backbone=dict(depth=101),
roi_head=dict(
bbox_head=dict(
_delete_=True,
type='SABLHead',
num_classes=80,
cls_in_channels=256,
reg_in_channels=256,
roi_feat_size=7,
reg_feat_up_ratio=2,
reg_pre_kernel=3,
reg_post_kernel=3,
reg_pre_num=2,
reg_post_num=1,
cls_out_channels=1024,
reg_offset_out_channels=256,
reg_cls_out_channels=256,
num_cls_fcs=1,
num_reg_fcs=0,
reg_class_agnostic=True,
norm_cfg=None,
bbox_coder=dict(
type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.7),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox_reg=dict(type='SmoothL1Loss', beta=0.1,
loss_weight=1.0))))

@@ -0,0 +1,34 @@
_base_ = [
'../_base_/models/faster_rcnn_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
roi_head=dict(
bbox_head=dict(
_delete_=True,
type='SABLHead',
num_classes=80,
cls_in_channels=256,
reg_in_channels=256,
roi_feat_size=7,
reg_feat_up_ratio=2,
reg_pre_kernel=3,
reg_post_kernel=3,
reg_pre_num=2,
reg_post_num=1,
cls_out_channels=1024,
reg_offset_out_channels=256,
reg_cls_out_channels=256,
num_cls_fcs=1,
num_reg_fcs=0,
reg_class_agnostic=True,
norm_cfg=None,
bbox_coder=dict(
type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.7),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_bbox_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox_reg=dict(type='SmoothL1Loss', beta=0.1,
loss_weight=1.0))))
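In every SABLHead config, `roi_feat_size=7`, `reg_feat_up_ratio=2`, and `num_buckets=14` are tied together: the head upsamples the RoI feature by `reg_feat_up_ratio` so that each row/column of the upsampled map corresponds to one bucket, and its constructor asserts the match. A minimal sketch of that consistency rule:
```python
# Mirrors the consistency check in SABLHead.__init__: the upsampled
# RoI feature must provide exactly one cell per bucket.
roi_feat_size = 7       # spatial size of the pooled RoI feature
reg_feat_up_ratio = 2   # upsample ratio of the regression feature
num_buckets = 14        # buckets per box dimension in the coder

up_reg_feat_size = roi_feat_size * reg_feat_up_ratio
assert up_reg_feat_size == num_buckets  # 7 * 2 == 14
```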

@@ -0,0 +1,52 @@
_base_ = [
'../_base_/models/retinanet_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
# model settings
model = dict(
pretrained='torchvision://resnet101',
backbone=dict(depth=101),
bbox_head=dict(
_delete_=True,
type='SABLRetinaHead',
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
approx_anchor_generator=dict(
type='AnchorGenerator',
octave_base_scale=4,
scales_per_octave=3,
ratios=[0.5, 1.0, 2.0],
strides=[8, 16, 32, 64, 128]),
square_anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
scales=[4],
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='BucketingBBoxCoder', num_buckets=14, scale_factor=3.0),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.5),
loss_bbox_reg=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5)))
# training and testing settings
train_cfg = dict(
assigner=dict(
type='ApproxMaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0.0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False)
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)

@@ -0,0 +1,54 @@
_base_ = [
'../_base_/models/retinanet_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
# model settings
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
pretrained='torchvision://resnet101',
backbone=dict(depth=101),
bbox_head=dict(
_delete_=True,
type='SABLRetinaHead',
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
approx_anchor_generator=dict(
type='AnchorGenerator',
octave_base_scale=4,
scales_per_octave=3,
ratios=[0.5, 1.0, 2.0],
strides=[8, 16, 32, 64, 128]),
square_anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
scales=[4],
strides=[8, 16, 32, 64, 128]),
norm_cfg=norm_cfg,
bbox_coder=dict(
type='BucketingBBoxCoder', num_buckets=14, scale_factor=3.0),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.5),
loss_bbox_reg=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5)))
# training and testing settings
train_cfg = dict(
assigner=dict(
type='ApproxMaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0.0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False)
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)

@@ -0,0 +1,71 @@
_base_ = [
'../_base_/models/retinanet_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_2x.py', '../_base_/default_runtime.py'
]
# model settings
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
pretrained='torchvision://resnet101',
backbone=dict(depth=101),
bbox_head=dict(
_delete_=True,
type='SABLRetinaHead',
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
approx_anchor_generator=dict(
type='AnchorGenerator',
octave_base_scale=4,
scales_per_octave=3,
ratios=[0.5, 1.0, 2.0],
strides=[8, 16, 32, 64, 128]),
square_anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
scales=[4],
strides=[8, 16, 32, 64, 128]),
norm_cfg=norm_cfg,
bbox_coder=dict(
type='BucketingBBoxCoder', num_buckets=14, scale_factor=3.0),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.5),
loss_bbox_reg=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5)))
# training and testing settings
train_cfg = dict(
assigner=dict(
type='ApproxMaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0.0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False)
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='Resize',
img_scale=[(1333, 480), (1333, 960)],
multiscale_mode='range',
keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
data = dict(train=dict(pipeline=train_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)

@@ -0,0 +1,71 @@
_base_ = [
'../_base_/models/retinanet_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_2x.py', '../_base_/default_runtime.py'
]
# model settings
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
pretrained='torchvision://resnet101',
backbone=dict(depth=101),
bbox_head=dict(
_delete_=True,
type='SABLRetinaHead',
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
approx_anchor_generator=dict(
type='AnchorGenerator',
octave_base_scale=4,
scales_per_octave=3,
ratios=[0.5, 1.0, 2.0],
strides=[8, 16, 32, 64, 128]),
square_anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
scales=[4],
strides=[8, 16, 32, 64, 128]),
norm_cfg=norm_cfg,
bbox_coder=dict(
type='BucketingBBoxCoder', num_buckets=14, scale_factor=3.0),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.5),
loss_bbox_reg=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5)))
# training and testing settings
train_cfg = dict(
assigner=dict(
type='ApproxMaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0.0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False)
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='Resize',
img_scale=[(1333, 640), (1333, 800)],
multiscale_mode='range',
keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
data = dict(train=dict(pipeline=train_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)

@@ -0,0 +1,50 @@
_base_ = [
'../_base_/models/retinanet_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
# model settings
model = dict(
bbox_head=dict(
_delete_=True,
type='SABLRetinaHead',
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
approx_anchor_generator=dict(
type='AnchorGenerator',
octave_base_scale=4,
scales_per_octave=3,
ratios=[0.5, 1.0, 2.0],
strides=[8, 16, 32, 64, 128]),
square_anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
scales=[4],
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='BucketingBBoxCoder', num_buckets=14, scale_factor=3.0),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.5),
loss_bbox_reg=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5)))
# training and testing settings
train_cfg = dict(
assigner=dict(
type='ApproxMaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0.0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False)
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)

@@ -0,0 +1,52 @@
_base_ = [
'../_base_/models/retinanet_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
# model settings
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
bbox_head=dict(
_delete_=True,
type='SABLRetinaHead',
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
approx_anchor_generator=dict(
type='AnchorGenerator',
octave_base_scale=4,
scales_per_octave=3,
ratios=[0.5, 1.0, 2.0],
strides=[8, 16, 32, 64, 128]),
square_anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
scales=[4],
strides=[8, 16, 32, 64, 128]),
norm_cfg=norm_cfg,
bbox_coder=dict(
type='BucketingBBoxCoder', num_buckets=14, scale_factor=3.0),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.5),
loss_bbox_reg=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5)))
# training and testing settings
train_cfg = dict(
assigner=dict(
type='ApproxMaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0.0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False)
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)

@@ -9,8 +9,8 @@ from .samplers import (BaseSampler, CombinedSampler,
OHEMSampler, PseudoSampler, RandomSampler,
SamplingResult, ScoreHLRSampler)
from .transforms import (bbox2distance, bbox2result, bbox2roi, bbox_flip,
bbox_mapping, bbox_mapping_back, distance2bbox,
roi2bbox)
bbox_mapping, bbox_mapping_back, bbox_rescale,
distance2bbox, roi2bbox)
__all__ = [
'bbox_overlaps', 'BboxOverlaps2D', 'BaseAssigner', 'MaxIoUAssigner',
@@ -20,5 +20,6 @@ __all__ = [
'build_sampler', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back',
'bbox2roi', 'roi2bbox', 'bbox2result', 'distance2bbox', 'bbox2distance',
'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder',
'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'CenterRegionAssigner'
'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'CenterRegionAssigner',
'bbox_rescale'
]

@@ -1,4 +1,5 @@
from .base_bbox_coder import BaseBBoxCoder
from .bucketing_bbox_coder import BucketingBBoxCoder
from .delta_xywh_bbox_coder import DeltaXYWHBBoxCoder
from .legacy_delta_xywh_bbox_coder import LegacyDeltaXYWHBBoxCoder
from .pseudo_bbox_coder import PseudoBBoxCoder
@@ -7,5 +8,6 @@ from .yolo_bbox_coder import YOLOBBoxCoder
__all__ = [
'BaseBBoxCoder', 'PseudoBBoxCoder', 'DeltaXYWHBBoxCoder',
'LegacyDeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'YOLOBBoxCoder'
'LegacyDeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'YOLOBBoxCoder',
'BucketingBBoxCoder'
]

@@ -0,0 +1,339 @@
import numpy as np
import torch
import torch.nn.functional as F
from ..builder import BBOX_CODERS
from ..transforms import bbox_rescale
from .base_bbox_coder import BaseBBoxCoder
@BBOX_CODERS.register_module()
class BucketingBBoxCoder(BaseBBoxCoder):
"""Bucketing BBox Coder for Side-Aware Bounday Localization (SABL).
Boundary Localization with Bucketing and Bucketing Guided Rescoring
are implemented here.
Please refer to https://arxiv.org/abs/1912.04260 for more details.
Args:
num_buckets (int): Number of buckets.
scale_factor (int): Scale factor of proposals to generate buckets.
offset_topk (int): Topk buckets are used to generate
bucket fine regression targets. Defaults to 2.
offset_upperbound (float): Offset upperbound to generate
bucket fine regression targets.
To avoid too large offset displacements. Defaults to 1.0.
cls_ignore_neighbor (bool): Ignore second nearest bucket or Not.
Defaults to True.
"""
def __init__(self,
num_buckets,
scale_factor,
offset_topk=2,
offset_upperbound=1.0,
cls_ignore_neighbor=True):
super(BucketingBBoxCoder, self).__init__()
self.num_buckets = num_buckets
self.scale_factor = scale_factor
self.offset_topk = offset_topk
self.offset_upperbound = offset_upperbound
self.cls_ignore_neighbor = cls_ignore_neighbor
def encode(self, bboxes, gt_bboxes):
"""Get bucketing estimation and fine regression targets during
training.
Args:
bboxes (torch.Tensor): source boxes, e.g., object proposals.
gt_bboxes (torch.Tensor): target of the transformation, e.g.,
ground truth boxes.
Returns:
encoded_bboxes (tuple[Tensor]): Bucketing estimation
and fine regression targets and weights.
"""
assert bboxes.size(0) == gt_bboxes.size(0)
assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
encoded_bboxes = bbox2bucket(bboxes, gt_bboxes, self.num_buckets,
self.scale_factor, self.offset_topk,
self.offset_upperbound,
self.cls_ignore_neighbor)
return encoded_bboxes
def decode(self, bboxes, pred_bboxes, max_shape=None):
"""Apply transformation `pred_bboxes` to `boxes`.
Args:
boxes (torch.Tensor): Basic boxes.
pred_bboxes (torch.Tensor): Predictions for bucketing estimation
and fine regression
max_shape (tuple[int], optional): Maximum shape of boxes.
Defaults to None.
Returns:
torch.Tensor: Decoded boxes.
"""
assert len(pred_bboxes) == 2
cls_preds, offset_preds = pred_bboxes
assert cls_preds.size(0) == bboxes.size(0) and offset_preds.size(
0) == bboxes.size(0)
decoded_bboxes = bucket2bbox(bboxes, cls_preds, offset_preds,
self.num_buckets, self.scale_factor,
max_shape)
return decoded_bboxes
def generate_buckets(proposals, num_buckets, scale_factor=1.0):
"""Generate buckets w.r.t. bucket number and scale factor of proposals.
Args:
proposals (Tensor): Shape (n, 4)
num_buckets (int): Number of buckets.
scale_factor (float): Scale factor to rescale proposals.
Returns:
tuple[Tensor]: (bucket_w, bucket_h, l_buckets, r_buckets,
t_buckets, d_buckets)
- bucket_w: Width of buckets on x-axis. Shape (n, ).
- bucket_h: Height of buckets on y-axis. Shape (n, ).
- l_buckets: Left buckets. Shape (n, ceil(side_num/2)).
- r_buckets: Right buckets. Shape (n, ceil(side_num/2)).
- t_buckets: Top buckets. Shape (n, ceil(side_num/2)).
- d_buckets: Down buckets. Shape (n, ceil(side_num/2)).
"""
proposals = bbox_rescale(proposals, scale_factor)
# number of buckets on each side
side_num = int(np.ceil(num_buckets / 2.0))
pw = proposals[..., 2] - proposals[..., 0]
ph = proposals[..., 3] - proposals[..., 1]
px1 = proposals[..., 0]
py1 = proposals[..., 1]
px2 = proposals[..., 2]
py2 = proposals[..., 3]
bucket_w = pw / num_buckets
bucket_h = ph / num_buckets
# left buckets
l_buckets = px1[:, None] + (0.5 + torch.arange(
0, side_num).to(proposals).float())[None, :] * bucket_w[:, None]
# right buckets
r_buckets = px2[:, None] - (0.5 + torch.arange(
0, side_num).to(proposals).float())[None, :] * bucket_w[:, None]
# top buckets
t_buckets = py1[:, None] + (0.5 + torch.arange(
0, side_num).to(proposals).float())[None, :] * bucket_h[:, None]
# down buckets
d_buckets = py2[:, None] - (0.5 + torch.arange(
0, side_num).to(proposals).float())[None, :] * bucket_h[:, None]
return bucket_w, bucket_h, l_buckets, r_buckets, t_buckets, d_buckets
def bbox2bucket(proposals,
gt,
num_buckets,
scale_factor,
offset_topk=2,
offset_upperbound=1.0,
cls_ignore_neighbor=True):
"""Generate buckets estimation and fine regression targets.
Args:
proposals (Tensor): Shape (n, 4)
gt (Tensor): Shape (n, 4)
num_buckets (int): Number of buckets.
scale_factor (float): Scale factor to rescale proposals.
offset_topk (int): Topk buckets are used to generate
bucket fine regression targets. Defaults to 2.
offset_upperbound (float): Offset allowance to generate
bucket fine regression targets,
to avoid too large offset displacements. Defaults to 1.0.
cls_ignore_neighbor (bool): Whether to ignore the second nearest bucket.
Defaults to True.
Returns:
tuple[Tensor]: (offsets, offsets_weights, bucket_labels, cls_weights).
- offsets: Fine regression targets. \
Shape (n, num_buckets*2).
- offsets_weights: Fine regression weights. \
Shape (n, num_buckets*2).
- bucket_labels: Bucketing estimation labels. \
Shape (n, num_buckets*2).
- cls_weights: Bucketing estimation weights. \
Shape (n, num_buckets*2).
"""
assert proposals.size() == gt.size()
# generate buckets
proposals = proposals.float()
gt = gt.float()
(bucket_w, bucket_h, l_buckets, r_buckets, t_buckets,
d_buckets) = generate_buckets(proposals, num_buckets, scale_factor)
gx1 = gt[..., 0]
gy1 = gt[..., 1]
gx2 = gt[..., 2]
gy2 = gt[..., 3]
# generate offset targets and weights
# offsets from buckets to gts
l_offsets = (l_buckets - gx1[:, None]) / bucket_w[:, None]
r_offsets = (r_buckets - gx2[:, None]) / bucket_w[:, None]
t_offsets = (t_buckets - gy1[:, None]) / bucket_h[:, None]
d_offsets = (d_buckets - gy2[:, None]) / bucket_h[:, None]
# select top-k nearest buckets
l_topk, l_label = l_offsets.abs().topk(
offset_topk, dim=1, largest=False, sorted=True)
r_topk, r_label = r_offsets.abs().topk(
offset_topk, dim=1, largest=False, sorted=True)
t_topk, t_label = t_offsets.abs().topk(
offset_topk, dim=1, largest=False, sorted=True)
d_topk, d_label = d_offsets.abs().topk(
offset_topk, dim=1, largest=False, sorted=True)
offset_l_weights = l_offsets.new_zeros(l_offsets.size())
offset_r_weights = r_offsets.new_zeros(r_offsets.size())
offset_t_weights = t_offsets.new_zeros(t_offsets.size())
offset_d_weights = d_offsets.new_zeros(d_offsets.size())
inds = torch.arange(0, proposals.size(0)).to(proposals).long()
# generate offset weights of top-k nearest buckets
for k in range(offset_topk):
if k >= 1:
offset_l_weights[inds, l_label[:,
k]] = (l_topk[:, k] <
offset_upperbound).float()
offset_r_weights[inds, r_label[:,
k]] = (r_topk[:, k] <
offset_upperbound).float()
offset_t_weights[inds, t_label[:,
k]] = (t_topk[:, k] <
offset_upperbound).float()
offset_d_weights[inds, d_label[:,
k]] = (d_topk[:, k] <
offset_upperbound).float()
else:
offset_l_weights[inds, l_label[:, k]] = 1.0
offset_r_weights[inds, r_label[:, k]] = 1.0
offset_t_weights[inds, t_label[:, k]] = 1.0
offset_d_weights[inds, d_label[:, k]] = 1.0
offsets = torch.cat([l_offsets, r_offsets, t_offsets, d_offsets], dim=-1)
offsets_weights = torch.cat([
offset_l_weights, offset_r_weights, offset_t_weights, offset_d_weights
],
dim=-1)
# generate bucket labels and weight
side_num = int(np.ceil(num_buckets / 2.0))
labels = torch.stack(
[l_label[:, 0], r_label[:, 0], t_label[:, 0], d_label[:, 0]], dim=-1)
batch_size = labels.size(0)
bucket_labels = F.one_hot(labels.view(-1), side_num).view(batch_size,
-1).float()
bucket_cls_l_weights = (l_offsets.abs() < 1).float()
bucket_cls_r_weights = (r_offsets.abs() < 1).float()
bucket_cls_t_weights = (t_offsets.abs() < 1).float()
bucket_cls_d_weights = (d_offsets.abs() < 1).float()
bucket_cls_weights = torch.cat([
bucket_cls_l_weights, bucket_cls_r_weights, bucket_cls_t_weights,
bucket_cls_d_weights
],
dim=-1)
# ignore second nearest buckets for cls if necessary
if cls_ignore_neighbor:
bucket_cls_weights = (~((bucket_cls_weights == 1) &
(bucket_labels == 0))).float()
else:
bucket_cls_weights[:] = 1.0
return offsets, offsets_weights, bucket_labels, bucket_cls_weights
def bucket2bbox(proposals,
cls_preds,
offset_preds,
num_buckets,
scale_factor=1.0,
max_shape=None):
"""Apply bucketing estimation (cls preds) and fine regression (offset
preds) to generate det bboxes.
Args:
proposals (Tensor): Boxes to be transformed. Shape (n, 4)
cls_preds (Tensor): bucketing estimation. Shape (n, num_buckets*2).
offset_preds (Tensor): fine regression. Shape (n, num_buckets*2).
num_buckets (int): Number of buckets.
scale_factor (float): Scale factor to rescale proposals.
max_shape (tuple[int, int]): Maximum bounds for boxes, specified as (H, W).
Returns:
tuple[Tensor]: (bboxes, loc_confidence).
- bboxes: predicted bboxes. Shape (n, 4)
- loc_confidence: localization confidence of predicted bboxes.
Shape (n,).
"""
side_num = int(np.ceil(num_buckets / 2.0))
cls_preds = cls_preds.view(-1, side_num)
offset_preds = offset_preds.view(-1, side_num)
scores = F.softmax(cls_preds, dim=1)
score_topk, score_label = scores.topk(2, dim=1, largest=True, sorted=True)
rescaled_proposals = bbox_rescale(proposals, scale_factor)
pw = rescaled_proposals[..., 2] - rescaled_proposals[..., 0]
ph = rescaled_proposals[..., 3] - rescaled_proposals[..., 1]
px1 = rescaled_proposals[..., 0]
py1 = rescaled_proposals[..., 1]
px2 = rescaled_proposals[..., 2]
py2 = rescaled_proposals[..., 3]
bucket_w = pw / num_buckets
bucket_h = ph / num_buckets
score_inds_l = score_label[0::4, 0]
score_inds_r = score_label[1::4, 0]
score_inds_t = score_label[2::4, 0]
score_inds_d = score_label[3::4, 0]
l_buckets = px1 + (0.5 + score_inds_l.float()) * bucket_w
r_buckets = px2 - (0.5 + score_inds_r.float()) * bucket_w
t_buckets = py1 + (0.5 + score_inds_t.float()) * bucket_h
d_buckets = py2 - (0.5 + score_inds_d.float()) * bucket_h
offsets = offset_preds.view(-1, 4, side_num)
inds = torch.arange(proposals.size(0)).to(proposals).long()
l_offsets = offsets[:, 0, :][inds, score_inds_l]
r_offsets = offsets[:, 1, :][inds, score_inds_r]
t_offsets = offsets[:, 2, :][inds, score_inds_t]
d_offsets = offsets[:, 3, :][inds, score_inds_d]
x1 = l_buckets - l_offsets * bucket_w
x2 = r_buckets - r_offsets * bucket_w
y1 = t_buckets - t_offsets * bucket_h
y2 = d_buckets - d_offsets * bucket_h
if max_shape is not None:
x1 = x1.clamp(min=0, max=max_shape[1] - 1)
y1 = y1.clamp(min=0, max=max_shape[0] - 1)
x2 = x2.clamp(min=0, max=max_shape[1] - 1)
y2 = y2.clamp(min=0, max=max_shape[0] - 1)
bboxes = torch.cat([x1[:, None], y1[:, None], x2[:, None], y2[:, None]],
dim=-1)
# bucketing guided rescoring
loc_confidence = score_topk[:, 0]
top2_neighbor_inds = (score_label[:, 0] - score_label[:, 1]).abs() == 1
loc_confidence += score_topk[:, 1] * top2_neighbor_inds.float()
loc_confidence = loc_confidence.view(-1, 4).mean(dim=1)
return bboxes, loc_confidence
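To make the coder's contract concrete, here is a small round-trip sketch with made-up boxes: feeding the encoded bucket labels and offsets straight back through `decode` recovers the ground truth, since the one-hot labels select the true buckets and the stored offsets are exact.
```python
import torch

from mmdet.core.bbox.coder import BucketingBBoxCoder

coder = BucketingBBoxCoder(num_buckets=14, scale_factor=1.7)

proposals = torch.tensor([[20., 20., 120., 120.]])
gt = torch.tensor([[25., 18., 110., 130.]])

# encode() yields fine-regression targets/weights plus bucket
# labels/weights, each of shape (n, 4 * ceil(num_buckets / 2)).
offsets, offset_weights, bucket_labels, cls_weights = coder.encode(
    proposals, gt)

# Using the targets as "predictions" reproduces the ground truth.
decoded, confidence = coder.decode(proposals, (bucket_labels, offsets))
print(decoded)     # ~[[25., 18., 110., 130.]]
print(confidence)  # bucketing-guided localization confidence, shape (n,)
```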

@@ -163,3 +163,35 @@ def bbox2distance(points, bbox, max_dis=None, eps=0.1):
right = right.clamp(min=0, max=max_dis - eps)
bottom = bottom.clamp(min=0, max=max_dis - eps)
return torch.stack([left, top, right, bottom], -1)
def bbox_rescale(bboxes, scale_factor=1.0):
"""Rescale bounding box w.r.t. scale_factor.
Args:
bboxes (Tensor): Shape (n, 4) for bboxes or (n, 5) for rois
scale_factor (float): Rescale factor.
Returns:
Tensor: Rescaled bboxes.
"""
if bboxes.size(1) == 5:
bboxes_ = bboxes[:, 1:]
inds_ = bboxes[:, 0]
else:
bboxes_ = bboxes
cx = (bboxes_[:, 0] + bboxes_[:, 2]) * 0.5
cy = (bboxes_[:, 1] + bboxes_[:, 3]) * 0.5
w = bboxes_[:, 2] - bboxes_[:, 0]
h = bboxes_[:, 3] - bboxes_[:, 1]
w = w * scale_factor
h = h * scale_factor
x1 = cx - 0.5 * w
x2 = cx + 0.5 * w
y1 = cy - 0.5 * h
y2 = cy + 0.5 * h
if bboxes.size(1) == 5:
rescaled_bboxes = torch.stack([inds_, x1, y1, x2, y2], dim=-1)
else:
rescaled_bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
return rescaled_bboxes
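`bbox_rescale` scales width and height about the box center, which is how proposals are expanded before bucketing. A worked example with made-up numbers:
```python
import torch

from mmdet.core.bbox import bbox_rescale

# A 100x100 box centered at (70, 70).
box = torch.tensor([[20., 20., 120., 120.]])
print(bbox_rescale(box, scale_factor=1.7))
# tensor([[-15., -15., 155., 155.]]) -- same center, sides scaled by 1.7
```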

@@ -18,6 +18,7 @@ from .reppoints_head import RepPointsHead
from .retina_head import RetinaHead
from .retina_sepbn_head import RetinaSepBNHead
from .rpn_head import RPNHead
from .sabl_retina_head import SABLRetinaHead
from .ssd_head import SSDHead
from .yolo_head import YOLOV3Head
@@ -27,5 +28,5 @@ __all__ = [
'SSDHead', 'FCOSHead', 'RepPointsHead', 'FoveaHead',
'FreeAnchorRetinaHead', 'ATSSHead', 'FSAFHead', 'NASFCOSHead',
'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'CornerHead', 'PAAHead',
'YOLOV3Head'
'YOLOV3Head', 'SABLRetinaHead'
]

@@ -0,0 +1,622 @@
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
from mmdet.core import (build_anchor_generator, build_assigner,
build_bbox_coder, build_sampler, force_fp32,
images_to_levels, multi_apply, multiclass_nms, unmap)
from ..builder import HEADS, build_loss
from .base_dense_head import BaseDenseHead
from .guided_anchor_head import GuidedAnchorHead
@HEADS.register_module()
class SABLRetinaHead(BaseDenseHead):
"""Side-Aware Boundary Localization (SABL) for RetinaNet.
The anchor generation, assignment, and sampling in SABLRetinaHead
are the same as in GuidedAnchorHead for guided anchoring.
Please refer to https://arxiv.org/abs/1912.04260 for more details.
Args:
num_classes (int): Number of classes.
in_channels (int): Number of channels in the input feature map.
stacked_convs (int): Number of Convs for classification \
and regression branches. Defaults to 4.
feat_channels (int): Number of hidden channels. \
Defaults to 256.
approx_anchor_generator (dict): Config dict for approx generator.
square_anchor_generator (dict): Config dict for square generator.
conv_cfg (dict): Config dict for ConvModule. Defaults to None.
norm_cfg (dict): Config dict for Norm Layer. Defaults to None.
bbox_coder (dict): Config dict for bbox coder.
reg_decoded_bbox (bool): Whether to regress decoded bbox. \
Defaults to False.
background_label (int): Background label. Defaults to None.
train_cfg (dict): Training config of SABLRetinaHead.
test_cfg (dict): Testing config of SABLRetinaHead.
loss_cls (dict): Config of classification loss.
loss_bbox_cls (dict): Config of classification loss for bbox branch.
loss_bbox_reg (dict): Config of regression loss for bbox branch.
"""
def __init__(self,
num_classes,
in_channels,
stacked_convs=4,
feat_channels=256,
approx_anchor_generator=dict(
type='AnchorGenerator',
octave_base_scale=4,
scales_per_octave=3,
ratios=[0.5, 1.0, 2.0],
strides=[8, 16, 32, 64, 128]),
square_anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
scales=[4],
strides=[8, 16, 32, 64, 128]),
conv_cfg=None,
norm_cfg=None,
bbox_coder=dict(
type='BucketingBBoxCoder',
num_buckets=14,
scale_factor=3.0),
reg_decoded_bbox=False,
background_label=None,
train_cfg=None,
test_cfg=None,
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.5),
loss_bbox_reg=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5)):
super(SABLRetinaHead, self).__init__()
self.in_channels = in_channels
self.num_classes = num_classes
self.feat_channels = feat_channels
self.num_buckets = bbox_coder['num_buckets']
self.side_num = int(np.ceil(self.num_buckets / 2))
assert (approx_anchor_generator['octave_base_scale'] ==
square_anchor_generator['scales'][0])
assert (approx_anchor_generator['strides'] ==
square_anchor_generator['strides'])
self.approx_anchor_generator = build_anchor_generator(
approx_anchor_generator)
self.square_anchor_generator = build_anchor_generator(
square_anchor_generator)
self.approxs_per_octave = (
self.approx_anchor_generator.num_base_anchors[0])
# one anchor per location
self.num_anchors = 1
self.stacked_convs = stacked_convs
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.reg_decoded_bbox = reg_decoded_bbox
self.background_label = (
num_classes if background_label is None else background_label)
self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
self.sampling = loss_cls['type'] not in [
'FocalLoss', 'GHMC', 'QualityFocalLoss'
]
if self.use_sigmoid_cls:
self.cls_out_channels = num_classes
else:
self.cls_out_channels = num_classes + 1
self.bbox_coder = build_bbox_coder(bbox_coder)
self.loss_cls = build_loss(loss_cls)
self.loss_bbox_cls = build_loss(loss_bbox_cls)
self.loss_bbox_reg = build_loss(loss_bbox_reg)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
if self.train_cfg:
self.assigner = build_assigner(self.train_cfg.assigner)
# use PseudoSampler when sampling is False
if self.sampling and hasattr(self.train_cfg, 'sampler'):
sampler_cfg = self.train_cfg.sampler
else:
sampler_cfg = dict(type='PseudoSampler')
self.sampler = build_sampler(sampler_cfg, context=self)
self.fp16_enabled = False
self._init_layers()
def _init_layers(self):
self.relu = nn.ReLU(inplace=True)
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
self.cls_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
self.reg_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
self.retina_cls = nn.Conv2d(
self.feat_channels, self.cls_out_channels, 3, padding=1)
self.retina_bbox_reg = nn.Conv2d(
self.feat_channels, self.side_num * 4, 3, padding=1)
self.retina_bbox_cls = nn.Conv2d(
self.feat_channels, self.side_num * 4, 3, padding=1)
def init_weights(self):
for m in self.cls_convs:
normal_init(m.conv, std=0.01)
for m in self.reg_convs:
normal_init(m.conv, std=0.01)
bias_cls = bias_init_with_prob(0.01)
normal_init(self.retina_cls, std=0.01, bias=bias_cls)
normal_init(self.retina_bbox_reg, std=0.01)
normal_init(self.retina_bbox_cls, std=0.01)
def forward_single(self, x):
cls_feat = x
reg_feat = x
for cls_conv in self.cls_convs:
cls_feat = cls_conv(cls_feat)
for reg_conv in self.reg_convs:
reg_feat = reg_conv(reg_feat)
cls_score = self.retina_cls(cls_feat)
bbox_cls_pred = self.retina_bbox_cls(reg_feat)
bbox_reg_pred = self.retina_bbox_reg(reg_feat)
bbox_pred = (bbox_cls_pred, bbox_reg_pred)
return cls_score, bbox_pred
def forward(self, feats):
return multi_apply(self.forward_single, feats)
def get_anchors(self, featmap_sizes, img_metas, device='cuda'):
"""Get squares according to feature map sizes and guided anchors.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
img_metas (list[dict]): Image meta info.
device (torch.device | str): device for returned tensors
Returns:
list[list[Tensor]]: Squares of each image.
"""
num_imgs = len(img_metas)
# since feature map sizes of all images are the same, we only compute
# squares once
multi_level_squares = self.square_anchor_generator.grid_anchors(
featmap_sizes, device=device)
squares_list = [multi_level_squares for _ in range(num_imgs)]
return squares_list
def get_target(self,
approx_list,
inside_flag_list,
square_list,
gt_bboxes_list,
img_metas,
gt_bboxes_ignore_list=None,
gt_labels_list=None,
label_channels=None,
sampling=True,
unmap_outputs=True):
"""Compute bucketing targets.
Args:
approx_list (list[list]): Multi level approxs of each image.
inside_flag_list (list[list]): Multi level inside flags of each
image.
square_list (list[list]): Multi level squares of each image.
gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
img_metas (list[dict]): Meta info of each image.
gt_bboxes_ignore_list (list[Tensor]): Ignored gt bboxes of each image.
gt_labels_list (list[Tensor]): Gt labels of each image.
label_channels (int): Channel of label.
sampling (bool): Whether to sample anchors.
unmap_outputs (bool): Whether to map outputs back to
the original set of anchors.
Returns:
tuple: Returns a tuple containing learning targets.
- labels_list (list[Tensor]): Labels of each level.
- label_weights_list (list[Tensor]): Label weights of each \
level.
- bbox_cls_targets_list (list[Tensor]): BBox cls targets of \
each level.
- bbox_cls_weights_list (list[Tensor]): BBox cls weights of \
each level.
- bbox_reg_targets_list (list[Tensor]): BBox reg targets of \
each level.
- bbox_reg_weights_list (list[Tensor]): BBox reg weights of \
each level.
- num_total_pos (int): Number of positive samples in all \
images.
- num_total_neg (int): Number of negative samples in all \
images.
"""
num_imgs = len(img_metas)
assert len(approx_list) == len(inside_flag_list) == len(
square_list) == num_imgs
# anchor number of multi levels
num_level_squares = [squares.size(0) for squares in square_list[0]]
# concat all level anchors and flags to a single tensor
inside_flag_flat_list = []
approx_flat_list = []
square_flat_list = []
for i in range(num_imgs):
assert len(square_list[i]) == len(inside_flag_list[i])
inside_flag_flat_list.append(torch.cat(inside_flag_list[i]))
approx_flat_list.append(torch.cat(approx_list[i]))
square_flat_list.append(torch.cat(square_list[i]))
# compute targets for each image
if gt_bboxes_ignore_list is None:
gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
if gt_labels_list is None:
gt_labels_list = [None for _ in range(num_imgs)]
(all_labels, all_label_weights, all_bbox_cls_targets,
all_bbox_cls_weights, all_bbox_reg_targets, all_bbox_reg_weights,
pos_inds_list, neg_inds_list) = multi_apply(
self._get_target_single,
approx_flat_list,
inside_flag_flat_list,
square_flat_list,
gt_bboxes_list,
gt_bboxes_ignore_list,
gt_labels_list,
img_metas,
label_channels=label_channels,
sampling=sampling,
unmap_outputs=unmap_outputs)
# no valid anchors
if any([labels is None for labels in all_labels]):
return None
# sampled anchors of all images
num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
# split targets to a list w.r.t. multiple levels
labels_list = images_to_levels(all_labels, num_level_squares)
label_weights_list = images_to_levels(all_label_weights,
num_level_squares)
bbox_cls_targets_list = images_to_levels(all_bbox_cls_targets,
num_level_squares)
bbox_cls_weights_list = images_to_levels(all_bbox_cls_weights,
num_level_squares)
bbox_reg_targets_list = images_to_levels(all_bbox_reg_targets,
num_level_squares)
bbox_reg_weights_list = images_to_levels(all_bbox_reg_weights,
num_level_squares)
return (labels_list, label_weights_list, bbox_cls_targets_list,
bbox_cls_weights_list, bbox_reg_targets_list,
bbox_reg_weights_list, num_total_pos, num_total_neg)
def _get_target_single(self,
flat_approxs,
inside_flags,
flat_squares,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
img_meta,
label_channels=None,
sampling=True,
unmap_outputs=True):
"""Compute regression and classification targets for anchors in a
single image.
Args:
flat_approxs (Tensor): flat approxs of a single image,
shape (n, 4)
inside_flags (Tensor): inside flags of a single image,
shape (n, ).
flat_squares (Tensor): flat squares of a single image,
shape (approxs_per_octave * n, 4)
gt_bboxes (Tensor): Ground truth bboxes of a single image, \
shape (num_gts, 4).
gt_bboxes_ignore (Tensor): Ground truth bboxes to be
ignored, shape (num_ignored_gts, 4).
gt_labels (Tensor): Ground truth labels of each box,
shape (num_gts,).
img_meta (dict): Meta info of the image.
label_channels (int): Channel of label.
sampling (bool): Whether to sample anchors.
unmap_outputs (bool): Whether to map outputs back to
the original set of anchors.
Returns:
tuple:
- labels (Tensor): Labels in a single image
- label_weights (Tensor): Label weights in a single image
- bbox_cls_targets (Tensor): BBox cls targets in a single image
- bbox_cls_weights (Tensor): BBox cls weights in a single image
- bbox_reg_targets (Tensor): BBox reg targets in a single image
- bbox_reg_weights (Tensor): BBox reg weights in a single image
- pos_inds (Tensor): Indices of positive samples \
in a single image
- neg_inds (Tensor): Indices of negative samples \
in a single image
"""
if not inside_flags.any():
return (None, ) * 8
# assign gt and sample anchors
expand_inside_flags = inside_flags[:, None].expand(
-1, self.approxs_per_octave).reshape(-1)
approxs = flat_approxs[expand_inside_flags, :]
squares = flat_squares[inside_flags, :]
assign_result = self.assigner.assign(approxs, squares,
self.approxs_per_octave,
gt_bboxes, gt_bboxes_ignore)
sampling_result = self.sampler.sample(assign_result, squares,
gt_bboxes)
num_valid_squares = squares.shape[0]
bbox_cls_targets = squares.new_zeros(
(num_valid_squares, self.side_num * 4))
bbox_cls_weights = squares.new_zeros(
(num_valid_squares, self.side_num * 4))
bbox_reg_targets = squares.new_zeros(
(num_valid_squares, self.side_num * 4))
bbox_reg_weights = squares.new_zeros(
(num_valid_squares, self.side_num * 4))
labels = squares.new_full((num_valid_squares, ),
self.background_label,
dtype=torch.long)
label_weights = squares.new_zeros(num_valid_squares, dtype=torch.float)
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
if len(pos_inds) > 0:
(pos_bbox_reg_targets, pos_bbox_reg_weights, pos_bbox_cls_targets,
pos_bbox_cls_weights) = self.bbox_coder.encode(
sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
bbox_cls_targets[pos_inds, :] = pos_bbox_cls_targets
bbox_reg_targets[pos_inds, :] = pos_bbox_reg_targets
bbox_cls_weights[pos_inds, :] = pos_bbox_cls_weights
bbox_reg_weights[pos_inds, :] = pos_bbox_reg_weights
if gt_labels is None:
labels[pos_inds] = 1
else:
labels[pos_inds] = gt_labels[
sampling_result.pos_assigned_gt_inds]
if self.train_cfg.pos_weight <= 0:
label_weights[pos_inds] = 1.0
else:
label_weights[pos_inds] = self.train_cfg.pos_weight
if len(neg_inds) > 0:
label_weights[neg_inds] = 1.0
# map up to original set of anchors
if unmap_outputs:
num_total_anchors = flat_squares.size(0)
labels = unmap(
labels,
num_total_anchors,
inside_flags,
fill=self.background_label)
label_weights = unmap(label_weights, num_total_anchors,
inside_flags)
bbox_cls_targets = unmap(bbox_cls_targets, num_total_anchors,
inside_flags)
bbox_cls_weights = unmap(bbox_cls_weights, num_total_anchors,
inside_flags)
bbox_reg_targets = unmap(bbox_reg_targets, num_total_anchors,
inside_flags)
bbox_reg_weights = unmap(bbox_reg_weights, num_total_anchors,
inside_flags)
return (labels, label_weights, bbox_cls_targets, bbox_cls_weights,
bbox_reg_targets, bbox_reg_weights, pos_inds, neg_inds)
def loss_single(self, cls_score, bbox_pred, labels, label_weights,
bbox_cls_targets, bbox_cls_weights, bbox_reg_targets,
bbox_reg_weights, num_total_samples):
# classification loss
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
cls_score = cls_score.permute(0, 2, 3,
1).reshape(-1, self.cls_out_channels)
loss_cls = self.loss_cls(
cls_score, labels, label_weights, avg_factor=num_total_samples)
# regression loss
bbox_cls_targets = bbox_cls_targets.reshape(-1, self.side_num * 4)
bbox_cls_weights = bbox_cls_weights.reshape(-1, self.side_num * 4)
bbox_reg_targets = bbox_reg_targets.reshape(-1, self.side_num * 4)
bbox_reg_weights = bbox_reg_weights.reshape(-1, self.side_num * 4)
(bbox_cls_pred, bbox_reg_pred) = bbox_pred
bbox_cls_pred = bbox_cls_pred.permute(0, 2, 3, 1).reshape(
-1, self.side_num * 4)
bbox_reg_pred = bbox_reg_pred.permute(0, 2, 3, 1).reshape(
-1, self.side_num * 4)
loss_bbox_cls = self.loss_bbox_cls(
bbox_cls_pred,
bbox_cls_targets.long(),
bbox_cls_weights,
avg_factor=num_total_samples * 4 * self.side_num)
loss_bbox_reg = self.loss_bbox_reg(
bbox_reg_pred,
bbox_reg_targets,
bbox_reg_weights,
avg_factor=num_total_samples * 4 * self.bbox_coder.offset_topk)
return loss_cls, loss_bbox_cls, loss_bbox_reg
@force_fp32(apply_to=('cls_scores', 'bbox_preds'))
def loss(self,
cls_scores,
bbox_preds,
gt_bboxes,
gt_labels,
img_metas,
gt_bboxes_ignore=None):
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == self.approx_anchor_generator.num_levels
device = cls_scores[0].device
# get sampled approxes
approxs_list, inside_flag_list = GuidedAnchorHead.get_sampled_approxs(
self, featmap_sizes, img_metas, device=device)
square_list = self.get_anchors(featmap_sizes, img_metas, device=device)
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = self.get_target(
approxs_list,
inside_flag_list,
square_list,
gt_bboxes,
img_metas,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels,
sampling=self.sampling)
if cls_reg_targets is None:
return None
(labels_list, label_weights_list, bbox_cls_targets_list,
bbox_cls_weights_list, bbox_reg_targets_list, bbox_reg_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets
num_total_samples = (
num_total_pos + num_total_neg if self.sampling else num_total_pos)
losses_cls, losses_bbox_cls, losses_bbox_reg = multi_apply(
self.loss_single,
cls_scores,
bbox_preds,
labels_list,
label_weights_list,
bbox_cls_targets_list,
bbox_cls_weights_list,
bbox_reg_targets_list,
bbox_reg_weights_list,
num_total_samples=num_total_samples)
return dict(
loss_cls=losses_cls,
loss_bbox_cls=losses_bbox_cls,
loss_bbox_reg=losses_bbox_reg)
@force_fp32(apply_to=('cls_scores', 'bbox_preds'))
def get_bboxes(self,
cls_scores,
bbox_preds,
img_metas,
cfg=None,
rescale=False):
assert len(cls_scores) == len(bbox_preds)
num_levels = len(cls_scores)
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
device = cls_scores[0].device
mlvl_anchors = self.get_anchors(
featmap_sizes, img_metas, device=device)
result_list = []
for img_id in range(len(img_metas)):
cls_score_list = [
cls_scores[i][img_id].detach() for i in range(num_levels)
]
bbox_cls_pred_list = [
bbox_preds[i][0][img_id].detach() for i in range(num_levels)
]
bbox_reg_pred_list = [
bbox_preds[i][1][img_id].detach() for i in range(num_levels)
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
proposals = self.get_bboxes_single(cls_score_list,
bbox_cls_pred_list,
bbox_reg_pred_list,
mlvl_anchors[img_id], img_shape,
scale_factor, cfg, rescale)
result_list.append(proposals)
return result_list
def get_bboxes_single(self,
cls_scores,
bbox_cls_preds,
bbox_reg_preds,
mlvl_anchors,
img_shape,
scale_factor,
cfg,
rescale=False):
cfg = self.test_cfg if cfg is None else cfg
mlvl_bboxes = []
mlvl_scores = []
mlvl_confids = []
assert len(cls_scores) == len(bbox_cls_preds) == len(
bbox_reg_preds) == len(mlvl_anchors)
for cls_score, bbox_cls_pred, bbox_reg_pred, anchors in zip(
cls_scores, bbox_cls_preds, bbox_reg_preds, mlvl_anchors):
assert cls_score.size()[-2:] == bbox_cls_pred.size(
)[-2:] == bbox_reg_pred.size()[-2:]
cls_score = cls_score.permute(1, 2,
0).reshape(-1, self.cls_out_channels)
if self.use_sigmoid_cls:
scores = cls_score.sigmoid()
else:
scores = cls_score.softmax(-1)
bbox_cls_pred = bbox_cls_pred.permute(1, 2, 0).reshape(
-1, self.side_num * 4)
bbox_reg_pred = bbox_reg_pred.permute(1, 2, 0).reshape(
-1, self.side_num * 4)
nms_pre = cfg.get('nms_pre', -1)
if nms_pre > 0 and scores.shape[0] > nms_pre:
if self.use_sigmoid_cls:
max_scores, _ = scores.max(dim=1)
else:
max_scores, _ = scores[:, :-1].max(dim=1)
_, topk_inds = max_scores.topk(nms_pre)
anchors = anchors[topk_inds, :]
bbox_cls_pred = bbox_cls_pred[topk_inds, :]
bbox_reg_pred = bbox_reg_pred[topk_inds, :]
scores = scores[topk_inds, :]
bbox_preds = [
bbox_cls_pred.contiguous(),
bbox_reg_pred.contiguous()
]
bboxes, confids = self.bbox_coder.decode(
anchors.contiguous(), bbox_preds, max_shape=img_shape)
mlvl_bboxes.append(bboxes)
mlvl_scores.append(scores)
mlvl_confids.append(confids)
mlvl_bboxes = torch.cat(mlvl_bboxes)
if rescale:
mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
mlvl_scores = torch.cat(mlvl_scores)
mlvl_confids = torch.cat(mlvl_confids)
if self.use_sigmoid_cls:
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
det_bboxes, det_labels = multiclass_nms(
mlvl_bboxes,
mlvl_scores,
cfg.score_thr,
cfg.nms,
cfg.max_per_img,
score_factors=mlvl_confids)
return det_bboxes, det_labels
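For shape intuition: `forward` returns, per FPN level, a classification map plus a `(bbox_cls_pred, bbox_reg_pred)` pair with `4 * side_num` channels each, where `side_num = ceil(num_buckets / 2) = 7` for the default coder. A minimal forward sketch on dummy FPN features, loosely mirroring the unit tests added in this PR:
```python
import torch

from mmdet.models.dense_heads import SABLRetinaHead

head = SABLRetinaHead(num_classes=80, in_channels=256)

# One dummy feature map per stride (8, 16, 32, 64, 128) of a 640x640 input.
feats = [torch.rand(1, 256, 640 // s, 640 // s) for s in (8, 16, 32, 64, 128)]
cls_scores, bbox_preds = head(feats)

print(cls_scores[0].shape)     # torch.Size([1, 80, 80, 80])
print(bbox_preds[0][0].shape)  # bbox_cls_pred: torch.Size([1, 28, 80, 80])
print(bbox_preds[0][1].shape)  # bbox_reg_pred: torch.Size([1, 28, 80, 80])
```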

@@ -2,8 +2,9 @@ from .bbox_head import BBoxHead
from .convfc_bbox_head import (ConvFCBBoxHead, Shared2FCBBoxHead,
Shared4Conv1FCBBoxHead)
from .double_bbox_head import DoubleConvFCBBoxHead
from .sabl_head import SABLHead
__all__ = [
'BBoxHead', 'ConvFCBBoxHead', 'Shared2FCBBoxHead',
'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead'
'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'SABLHead'
]

@@ -0,0 +1,563 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, kaiming_init, normal_init, xavier_init
from mmdet.core import (build_bbox_coder, force_fp32, multi_apply,
multiclass_nms)
from mmdet.models.builder import HEADS, build_loss
from mmdet.models.losses import accuracy


@HEADS.register_module()
class SABLHead(nn.Module):
"""Side-Aware Boundary Localization (SABL) for RoI-Head.
Side-Aware features are extracted by conv layers
with an attention mechanism.
Boundary Localization with Bucketing and Bucketing Guided Rescoring
are implemented in BucketingBBoxCoder.
Please refer to https://arxiv.org/abs/1912.04260 for more details.
Args:
cls_in_channels (int): Input channels of cls RoI feature. \
Defaults to 256.
reg_in_channels (int): Input channels of reg RoI feature. \
Defaults to 256.
roi_feat_size (int): Size of RoI features. Defaults to 7.
reg_feat_up_ratio (int): Upsample ratio of reg features. \
Defaults to 2.
reg_pre_kernel (int): Kernel of 2D conv layers before \
attention pooling. Defaults to 3.
reg_post_kernel (int): Kernel of 1D conv layers after \
attention pooling. Defaults to 3.
reg_pre_num (int): Number of pre convs. Defaults to 2.
reg_post_num (int): Number of post convs. Defaults to 1.
num_classes (int): Number of classes in dataset. Defaults to 80.
cls_out_channels (int): Hidden channels in cls fcs. Defaults to 1024.
reg_offset_out_channels (int): Hidden and output channel \
of reg offset branch. Defaults to 256.
reg_cls_out_channels (int): Hidden and output channel \
of reg cls branch. Defaults to 256.
num_cls_fcs (int): Number of fcs for cls branch. Defaults to 1.
        num_reg_fcs (int): Number of fcs for reg branch. Defaults to 0.
        reg_class_agnostic (bool): Class agnostic regression or not. \
            Defaults to True.
        norm_cfg (dict): Config of norm layers. Defaults to None.
        bbox_coder (dict): Config of bbox coder. Defaults to \
            'BucketingBBoxCoder'.
loss_cls (dict): Config of classification loss.
loss_bbox_cls (dict): Config of classification loss for bbox branch.
loss_bbox_reg (dict): Config of regression loss for bbox branch.
"""

    def __init__(self,
num_classes,
cls_in_channels=256,
reg_in_channels=256,
roi_feat_size=7,
reg_feat_up_ratio=2,
reg_pre_kernel=3,
reg_post_kernel=3,
reg_pre_num=2,
reg_post_num=1,
cls_out_channels=1024,
reg_offset_out_channels=256,
reg_cls_out_channels=256,
num_cls_fcs=1,
num_reg_fcs=0,
reg_class_agnostic=True,
norm_cfg=None,
bbox_coder=dict(
type='BucketingBBoxCoder',
num_buckets=14,
scale_factor=1.7),
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
loss_bbox_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0),
loss_bbox_reg=dict(
type='SmoothL1Loss', beta=0.1, loss_weight=1.0)):
super(SABLHead, self).__init__()
self.cls_in_channels = cls_in_channels
self.reg_in_channels = reg_in_channels
self.roi_feat_size = roi_feat_size
self.reg_feat_up_ratio = int(reg_feat_up_ratio)
self.num_buckets = bbox_coder['num_buckets']
assert self.reg_feat_up_ratio // 2 >= 1
self.up_reg_feat_size = roi_feat_size * self.reg_feat_up_ratio
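        # The upsampled side feature must provide exactly one position per
        # bucket so that bucket predictions align with feature columns.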
assert self.up_reg_feat_size == bbox_coder['num_buckets']
self.reg_pre_kernel = reg_pre_kernel
self.reg_post_kernel = reg_post_kernel
self.reg_pre_num = reg_pre_num
self.reg_post_num = reg_post_num
self.num_classes = num_classes
self.cls_out_channels = cls_out_channels
self.reg_offset_out_channels = reg_offset_out_channels
self.reg_cls_out_channels = reg_cls_out_channels
self.num_cls_fcs = num_cls_fcs
self.num_reg_fcs = num_reg_fcs
self.reg_class_agnostic = reg_class_agnostic
assert self.reg_class_agnostic
self.norm_cfg = norm_cfg
self.bbox_coder = build_bbox_coder(bbox_coder)
self.loss_cls = build_loss(loss_cls)
self.loss_bbox_cls = build_loss(loss_bbox_cls)
self.loss_bbox_reg = build_loss(loss_bbox_reg)
self.cls_fcs = self._add_fc_branch(self.num_cls_fcs,
self.cls_in_channels,
self.roi_feat_size,
self.cls_out_channels)
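        # Each boundary only predicts the buckets on its own side,
        # i.e. ceil(num_buckets / 2) buckets per side.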
self.side_num = int(np.ceil(self.num_buckets / 2))
if self.reg_feat_up_ratio > 1:
self.upsample_x = nn.ConvTranspose1d(
reg_in_channels,
reg_in_channels,
self.reg_feat_up_ratio,
stride=self.reg_feat_up_ratio)
self.upsample_y = nn.ConvTranspose1d(
reg_in_channels,
reg_in_channels,
self.reg_feat_up_ratio,
stride=self.reg_feat_up_ratio)
self.reg_pre_convs = nn.ModuleList()
for i in range(self.reg_pre_num):
reg_pre_conv = ConvModule(
reg_in_channels,
reg_in_channels,
kernel_size=reg_pre_kernel,
padding=reg_pre_kernel // 2,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU'))
self.reg_pre_convs.append(reg_pre_conv)
self.reg_post_conv_xs = nn.ModuleList()
for i in range(self.reg_post_num):
reg_post_conv_x = ConvModule(
reg_in_channels,
reg_in_channels,
kernel_size=(1, reg_post_kernel),
padding=(0, reg_post_kernel // 2),
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU'))
self.reg_post_conv_xs.append(reg_post_conv_x)
self.reg_post_conv_ys = nn.ModuleList()
for i in range(self.reg_post_num):
reg_post_conv_y = ConvModule(
reg_in_channels,
reg_in_channels,
kernel_size=(reg_post_kernel, 1),
padding=(reg_post_kernel // 2, 0),
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU'))
self.reg_post_conv_ys.append(reg_post_conv_y)
self.reg_conv_att_x = nn.Conv2d(reg_in_channels, 1, 1)
self.reg_conv_att_y = nn.Conv2d(reg_in_channels, 1, 1)
self.fc_cls = nn.Linear(self.cls_out_channels, self.num_classes + 1)
self.relu = nn.ReLU(inplace=True)
self.reg_cls_fcs = self._add_fc_branch(self.num_reg_fcs,
self.reg_in_channels, 1,
self.reg_cls_out_channels)
self.reg_offset_fcs = self._add_fc_branch(self.num_reg_fcs,
self.reg_in_channels, 1,
self.reg_offset_out_channels)
self.fc_reg_cls = nn.Linear(self.reg_cls_out_channels, 1)
self.fc_reg_offset = nn.Linear(self.reg_offset_out_channels, 1)

    def _add_fc_branch(self, num_branch_fcs, in_channels, roi_feat_size,
fc_out_channels):
in_channels = in_channels * roi_feat_size * roi_feat_size
branch_fcs = nn.ModuleList()
for i in range(num_branch_fcs):
fc_in_channels = (in_channels if i == 0 else fc_out_channels)
branch_fcs.append(nn.Linear(fc_in_channels, fc_out_channels))
return branch_fcs

    def init_weights(self):
for module_list in [
self.reg_cls_fcs, self.reg_offset_fcs, self.cls_fcs
]:
for m in module_list.modules():
if isinstance(m, nn.Linear):
xavier_init(m, distribution='uniform')
if self.reg_feat_up_ratio > 1:
kaiming_init(self.upsample_x, distribution='normal')
kaiming_init(self.upsample_y, distribution='normal')
normal_init(self.reg_conv_att_x, 0, 0.01)
normal_init(self.reg_conv_att_y, 0, 0.01)
normal_init(self.fc_reg_offset, 0, 0.001)
normal_init(self.fc_reg_cls, 0, 0.01)
normal_init(self.fc_cls, 0, 0.01)

    def cls_forward(self, cls_x):
cls_x = cls_x.view(cls_x.size(0), -1)
for fc in self.cls_fcs:
cls_x = self.relu(fc(cls_x))
cls_score = self.fc_cls(cls_x)
return cls_score

    def attention_pool(self, reg_x):
        """Extract direction-specific features fx and fy with an attention
        mechanism."""
reg_fx = reg_x
reg_fy = reg_x
reg_fx_att = self.reg_conv_att_x(reg_fx).sigmoid()
reg_fy_att = self.reg_conv_att_y(reg_fy).sigmoid()
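        # Normalize the attention maps over H for the x-branch and over W
        # for the y-branch, then pool along that axis into 1D side features.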
reg_fx_att = reg_fx_att / reg_fx_att.sum(dim=2).unsqueeze(2)
reg_fy_att = reg_fy_att / reg_fy_att.sum(dim=3).unsqueeze(3)
reg_fx = (reg_fx * reg_fx_att).sum(dim=2)
reg_fy = (reg_fy * reg_fy_att).sum(dim=3)
return reg_fx, reg_fy

    def side_aware_feature_extractor(self, reg_x):
        """Refine and extract side-aware features without splitting them."""
for reg_pre_conv in self.reg_pre_convs:
reg_x = reg_pre_conv(reg_x)
reg_fx, reg_fy = self.attention_pool(reg_x)
if self.reg_post_num > 0:
reg_fx = reg_fx.unsqueeze(2)
reg_fy = reg_fy.unsqueeze(3)
for i in range(self.reg_post_num):
reg_fx = self.reg_post_conv_xs[i](reg_fx)
reg_fy = self.reg_post_conv_ys[i](reg_fy)
reg_fx = reg_fx.squeeze(2)
reg_fy = reg_fy.squeeze(3)
if self.reg_feat_up_ratio > 1:
reg_fx = self.relu(self.upsample_x(reg_fx))
reg_fy = self.relu(self.upsample_y(reg_fy))
reg_fx = torch.transpose(reg_fx, 1, 2)
reg_fy = torch.transpose(reg_fy, 1, 2)
return reg_fx.contiguous(), reg_fy.contiguous()

    def reg_pred(self, x, offset_fcs, cls_fcs):
        """Predict bucketing estimation (cls_pred) and fine regression
        (offset_pred) with side-aware features."""
        x_offset = x.view(-1, self.reg_in_channels)
        x_cls = x.view(-1, self.reg_in_channels)
        for fc in offset_fcs:
x_offset = self.relu(fc(x_offset))
for fc in cls_fcs:
x_cls = self.relu(fc(x_cls))
offset_pred = self.fc_reg_offset(x_offset)
cls_pred = self.fc_reg_cls(x_cls)
offset_pred = offset_pred.view(x.size(0), -1)
cls_pred = cls_pred.view(x.size(0), -1)
return offset_pred, cls_pred

    def side_aware_split(self, feat):
        """Split side-aware features aligned with the order of bucketing
        targets."""
l_end = int(np.ceil(self.up_reg_feat_size / 2))
r_start = int(np.floor(self.up_reg_feat_size / 2))
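        # Flip the right half so both halves are ordered from the object
        # boundary inwards, matching the order of the bucketing targets.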
feat_fl = feat[:, :l_end]
feat_fr = feat[:, r_start:].flip(dims=(1, ))
feat_fl = feat_fl.contiguous()
feat_fr = feat_fr.contiguous()
feat = torch.cat([feat_fl, feat_fr], dim=-1)
return feat

    def reg_forward(self, reg_x):
outs = self.side_aware_feature_extractor(reg_x)
edge_offset_preds = []
edge_cls_preds = []
reg_fx = outs[0]
reg_fy = outs[1]
offset_pred_x, cls_pred_x = self.reg_pred(reg_fx, self.reg_offset_fcs,
self.reg_cls_fcs)
offset_pred_y, cls_pred_y = self.reg_pred(reg_fy, self.reg_offset_fcs,
self.reg_cls_fcs)
offset_pred_x = self.side_aware_split(offset_pred_x)
offset_pred_y = self.side_aware_split(offset_pred_y)
cls_pred_x = self.side_aware_split(cls_pred_x)
cls_pred_y = self.side_aware_split(cls_pred_y)
edge_offset_preds = torch.cat([offset_pred_x, offset_pred_y], dim=-1)
edge_cls_preds = torch.cat([cls_pred_x, cls_pred_y], dim=-1)
return (edge_cls_preds, edge_offset_preds)

    def forward(self, x):
bbox_pred = self.reg_forward(x)
cls_score = self.cls_forward(x)
return cls_score, bbox_pred

    def get_targets(self, sampling_results, gt_bboxes, gt_labels,
rcnn_train_cfg):
pos_proposals = [res.pos_bboxes for res in sampling_results]
neg_proposals = [res.neg_bboxes for res in sampling_results]
pos_gt_bboxes = [res.pos_gt_bboxes for res in sampling_results]
pos_gt_labels = [res.pos_gt_labels for res in sampling_results]
cls_reg_targets = self.bucket_target(pos_proposals, neg_proposals,
pos_gt_bboxes, pos_gt_labels,
rcnn_train_cfg)
(labels, label_weights, bucket_cls_targets, bucket_cls_weights,
bucket_offset_targets, bucket_offset_weights) = cls_reg_targets
        return (labels, label_weights,
                (bucket_cls_targets, bucket_offset_targets),
                (bucket_cls_weights, bucket_offset_weights))

    def bucket_target(self,
pos_proposals_list,
neg_proposals_list,
pos_gt_bboxes_list,
pos_gt_labels_list,
rcnn_train_cfg,
concat=True):
(labels, label_weights, bucket_cls_targets, bucket_cls_weights,
bucket_offset_targets, bucket_offset_weights) = multi_apply(
self._bucket_target_single,
pos_proposals_list,
neg_proposals_list,
pos_gt_bboxes_list,
pos_gt_labels_list,
cfg=rcnn_train_cfg)
if concat:
labels = torch.cat(labels, 0)
label_weights = torch.cat(label_weights, 0)
bucket_cls_targets = torch.cat(bucket_cls_targets, 0)
bucket_cls_weights = torch.cat(bucket_cls_weights, 0)
bucket_offset_targets = torch.cat(bucket_offset_targets, 0)
bucket_offset_weights = torch.cat(bucket_offset_weights, 0)
return (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
bucket_offset_targets, bucket_offset_weights)

    def _bucket_target_single(self, pos_proposals, neg_proposals,
pos_gt_bboxes, pos_gt_labels, cfg):
"""Compute bucketing estimation targets and fine regression targets for
a single image.
Args:
pos_proposals (Tensor): positive proposals of a single image,
Shape (n_pos, 4)
neg_proposals (Tensor): negative proposals of a single image,
Shape (n_neg, 4).
pos_gt_bboxes (Tensor): gt bboxes assigned to positive proposals
of a single image, Shape (n_pos, 4).
pos_gt_labels (Tensor): gt labels assigned to positive proposals
of a single image, Shape (n_pos, ).
cfg (dict): Config of calculating targets
Returns:
tuple:
- labels (Tensor): Labels in a single image. \
Shape (n,).
- label_weights (Tensor): Label weights in a single image.\
Shape (n,)
- bucket_cls_targets (Tensor): Bucket cls targets in \
a single image. Shape (n, num_buckets*2).
- bucket_cls_weights (Tensor): Bucket cls weights in \
a single image. Shape (n, num_buckets*2).
- bucket_offset_targets (Tensor): Bucket offset targets \
in a single image. Shape (n, num_buckets*2).
                - bucket_offset_weights (Tensor): Bucket offset weights \
                    in a single image. Shape (n, num_buckets*2).
"""
num_pos = pos_proposals.size(0)
num_neg = neg_proposals.size(0)
num_samples = num_pos + num_neg
labels = pos_gt_bboxes.new_full((num_samples, ),
self.num_classes,
dtype=torch.long)
label_weights = pos_proposals.new_zeros(num_samples)
bucket_cls_targets = pos_proposals.new_zeros(num_samples,
4 * self.side_num)
bucket_cls_weights = pos_proposals.new_zeros(num_samples,
4 * self.side_num)
bucket_offset_targets = pos_proposals.new_zeros(
num_samples, 4 * self.side_num)
bucket_offset_weights = pos_proposals.new_zeros(
num_samples, 4 * self.side_num)
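        # Sampled RoIs are ordered positives-first, so the first `num_pos`
        # rows of every target tensor belong to positive proposals.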
if num_pos > 0:
labels[:num_pos] = pos_gt_labels
label_weights[:num_pos] = 1.0
(pos_bucket_offset_targets, pos_bucket_offset_weights,
pos_bucket_cls_targets,
pos_bucket_cls_weights) = self.bbox_coder.encode(
pos_proposals, pos_gt_bboxes)
bucket_cls_targets[:num_pos, :] = pos_bucket_cls_targets
bucket_cls_weights[:num_pos, :] = pos_bucket_cls_weights
bucket_offset_targets[:num_pos, :] = pos_bucket_offset_targets
bucket_offset_weights[:num_pos, :] = pos_bucket_offset_weights
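        # Negatives carry no localization targets but still contribute to
        # the classification loss as background samples.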
if num_neg > 0:
label_weights[-num_neg:] = 1.0
return (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
bucket_offset_targets, bucket_offset_weights)

    def loss(self,
cls_score,
bbox_pred,
rois,
labels,
label_weights,
bbox_targets,
bbox_weights,
reduction_override=None):
losses = dict()
if cls_score is not None:
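            # Average over the number of weighted samples, clamped to at
            # least 1 to avoid division by zero.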
avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
losses['loss_cls'] = self.loss_cls(
cls_score,
labels,
label_weights,
avg_factor=avg_factor,
reduction_override=reduction_override)
losses['acc'] = accuracy(cls_score, labels)
if bbox_pred is not None:
bucket_cls_preds, bucket_offset_preds = bbox_pred
bucket_cls_targets, bucket_offset_targets = bbox_targets
bucket_cls_weights, bucket_offset_weights = bbox_weights
# edge cls
bucket_cls_preds = bucket_cls_preds.view(-1, self.side_num)
bucket_cls_targets = bucket_cls_targets.view(-1, self.side_num)
bucket_cls_weights = bucket_cls_weights.view(-1, self.side_num)
losses['loss_bbox_cls'] = self.loss_bbox_cls(
bucket_cls_preds,
bucket_cls_targets,
bucket_cls_weights,
avg_factor=bucket_cls_targets.size(0),
reduction_override=reduction_override)
losses['loss_bbox_reg'] = self.loss_bbox_reg(
bucket_offset_preds,
bucket_offset_targets,
bucket_offset_weights,
avg_factor=bucket_offset_targets.size(0),
reduction_override=reduction_override)
return losses

    @force_fp32(apply_to=('cls_score', 'bbox_pred'))
    def get_bboxes(self,
rois,
cls_score,
bbox_pred,
img_shape,
scale_factor,
rescale=False,
cfg=None):
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
scores = F.softmax(cls_score, dim=1) if cls_score is not None else None
if bbox_pred is not None:
bboxes, confids = self.bbox_coder.decode(rois[:, 1:], bbox_pred,
img_shape)
else:
bboxes = rois[:, 1:].clone()
confids = None
if img_shape is not None:
bboxes[:, [0, 2]].clamp_(min=0, max=img_shape[1] - 1)
bboxes[:, [1, 3]].clamp_(min=0, max=img_shape[0] - 1)
if rescale and bboxes.size(0) > 0:
if isinstance(scale_factor, float):
bboxes /= scale_factor
else:
bboxes /= torch.from_numpy(scale_factor).to(bboxes.device)
if cfg is None:
return bboxes, scores
else:
det_bboxes, det_labels = multiclass_nms(
bboxes,
scores,
cfg.score_thr,
cfg.nms,
cfg.max_per_img,
score_factors=confids)
return det_bboxes, det_labels

    @force_fp32(apply_to=('bbox_preds', ))
    def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
"""Refine bboxes during training.
Args:
rois (Tensor): Shape (n*bs, 5), where n is image number per GPU,
and bs is the sampled RoIs per image.
labels (Tensor): Shape (n*bs, ).
bbox_preds (list[Tensor]): Shape [(n*bs, num_buckets*2), \
(n*bs, num_buckets*2)].
pos_is_gts (list[Tensor]): Flags indicating if each positive bbox
is a gt bbox.
img_metas (list[dict]): Meta info of each image.
Returns:
list[Tensor]: Refined bboxes of each image in a mini-batch.
"""
img_ids = rois[:, 0].long().unique(sorted=True)
assert img_ids.numel() == len(img_metas)
bboxes_list = []
for i in range(len(img_metas)):
inds = torch.nonzero(
rois[:, 0] == i, as_tuple=False).squeeze(dim=1)
num_rois = inds.numel()
bboxes_ = rois[inds, 1:]
label_ = labels[inds]
edge_cls_preds, edge_offset_preds = bbox_preds
edge_cls_preds_ = edge_cls_preds[inds]
edge_offset_preds_ = edge_offset_preds[inds]
bbox_pred_ = [edge_cls_preds_, edge_offset_preds_]
img_meta_ = img_metas[i]
pos_is_gts_ = pos_is_gts[i]
bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_,
img_meta_)
# filter gt bboxes
pos_keep = 1 - pos_is_gts_
keep_inds = pos_is_gts_.new_ones(num_rois)
keep_inds[:len(pos_is_gts_)] = pos_keep
bboxes_list.append(bboxes[keep_inds.type(torch.bool)])
return bboxes_list

    @force_fp32(apply_to=('bbox_pred', ))
    def regress_by_class(self, rois, label, bbox_pred, img_meta):
"""Regress the bbox for the predicted class. Used in Cascade R-CNN.
Args:
rois (Tensor): shape (n, 4) or (n, 5)
label (Tensor): shape (n, )
bbox_pred (list[Tensor]): shape [(n, num_buckets *2), \
(n, num_buckets *2)]
img_meta (dict): Image meta info.
Returns:
Tensor: Regressed bboxes, the same shape as input rois.
"""
assert rois.size(1) == 4 or rois.size(1) == 5
if rois.size(1) == 4:
new_rois, _ = self.bbox_coder.decode(rois, bbox_pred,
img_meta['img_shape'])
else:
bboxes, _ = self.bbox_coder.decode(rois[:, 1:], bbox_pred,
img_meta['img_shape'])
new_rois = torch.cat((rois[:, [0]], bboxes), dim=1)
return new_rois

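To make the split-and-flip in `side_aware_split` concrete, here is a toy run of the same logic outside the module (a standalone sketch; `up_reg_feat_size=4` is an arbitrary illustrative value, not a config from this PR):

```python
import numpy as np
import torch

up_reg_feat_size = 4  # illustrative; equals num_buckets in SABLHead
feat = torch.arange(8.).view(2, 4)  # 2 RoIs, 4 upsampled positions per side

l_end = int(np.ceil(up_reg_feat_size / 2))     # 2
r_start = int(np.floor(up_reg_feat_size / 2))  # 2
feat_fl = feat[:, :l_end]                      # left side, boundary-first
feat_fr = feat[:, r_start:].flip(dims=(1, ))   # right side, flipped
out = torch.cat([feat_fl, feat_fr], dim=-1)
print(out)
# tensor([[0., 1., 3., 2.],
#         [4., 5., 7., 6.]])
```
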
@@ -324,12 +324,21 @@ def _check_bbox_head(bbox_cfg, bbox_head):
_check_bbox_head(bbox_cfg, single_bbox_head)
else:
assert bbox_cfg['type'] == bbox_head.__class__.__name__
if bbox_cfg['type'] == 'SABLHead':
assert bbox_cfg.cls_in_channels == bbox_head.cls_in_channels
assert bbox_cfg.reg_in_channels == bbox_head.reg_in_channels
cls_out_channels = bbox_cfg.get('cls_out_channels', 1024)
assert (cls_out_channels == bbox_head.fc_cls.in_features)
assert (bbox_cfg.num_classes + 1 == bbox_head.fc_cls.out_features)
else:
assert bbox_cfg.in_channels == bbox_head.in_channels
with_cls = bbox_cfg.get('with_cls', True)
if with_cls:
fc_out_channels = bbox_cfg.get('fc_out_channels', 2048)
assert (fc_out_channels == bbox_head.fc_cls.in_features)
assert bbox_cfg.num_classes + 1 == bbox_head.fc_cls.out_features
assert (bbox_cfg.num_classes +
1 == bbox_head.fc_cls.out_features)
with_reg = bbox_cfg.get('with_reg', True)
if with_reg:
@@ -350,6 +359,10 @@ def _check_anchorhead(config, head):
assert (config.feat_channels == head.atss_cls.in_channels)
assert (config.feat_channels == head.atss_reg.in_channels)
assert (config.feat_channels == head.atss_centerness.in_channels)
elif config['type'] == 'SABLRetinaHead':
assert (config.feat_channels == head.retina_cls.in_channels)
assert (config.feat_channels == head.retina_bbox_reg.in_channels)
assert (config.feat_channels == head.retina_bbox_cls.in_channels)
else:
assert (config.in_channels == head.conv_cls.in_channels)
assert (config.in_channels == head.conv_reg.in_channels)

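For reference, a minimal sketch of constructing the newly registered head through the registry, which is what the config check above exercises (assuming an installed mmdet 2.x environment; the channel values are arbitrary):

```python
from mmdet.models.builder import build_head

# Build SABLHead via its registry entry.
head = build_head(
    dict(
        type='SABLHead',
        num_classes=80,
        cls_in_channels=256,
        reg_in_channels=256,
        roi_feat_size=7))
head.init_weights()
```
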
@@ -6,9 +6,9 @@ from mmdet.core import bbox2roi, build_assigner, build_sampler
from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps
from mmdet.models.dense_heads import (AnchorHead, CornerHead, FCOSHead,
FSAFHead, GuidedAnchorHead, PAAHead,
paa_head)
SABLRetinaHead, paa_head)
from mmdet.models.dense_heads.paa_head import levels_to_images
from mmdet.models.roi_heads.bbox_heads import BBoxHead
from mmdet.models.roi_heads.bbox_heads import BBoxHead, SABLHead
from mmdet.models.roi_heads.mask_heads import FCNMaskHead, MaskIoUHead
@@ -473,6 +473,147 @@ def test_bbox_head_loss():
assert losses.get('loss_bbox', 0) > 0, 'box-loss should be non-zero'


def test_sabl_bbox_head_loss():
"""Tests bbox head loss when truth is empty and non-empty."""
self = SABLHead(
num_classes=4,
cls_in_channels=3,
reg_in_channels=3,
cls_out_channels=3,
reg_offset_out_channels=3,
reg_cls_out_channels=3,
roi_feat_size=7)
# Dummy proposals
proposal_list = [
torch.Tensor([[23.6667, 23.8757, 228.6326, 153.8874]]),
]
target_cfg = mmcv.Config(dict(pos_weight=1))
# Test bbox loss when truth is empty
gt_bboxes = [torch.empty((0, 4))]
gt_labels = [torch.LongTensor([])]
sampling_results = _dummy_bbox_sampling(proposal_list, gt_bboxes,
gt_labels)
bbox_targets = self.get_targets(sampling_results, gt_bboxes, gt_labels,
target_cfg)
labels, label_weights, bbox_targets, bbox_weights = bbox_targets
# Create dummy features "extracted" for each sampled bbox
num_sampled = sum(len(res.bboxes) for res in sampling_results)
rois = bbox2roi([res.bboxes for res in sampling_results])
dummy_feats = torch.rand(num_sampled, 3, 7, 7)
cls_scores, bbox_preds = self.forward(dummy_feats)
losses = self.loss(cls_scores, bbox_preds, rois, labels, label_weights,
bbox_targets, bbox_weights)
assert losses.get('loss_cls', 0) > 0, 'cls-loss should be non-zero'
assert losses.get('loss_bbox_cls',
0) == 0, 'empty gt bbox-cls-loss should be zero'
assert losses.get('loss_bbox_reg',
0) == 0, 'empty gt bbox-reg-loss should be zero'
# Test bbox loss when truth is non-empty
gt_bboxes = [
torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
]
gt_labels = [torch.LongTensor([2])]
sampling_results = _dummy_bbox_sampling(proposal_list, gt_bboxes,
gt_labels)
rois = bbox2roi([res.bboxes for res in sampling_results])
bbox_targets = self.get_targets(sampling_results, gt_bboxes, gt_labels,
target_cfg)
labels, label_weights, bbox_targets, bbox_weights = bbox_targets
# Create dummy features "extracted" for each sampled bbox
num_sampled = sum(len(res.bboxes) for res in sampling_results)
dummy_feats = torch.rand(num_sampled, 3, 7, 7)
cls_scores, bbox_preds = self.forward(dummy_feats)
losses = self.loss(cls_scores, bbox_preds, rois, labels, label_weights,
bbox_targets, bbox_weights)
    assert losses.get('loss_bbox_cls',
                      0) > 0, 'non-empty gt bbox-cls-loss should be non-zero'
    assert losses.get('loss_bbox_reg',
                      0) > 0, 'non-empty gt bbox-reg-loss should be non-zero'


def test_sabl_retina_head_loss():
"""Tests anchor head loss when truth is empty and non-empty."""
s = 256
img_metas = [{
'img_shape': (s, s, 3),
'scale_factor': 1,
'pad_shape': (s, s, 3)
}]
cfg = mmcv.Config(
dict(
assigner=dict(
type='ApproxMaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0.0,
ignore_iof_thr=-1),
allowed_border=-1,
pos_weight=-1,
debug=False))
head = SABLRetinaHead(
num_classes=4,
in_channels=3,
feat_channels=10,
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
train_cfg=cfg)
if torch.cuda.is_available():
head.cuda()
    # Anchor head expects multiple levels of features per image
feat = [
torch.rand(1, 3, s // (2**(i + 2)), s // (2**(i + 2))).cuda()
for i in range(len(head.approx_anchor_generator.base_anchors))
]
cls_scores, bbox_preds = head.forward(feat)
# Test that empty ground truth encourages the network
# to predict background
gt_bboxes = [torch.empty((0, 4)).cuda()]
gt_labels = [torch.LongTensor([]).cuda()]
gt_bboxes_ignore = None
empty_gt_losses = head.loss(cls_scores, bbox_preds, gt_bboxes,
gt_labels, img_metas, gt_bboxes_ignore)
# When there is no truth, the cls loss should be nonzero but there
# should be no box loss.
empty_cls_loss = sum(empty_gt_losses['loss_cls'])
empty_box_cls_loss = sum(empty_gt_losses['loss_bbox_cls'])
empty_box_reg_loss = sum(empty_gt_losses['loss_bbox_reg'])
assert empty_cls_loss.item() > 0, 'cls loss should be non-zero'
assert empty_box_cls_loss.item() == 0, (
'there should be no box cls loss when there are no true boxes')
assert empty_box_reg_loss.item() == 0, (
'there should be no box reg loss when there are no true boxes')
# When truth is non-empty then both cls and box loss should
# be nonzero for random inputs
gt_bboxes = [
torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]).cuda(),
]
gt_labels = [torch.LongTensor([2]).cuda()]
one_gt_losses = head.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels,
img_metas, gt_bboxes_ignore)
onegt_cls_loss = sum(one_gt_losses['loss_cls'])
onegt_box_cls_loss = sum(one_gt_losses['loss_bbox_cls'])
onegt_box_reg_loss = sum(one_gt_losses['loss_bbox_reg'])
assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero'
assert onegt_box_cls_loss.item() > 0, 'box loss cls should be non-zero'
assert onegt_box_reg_loss.item() > 0, 'box loss reg should be non-zero'


def test_refine_boxes():
"""Mirrors the doctest in
``mmdet.models.bbox_heads.bbox_head.BBoxHead.refine_boxes`` but checks for
