Code Release of the ECCV 2020 Spotlight paper Side-Aware Boundary Localization for More Precise Object Detection (#3603)
* add sabl two stage
* add sabl retina
* ret cfg bug fix
* test bug fix
* minor update
* update
* add r101 two stage
* update
* add cfgs
* add cfgs
* update cfgs
* format
* format
* add readme
* fix isort
* update
* update readme
* add doc string for sabl retina head
* add doc string and rename some functions
* add docstring for bucketing coder
* update docstring
* bucket_num -> num_buckets
* bucket_pw -> bucket_w, bucket_ph -> bucket_h
* update label2onehot
* update bucketing bbox coder doc
* update
* typo fix
* bboxes_ -> rescaled_bboxes
* rename some params in sabl head
* init with mmcv.cnn
* update doc
* rename pos -> post
* update cfgs
* update test cfg
* update
* add unittest for sabl head
* add unittest for sabl retina
* rename
* minor rename
* minor update
* update docstring
* update
* use F.one_hot
* update docstring
* update test heads
* update ReadMe
* fix
parent 08d1402c1e
commit 26562a1d0e
22 changed files with 2365 additions and 19 deletions
configs/sabl/README.md
@@ -0,0 +1,36 @@
# Side-Aware Boundary Localization for More Precise Object Detection

## Introduction

We provide config files to reproduce the object detection results of the ECCV 2020 Spotlight paper [Side-Aware Boundary Localization for More Precise Object Detection](https://arxiv.org/abs/1912.04260).

```
@inproceedings{Wang_2020_ECCV,
    title = {Side-Aware Boundary Localization for More Precise Object Detection},
    author = {Wang, Jiaqi and Zhang, Wenwei and Cao, Yuhang and Chen, Kai and Pang, Jiangmiao and Gong, Tao and Shi, Jianping and Loy, Chen Change and Lin, Dahua},
    booktitle = {ECCV},
    year = {2020}
}
```

## Results and Models

The results on COCO 2017 val are shown in the tables below (results on test-dev are usually slightly higher than on val).
Single-scale testing (1333x800) is adopted for all results.

| Method | Backbone | Lr schd | ms-train | box AP | Download |
| :----------------: | :-------: | :-----: | :------: | :----: | :------: |
| SABL Faster R-CNN | R-50-FPN | 1x | N | 39.9 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_faster_rcnn_r50_fpn_1x_coco/sabl_faster_rcnn_r50_fpn_1x_coco-e867595b.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_faster_rcnn_r50_fpn_1x_coco/20200830_130324.log.json) |
| SABL Faster R-CNN | R-101-FPN | 1x | N | 41.7 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_faster_rcnn_r101_fpn_1x_coco/sabl_faster_rcnn_r101_fpn_1x_coco-f804c6c1.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_faster_rcnn_r101_fpn_1x_coco/20200830_183949.log.json) |
| SABL Cascade R-CNN | R-50-FPN | 1x | N | 41.6 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_cascade_rcnn_r50_fpn_1x_coco/sabl_cascade_rcnn_r50_fpn_1x_coco-e1748e5e.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_cascade_rcnn_r50_fpn_1x_coco/20200831_033726.log.json) |
| SABL Cascade R-CNN | R-101-FPN | 1x | N | 43.0 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_cascade_rcnn_r101_fpn_1x_coco/sabl_cascade_rcnn_r101_fpn_1x_coco-2b83e87c.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_cascade_rcnn_r101_fpn_1x_coco/20200831_141745.log.json) |

| Method | Backbone | GN | Lr schd | ms-train | box AP | Download |
| :------------: | :-------: | :-: | :-----: | :---------: | :----: | :------: |
| SABL RetinaNet | R-50-FPN | N | 1x | N | 37.7 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r50_fpn_1x_coco/sabl_retinanet_r50_fpn_1x_coco-6c54fd4f.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r50_fpn_1x_coco/20200830_053451.log.json) |
| SABL RetinaNet | R-50-FPN | Y | 1x | N | 38.8 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r50_fpn_gn_1x_coco/sabl_retinanet_r50_fpn_gn_1x_coco-e16dfcf1.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r50_fpn_gn_1x_coco/20200831_141955.log.json) |
| SABL RetinaNet | R-101-FPN | N | 1x | N | 39.7 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_1x_coco/sabl_retinanet_r101_fpn_1x_coco-42026904.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_1x_coco/20200831_034256.log.json) |
| SABL RetinaNet | R-101-FPN | Y | 1x | N | 40.5 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_gn_1x_coco/sabl_retinanet_r101_fpn_gn_1x_coco-40a893e8.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_gn_1x_coco/20200830_201422.log.json) |
| SABL RetinaNet | R-101-FPN | Y | 2x | Y (640~800) | 42.9 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_640_800_coco/sabl_retinanet_r101_fpn_gn_2x_ms_640_800_coco-1e63382c.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_640_800_coco/20200830_144807.log.json) |
| SABL RetinaNet | R-101-FPN | Y | 2x | Y (480~960) | 43.6 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_480_960_coco/sabl_retinanet_r101_fpn_gn_2x_ms_480_960_coco-5342f857.pth) \| [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_480_960_coco/20200830_164537.log.json) |
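For quick verification, any checkpoint from the tables above can be loaded through the standard MMDetection v2 high-level API. A minimal sketch (the config/checkpoint paths and demo image below are placeholders, assuming the checkpoint has been downloaded locally):

```python
# Minimal inference sketch with the MMDetection v2 high-level API.
# Paths are illustrative; download the checkpoint from the table above first.
from mmdet.apis import init_detector, inference_detector

config_file = 'configs/sabl/sabl_faster_rcnn_r50_fpn_1x_coco.py'
checkpoint_file = 'sabl_faster_rcnn_r50_fpn_1x_coco-e867595b.pth'

model = init_detector(config_file, checkpoint_file, device='cuda:0')
# Returns a list with one (k, 5) array of [x1, y1, x2, y2, score] per class.
result = inference_detector(model, 'demo/demo.jpg')
```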
configs/sabl/sabl_cascade_rcnn_r101_fpn_1x_coco.py
@@ -0,0 +1,88 @@
_base_ = [
    '../_base_/models/cascade_rcnn_r50_fpn.py',
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
# model settings
model = dict(
    pretrained='torchvision://resnet101',
    backbone=dict(depth=101),
    # The three cascade stages use identical SABLHead settings except for the
    # bucket scale_factor, which shrinks stage by stage (1.7 -> 1.5 -> 1.3).
    roi_head=dict(bbox_head=[
        dict(
            type='SABLHead',
            num_classes=80,
            cls_in_channels=256,
            reg_in_channels=256,
            roi_feat_size=7,
            reg_feat_up_ratio=2,
            reg_pre_kernel=3,
            reg_post_kernel=3,
            reg_pre_num=2,
            reg_post_num=1,
            cls_out_channels=1024,
            reg_offset_out_channels=256,
            reg_cls_out_channels=256,
            num_cls_fcs=1,
            num_reg_fcs=0,
            reg_class_agnostic=True,
            norm_cfg=None,
            bbox_coder=dict(
                type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.7),
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
            loss_bbox_reg=dict(
                type='SmoothL1Loss', beta=0.1, loss_weight=1.0)),
        dict(
            type='SABLHead',
            num_classes=80,
            cls_in_channels=256,
            reg_in_channels=256,
            roi_feat_size=7,
            reg_feat_up_ratio=2,
            reg_pre_kernel=3,
            reg_post_kernel=3,
            reg_pre_num=2,
            reg_post_num=1,
            cls_out_channels=1024,
            reg_offset_out_channels=256,
            reg_cls_out_channels=256,
            num_cls_fcs=1,
            num_reg_fcs=0,
            reg_class_agnostic=True,
            norm_cfg=None,
            bbox_coder=dict(
                type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.5),
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
            loss_bbox_reg=dict(
                type='SmoothL1Loss', beta=0.1, loss_weight=1.0)),
        dict(
            type='SABLHead',
            num_classes=80,
            cls_in_channels=256,
            reg_in_channels=256,
            roi_feat_size=7,
            reg_feat_up_ratio=2,
            reg_pre_kernel=3,
            reg_post_kernel=3,
            reg_pre_num=2,
            reg_post_num=1,
            cls_out_channels=1024,
            reg_offset_out_channels=256,
            reg_cls_out_channels=256,
            num_cls_fcs=1,
            num_reg_fcs=0,
            reg_class_agnostic=True,
            norm_cfg=None,
            bbox_coder=dict(
                type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.3),
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
            loss_bbox_reg=dict(
                type='SmoothL1Loss', beta=0.1, loss_weight=1.0))
    ]))
configs/sabl/sabl_cascade_rcnn_r50_fpn_1x_coco.py
@@ -0,0 +1,86 @@
_base_ = [
    '../_base_/models/cascade_rcnn_r50_fpn.py',
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
# model settings
model = dict(
    # The three cascade stages use identical SABLHead settings except for the
    # bucket scale_factor, which shrinks stage by stage (1.7 -> 1.5 -> 1.3).
    roi_head=dict(bbox_head=[
        dict(
            type='SABLHead',
            num_classes=80,
            cls_in_channels=256,
            reg_in_channels=256,
            roi_feat_size=7,
            reg_feat_up_ratio=2,
            reg_pre_kernel=3,
            reg_post_kernel=3,
            reg_pre_num=2,
            reg_post_num=1,
            cls_out_channels=1024,
            reg_offset_out_channels=256,
            reg_cls_out_channels=256,
            num_cls_fcs=1,
            num_reg_fcs=0,
            reg_class_agnostic=True,
            norm_cfg=None,
            bbox_coder=dict(
                type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.7),
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
            loss_bbox_reg=dict(
                type='SmoothL1Loss', beta=0.1, loss_weight=1.0)),
        dict(
            type='SABLHead',
            num_classes=80,
            cls_in_channels=256,
            reg_in_channels=256,
            roi_feat_size=7,
            reg_feat_up_ratio=2,
            reg_pre_kernel=3,
            reg_post_kernel=3,
            reg_pre_num=2,
            reg_post_num=1,
            cls_out_channels=1024,
            reg_offset_out_channels=256,
            reg_cls_out_channels=256,
            num_cls_fcs=1,
            num_reg_fcs=0,
            reg_class_agnostic=True,
            norm_cfg=None,
            bbox_coder=dict(
                type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.5),
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
            loss_bbox_reg=dict(
                type='SmoothL1Loss', beta=0.1, loss_weight=1.0)),
        dict(
            type='SABLHead',
            num_classes=80,
            cls_in_channels=256,
            reg_in_channels=256,
            roi_feat_size=7,
            reg_feat_up_ratio=2,
            reg_pre_kernel=3,
            reg_post_kernel=3,
            reg_pre_num=2,
            reg_post_num=1,
            cls_out_channels=1024,
            reg_offset_out_channels=256,
            reg_cls_out_channels=256,
            num_cls_fcs=1,
            num_reg_fcs=0,
            reg_class_agnostic=True,
            norm_cfg=None,
            bbox_coder=dict(
                type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.3),
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
            loss_bbox_reg=dict(
                type='SmoothL1Loss', beta=0.1, loss_weight=1.0))
    ]))
configs/sabl/sabl_faster_rcnn_r101_fpn_1x_coco.py
@@ -0,0 +1,36 @@
_base_ = [
    '../_base_/models/faster_rcnn_r50_fpn.py',
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
    pretrained='torchvision://resnet101',
    backbone=dict(depth=101),
    roi_head=dict(
        bbox_head=dict(
            _delete_=True,
            type='SABLHead',
            num_classes=80,
            cls_in_channels=256,
            reg_in_channels=256,
            roi_feat_size=7,
            reg_feat_up_ratio=2,
            reg_pre_kernel=3,
            reg_post_kernel=3,
            reg_pre_num=2,
            reg_post_num=1,
            cls_out_channels=1024,
            reg_offset_out_channels=256,
            reg_cls_out_channels=256,
            num_cls_fcs=1,
            num_reg_fcs=0,
            reg_class_agnostic=True,
            norm_cfg=None,
            bbox_coder=dict(
                type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.7),
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
            loss_bbox_reg=dict(
                type='SmoothL1Loss', beta=0.1, loss_weight=1.0))))
configs/sabl/sabl_faster_rcnn_r50_fpn_1x_coco.py
@@ -0,0 +1,34 @@
_base_ = [
    '../_base_/models/faster_rcnn_r50_fpn.py',
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
    roi_head=dict(
        bbox_head=dict(
            _delete_=True,
            type='SABLHead',
            num_classes=80,
            cls_in_channels=256,
            reg_in_channels=256,
            roi_feat_size=7,
            reg_feat_up_ratio=2,
            reg_pre_kernel=3,
            reg_post_kernel=3,
            reg_pre_num=2,
            reg_post_num=1,
            cls_out_channels=1024,
            reg_offset_out_channels=256,
            reg_cls_out_channels=256,
            num_cls_fcs=1,
            num_reg_fcs=0,
            reg_class_agnostic=True,
            norm_cfg=None,
            bbox_coder=dict(
                type='BucketingBBoxCoder', num_buckets=14, scale_factor=1.7),
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
            loss_bbox_reg=dict(
                type='SmoothL1Loss', beta=0.1, loss_weight=1.0))))
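Both Faster R-CNN configs above swap in `SABLHead` with `_delete_=True`, which tells the mmcv config system to discard the base config's `bbox_head` entirely instead of merging into it key by key. A self-contained sketch of that merge rule (the file names below are throwaway examples, not part of this release):

```python
# Hedged sketch of mmcv's `_delete_=True` semantics in config inheritance.
import os
import tempfile

from mmcv import Config

with tempfile.TemporaryDirectory() as d:
    with open(os.path.join(d, 'base.py'), 'w') as f:
        f.write("model = dict(bbox_head=dict(type='Shared2FCBBoxHead', "
                "fc_out_channels=1024))\n")
    with open(os.path.join(d, 'child.py'), 'w') as f:
        f.write("_base_ = './base.py'\n"
                "model = dict(bbox_head=dict(_delete_=True, "
                "type='SABLHead', num_classes=80))\n")
    cfg = Config.fromfile(os.path.join(d, 'child.py'))
    # Without _delete_, 'fc_out_channels' would leak into the merged head;
    # with it, only the keys written in the child config survive.
    print(cfg.model.bbox_head)
```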
configs/sabl/sabl_retinanet_r101_fpn_1x_coco.py
@@ -0,0 +1,52 @@
_base_ = [
    '../_base_/models/retinanet_r50_fpn.py',
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
# model settings
model = dict(
    pretrained='torchvision://resnet101',
    backbone=dict(depth=101),
    bbox_head=dict(
        _delete_=True,
        type='SABLRetinaHead',
        num_classes=80,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        approx_anchor_generator=dict(
            type='AnchorGenerator',
            octave_base_scale=4,
            scales_per_octave=3,
            ratios=[0.5, 1.0, 2.0],
            strides=[8, 16, 32, 64, 128]),
        square_anchor_generator=dict(
            type='AnchorGenerator',
            ratios=[1.0],
            scales=[4],
            strides=[8, 16, 32, 64, 128]),
        bbox_coder=dict(
            type='BucketingBBoxCoder', num_buckets=14, scale_factor=3.0),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.5),
        loss_bbox_reg=dict(
            type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5)))
# training and testing settings
train_cfg = dict(
    assigner=dict(
        type='ApproxMaxIoUAssigner',
        pos_iou_thr=0.5,
        neg_iou_thr=0.4,
        min_pos_iou=0.0,
        ignore_iof_thr=-1),
    allowed_border=-1,
    pos_weight=-1,
    debug=False)
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
configs/sabl/sabl_retinanet_r101_fpn_gn_1x_coco.py
@@ -0,0 +1,54 @@
_base_ = [
    '../_base_/models/retinanet_r50_fpn.py',
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
# model settings
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
    pretrained='torchvision://resnet101',
    backbone=dict(depth=101),
    bbox_head=dict(
        _delete_=True,
        type='SABLRetinaHead',
        num_classes=80,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        approx_anchor_generator=dict(
            type='AnchorGenerator',
            octave_base_scale=4,
            scales_per_octave=3,
            ratios=[0.5, 1.0, 2.0],
            strides=[8, 16, 32, 64, 128]),
        square_anchor_generator=dict(
            type='AnchorGenerator',
            ratios=[1.0],
            scales=[4],
            strides=[8, 16, 32, 64, 128]),
        norm_cfg=norm_cfg,
        bbox_coder=dict(
            type='BucketingBBoxCoder', num_buckets=14, scale_factor=3.0),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.5),
        loss_bbox_reg=dict(
            type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5)))
# training and testing settings
train_cfg = dict(
    assigner=dict(
        type='ApproxMaxIoUAssigner',
        pos_iou_thr=0.5,
        neg_iou_thr=0.4,
        min_pos_iou=0.0,
        ignore_iof_thr=-1),
    allowed_border=-1,
    pos_weight=-1,
    debug=False)
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
configs/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_480_960_coco.py
@@ -0,0 +1,71 @@
_base_ = [
    '../_base_/models/retinanet_r50_fpn.py',
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_2x.py', '../_base_/default_runtime.py'
]
# model settings
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
    pretrained='torchvision://resnet101',
    backbone=dict(depth=101),
    bbox_head=dict(
        _delete_=True,
        type='SABLRetinaHead',
        num_classes=80,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        approx_anchor_generator=dict(
            type='AnchorGenerator',
            octave_base_scale=4,
            scales_per_octave=3,
            ratios=[0.5, 1.0, 2.0],
            strides=[8, 16, 32, 64, 128]),
        square_anchor_generator=dict(
            type='AnchorGenerator',
            ratios=[1.0],
            scales=[4],
            strides=[8, 16, 32, 64, 128]),
        norm_cfg=norm_cfg,
        bbox_coder=dict(
            type='BucketingBBoxCoder', num_buckets=14, scale_factor=3.0),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.5),
        loss_bbox_reg=dict(
            type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5)))
# training and testing settings
train_cfg = dict(
    assigner=dict(
        type='ApproxMaxIoUAssigner',
        pos_iou_thr=0.5,
        neg_iou_thr=0.4,
        min_pos_iou=0.0,
        ignore_iof_thr=-1),
    allowed_border=-1,
    pos_weight=-1,
    debug=False)
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='Resize',
        img_scale=[(1333, 480), (1333, 960)],
        multiscale_mode='range',
        keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
data = dict(train=dict(pipeline=train_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
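This config and the 640~800 variant that follows differ only in the `img_scale` range of `train_pipeline`. With `multiscale_mode='range'`, the Resize transform samples each edge of the target scale uniformly between the two endpoints; since the long edges are both 1333 here, only the short edge (480-960 or 640-800) actually varies. A minimal sketch of that sampling rule (a simplified re-implementation for illustration, not the mmdet code itself):

```python
# Simplified sketch of Resize scale sampling when multiscale_mode='range'.
import numpy as np

def random_sample_scale(img_scales):
    longs = [max(s) for s in img_scales]
    shorts = [min(s) for s in img_scales]
    long_edge = np.random.randint(min(longs), max(longs) + 1)
    short_edge = np.random.randint(min(shorts), max(shorts) + 1)
    return (long_edge, short_edge)

print(random_sample_scale([(1333, 480), (1333, 960)]))  # e.g. (1333, 731)
```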
configs/sabl/sabl_retinanet_r101_fpn_gn_2x_ms_640_800_coco.py
@@ -0,0 +1,71 @@
_base_ = [
    '../_base_/models/retinanet_r50_fpn.py',
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_2x.py', '../_base_/default_runtime.py'
]
# model settings
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
    pretrained='torchvision://resnet101',
    backbone=dict(depth=101),
    bbox_head=dict(
        _delete_=True,
        type='SABLRetinaHead',
        num_classes=80,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        approx_anchor_generator=dict(
            type='AnchorGenerator',
            octave_base_scale=4,
            scales_per_octave=3,
            ratios=[0.5, 1.0, 2.0],
            strides=[8, 16, 32, 64, 128]),
        square_anchor_generator=dict(
            type='AnchorGenerator',
            ratios=[1.0],
            scales=[4],
            strides=[8, 16, 32, 64, 128]),
        norm_cfg=norm_cfg,
        bbox_coder=dict(
            type='BucketingBBoxCoder', num_buckets=14, scale_factor=3.0),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.5),
        loss_bbox_reg=dict(
            type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5)))
# training and testing settings
train_cfg = dict(
    assigner=dict(
        type='ApproxMaxIoUAssigner',
        pos_iou_thr=0.5,
        neg_iou_thr=0.4,
        min_pos_iou=0.0,
        ignore_iof_thr=-1),
    allowed_border=-1,
    pos_weight=-1,
    debug=False)
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='Resize',
        img_scale=[(1333, 640), (1333, 800)],
        multiscale_mode='range',
        keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
data = dict(train=dict(pipeline=train_pipeline))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
configs/sabl/sabl_retinanet_r50_fpn_1x_coco.py
@@ -0,0 +1,50 @@
_base_ = [
    '../_base_/models/retinanet_r50_fpn.py',
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
# model settings
model = dict(
    bbox_head=dict(
        _delete_=True,
        type='SABLRetinaHead',
        num_classes=80,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        approx_anchor_generator=dict(
            type='AnchorGenerator',
            octave_base_scale=4,
            scales_per_octave=3,
            ratios=[0.5, 1.0, 2.0],
            strides=[8, 16, 32, 64, 128]),
        square_anchor_generator=dict(
            type='AnchorGenerator',
            ratios=[1.0],
            scales=[4],
            strides=[8, 16, 32, 64, 128]),
        bbox_coder=dict(
            type='BucketingBBoxCoder', num_buckets=14, scale_factor=3.0),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.5),
        loss_bbox_reg=dict(
            type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5)))
# training and testing settings
train_cfg = dict(
    assigner=dict(
        type='ApproxMaxIoUAssigner',
        pos_iou_thr=0.5,
        neg_iou_thr=0.4,
        min_pos_iou=0.0,
        ignore_iof_thr=-1),
    allowed_border=-1,
    pos_weight=-1,
    debug=False)
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
configs/sabl/sabl_retinanet_r50_fpn_gn_1x_coco.py
@@ -0,0 +1,52 @@
_base_ = [
    '../_base_/models/retinanet_r50_fpn.py',
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
# model settings
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
    bbox_head=dict(
        _delete_=True,
        type='SABLRetinaHead',
        num_classes=80,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        approx_anchor_generator=dict(
            type='AnchorGenerator',
            octave_base_scale=4,
            scales_per_octave=3,
            ratios=[0.5, 1.0, 2.0],
            strides=[8, 16, 32, 64, 128]),
        square_anchor_generator=dict(
            type='AnchorGenerator',
            ratios=[1.0],
            scales=[4],
            strides=[8, 16, 32, 64, 128]),
        norm_cfg=norm_cfg,
        bbox_coder=dict(
            type='BucketingBBoxCoder', num_buckets=14, scale_factor=3.0),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.5),
        loss_bbox_reg=dict(
            type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5)))
# training and testing settings
train_cfg = dict(
    assigner=dict(
        type='ApproxMaxIoUAssigner',
        pos_iou_thr=0.5,
        neg_iou_thr=0.4,
        min_pos_iou=0.0,
        ignore_iof_thr=-1),
    allowed_border=-1,
    pos_weight=-1,
    debug=False)
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
mmdet/core/bbox/coder/bucketing_bbox_coder.py
@@ -0,0 +1,339 @@
import numpy as np
import torch
import torch.nn.functional as F

from ..builder import BBOX_CODERS
from ..transforms import bbox_rescale
from .base_bbox_coder import BaseBBoxCoder


@BBOX_CODERS.register_module()
class BucketingBBoxCoder(BaseBBoxCoder):
    """Bucketing BBox Coder for Side-Aware Boundary Localization (SABL).

    Boundary Localization with Bucketing and Bucketing Guided Rescoring
    are implemented here.

    Please refer to https://arxiv.org/abs/1912.04260 for more details.

    Args:
        num_buckets (int): Number of buckets.
        scale_factor (int): Scale factor of proposals to generate buckets.
        offset_topk (int): Topk buckets are used to generate
            bucket fine regression targets. Defaults to 2.
        offset_upperbound (float): Offset upperbound to generate
            bucket fine regression targets.
            To avoid too large offset displacements. Defaults to 1.0.
        cls_ignore_neighbor (bool): Ignore second nearest bucket or not.
            Defaults to True.
    """

    def __init__(self,
                 num_buckets,
                 scale_factor,
                 offset_topk=2,
                 offset_upperbound=1.0,
                 cls_ignore_neighbor=True):
        super(BucketingBBoxCoder, self).__init__()
        self.num_buckets = num_buckets
        self.scale_factor = scale_factor
        self.offset_topk = offset_topk
        self.offset_upperbound = offset_upperbound
        self.cls_ignore_neighbor = cls_ignore_neighbor

    def encode(self, bboxes, gt_bboxes):
        """Get bucketing estimation and fine regression targets during
        training.

        Args:
            bboxes (torch.Tensor): Source boxes, e.g., object proposals.
            gt_bboxes (torch.Tensor): Target of the transformation, e.g.,
                ground truth boxes.

        Returns:
            encoded_bboxes (tuple[Tensor]): Bucketing estimation
                and fine regression targets and weights.
        """
        assert bboxes.size(0) == gt_bboxes.size(0)
        assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
        encoded_bboxes = bbox2bucket(bboxes, gt_bboxes, self.num_buckets,
                                     self.scale_factor, self.offset_topk,
                                     self.offset_upperbound,
                                     self.cls_ignore_neighbor)
        return encoded_bboxes

    def decode(self, bboxes, pred_bboxes, max_shape=None):
        """Apply transformation `pred_bboxes` to `bboxes`.

        Args:
            bboxes (torch.Tensor): Basic boxes.
            pred_bboxes (torch.Tensor): Predictions for bucketing estimation
                and fine regression.
            max_shape (tuple[int], optional): Maximum shape of boxes.
                Defaults to None.

        Returns:
            torch.Tensor: Decoded boxes.
        """
        assert len(pred_bboxes) == 2
        cls_preds, offset_preds = pred_bboxes
        assert cls_preds.size(0) == bboxes.size(0) and offset_preds.size(
            0) == bboxes.size(0)
        decoded_bboxes = bucket2bbox(bboxes, cls_preds, offset_preds,
                                     self.num_buckets, self.scale_factor,
                                     max_shape)
        return decoded_bboxes


def generate_buckets(proposals, num_buckets, scale_factor=1.0):
    """Generate buckets w.r.t. bucket number and scale factor of proposals.

    Args:
        proposals (Tensor): Shape (n, 4)
        num_buckets (int): Number of buckets.
        scale_factor (float): Scale factor to rescale proposals.

    Returns:
        tuple[Tensor]: (bucket_w, bucket_h, l_buckets, r_buckets,
            t_buckets, d_buckets)

            - bucket_w: Width of buckets on x-axis. Shape (n, ).
            - bucket_h: Height of buckets on y-axis. Shape (n, ).
            - l_buckets: Left buckets. Shape (n, ceil(num_buckets/2)).
            - r_buckets: Right buckets. Shape (n, ceil(num_buckets/2)).
            - t_buckets: Top buckets. Shape (n, ceil(num_buckets/2)).
            - d_buckets: Down buckets. Shape (n, ceil(num_buckets/2)).
    """
    proposals = bbox_rescale(proposals, scale_factor)

    # number of buckets on each side
    side_num = int(np.ceil(num_buckets / 2.0))
    pw = proposals[..., 2] - proposals[..., 0]
    ph = proposals[..., 3] - proposals[..., 1]
    px1 = proposals[..., 0]
    py1 = proposals[..., 1]
    px2 = proposals[..., 2]
    py2 = proposals[..., 3]

    bucket_w = pw / num_buckets
    bucket_h = ph / num_buckets

    # bucket centers, measured from each of the four proposal borders inwards
    steps = (0.5 + torch.arange(0, side_num).to(proposals).float())[None, :]
    # left buckets
    l_buckets = px1[:, None] + steps * bucket_w[:, None]
    # right buckets
    r_buckets = px2[:, None] - steps * bucket_w[:, None]
    # top buckets
    t_buckets = py1[:, None] + steps * bucket_h[:, None]
    # down buckets
    d_buckets = py2[:, None] - steps * bucket_h[:, None]
    return bucket_w, bucket_h, l_buckets, r_buckets, t_buckets, d_buckets


def bbox2bucket(proposals,
                gt,
                num_buckets,
                scale_factor,
                offset_topk=2,
                offset_upperbound=1.0,
                cls_ignore_neighbor=True):
    """Generate buckets estimation and fine regression targets.

    Args:
        proposals (Tensor): Shape (n, 4)
        gt (Tensor): Shape (n, 4)
        num_buckets (int): Number of buckets.
        scale_factor (float): Scale factor to rescale proposals.
        offset_topk (int): Topk buckets are used to generate
            bucket fine regression targets. Defaults to 2.
        offset_upperbound (float): Offset allowance to generate
            bucket fine regression targets.
            To avoid too large offset displacements. Defaults to 1.0.
        cls_ignore_neighbor (bool): Ignore second nearest bucket or not.
            Defaults to True.

    Returns:
        tuple[Tensor]: (offsets, offsets_weights, bucket_labels, cls_weights).

            - offsets: Fine regression targets. \
                Shape (n, num_buckets*2).
            - offsets_weights: Fine regression weights. \
                Shape (n, num_buckets*2).
            - bucket_labels: Bucketing estimation labels. \
                Shape (n, num_buckets*2).
            - cls_weights: Bucketing estimation weights. \
                Shape (n, num_buckets*2).
    """
    assert proposals.size() == gt.size()

    # generate buckets
    proposals = proposals.float()
    gt = gt.float()
    (bucket_w, bucket_h, l_buckets, r_buckets, t_buckets,
     d_buckets) = generate_buckets(proposals, num_buckets, scale_factor)

    gx1 = gt[..., 0]
    gy1 = gt[..., 1]
    gx2 = gt[..., 2]
    gy2 = gt[..., 3]

    # generate offset targets and weights
    # offsets from buckets to gts
    l_offsets = (l_buckets - gx1[:, None]) / bucket_w[:, None]
    r_offsets = (r_buckets - gx2[:, None]) / bucket_w[:, None]
    t_offsets = (t_buckets - gy1[:, None]) / bucket_h[:, None]
    d_offsets = (d_buckets - gy2[:, None]) / bucket_h[:, None]

    # select top-k nearest buckets
    l_topk, l_label = l_offsets.abs().topk(
        offset_topk, dim=1, largest=False, sorted=True)
    r_topk, r_label = r_offsets.abs().topk(
        offset_topk, dim=1, largest=False, sorted=True)
    t_topk, t_label = t_offsets.abs().topk(
        offset_topk, dim=1, largest=False, sorted=True)
    d_topk, d_label = d_offsets.abs().topk(
        offset_topk, dim=1, largest=False, sorted=True)

    offset_l_weights = l_offsets.new_zeros(l_offsets.size())
    offset_r_weights = r_offsets.new_zeros(r_offsets.size())
    offset_t_weights = t_offsets.new_zeros(t_offsets.size())
    offset_d_weights = d_offsets.new_zeros(d_offsets.size())
    inds = torch.arange(0, proposals.size(0)).to(proposals).long()

    # generate offset weights of top-k nearest buckets
    for k in range(offset_topk):
        if k >= 1:
            # the k-th nearest bucket is supervised only when its offset
            # to the gt boundary stays within the upperbound
            offset_l_weights[inds, l_label[:, k]] = (
                l_topk[:, k] < offset_upperbound).float()
            offset_r_weights[inds, r_label[:, k]] = (
                r_topk[:, k] < offset_upperbound).float()
            offset_t_weights[inds, t_label[:, k]] = (
                t_topk[:, k] < offset_upperbound).float()
            offset_d_weights[inds, d_label[:, k]] = (
                d_topk[:, k] < offset_upperbound).float()
        else:
            offset_l_weights[inds, l_label[:, k]] = 1.0
            offset_r_weights[inds, r_label[:, k]] = 1.0
            offset_t_weights[inds, t_label[:, k]] = 1.0
            offset_d_weights[inds, d_label[:, k]] = 1.0

    offsets = torch.cat([l_offsets, r_offsets, t_offsets, d_offsets], dim=-1)
    offsets_weights = torch.cat([
        offset_l_weights, offset_r_weights, offset_t_weights, offset_d_weights
    ], dim=-1)

    # generate bucket labels and weight
    side_num = int(np.ceil(num_buckets / 2.0))
    labels = torch.stack(
        [l_label[:, 0], r_label[:, 0], t_label[:, 0], d_label[:, 0]], dim=-1)

    batch_size = labels.size(0)
    bucket_labels = F.one_hot(labels.view(-1),
                              side_num).view(batch_size, -1).float()
    bucket_cls_l_weights = (l_offsets.abs() < 1).float()
    bucket_cls_r_weights = (r_offsets.abs() < 1).float()
    bucket_cls_t_weights = (t_offsets.abs() < 1).float()
    bucket_cls_d_weights = (d_offsets.abs() < 1).float()
    bucket_cls_weights = torch.cat([
        bucket_cls_l_weights, bucket_cls_r_weights, bucket_cls_t_weights,
        bucket_cls_d_weights
    ], dim=-1)
    # ignore second nearest buckets for cls if necessary
    if cls_ignore_neighbor:
        bucket_cls_weights = (~((bucket_cls_weights == 1) &
                                (bucket_labels == 0))).float()
    else:
        bucket_cls_weights[:] = 1.0
    return offsets, offsets_weights, bucket_labels, bucket_cls_weights


def bucket2bbox(proposals,
                cls_preds,
                offset_preds,
                num_buckets,
                scale_factor=1.0,
                max_shape=None):
    """Apply bucketing estimation (cls preds) and fine regression (offset
    preds) to generate det bboxes.

    Args:
        proposals (Tensor): Boxes to be transformed. Shape (n, 4)
        cls_preds (Tensor): Bucketing estimation. Shape (n, num_buckets*2).
        offset_preds (Tensor): Fine regression. Shape (n, num_buckets*2).
        num_buckets (int): Number of buckets.
        scale_factor (float): Scale factor to rescale proposals.
        max_shape (tuple[int, int]): Maximum bounds for boxes, specifies
            (H, W).

    Returns:
        tuple[Tensor]: (bboxes, loc_confidence).

            - bboxes: Predicted bboxes. Shape (n, 4)
            - loc_confidence: Localization confidence of predicted bboxes.
                Shape (n,).
    """
    side_num = int(np.ceil(num_buckets / 2.0))
    cls_preds = cls_preds.view(-1, side_num)
    offset_preds = offset_preds.view(-1, side_num)

    scores = F.softmax(cls_preds, dim=1)
    score_topk, score_label = scores.topk(2, dim=1, largest=True, sorted=True)

    rescaled_proposals = bbox_rescale(proposals, scale_factor)

    pw = rescaled_proposals[..., 2] - rescaled_proposals[..., 0]
    ph = rescaled_proposals[..., 3] - rescaled_proposals[..., 1]
    px1 = rescaled_proposals[..., 0]
    py1 = rescaled_proposals[..., 1]
    px2 = rescaled_proposals[..., 2]
    py2 = rescaled_proposals[..., 3]

    bucket_w = pw / num_buckets
    bucket_h = ph / num_buckets

    score_inds_l = score_label[0::4, 0]
    score_inds_r = score_label[1::4, 0]
    score_inds_t = score_label[2::4, 0]
    score_inds_d = score_label[3::4, 0]
    l_buckets = px1 + (0.5 + score_inds_l.float()) * bucket_w
    r_buckets = px2 - (0.5 + score_inds_r.float()) * bucket_w
    t_buckets = py1 + (0.5 + score_inds_t.float()) * bucket_h
    d_buckets = py2 - (0.5 + score_inds_d.float()) * bucket_h

    offsets = offset_preds.view(-1, 4, side_num)
    inds = torch.arange(proposals.size(0)).to(proposals).long()
    l_offsets = offsets[:, 0, :][inds, score_inds_l]
    r_offsets = offsets[:, 1, :][inds, score_inds_r]
    t_offsets = offsets[:, 2, :][inds, score_inds_t]
    d_offsets = offsets[:, 3, :][inds, score_inds_d]

    x1 = l_buckets - l_offsets * bucket_w
    x2 = r_buckets - r_offsets * bucket_w
    y1 = t_buckets - t_offsets * bucket_h
    y2 = d_buckets - d_offsets * bucket_h

    if max_shape is not None:
        x1 = x1.clamp(min=0, max=max_shape[1] - 1)
        y1 = y1.clamp(min=0, max=max_shape[0] - 1)
        x2 = x2.clamp(min=0, max=max_shape[1] - 1)
        y2 = y2.clamp(min=0, max=max_shape[0] - 1)
    bboxes = torch.cat([x1[:, None], y1[:, None], x2[:, None], y2[:, None]],
                       dim=-1)

    # bucketing guided rescoring: the second-best bucket score contributes
    # only when the top-2 buckets are adjacent
    loc_confidence = score_topk[:, 0]
    top2_neighbor_inds = (score_label[:, 0] - score_label[:, 1]).abs() == 1
    loc_confidence += score_topk[:, 1] * top2_neighbor_inds.float()
    loc_confidence = loc_confidence.view(-1, 4).mean(dim=1)

    return bboxes, loc_confidence
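To make the bucket geometry concrete, the following self-contained sketch (plain PyTorch, assuming `scale_factor=1.0` so that the rescaling step is a no-op) reproduces how `generate_buckets` places the side bucket centers for one proposal:

```python
# Self-contained sketch of the side bucket centers for a single proposal,
# mirroring generate_buckets with scale_factor=1.0 (no rescaling).
import torch

num_buckets = 14
side_num = (num_buckets + 1) // 2  # 7 buckets owned by each side
proposal = torch.tensor([[100., 100., 240., 170.]])  # (x1, y1, x2, y2)

pw = proposal[:, 2] - proposal[:, 0]  # width = 140
bucket_w = pw / num_buckets           # each x-bucket is 10 px wide

steps = 0.5 + torch.arange(side_num).float()
l_centers = proposal[:, 0:1] + steps[None, :] * bucket_w[:, None]
r_centers = proposal[:, 2:3] - steps[None, :] * bucket_w[:, None]
print(l_centers)  # tensor([[105., 115., 125., 135., 145., 155., 165.]])
print(r_centers)  # tensor([[235., 225., 215., 205., 195., 185., 175.]])
```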
@ -0,0 +1,622 @@ |
||||
import numpy as np |
||||
import torch |
||||
import torch.nn as nn |
||||
from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init |
||||
|
||||
from mmdet.core import (build_anchor_generator, build_assigner, |
||||
build_bbox_coder, build_sampler, force_fp32, |
||||
images_to_levels, multi_apply, multiclass_nms, unmap) |
||||
from ..builder import HEADS, build_loss |
||||
from .base_dense_head import BaseDenseHead |
||||
from .guided_anchor_head import GuidedAnchorHead |
||||
|
||||
|
||||
@HEADS.register_module |
||||
class SABLRetinaHead(BaseDenseHead): |
||||
"""Side-Aware Boundary Localization (SABL) for RetinaNet. |
||||
|
||||
The anchor generation, assigning and sampling in SABLRetinaHead |
||||
are the same as GuidedAnchorHead for guided anchoring. |
||||
|
||||
Please refer to https://arxiv.org/abs/1912.04260 for more details. |
||||
|
||||
Args: |
||||
num_classes (int): Number of classes. |
||||
in_channels (int): Number of channels in the input feature map. |
||||
stacked_convs (int): Number of Convs for classification \ |
||||
and regression branches. Defaults to 4. |
||||
feat_channels (int): Number of hidden channels. \ |
||||
Defaults to 256. |
||||
approx_anchor_generator (dict): Config dict for approx generator. |
||||
square_anchor_generator (dict): Config dict for square generator. |
||||
conv_cfg (dict): Config dict for ConvModule. Defaults to None. |
||||
norm_cfg (dict): Config dict for Norm Layer. Defaults to None. |
||||
bbox_coder (dict): Config dict for bbox coder. |
||||
reg_decoded_bbox (bool): Whether to regress decoded bbox. \ |
||||
Defaults to False. |
||||
background_label (int): Background label. Defaults to None. |
||||
train_cfg (dict): Training config of SABLRetinaHead. |
||||
test_cfg (dict): Testing config of SABLRetinaHead. |
||||
loss_cls (dict): Config of classification loss. |
||||
loss_bbox_cls (dict): Config of classification loss for bbox branch. |
||||
loss_bbox_reg (dict): Config of regression loss for bbox branch. |
||||
""" |
||||
|
||||
def __init__(self, |
||||
num_classes, |
||||
in_channels, |
||||
stacked_convs=4, |
||||
feat_channels=256, |
||||
approx_anchor_generator=dict( |
||||
type='AnchorGenerator', |
||||
octave_base_scale=4, |
||||
scales_per_octave=3, |
||||
ratios=[0.5, 1.0, 2.0], |
||||
strides=[8, 16, 32, 64, 128]), |
||||
square_anchor_generator=dict( |
||||
type='AnchorGenerator', |
||||
ratios=[1.0], |
||||
scales=[4], |
||||
strides=[8, 16, 32, 64, 128]), |
||||
conv_cfg=None, |
||||
norm_cfg=None, |
||||
bbox_coder=dict( |
||||
type='BucketingBBoxCoder', |
||||
num_buckets=14, |
||||
scale_factor=3.0), |
||||
reg_decoded_bbox=False, |
||||
background_label=None, |
||||
train_cfg=None, |
||||
test_cfg=None, |
||||
loss_cls=dict( |
||||
type='FocalLoss', |
||||
use_sigmoid=True, |
||||
gamma=2.0, |
||||
alpha=0.25, |
||||
loss_weight=1.0), |
||||
loss_bbox_cls=dict( |
||||
type='CrossEntropyLoss', |
||||
use_sigmoid=True, |
||||
loss_weight=1.5), |
||||
loss_bbox_reg=dict( |
||||
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5)): |
||||
super(SABLRetinaHead, self).__init__() |
||||
self.in_channels = in_channels |
||||
self.num_classes = num_classes |
||||
self.feat_channels = feat_channels |
||||
self.num_buckets = bbox_coder['num_buckets'] |
||||
self.side_num = int(np.ceil(self.num_buckets / 2)) |
||||
|
||||
assert (approx_anchor_generator['octave_base_scale'] == |
||||
square_anchor_generator['scales'][0]) |
||||
assert (approx_anchor_generator['strides'] == |
||||
square_anchor_generator['strides']) |
||||
|
||||
self.approx_anchor_generator = build_anchor_generator( |
||||
approx_anchor_generator) |
||||
self.square_anchor_generator = build_anchor_generator( |
||||
square_anchor_generator) |
||||
self.approxs_per_octave = ( |
||||
self.approx_anchor_generator.num_base_anchors[0]) |
||||
|
||||
# one anchor per location |
||||
self.num_anchors = 1 |
||||
self.stacked_convs = stacked_convs |
||||
self.conv_cfg = conv_cfg |
||||
self.norm_cfg = norm_cfg |
||||
|
||||
self.reg_decoded_bbox = reg_decoded_bbox |
||||
self.background_label = ( |
||||
num_classes if background_label is None else background_label) |
||||
|
||||
self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) |
||||
self.sampling = loss_cls['type'] not in [ |
||||
'FocalLoss', 'GHMC', 'QualityFocalLoss' |
||||
] |
||||
if self.use_sigmoid_cls: |
||||
self.cls_out_channels = num_classes |
||||
else: |
||||
self.cls_out_channels = num_classes + 1 |
||||
|
||||
self.bbox_coder = build_bbox_coder(bbox_coder) |
||||
self.loss_cls = build_loss(loss_cls) |
||||
self.loss_bbox_cls = build_loss(loss_bbox_cls) |
||||
self.loss_bbox_reg = build_loss(loss_bbox_reg) |
||||
|
||||
self.train_cfg = train_cfg |
||||
self.test_cfg = test_cfg |
||||
|
||||
if self.train_cfg: |
||||
self.assigner = build_assigner(self.train_cfg.assigner) |
||||
# use PseudoSampler when sampling is False |
||||
if self.sampling and hasattr(self.train_cfg, 'sampler'): |
||||
sampler_cfg = self.train_cfg.sampler |
||||
else: |
||||
sampler_cfg = dict(type='PseudoSampler') |
||||
self.sampler = build_sampler(sampler_cfg, context=self) |
||||
|
||||
self.fp16_enabled = False |
||||
self._init_layers() |
||||
|
||||
def _init_layers(self): |
||||
self.relu = nn.ReLU(inplace=True) |
||||
self.cls_convs = nn.ModuleList() |
||||
self.reg_convs = nn.ModuleList() |
||||
for i in range(self.stacked_convs): |
||||
chn = self.in_channels if i == 0 else self.feat_channels |
||||
self.cls_convs.append( |
||||
ConvModule( |
||||
chn, |
||||
self.feat_channels, |
||||
3, |
||||
stride=1, |
||||
padding=1, |
||||
conv_cfg=self.conv_cfg, |
||||
norm_cfg=self.norm_cfg)) |
||||
self.reg_convs.append( |
||||
ConvModule( |
||||
chn, |
||||
self.feat_channels, |
||||
3, |
||||
stride=1, |
||||
padding=1, |
||||
conv_cfg=self.conv_cfg, |
||||
norm_cfg=self.norm_cfg)) |
||||
self.retina_cls = nn.Conv2d( |
||||
self.feat_channels, self.cls_out_channels, 3, padding=1) |
||||
self.retina_bbox_reg = nn.Conv2d( |
||||
self.feat_channels, self.side_num * 4, 3, padding=1) |
||||
self.retina_bbox_cls = nn.Conv2d( |
||||
self.feat_channels, self.side_num * 4, 3, padding=1) |
||||
|
||||
def init_weights(self): |
||||
for m in self.cls_convs: |
||||
normal_init(m.conv, std=0.01) |
||||
for m in self.reg_convs: |
||||
normal_init(m.conv, std=0.01) |
||||
bias_cls = bias_init_with_prob(0.01) |
||||
normal_init(self.retina_cls, std=0.01, bias=bias_cls) |
||||
normal_init(self.retina_bbox_reg, std=0.01) |
||||
normal_init(self.retina_bbox_cls, std=0.01) |
||||
|
||||
def forward_single(self, x): |
||||
cls_feat = x |
||||
reg_feat = x |
||||
for cls_conv in self.cls_convs: |
||||
cls_feat = cls_conv(cls_feat) |
||||
for reg_conv in self.reg_convs: |
||||
reg_feat = reg_conv(reg_feat) |
||||
cls_score = self.retina_cls(cls_feat) |
||||
bbox_cls_pred = self.retina_bbox_cls(reg_feat) |
||||
bbox_reg_pred = self.retina_bbox_reg(reg_feat) |
||||
bbox_pred = (bbox_cls_pred, bbox_reg_pred) |
||||
return cls_score, bbox_pred |
||||
|
||||
def forward(self, feats): |
||||
return multi_apply(self.forward_single, feats) |
||||
|
||||
def get_anchors(self, featmap_sizes, img_metas, device='cuda'): |
||||
"""Get squares according to feature map sizes and guided anchors. |
||||
|
||||
Args: |
||||
featmap_sizes (list[tuple]): Multi-level feature map sizes. |
||||
img_metas (list[dict]): Image meta info. |
||||
device (torch.device | str): device for returned tensors |
||||
|
||||
Returns: |
||||
tuple: square approxs of each image |
||||
""" |
||||
num_imgs = len(img_metas) |
||||
|
||||
# since feature map sizes of all images are the same, we only compute |
||||
# squares for one time |
||||
multi_level_squares = self.square_anchor_generator.grid_anchors( |
||||
featmap_sizes, device=device) |
||||
squares_list = [multi_level_squares for _ in range(num_imgs)] |
||||
|
||||
return squares_list |
||||
|
||||
def get_target(self, |
||||
approx_list, |
||||
inside_flag_list, |
||||
square_list, |
||||
gt_bboxes_list, |
||||
img_metas, |
||||
gt_bboxes_ignore_list=None, |
||||
gt_labels_list=None, |
||||
label_channels=None, |
||||
sampling=True, |
||||
unmap_outputs=True): |
||||
"""Compute bucketing targets. |
||||
Args: |
||||
approx_list (list[list]): Multi level approxs of each image. |
||||
inside_flag_list (list[list]): Multi level inside flags of each |
||||
image. |
||||
square_list (list[list]): Multi level squares of each image. |
||||
gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image. |
||||
img_metas (list[dict]): Meta info of each image. |
||||
gt_bboxes_ignore_list (list[Tensor]): ignore list of gt bboxes. |
||||
gt_bboxes_list (list[Tensor]): Gt bboxes of each image. |
||||
label_channels (int): Channel of label. |
||||
sampling (bool): Sample Anchors or not. |
||||
unmap_outputs (bool): unmap outputs or not. |
||||
|
||||
Returns: |
||||
tuple: Returns a tuple containing learning targets. |
||||
|
||||
- labels_list (list[Tensor]): Labels of each level. |
||||
- label_weights_list (list[Tensor]): Label weights of each \ |
||||
level. |
||||
- bbox_cls_targets_list (list[Tensor]): BBox cls targets of \ |
||||
each level. |
||||
- bbox_cls_weights_list (list[Tensor]): BBox cls weights of \ |
||||
each level. |
||||
- bbox_reg_targets_list (list[Tensor]): BBox reg targets of \ |
||||
each level. |
||||
- bbox_reg_weights_list (list[Tensor]): BBox reg weights of \ |
||||
each level. |
||||
- num_total_pos (int): Number of positive samples in all \ |
||||
images. |
||||
- num_total_neg (int): Number of negative samples in all \ |
||||
images. |
||||
""" |
||||
num_imgs = len(img_metas) |
||||
assert len(approx_list) == len(inside_flag_list) == len( |
||||
square_list) == num_imgs |
||||
# anchor number of multi levels |
||||
num_level_squares = [squares.size(0) for squares in square_list[0]] |
||||
# concat all level anchors and flags to a single tensor |
||||
inside_flag_flat_list = [] |
||||
approx_flat_list = [] |
||||
square_flat_list = [] |
||||
for i in range(num_imgs): |
||||
assert len(square_list[i]) == len(inside_flag_list[i]) |
||||
inside_flag_flat_list.append(torch.cat(inside_flag_list[i])) |
||||
approx_flat_list.append(torch.cat(approx_list[i])) |
||||
square_flat_list.append(torch.cat(square_list[i])) |
||||
|
||||
# compute targets for each image |
||||
if gt_bboxes_ignore_list is None: |
||||
gt_bboxes_ignore_list = [None for _ in range(num_imgs)] |
||||
if gt_labels_list is None: |
||||
gt_labels_list = [None for _ in range(num_imgs)] |
||||
(all_labels, all_label_weights, all_bbox_cls_targets, |
||||
all_bbox_cls_weights, all_bbox_reg_targets, all_bbox_reg_weights, |
||||
pos_inds_list, neg_inds_list) = multi_apply( |
||||
self._get_target_single, |
||||
approx_flat_list, |
||||
inside_flag_flat_list, |
||||
square_flat_list, |
||||
gt_bboxes_list, |
||||
gt_bboxes_ignore_list, |
||||
gt_labels_list, |
||||
img_metas, |
||||
label_channels=label_channels, |
||||
sampling=sampling, |
||||
unmap_outputs=unmap_outputs) |
||||
# no valid anchors |
||||
if any([labels is None for labels in all_labels]): |
||||
return None |
||||
# sampled anchors of all images |
||||
num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list]) |
||||
num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list]) |
||||
# split targets to a list w.r.t. multiple levels |
||||
labels_list = images_to_levels(all_labels, num_level_squares) |
||||
label_weights_list = images_to_levels(all_label_weights, |
||||
num_level_squares) |
||||
bbox_cls_targets_list = images_to_levels(all_bbox_cls_targets, |
||||
num_level_squares) |
||||
bbox_cls_weights_list = images_to_levels(all_bbox_cls_weights, |
||||
num_level_squares) |
||||
bbox_reg_targets_list = images_to_levels(all_bbox_reg_targets, |
||||
num_level_squares) |
||||
bbox_reg_weights_list = images_to_levels(all_bbox_reg_weights, |
||||
num_level_squares) |
||||
return (labels_list, label_weights_list, bbox_cls_targets_list, |
||||
bbox_cls_weights_list, bbox_reg_targets_list, |
||||
bbox_reg_weights_list, num_total_pos, num_total_neg) |
||||
|
||||
def _get_target_single(self, |
||||
flat_approxs, |
||||
inside_flags, |
||||
flat_squares, |
||||
gt_bboxes, |
||||
gt_bboxes_ignore, |
||||
gt_labels, |
||||
img_meta, |
||||
label_channels=None, |
||||
sampling=True, |
||||
unmap_outputs=True): |
||||
"""Compute regression and classification targets for anchors in a |
||||
single image. |
||||
|
||||
Args: |
||||
flat_approxs (Tensor): flat approxs of a single image, |
||||
shape (n, 4) |
||||
inside_flags (Tensor): inside flags of a single image, |
||||
shape (n, ). |
||||
flat_squares (Tensor): flat squares of a single image, |
||||
shape (approxs_per_octave * n, 4) |
||||
gt_bboxes (Tensor): Ground truth bboxes of a single image, \ |
||||
shape (num_gts, 4). |
||||
gt_bboxes_ignore (Tensor): Ground truth bboxes to be |
||||
ignored, shape (num_ignored_gts, 4). |
||||
gt_labels (Tensor): Ground truth labels of each box, |
||||
shape (num_gts,). |
||||
img_meta (dict): Meta info of the image. |
||||
label_channels (int): Channel of label. |
||||
sampling (bool): Sample Anchors or not. |
||||
unmap_outputs (bool): unmap outputs or not. |
||||
|
||||
Returns: |
||||
tuple: |
||||
|
||||
- labels_list (Tensor): Labels in a single image |
||||
- label_weights (Tensor): Label weights in a single image |
||||
- bbox_cls_targets (Tensor): BBox cls targets in a single image |
||||
- bbox_cls_weights (Tensor): BBox cls weights in a single image |
||||
- bbox_reg_targets (Tensor): BBox reg targets in a single image |
||||
- bbox_reg_weights (Tensor): BBox reg weights in a single image |
||||
- num_total_pos (int): Number of positive samples \ |
||||
in a single image |
||||
- num_total_neg (int): Number of negative samples \ |
||||
in a single image |
||||
""" |
||||
if not inside_flags.any(): |
||||
return (None, ) * 8 |
||||
# assign gt and sample anchors |
||||
expand_inside_flags = inside_flags[:, None].expand( |
||||
-1, self.approxs_per_octave).reshape(-1) |
||||
approxs = flat_approxs[expand_inside_flags, :] |
||||
squares = flat_squares[inside_flags, :] |
||||
|
||||
assign_result = self.assigner.assign(approxs, squares, |
||||
self.approxs_per_octave, |
||||
gt_bboxes, gt_bboxes_ignore) |
||||
sampling_result = self.sampler.sample(assign_result, squares, |
||||
gt_bboxes) |
||||
|
||||
num_valid_squares = squares.shape[0] |
||||
bbox_cls_targets = squares.new_zeros( |
||||
(num_valid_squares, self.side_num * 4)) |
||||
bbox_cls_weights = squares.new_zeros( |
||||
(num_valid_squares, self.side_num * 4)) |
||||
bbox_reg_targets = squares.new_zeros( |
||||
(num_valid_squares, self.side_num * 4)) |
||||
bbox_reg_weights = squares.new_zeros( |
||||
(num_valid_squares, self.side_num * 4)) |
||||
labels = squares.new_full((num_valid_squares, ), |
||||
self.background_label, |
||||
dtype=torch.long) |
||||
label_weights = squares.new_zeros(num_valid_squares, dtype=torch.float) |
||||
|
||||
pos_inds = sampling_result.pos_inds |
||||
neg_inds = sampling_result.neg_inds |
||||
if len(pos_inds) > 0: |
||||
(pos_bbox_reg_targets, pos_bbox_reg_weights, pos_bbox_cls_targets, |
||||
pos_bbox_cls_weights) = self.bbox_coder.encode( |
||||
sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes) |
||||
|
||||
bbox_cls_targets[pos_inds, :] = pos_bbox_cls_targets |
||||
bbox_reg_targets[pos_inds, :] = pos_bbox_reg_targets |
||||
bbox_cls_weights[pos_inds, :] = pos_bbox_cls_weights |
||||
bbox_reg_weights[pos_inds, :] = pos_bbox_reg_weights |
||||
if gt_labels is None: |
||||
labels[pos_inds] = 1 |
||||
else: |
||||
labels[pos_inds] = gt_labels[ |
||||
sampling_result.pos_assigned_gt_inds] |
||||
if self.train_cfg.pos_weight <= 0: |
||||
label_weights[pos_inds] = 1.0 |
||||
else: |
||||
label_weights[pos_inds] = self.train_cfg.pos_weight |
||||
if len(neg_inds) > 0: |
||||
label_weights[neg_inds] = 1.0 |
||||
|
||||
# map up to original set of anchors |
||||
if unmap_outputs: |
||||
num_total_anchors = flat_squares.size(0) |
||||
labels = unmap( |
||||
labels, |
||||
num_total_anchors, |
||||
inside_flags, |
||||
fill=self.background_label) |
||||
label_weights = unmap(label_weights, num_total_anchors, |
||||
inside_flags) |
||||
bbox_cls_targets = unmap(bbox_cls_targets, num_total_anchors, |
||||
inside_flags) |
||||
bbox_cls_weights = unmap(bbox_cls_weights, num_total_anchors, |
||||
inside_flags) |
||||
bbox_reg_targets = unmap(bbox_reg_targets, num_total_anchors, |
||||
inside_flags) |
||||
bbox_reg_weights = unmap(bbox_reg_weights, num_total_anchors, |
||||
inside_flags) |
||||
return (labels, label_weights, bbox_cls_targets, bbox_cls_weights, |
||||
bbox_reg_targets, bbox_reg_weights, pos_inds, neg_inds) |
||||
|
||||
def loss_single(self, cls_score, bbox_pred, labels, label_weights, |
||||
bbox_cls_targets, bbox_cls_weights, bbox_reg_targets, |
||||
bbox_reg_weights, num_total_samples): |
        # classification loss
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        loss_cls = self.loss_cls(
            cls_score, labels, label_weights, avg_factor=num_total_samples)
        # regression loss
        bbox_cls_targets = bbox_cls_targets.reshape(-1, self.side_num * 4)
        bbox_cls_weights = bbox_cls_weights.reshape(-1, self.side_num * 4)
        bbox_reg_targets = bbox_reg_targets.reshape(-1, self.side_num * 4)
        bbox_reg_weights = bbox_reg_weights.reshape(-1, self.side_num * 4)
        (bbox_cls_pred, bbox_reg_pred) = bbox_pred
        bbox_cls_pred = bbox_cls_pred.permute(0, 2, 3, 1).reshape(
            -1, self.side_num * 4)
        bbox_reg_pred = bbox_reg_pred.permute(0, 2, 3, 1).reshape(
            -1, self.side_num * 4)
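        # Our reading of the avg_factors below: each sample supervises
        # 4 sides x side_num bucket scores, but only 4 sides x offset_topk
        # fine offsets, so the two losses are averaged over different counts.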
        loss_bbox_cls = self.loss_bbox_cls(
            bbox_cls_pred,
            bbox_cls_targets.long(),
            bbox_cls_weights,
            avg_factor=num_total_samples * 4 * self.side_num)
        loss_bbox_reg = self.loss_bbox_reg(
            bbox_reg_pred,
            bbox_reg_targets,
            bbox_reg_weights,
            avg_factor=num_total_samples * 4 * self.bbox_coder.offset_topk)
        return loss_cls, loss_bbox_cls, loss_bbox_reg

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.approx_anchor_generator.num_levels

        device = cls_scores[0].device

        # get sampled approxes
        approxs_list, inside_flag_list = GuidedAnchorHead.get_sampled_approxs(
            self, featmap_sizes, img_metas, device=device)

        square_list = self.get_anchors(featmap_sizes, img_metas, device=device)

        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1

        cls_reg_targets = self.get_target(
            approxs_list,
            inside_flag_list,
            square_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            sampling=self.sampling)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_cls_targets_list,
         bbox_cls_weights_list, bbox_reg_targets_list, bbox_reg_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        num_total_samples = (
            num_total_pos + num_total_neg if self.sampling else num_total_pos)
        losses_cls, losses_bbox_cls, losses_bbox_reg = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_cls_targets_list,
            bbox_cls_weights_list,
            bbox_reg_targets_list,
            bbox_reg_weights_list,
            num_total_samples=num_total_samples)
        return dict(
            loss_cls=losses_cls,
            loss_bbox_cls=losses_bbox_cls,
            loss_bbox_reg=losses_bbox_reg)

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   img_metas,
                   cfg=None,
                   rescale=False):
        assert len(cls_scores) == len(bbox_preds)
        num_levels = len(cls_scores)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]

        device = cls_scores[0].device
        mlvl_anchors = self.get_anchors(
            featmap_sizes, img_metas, device=device)
        result_list = []
        for img_id in range(len(img_metas)):
            cls_score_list = [
                cls_scores[i][img_id].detach() for i in range(num_levels)
            ]
            bbox_cls_pred_list = [
                bbox_preds[i][0][img_id].detach() for i in range(num_levels)
            ]
            bbox_reg_pred_list = [
                bbox_preds[i][1][img_id].detach() for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            proposals = self.get_bboxes_single(cls_score_list,
                                               bbox_cls_pred_list,
                                               bbox_reg_pred_list,
                                               mlvl_anchors[img_id], img_shape,
                                               scale_factor, cfg, rescale)
            result_list.append(proposals)
        return result_list

    def get_bboxes_single(self,
                          cls_scores,
                          bbox_cls_preds,
                          bbox_reg_preds,
                          mlvl_anchors,
                          img_shape,
                          scale_factor,
                          cfg,
                          rescale=False):
        cfg = self.test_cfg if cfg is None else cfg
        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_confids = []
        assert len(cls_scores) == len(bbox_cls_preds) == len(
            bbox_reg_preds) == len(mlvl_anchors)
        for cls_score, bbox_cls_pred, bbox_reg_pred, anchors in zip(
                cls_scores, bbox_cls_preds, bbox_reg_preds, mlvl_anchors):
            assert cls_score.size()[-2:] == bbox_cls_pred.size(
            )[-2:] == bbox_reg_pred.size()[-2:]
            cls_score = cls_score.permute(1, 2,
                                          0).reshape(-1, self.cls_out_channels)
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                scores = cls_score.softmax(-1)
            bbox_cls_pred = bbox_cls_pred.permute(1, 2, 0).reshape(
                -1, self.side_num * 4)
            bbox_reg_pred = bbox_reg_pred.permute(1, 2, 0).reshape(
                -1, self.side_num * 4)
            nms_pre = cfg.get('nms_pre', -1)
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                if self.use_sigmoid_cls:
                    max_scores, _ = scores.max(dim=1)
                else:
                    max_scores, _ = scores[:, :-1].max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                anchors = anchors[topk_inds, :]
                bbox_cls_pred = bbox_cls_pred[topk_inds, :]
                bbox_reg_pred = bbox_reg_pred[topk_inds, :]
                scores = scores[topk_inds, :]
            bbox_preds = [
                bbox_cls_pred.contiguous(),
                bbox_reg_pred.contiguous()
            ]
            bboxes, confids = self.bbox_coder.decode(
                anchors.contiguous(), bbox_preds, max_shape=img_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_confids.append(confids)
        mlvl_bboxes = torch.cat(mlvl_bboxes)
        if rescale:
            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
        mlvl_scores = torch.cat(mlvl_scores)
        mlvl_confids = torch.cat(mlvl_confids)
        if self.use_sigmoid_cls:
            padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
            mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
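        # Bucketing-guided rescoring: the decoder's bucket confidences are
        # passed as score_factors, so NMS ranks boxes by cls score x confid,
        # mirroring the rescoring described in the paper.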
        det_bboxes, det_labels = multiclass_nms(
            mlvl_bboxes,
            mlvl_scores,
            cfg.score_thr,
            cfg.nms,
            cfg.max_per_img,
            score_factors=mlvl_confids)
        return det_bboxes, det_labels
@ -0,0 +1,563 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, kaiming_init, normal_init, xavier_init

from mmdet.core import (build_bbox_coder, force_fp32, multi_apply,
                        multiclass_nms)
from mmdet.models.builder import HEADS, build_loss
from mmdet.models.losses import accuracy


@HEADS.register_module()
class SABLHead(nn.Module):
    """Side-Aware Boundary Localization (SABL) for RoI-Head.

    Side-Aware features are extracted by conv layers
    with an attention mechanism.
    Boundary Localization with Bucketing and Bucketing Guided Rescoring
    are implemented in BucketingBBoxCoder.

    Please refer to https://arxiv.org/abs/1912.04260 for more details.

    Args:
        cls_in_channels (int): Input channels of cls RoI feature. \
            Defaults to 256.
        reg_in_channels (int): Input channels of reg RoI feature. \
            Defaults to 256.
        roi_feat_size (int): Size of RoI features. Defaults to 7.
        reg_feat_up_ratio (int): Upsample ratio of reg features. \
            Defaults to 2.
        reg_pre_kernel (int): Kernel of 2D conv layers before \
            attention pooling. Defaults to 3.
        reg_post_kernel (int): Kernel of 1D conv layers after \
            attention pooling. Defaults to 3.
        reg_pre_num (int): Number of pre convs. Defaults to 2.
        reg_post_num (int): Number of post convs. Defaults to 1.
        num_classes (int): Number of classes in dataset. Defaults to 80.
        cls_out_channels (int): Hidden channels in cls fcs. Defaults to 1024.
        reg_offset_out_channels (int): Hidden and output channel \
            of reg offset branch. Defaults to 256.
        reg_cls_out_channels (int): Hidden and output channel \
            of reg cls branch. Defaults to 256.
        num_cls_fcs (int): Number of fcs for cls branch. Defaults to 1.
        num_reg_fcs (int): Number of fcs for reg branch. Defaults to 0.
        reg_class_agnostic (bool): Class agnostic regression or not. \
            Defaults to True.
        norm_cfg (dict): Config of norm layers. Defaults to None.
        bbox_coder (dict): Config of bbox coder. Defaults to \
            'BucketingBBoxCoder'.
        loss_cls (dict): Config of classification loss.
        loss_bbox_cls (dict): Config of classification loss for bbox branch.
        loss_bbox_reg (dict): Config of regression loss for bbox branch.
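
    Example:
        A minimal usage sketch; the RoI feature below is random dummy data,
        not the output of a real RoI extractor::

            >>> import torch
            >>> self = SABLHead(num_classes=80)
            >>> roi_feat = torch.rand(2, 256, 7, 7)  # (num_rois, C, H, W)
            >>> cls_score, bbox_pred = self(roi_feat)
            >>> tuple(cls_score.shape)  # (num_rois, num_classes + 1)
            (2, 81)
            >>> bucket_cls_pred, bucket_offset_pred = bbox_pred
            >>> tuple(bucket_cls_pred.shape)  # (num_rois, 4 * side_num)
            (2, 28)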
    """

    def __init__(self,
                 num_classes,
                 cls_in_channels=256,
                 reg_in_channels=256,
                 roi_feat_size=7,
                 reg_feat_up_ratio=2,
                 reg_pre_kernel=3,
                 reg_post_kernel=3,
                 reg_pre_num=2,
                 reg_post_num=1,
                 cls_out_channels=1024,
                 reg_offset_out_channels=256,
                 reg_cls_out_channels=256,
                 num_cls_fcs=1,
                 num_reg_fcs=0,
                 reg_class_agnostic=True,
                 norm_cfg=None,
                 bbox_coder=dict(
                     type='BucketingBBoxCoder',
                     num_buckets=14,
                     scale_factor=1.7),
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 loss_bbox_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 loss_bbox_reg=dict(
                     type='SmoothL1Loss', beta=0.1, loss_weight=1.0)):
        super(SABLHead, self).__init__()
        self.cls_in_channels = cls_in_channels
        self.reg_in_channels = reg_in_channels
        self.roi_feat_size = roi_feat_size
        self.reg_feat_up_ratio = int(reg_feat_up_ratio)
        self.num_buckets = bbox_coder['num_buckets']
        assert self.reg_feat_up_ratio // 2 >= 1
        self.up_reg_feat_size = roi_feat_size * self.reg_feat_up_ratio
        assert self.up_reg_feat_size == bbox_coder['num_buckets']
        self.reg_pre_kernel = reg_pre_kernel
        self.reg_post_kernel = reg_post_kernel
        self.reg_pre_num = reg_pre_num
        self.reg_post_num = reg_post_num
        self.num_classes = num_classes
        self.cls_out_channels = cls_out_channels
        self.reg_offset_out_channels = reg_offset_out_channels
        self.reg_cls_out_channels = reg_cls_out_channels
        self.num_cls_fcs = num_cls_fcs
        self.num_reg_fcs = num_reg_fcs
        self.reg_class_agnostic = reg_class_agnostic
        assert self.reg_class_agnostic
        self.norm_cfg = norm_cfg

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox_cls = build_loss(loss_bbox_cls)
        self.loss_bbox_reg = build_loss(loss_bbox_reg)

        self.cls_fcs = self._add_fc_branch(self.num_cls_fcs,
                                           self.cls_in_channels,
                                           self.roi_feat_size,
                                           self.cls_out_channels)

        self.side_num = int(np.ceil(self.num_buckets / 2))

        if self.reg_feat_up_ratio > 1:
            self.upsample_x = nn.ConvTranspose1d(
                reg_in_channels,
                reg_in_channels,
                self.reg_feat_up_ratio,
                stride=self.reg_feat_up_ratio)
            self.upsample_y = nn.ConvTranspose1d(
                reg_in_channels,
                reg_in_channels,
                self.reg_feat_up_ratio,
                stride=self.reg_feat_up_ratio)

        self.reg_pre_convs = nn.ModuleList()
        for i in range(self.reg_pre_num):
            reg_pre_conv = ConvModule(
                reg_in_channels,
                reg_in_channels,
                kernel_size=reg_pre_kernel,
                padding=reg_pre_kernel // 2,
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU'))
            self.reg_pre_convs.append(reg_pre_conv)

        self.reg_post_conv_xs = nn.ModuleList()
        for i in range(self.reg_post_num):
            reg_post_conv_x = ConvModule(
                reg_in_channels,
                reg_in_channels,
                kernel_size=(1, reg_post_kernel),
                padding=(0, reg_post_kernel // 2),
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU'))
            self.reg_post_conv_xs.append(reg_post_conv_x)
        self.reg_post_conv_ys = nn.ModuleList()
        for i in range(self.reg_post_num):
            reg_post_conv_y = ConvModule(
                reg_in_channels,
                reg_in_channels,
                kernel_size=(reg_post_kernel, 1),
                padding=(reg_post_kernel // 2, 0),
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU'))
            self.reg_post_conv_ys.append(reg_post_conv_y)

        self.reg_conv_att_x = nn.Conv2d(reg_in_channels, 1, 1)
        self.reg_conv_att_y = nn.Conv2d(reg_in_channels, 1, 1)

        self.fc_cls = nn.Linear(self.cls_out_channels, self.num_classes + 1)
        self.relu = nn.ReLU(inplace=True)

        self.reg_cls_fcs = self._add_fc_branch(self.num_reg_fcs,
                                               self.reg_in_channels, 1,
                                               self.reg_cls_out_channels)
        self.reg_offset_fcs = self._add_fc_branch(
            self.num_reg_fcs, self.reg_in_channels, 1,
            self.reg_offset_out_channels)
        self.fc_reg_cls = nn.Linear(self.reg_cls_out_channels, 1)
        self.fc_reg_offset = nn.Linear(self.reg_offset_out_channels, 1)

    def _add_fc_branch(self, num_branch_fcs, in_channels, roi_feat_size,
                       fc_out_channels):
        in_channels = in_channels * roi_feat_size * roi_feat_size
        branch_fcs = nn.ModuleList()
        for i in range(num_branch_fcs):
            fc_in_channels = (in_channels if i == 0 else fc_out_channels)
            branch_fcs.append(nn.Linear(fc_in_channels, fc_out_channels))
        return branch_fcs

    def init_weights(self):
        for module_list in [
                self.reg_cls_fcs, self.reg_offset_fcs, self.cls_fcs
        ]:
            for m in module_list.modules():
                if isinstance(m, nn.Linear):
                    xavier_init(m, distribution='uniform')
        if self.reg_feat_up_ratio > 1:
            kaiming_init(self.upsample_x, distribution='normal')
            kaiming_init(self.upsample_y, distribution='normal')

        normal_init(self.reg_conv_att_x, 0, 0.01)
        normal_init(self.reg_conv_att_y, 0, 0.01)
        normal_init(self.fc_reg_offset, 0, 0.001)
        normal_init(self.fc_reg_cls, 0, 0.01)
        normal_init(self.fc_cls, 0, 0.01)

    def cls_forward(self, cls_x):
        cls_x = cls_x.view(cls_x.size(0), -1)
        for fc in self.cls_fcs:
            cls_x = self.relu(fc(cls_x))
        cls_score = self.fc_cls(cls_x)
        return cls_score

    def attention_pool(self, reg_x):
        """Extract direction-specific features fx and fy with attention
        mechanism."""
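        # Shape walkthrough (annotation ours, inferred from the ops below):
        #   reg_x:  (num_rois, C, H, W)
        #   reg_fx: (num_rois, C, W) -- attention-weighted sum over H
        #   reg_fy: (num_rois, C, H) -- attention-weighted sum over W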
        reg_fx = reg_x
        reg_fy = reg_x
        reg_fx_att = self.reg_conv_att_x(reg_fx).sigmoid()
        reg_fy_att = self.reg_conv_att_y(reg_fy).sigmoid()
        reg_fx_att = reg_fx_att / reg_fx_att.sum(dim=2).unsqueeze(2)
        reg_fy_att = reg_fy_att / reg_fy_att.sum(dim=3).unsqueeze(3)
        reg_fx = (reg_fx * reg_fx_att).sum(dim=2)
        reg_fy = (reg_fy * reg_fy_att).sum(dim=3)
        return reg_fx, reg_fy

    def side_aware_feature_extractor(self, reg_x):
        """Refine and extract side-aware features without splitting them."""
        for reg_pre_conv in self.reg_pre_convs:
            reg_x = reg_pre_conv(reg_x)
        reg_fx, reg_fy = self.attention_pool(reg_x)

        if self.reg_post_num > 0:
            reg_fx = reg_fx.unsqueeze(2)
            reg_fy = reg_fy.unsqueeze(3)
            for i in range(self.reg_post_num):
                reg_fx = self.reg_post_conv_xs[i](reg_fx)
                reg_fy = self.reg_post_conv_ys[i](reg_fy)
            reg_fx = reg_fx.squeeze(2)
            reg_fy = reg_fy.squeeze(3)
        if self.reg_feat_up_ratio > 1:
            reg_fx = self.relu(self.upsample_x(reg_fx))
            reg_fy = self.relu(self.upsample_y(reg_fy))
        reg_fx = torch.transpose(reg_fx, 1, 2)
        reg_fy = torch.transpose(reg_fy, 1, 2)
        return reg_fx.contiguous(), reg_fy.contiguous()

    def reg_pred(self, x, offset_fcs, cls_fcs):
        """Predict bucketing estimation (cls_pred) and fine regression
        (offset_pred) with side-aware features."""
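        # x: (num_rois, up_reg_feat_size, reg_in_channels); every position
        # along one direction yields one bucket score and one offset
        # (annotation ours, derived from the view/fc calls below).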
        x_offset = x.view(-1, self.reg_in_channels)
        x_cls = x.view(-1, self.reg_in_channels)

        for fc in offset_fcs:
            x_offset = self.relu(fc(x_offset))
        for fc in cls_fcs:
            x_cls = self.relu(fc(x_cls))
        offset_pred = self.fc_reg_offset(x_offset)
        cls_pred = self.fc_reg_cls(x_cls)

        offset_pred = offset_pred.view(x.size(0), -1)
        cls_pred = cls_pred.view(x.size(0), -1)

        return offset_pred, cls_pred

    def side_aware_split(self, feat):
        """Split side-aware features aligned with the order of the bucketing
        targets."""
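        # E.g. with up_reg_feat_size = 14: the left/top half keeps positions
        # 0..6 and the right/bottom half takes positions 7..13 reversed, so
        # both halves run from the boundary inwards (example ours).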
        l_end = int(np.ceil(self.up_reg_feat_size / 2))
        r_start = int(np.floor(self.up_reg_feat_size / 2))
        feat_fl = feat[:, :l_end]
        feat_fr = feat[:, r_start:].flip(dims=(1, ))
        feat_fl = feat_fl.contiguous()
        feat_fr = feat_fr.contiguous()
        feat = torch.cat([feat_fl, feat_fr], dim=-1)
        return feat

    def reg_forward(self, reg_x):
        reg_fx, reg_fy = self.side_aware_feature_extractor(reg_x)
        offset_pred_x, cls_pred_x = self.reg_pred(reg_fx, self.reg_offset_fcs,
                                                  self.reg_cls_fcs)
        offset_pred_y, cls_pred_y = self.reg_pred(reg_fy, self.reg_offset_fcs,
                                                  self.reg_cls_fcs)
        offset_pred_x = self.side_aware_split(offset_pred_x)
        offset_pred_y = self.side_aware_split(offset_pred_y)
        cls_pred_x = self.side_aware_split(cls_pred_x)
        cls_pred_y = self.side_aware_split(cls_pred_y)
        edge_offset_preds = torch.cat([offset_pred_x, offset_pred_y], dim=-1)
        edge_cls_preds = torch.cat([cls_pred_x, cls_pred_y], dim=-1)

        return (edge_cls_preds, edge_offset_preds)

    def forward(self, x):
        bbox_pred = self.reg_forward(x)
        cls_score = self.cls_forward(x)
        return cls_score, bbox_pred

    def get_targets(self, sampling_results, gt_bboxes, gt_labels,
                    rcnn_train_cfg):
        pos_proposals = [res.pos_bboxes for res in sampling_results]
        neg_proposals = [res.neg_bboxes for res in sampling_results]
        pos_gt_bboxes = [res.pos_gt_bboxes for res in sampling_results]
        pos_gt_labels = [res.pos_gt_labels for res in sampling_results]
        cls_reg_targets = self.bucket_target(pos_proposals, neg_proposals,
                                             pos_gt_bboxes, pos_gt_labels,
                                             rcnn_train_cfg)
        (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
         bucket_offset_targets, bucket_offset_weights) = cls_reg_targets
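        # Regroup so that bbox targets/weights come back as
        # (bucketing, offset) pairs, which is the structure `loss` unpacks.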
        return (labels, label_weights, (bucket_cls_targets,
                                        bucket_offset_targets),
                (bucket_cls_weights, bucket_offset_weights))

    def bucket_target(self,
                      pos_proposals_list,
                      neg_proposals_list,
                      pos_gt_bboxes_list,
                      pos_gt_labels_list,
                      rcnn_train_cfg,
                      concat=True):
        (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
         bucket_offset_targets, bucket_offset_weights) = multi_apply(
             self._bucket_target_single,
             pos_proposals_list,
             neg_proposals_list,
             pos_gt_bboxes_list,
             pos_gt_labels_list,
             cfg=rcnn_train_cfg)

        if concat:
            labels = torch.cat(labels, 0)
            label_weights = torch.cat(label_weights, 0)
            bucket_cls_targets = torch.cat(bucket_cls_targets, 0)
            bucket_cls_weights = torch.cat(bucket_cls_weights, 0)
            bucket_offset_targets = torch.cat(bucket_offset_targets, 0)
            bucket_offset_weights = torch.cat(bucket_offset_weights, 0)
        return (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
                bucket_offset_targets, bucket_offset_weights)

    def _bucket_target_single(self, pos_proposals, neg_proposals,
                              pos_gt_bboxes, pos_gt_labels, cfg):
        """Compute bucketing estimation targets and fine regression targets
        for a single image.

        Args:
            pos_proposals (Tensor): positive proposals of a single image,
                Shape (n_pos, 4).
            neg_proposals (Tensor): negative proposals of a single image,
                Shape (n_neg, 4).
            pos_gt_bboxes (Tensor): gt bboxes assigned to positive proposals
                of a single image, Shape (n_pos, 4).
            pos_gt_labels (Tensor): gt labels assigned to positive proposals
                of a single image, Shape (n_pos, ).
            cfg (dict): Config of calculating targets.

        Returns:
            tuple:

                - labels (Tensor): Labels in a single image. \
                    Shape (n,).
                - label_weights (Tensor): Label weights in a single image. \
                    Shape (n,).
                - bucket_cls_targets (Tensor): Bucket cls targets in \
                    a single image. Shape (n, num_buckets*2).
                - bucket_cls_weights (Tensor): Bucket cls weights in \
                    a single image. Shape (n, num_buckets*2).
                - bucket_offset_targets (Tensor): Bucket offset targets \
                    in a single image. Shape (n, num_buckets*2).
                - bucket_offset_weights (Tensor): Bucket offset weights \
                    in a single image. Shape (n, num_buckets*2).
        """
        num_pos = pos_proposals.size(0)
        num_neg = neg_proposals.size(0)
        num_samples = num_pos + num_neg
        labels = pos_gt_bboxes.new_full((num_samples, ),
                                        self.num_classes,
                                        dtype=torch.long)
        label_weights = pos_proposals.new_zeros(num_samples)
        bucket_cls_targets = pos_proposals.new_zeros(num_samples,
                                                     4 * self.side_num)
        bucket_cls_weights = pos_proposals.new_zeros(num_samples,
                                                     4 * self.side_num)
        bucket_offset_targets = pos_proposals.new_zeros(
            num_samples, 4 * self.side_num)
        bucket_offset_weights = pos_proposals.new_zeros(
            num_samples, 4 * self.side_num)
        if num_pos > 0:
            labels[:num_pos] = pos_gt_labels
            label_weights[:num_pos] = 1.0
            (pos_bucket_offset_targets, pos_bucket_offset_weights,
             pos_bucket_cls_targets,
             pos_bucket_cls_weights) = self.bbox_coder.encode(
                 pos_proposals, pos_gt_bboxes)
            bucket_cls_targets[:num_pos, :] = pos_bucket_cls_targets
            bucket_cls_weights[:num_pos, :] = pos_bucket_cls_weights
            bucket_offset_targets[:num_pos, :] = pos_bucket_offset_targets
            bucket_offset_weights[:num_pos, :] = pos_bucket_offset_weights
        if num_neg > 0:
            label_weights[-num_neg:] = 1.0
        return (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
                bucket_offset_targets, bucket_offset_weights)

    def loss(self,
             cls_score,
             bbox_pred,
             rois,
             labels,
             label_weights,
             bbox_targets,
             bbox_weights,
             reduction_override=None):
        losses = dict()
        if cls_score is not None:
            avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
            losses['loss_cls'] = self.loss_cls(
                cls_score,
                labels,
                label_weights,
                avg_factor=avg_factor,
                reduction_override=reduction_override)
            losses['acc'] = accuracy(cls_score, labels)

        if bbox_pred is not None:
            bucket_cls_preds, bucket_offset_preds = bbox_pred
            bucket_cls_targets, bucket_offset_targets = bbox_targets
            bucket_cls_weights, bucket_offset_weights = bbox_weights
            # edge cls
            bucket_cls_preds = bucket_cls_preds.view(-1, self.side_num)
            bucket_cls_targets = bucket_cls_targets.view(-1, self.side_num)
            bucket_cls_weights = bucket_cls_weights.view(-1, self.side_num)
            losses['loss_bbox_cls'] = self.loss_bbox_cls(
                bucket_cls_preds,
                bucket_cls_targets,
                bucket_cls_weights,
                avg_factor=bucket_cls_targets.size(0),
                reduction_override=reduction_override)

            losses['loss_bbox_reg'] = self.loss_bbox_reg(
                bucket_offset_preds,
                bucket_offset_targets,
                bucket_offset_weights,
                avg_factor=bucket_offset_targets.size(0),
                reduction_override=reduction_override)

        return losses

    @force_fp32(apply_to=('cls_score', 'bbox_pred'))
    def get_bboxes(self,
                   rois,
                   cls_score,
                   bbox_pred,
                   img_shape,
                   scale_factor,
                   rescale=False,
                   cfg=None):
        if isinstance(cls_score, list):
            cls_score = sum(cls_score) / float(len(cls_score))
        scores = F.softmax(cls_score, dim=1) if cls_score is not None else None

        if bbox_pred is not None:
            bboxes, confids = self.bbox_coder.decode(rois[:, 1:], bbox_pred,
                                                     img_shape)
        else:
            bboxes = rois[:, 1:].clone()
            confids = None
            if img_shape is not None:
                bboxes[:, [0, 2]].clamp_(min=0, max=img_shape[1] - 1)
                bboxes[:, [1, 3]].clamp_(min=0, max=img_shape[0] - 1)

        if rescale and bboxes.size(0) > 0:
            if isinstance(scale_factor, float):
                bboxes /= scale_factor
            else:
                bboxes /= torch.from_numpy(scale_factor).to(bboxes.device)

        if cfg is None:
            return bboxes, scores
        else:
            det_bboxes, det_labels = multiclass_nms(
                bboxes,
                scores,
                cfg.score_thr,
                cfg.nms,
                cfg.max_per_img,
                score_factors=confids)

            return det_bboxes, det_labels

    @force_fp32(apply_to=('bbox_preds', ))
    def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
        """Refine bboxes during training.

        Args:
            rois (Tensor): Shape (n*bs, 5), where n is image number per GPU,
                and bs is the sampled RoIs per image.
            labels (Tensor): Shape (n*bs, ).
            bbox_preds (list[Tensor]): Shape [(n*bs, num_buckets*2), \
                (n*bs, num_buckets*2)].
            pos_is_gts (list[Tensor]): Flags indicating if each positive bbox
                is a gt bbox.
            img_metas (list[dict]): Meta info of each image.

        Returns:
            list[Tensor]: Refined bboxes of each image in a mini-batch.
        """
        img_ids = rois[:, 0].long().unique(sorted=True)
        assert img_ids.numel() == len(img_metas)

        bboxes_list = []
        for i in range(len(img_metas)):
            inds = torch.nonzero(
                rois[:, 0] == i, as_tuple=False).squeeze(dim=1)
            num_rois = inds.numel()

            bboxes_ = rois[inds, 1:]
            label_ = labels[inds]
            edge_cls_preds, edge_offset_preds = bbox_preds
            edge_cls_preds_ = edge_cls_preds[inds]
            edge_offset_preds_ = edge_offset_preds[inds]
            bbox_pred_ = [edge_cls_preds_, edge_offset_preds_]
            img_meta_ = img_metas[i]
            pos_is_gts_ = pos_is_gts[i]

            bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_,
                                           img_meta_)
            # filter gt bboxes
            pos_keep = 1 - pos_is_gts_
            keep_inds = pos_is_gts_.new_ones(num_rois)
            keep_inds[:len(pos_is_gts_)] = pos_keep

            bboxes_list.append(bboxes[keep_inds.type(torch.bool)])

        return bboxes_list

    @force_fp32(apply_to=('bbox_pred', ))
    def regress_by_class(self, rois, label, bbox_pred, img_meta):
        """Regress the bbox for the predicted class. Used in Cascade R-CNN.

        Args:
            rois (Tensor): shape (n, 4) or (n, 5)
            label (Tensor): shape (n, )
            bbox_pred (list[Tensor]): shape [(n, num_buckets*2), \
                (n, num_buckets*2)]
            img_meta (dict): Image meta info.

        Returns:
            Tensor: Regressed bboxes, the same shape as input rois.
        """
        assert rois.size(1) == 4 or rois.size(1) == 5

        if rois.size(1) == 4:
            new_rois, _ = self.bbox_coder.decode(rois, bbox_pred,
                                                 img_meta['img_shape'])
        else:
            bboxes, _ = self.bbox_coder.decode(rois[:, 1:], bbox_pred,
                                               img_meta['img_shape'])
            new_rois = torch.cat((rois[:, [0]], bboxes), dim=1)

        return new_rois