Code for Cascade RPN - NeurIPS 2019 (#1900)
* Code for Cascade RPN - NeurIPS 2019
* using mmdet2 apis
* update mmdet2 coords and minor fixes
* fix format
* update configs
* fix simple test
* update fast rcnn config
* minor fixes
* update docstring
* update readme
* update doc string
* minor fixes
* standard adaptive interface
* update args docs

Co-authored-by: Cao Yuhang <yhcao6@gmail.com>
parent: b951522ef8
commit: 23ded99365
12 changed files with 1148 additions and 8 deletions
@@ -0,0 +1,27 @@
# Cascade RPN

We provide the code for reproducing the experimental results of [Cascade RPN](https://arxiv.org/abs/1909.06720).

```
@inproceedings{vu2019cascade,
  title={Cascade RPN: Delving into High-Quality Region Proposal Network with Adaptive Convolution},
  author={Vu, Thang and Jang, Hyunjun and Pham, Trung X and Yoo, Chang D},
  booktitle={Conference on Neural Information Processing Systems (NeurIPS)},
  year={2019}
}
```

## Benchmark

### Region proposal performance

| Method | Backbone | Style | Mem (GB) | Train time (s/iter) | Inf time (fps) | AR 1000 | Download |
|:------:|:--------:|:-----:|:--------:|:-------------------:|:--------------:|:-------:|:--------:|
| CRPN   | R-50-FPN | caffe | -        | -                   | -              | 72.0    | [model](https://drive.google.com/file/d/1qxVdOnCgK-ee7_z0x6mvAir_glMu2Ihi/view?usp=sharing) |

### Detection performance

| Method       | Proposal    | Backbone | Style | Schedule | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download |
|:------------:|:-----------:|:--------:|:-----:|:--------:|:--------:|:-------------------:|:--------------:|:------:|:--------:|
| Fast R-CNN   | Cascade RPN | R-50-FPN | caffe | 1x       | -        | -                   | -              | 39.9   | [model](https://drive.google.com/file/d/1NmbnuY5VHi8I9FE8xnp5uNvh2i-t-6_L/view?usp=sharing) |
| Faster R-CNN | Cascade RPN | R-50-FPN | caffe | 1x       | -        | -                   | -              | 40.4   | [model](https://drive.google.com/file/d/1dS3Q66qXMJpcuuQgDNkLp669E5w1UMuZ/view?usp=sharing) |
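
For quick reference, a minimal inference sketch using mmdetection's high-level APIs; the config and checkpoint paths below are placeholders, not files shipped in this diff:

```python
from mmdet.apis import inference_detector, init_detector

# Placeholder paths: point these at a Cascade RPN config from this
# directory and a checkpoint downloaded from the tables above.
config_file = 'configs/cascade_rpn/<crpn_config>.py'
checkpoint_file = '<downloaded_crpn_checkpoint>.pth'

model = init_detector(config_file, checkpoint_file, device='cuda:0')
result = inference_detector(model, 'demo/demo.jpg')  # per-class bbox arrays
```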
@@ -0,0 +1,74 @@
_base_ = '../fast_rcnn/fast_rcnn_r50_fpn_1x_coco.py'
model = dict(
    pretrained='open-mmlab://detectron2/resnet50_caffe',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='caffe'),
    roi_head=dict(
        bbox_head=dict(
            bbox_coder=dict(target_stds=[0.04, 0.04, 0.08, 0.08]),
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.5),
            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))))
# model training and testing settings
train_cfg = dict(
    rcnn=dict(
        assigner=dict(pos_iou_thr=0.65, neg_iou_thr=0.65, min_pos_iou=0.65),
        sampler=dict(num=256)))
test_cfg = dict(rcnn=dict(score_thr=1e-3))
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadProposals', num_max_proposals=300),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'proposals', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadProposals', num_max_proposals=300),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='ToTensor', keys=['proposals']),
            dict(
                type='ToDataContainer',
                fields=[dict(key='proposals', stack=False)]),
            dict(type='Collect', keys=['img', 'proposals']),
        ])
]
data = dict(
    train=dict(
        proposal_file=data_root +
        'proposals/crpn_r50_caffe_fpn_1x_train2017.pkl',
        pipeline=train_pipeline),
    val=dict(
        proposal_file=data_root +
        'proposals/crpn_r50_caffe_fpn_1x_val2017.pkl',
        pipeline=test_pipeline),
    test=dict(
        proposal_file=data_root +
        'proposals/crpn_r50_caffe_fpn_1x_val2017.pkl',
        pipeline=test_pipeline))
optimizer_config = dict(
    _delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
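
The Fast R-CNN config above consumes precomputed Cascade RPN proposals via `LoadProposals`. A hedged sketch of the proposal-file layout this pipeline expects (an assumption based on mmdetection's proposal format, not spelled out in this diff): one float32 array per image, in dataset order, with rows of `(x1, y1, x2, y2)` or `(x1, y1, x2, y2, score)`:

```python
import pickle

import numpy as np

# Build a dummy proposal file with the assumed layout: a list with one
# (num_proposals, 5) float32 array per image, ordered as in the dataset.
def dummy_proposals(num_images, num_props=300):
    proposals = []
    for _ in range(num_images):
        x1 = np.random.uniform(0, 400, (num_props, 1))
        y1 = np.random.uniform(0, 400, (num_props, 1))
        w = np.random.uniform(8, 200, (num_props, 1))
        h = np.random.uniform(8, 200, (num_props, 1))
        score = np.random.uniform(0, 1, (num_props, 1))
        boxes = np.hstack([x1, y1, x1 + w, y1 + h, score]).astype(np.float32)
        proposals.append(boxes)
    return proposals

with open('dummy_proposals.pkl', 'wb') as f:
    pickle.dump(dummy_proposals(num_images=3), f)
```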
@@ -0,0 +1,89 @@
_base_ = '../faster_rcnn/faster_rcnn_r50_caffe_fpn_1x_coco.py'
rpn_weight = 0.7
model = dict(
    rpn_head=dict(
        _delete_=True,
        type='CascadeRPNHead',
        num_stages=2,
        stages=[
            dict(
                type='StageCascadeRPNHead',
                in_channels=256,
                feat_channels=256,
                anchor_generator=dict(
                    type='AnchorGenerator',
                    scales=[8],
                    ratios=[1.0],
                    strides=[4, 8, 16, 32, 64]),
                adapt_cfg=dict(type='dilation', dilation=3),
                bridged_feature=True,
                sampling=False,
                with_cls=False,
                reg_decoded_bbox=True,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=(.0, .0, .0, .0),
                    target_stds=(0.1, 0.1, 0.5, 0.5)),
                loss_bbox=dict(
                    type='IoULoss', linear=True,
                    loss_weight=10.0 * rpn_weight)),
            dict(
                type='StageCascadeRPNHead',
                in_channels=256,
                feat_channels=256,
                adapt_cfg=dict(type='offset'),
                bridged_feature=False,
                sampling=True,
                with_cls=True,
                reg_decoded_bbox=True,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=(.0, .0, .0, .0),
                    target_stds=(0.05, 0.05, 0.1, 0.1)),
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=True,
                    loss_weight=1.0 * rpn_weight),
                loss_bbox=dict(
                    type='IoULoss', linear=True,
                    loss_weight=10.0 * rpn_weight))
        ]),
    roi_head=dict(
        bbox_head=dict(
            bbox_coder=dict(target_stds=[0.04, 0.04, 0.08, 0.08]),
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.5),
            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))))
# model training and testing settings
train_cfg = dict(
    rpn=[
        dict(
            assigner=dict(
                type='RegionAssigner', center_ratio=0.2, ignore_ratio=0.5),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.7,
                min_pos_iou=0.3,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False)
    ],
    rpn_proposal=dict(max_num=300, nms_thr=0.8),
    rcnn=dict(
        assigner=dict(pos_iou_thr=0.65, neg_iou_thr=0.65, min_pos_iou=0.65),
        sampler=dict(type='RandomSampler', num=256)))
test_cfg = dict(rpn=dict(max_num=300, nms_thr=0.8), rcnn=dict(score_thr=1e-3))
optimizer_config = dict(
    _delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
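
Because this config replaces the inherited `rpn_head` wholesale, it sets `_delete_=True`, which tells mmcv's config merging to discard the base dict instead of merging into it. A quick sanity check of the merged result (the file path is an assumption about where this config lives on disk):

```python
from mmcv import Config

# Path is an assumption; point it at the Cascade RPN Faster R-CNN config above.
cfg = Config.fromfile('configs/cascade_rpn/<crpn_faster_rcnn_config>.py')
print(cfg.model.rpn_head.type)         # 'CascadeRPNHead' (base RPN head deleted)
print(len(cfg.model.rpn_head.stages))  # 2
```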
@@ -0,0 +1,79 @@
_base_ = '../rpn/rpn_r50_caffe_fpn_1x_coco.py'
model = dict(
    rpn_head=dict(
        _delete_=True,
        type='CascadeRPNHead',
        num_stages=2,
        stages=[
            dict(
                type='StageCascadeRPNHead',
                in_channels=256,
                feat_channels=256,
                anchor_generator=dict(
                    type='AnchorGenerator',
                    scales=[8],
                    ratios=[1.0],
                    strides=[4, 8, 16, 32, 64]),
                adapt_cfg=dict(type='dilation', dilation=3),
                bridged_feature=True,
                sampling=False,
                with_cls=False,
                reg_decoded_bbox=True,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=(.0, .0, .0, .0),
                    target_stds=(0.1, 0.1, 0.5, 0.5)),
                loss_bbox=dict(type='IoULoss', linear=True, loss_weight=10.0)),
            dict(
                type='StageCascadeRPNHead',
                in_channels=256,
                feat_channels=256,
                adapt_cfg=dict(type='offset'),
                bridged_feature=False,
                sampling=True,
                with_cls=True,
                reg_decoded_bbox=True,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=(.0, .0, .0, .0),
                    target_stds=(0.05, 0.05, 0.1, 0.1)),
                loss_cls=dict(
                    type='CrossEntropyLoss', use_sigmoid=True,
                    loss_weight=1.0),
                loss_bbox=dict(type='IoULoss', linear=True, loss_weight=10.0))
        ]))
train_cfg = dict(rpn=[
    dict(
        assigner=dict(
            type='RegionAssigner', center_ratio=0.2, ignore_ratio=0.5),
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    dict(
        assigner=dict(
            type='MaxIoUAssigner',
            pos_iou_thr=0.7,
            neg_iou_thr=0.7,
            min_pos_iou=0.3,
            ignore_iof_thr=-1,
            iou_calculator=dict(type='BboxOverlaps2D')),
        sampler=dict(
            type='RandomSampler',
            num=256,
            pos_fraction=0.5,
            neg_pos_ub=-1,
            add_gt_as_proposals=False),
        allowed_border=-1,
        pos_weight=-1,
        debug=False)
])
test_cfg = dict(
    rpn=dict(
        nms_across_levels=False,
        nms_pre=2000,
        nms_post=2000,
        max_num=2000,
        nms_thr=0.8,
        min_bbox_size=0))
optimizer_config = dict(
    _delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
@@ -0,0 +1,204 @@
import torch

from mmdet.core import anchor_inside_flags
from ..builder import BBOX_ASSIGNERS
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


def calc_region(bbox, ratio, stride, featmap_size=None):
    """Calculate the region of a box defined by ``ratio``, where the ratio
    is measured from the center of the box toward each edge."""
    # project bbox onto the feature map
    f_bbox = bbox / stride
    x1 = torch.round((1 - ratio) * f_bbox[0] + ratio * f_bbox[2])
    y1 = torch.round((1 - ratio) * f_bbox[1] + ratio * f_bbox[3])
    x2 = torch.round(ratio * f_bbox[0] + (1 - ratio) * f_bbox[2])
    y2 = torch.round(ratio * f_bbox[1] + (1 - ratio) * f_bbox[3])
    if featmap_size is not None:
        x1 = x1.clamp(min=0, max=featmap_size[1])
        y1 = y1.clamp(min=0, max=featmap_size[0])
        x2 = x2.clamp(min=0, max=featmap_size[1])
        y2 = y2.clamp(min=0, max=featmap_size[0])
    return (x1, y1, x2, y2)


def anchor_ctr_inside_region_flags(anchors, stride, region):
    """Get flags indicating whether anchor centers are inside the region."""
    x1, y1, x2, y2 = region
    f_anchors = anchors / stride
    x = (f_anchors[:, 0] + f_anchors[:, 2]) * 0.5
    y = (f_anchors[:, 1] + f_anchors[:, 3]) * 0.5
    flags = (x >= x1) & (x <= x2) & (y >= y1) & (y <= y2)
    return flags
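
A toy check of the two helpers above (assuming they are in scope), using the `center_ratio=0.2` from the configs, so the positive region is the central 20% of the box:

```python
import torch

# gt box (32, 32, 96, 96) on a stride-8 level; r1 = (1 - 0.2) / 2 = 0.4.
gt_bbox = torch.tensor([32., 32., 96., 96.])
r1 = (1 - 0.2) / 2
ctr_region = calc_region(gt_bbox, r1, stride=8)  # (7., 7., 9., 9.) in feature coords

# Anchors centred at feature cells (8, 8) and (2, 2): only the first is inside.
anchors = torch.tensor([[56., 56., 72., 72.],
                        [8., 8., 24., 24.]])
flags = anchor_ctr_inside_region_flags(anchors, 8, ctr_region)
print(flags)  # tensor([ True, False])
```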


@BBOX_ASSIGNERS.register_module()
class RegionAssigner(BaseAssigner):
    """Assign a corresponding gt bbox or background to each bbox.

    Each proposal will be assigned with `-1`, `0`, or a positive integer
    indicating the ground truth index.

    - -1: don't care
    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        center_ratio (float): Ratio of the region in the center of the bbox
            used to define positive samples.
        ignore_ratio (float): Ratio of the region used to define ignore
            samples.
    """

    def __init__(self, center_ratio=0.2, ignore_ratio=0.5):
        self.center_ratio = center_ratio
        self.ignore_ratio = ignore_ratio

    def assign(self,
               mlvl_anchors,
               mlvl_valid_flags,
               gt_bboxes,
               img_meta,
               featmap_sizes,
               anchor_scale,
               anchor_strides,
               gt_bboxes_ignore=None,
               gt_labels=None,
               allowed_border=0):
        """Assign gt to anchors.

        This method assigns a gt bbox to every bbox (proposal/anchor); each
        bbox will be assigned with -1, 0, or a positive number. -1 means
        don't care, 0 means negative sample, and a positive number is the
        index (1-based) of the assigned gt.

        The assignment is done in the following steps, and the order matters:

        1. Assign every anchor to 0 (negative).
           For each gt_bbox:
        2. Compute ignore flags based on ignore_region, then
           assign -1 to anchors w.r.t. ignore flags.
        3. Compute pos flags based on center_region, then
           assign gt_bboxes to anchors w.r.t. pos flags.
        4. Compute ignore flags based on adjacent anchor levels, then
           assign -1 to anchors w.r.t. ignore flags.
        5. Assign anchors outside of the image to -1.

        Args:
            mlvl_anchors (list[Tensor]): Multi level anchors.
            mlvl_valid_flags (list[Tensor]): Multi level valid flags.
            gt_bboxes (Tensor): Ground truth bboxes of the image,
                shape (k, 4).
            img_meta (dict): Meta info of the image.
            featmap_sizes (list[tuple]): Feature map size of each level.
            anchor_scale (int): Scale of the anchor.
            anchor_strides (list[int]): Stride of the anchor at each level.
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO.
            gt_labels (Tensor, optional): Labels of gt_bboxes, shape (k, ).
            allowed_border (int, optional): The border to allow the valid
                anchor. Defaults to 0.

        Returns:
            :obj:`AssignResult`: The assign result.
        """
        # TODO support gt_bboxes_ignore
        if gt_bboxes_ignore is not None:
            raise NotImplementedError
        if gt_bboxes.shape[0] == 0:
            raise ValueError('No gt bboxes')
        num_gts = gt_bboxes.shape[0]
        num_lvls = len(mlvl_anchors)
        r1 = (1 - self.center_ratio) / 2
        r2 = (1 - self.ignore_ratio) / 2

        # map each gt to a single pyramid level based on its scale
        scale = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
                           (gt_bboxes[:, 3] - gt_bboxes[:, 1]))
        min_anchor_size = scale.new_full(
            (1, ), float(anchor_scale * anchor_strides[0]))
        target_lvls = torch.floor(
            torch.log2(scale) - torch.log2(min_anchor_size) + 0.5)
        target_lvls = target_lvls.clamp(min=0, max=num_lvls - 1).long()

        # 1. assign 0 (negative) by default
        mlvl_assigned_gt_inds = []
        mlvl_ignore_flags = []
        for lvl in range(num_lvls):
            h, w = featmap_sizes[lvl]
            assert h * w == mlvl_anchors[lvl].shape[0]
            assigned_gt_inds = gt_bboxes.new_full((h * w, ),
                                                  0,
                                                  dtype=torch.long)
            # use bool flags so they can be used as masks for indexing below
            ignore_flags = torch.zeros_like(assigned_gt_inds, dtype=torch.bool)
            mlvl_assigned_gt_inds.append(assigned_gt_inds)
            mlvl_ignore_flags.append(ignore_flags)

        for gt_id in range(num_gts):
            lvl = target_lvls[gt_id].item()
            featmap_size = featmap_sizes[lvl]
            stride = anchor_strides[lvl]
            anchors = mlvl_anchors[lvl]
            gt_bbox = gt_bboxes[gt_id, :4]

            # Compute regions
            ignore_region = calc_region(gt_bbox, r2, stride, featmap_size)
            ctr_region = calc_region(gt_bbox, r1, stride, featmap_size)

            # 2. Assign -1 to ignore flags
            ignore_flags = anchor_ctr_inside_region_flags(
                anchors, stride, ignore_region)
            mlvl_assigned_gt_inds[lvl][ignore_flags] = -1

            # 3. Assign gt_bboxes to pos flags
            pos_flags = anchor_ctr_inside_region_flags(anchors, stride,
                                                       ctr_region)
            mlvl_assigned_gt_inds[lvl][pos_flags] = gt_id + 1

            # 4. Assign -1 to ignore adjacent lvl
            if lvl > 0:
                d_lvl = lvl - 1
                d_anchors = mlvl_anchors[d_lvl]
                d_featmap_size = featmap_sizes[d_lvl]
                d_stride = anchor_strides[d_lvl]
                d_ignore_region = calc_region(gt_bbox, r2, d_stride,
                                              d_featmap_size)
                ignore_flags = anchor_ctr_inside_region_flags(
                    d_anchors, d_stride, d_ignore_region)
                mlvl_ignore_flags[d_lvl][ignore_flags] = True
            if lvl < num_lvls - 1:
                u_lvl = lvl + 1
                u_anchors = mlvl_anchors[u_lvl]
                u_featmap_size = featmap_sizes[u_lvl]
                u_stride = anchor_strides[u_lvl]
                u_ignore_region = calc_region(gt_bbox, r2, u_stride,
                                              u_featmap_size)
                ignore_flags = anchor_ctr_inside_region_flags(
                    u_anchors, u_stride, u_ignore_region)
                mlvl_ignore_flags[u_lvl][ignore_flags] = True

        # 4. (cont.) Assign -1 to ignore adjacent lvl
        for lvl in range(num_lvls):
            ignore_flags = mlvl_ignore_flags[lvl]
            mlvl_assigned_gt_inds[lvl][ignore_flags] = -1

        # 5. Assign -1 to anchors outside of the image
        flat_assigned_gt_inds = torch.cat(mlvl_assigned_gt_inds)
        flat_anchors = torch.cat(mlvl_anchors)
        flat_valid_flags = torch.cat(mlvl_valid_flags)
        assert (flat_assigned_gt_inds.shape[0] == flat_anchors.shape[0] ==
                flat_valid_flags.shape[0])
        inside_flags = anchor_inside_flags(flat_anchors, flat_valid_flags,
                                           img_meta['img_shape'],
                                           allowed_border)
        outside_flags = ~inside_flags
        flat_assigned_gt_inds[outside_flags] = -1

        if gt_labels is not None:
            assigned_labels = torch.zeros_like(flat_assigned_gt_inds)
            pos_flags = flat_assigned_gt_inds > 0
            assigned_labels[pos_flags] = gt_labels[
                flat_assigned_gt_inds[pos_flags] - 1]
        else:
            assigned_labels = None

        return AssignResult(
            num_gts, flat_assigned_gt_inds, None, labels=assigned_labels)
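
To make the scale-to-level mapping in `assign()` concrete, a standalone sketch assuming `anchor_scale=8` and strides `(4, 8, 16, 32, 64)` as in the configs above:

```python
import math

import torch

# Each gt is assigned to one pyramid level based on sqrt(w * h).
anchor_scale, strides, num_lvls = 8, [4, 8, 16, 32, 64], 5
scale = torch.tensor([24., 48., 100., 400., 900.])   # sqrt(w * h) per gt box
min_anchor_size = float(anchor_scale * strides[0])   # 32: base anchor at level 0
target_lvls = torch.floor(torch.log2(scale) - math.log2(min_anchor_size) + 0.5)
print(target_lvls.clamp(min=0, max=num_lvls - 1).long())  # tensor([0, 1, 2, 4, 4])
```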
@@ -0,0 +1,648 @@
from __future__ import division

import torch
import torch.nn as nn
from mmcv.cnn import normal_init
from mmcv.ops import DeformConv2d

from mmdet.core import (RegionAssigner, build_assigner, build_sampler,
                        images_to_levels, multi_apply)
from ..builder import HEADS, build_head
from .base_dense_head import BaseDenseHead
from .rpn_head import RPNHead


class AdaptiveConv(nn.Module):
    """AdaptiveConv used to adapt the sampling locations to the anchors.

    Args:
        in_channels (int): Number of channels in the input image.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int or tuple): Size of the conv kernel. Default: 3.
        stride (int or tuple, optional): Stride of the convolution. Default: 1.
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 1.
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 3.
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1.
        bias (bool, optional): If set True, adds a learnable bias to the
            output. Default: False.
        type (str, optional): Type of adaptive conv, can be either 'offset'
            (arbitrary anchors) or 'dilation' (uniform anchors).
            Default: 'dilation'.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=3,
                 groups=1,
                 bias=False,
                 type='dilation'):
        super(AdaptiveConv, self).__init__()
        assert type in ['offset', 'dilation']
        self.adapt_type = type

        assert kernel_size == 3, 'Adaptive conv only supports kernel_size 3'
        if self.adapt_type == 'offset':
            assert stride == 1 and padding == 1 and groups == 1, \
                'Adaptive conv offset mode only supports padding: 1, ' \
                'stride: 1, groups: 1'
            self.conv = DeformConv2d(
                in_channels,
                out_channels,
                kernel_size,
                padding=padding,
                stride=stride,
                groups=groups,
                bias=bias)
        else:
            self.conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                padding=dilation,
                dilation=dilation)

    def init_weights(self):
        """Init weights."""
        normal_init(self.conv, std=0.01)

    def forward(self, x, offset):
        """Forward function."""
        if self.adapt_type == 'offset':
            N, _, H, W = x.shape
            assert offset is not None
            assert H * W == offset.shape[1]
            # reshape [N, H * W, 18] to (N, 18, H, W)
            offset = offset.permute(0, 2, 1).reshape(N, -1, H, W)
            offset = offset.contiguous()
            x = self.conv(x, offset)
        else:
            assert offset is None
            x = self.conv(x)
        return x
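
A minimal smoke test of `AdaptiveConv` in its two modes, assuming the class above is in scope; note the 'offset' branch relies on `mmcv.ops.DeformConv2d`, which needs a CUDA build of mmcv:

```python
import torch

# 'dilation' mode is a plain dilated 3x3 conv and must be called with offset=None.
conv = AdaptiveConv(256, 256, dilation=3, type='dilation')
feat = torch.rand(2, 256, 32, 32)
out = conv(feat, offset=None)  # shape: (2, 256, 32, 32)

# 'offset' mode expects one 18-dim offset vector (3x3 kernel, (dy, dx) pairs)
# per spatial location; it runs on GPU with a CUDA build of mmcv.
if torch.cuda.is_available():
    conv = AdaptiveConv(256, 256, type='offset').cuda()
    feat = feat.cuda()
    offset = torch.zeros(2, 32 * 32, 18, device='cuda')  # zero offsets = plain conv
    out = conv(feat, offset)   # shape: (2, 256, 32, 32)
```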


@HEADS.register_module()
class StageCascadeRPNHead(RPNHead):
    """Stage of CascadeRPNHead.

    Args:
        in_channels (int): Number of channels in the input feature map.
        anchor_generator (dict): anchor generator config.
        adapt_cfg (dict): adaptation config.
        bridged_feature (bool, optional): whether to bridge (update) the rpn
            feature. Default: False.
        with_cls (bool, optional): whether to use the classification branch.
            Default: True.
        sampling (bool, optional): whether to use sampling. Default: True.
    """

    def __init__(self,
                 in_channels,
                 anchor_generator=dict(
                     type='AnchorGenerator',
                     scales=[8],
                     ratios=[1.0],
                     strides=[4, 8, 16, 32, 64]),
                 adapt_cfg=dict(type='dilation', dilation=3),
                 bridged_feature=False,
                 with_cls=True,
                 sampling=True,
                 **kwargs):
        self.with_cls = with_cls
        self.anchor_strides = anchor_generator['strides']
        self.anchor_scales = anchor_generator['scales']
        self.bridged_feature = bridged_feature
        self.adapt_cfg = adapt_cfg
        super(StageCascadeRPNHead, self).__init__(
            in_channels, anchor_generator=anchor_generator, **kwargs)

        # override sampling and sampler
        self.sampling = sampling
        if self.train_cfg:
            self.assigner = build_assigner(self.train_cfg.assigner)
            # use PseudoSampler when sampling is False
            if self.sampling and hasattr(self.train_cfg, 'sampler'):
                sampler_cfg = self.train_cfg.sampler
            else:
                sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)

    def _init_layers(self):
        """Init layers of a CascadeRPN stage."""
        self.rpn_conv = AdaptiveConv(self.in_channels, self.feat_channels,
                                     **self.adapt_cfg)
        if self.with_cls:
            self.rpn_cls = nn.Conv2d(self.feat_channels,
                                     self.num_anchors * self.cls_out_channels,
                                     1)
        self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)
        self.relu = nn.ReLU(inplace=True)

    def init_weights(self):
        """Init weights of a CascadeRPN stage."""
        self.rpn_conv.init_weights()
        normal_init(self.rpn_reg, std=0.01)
        if self.with_cls:
            normal_init(self.rpn_cls, std=0.01)

    def forward_single(self, x, offset):
        """Forward function of a single scale."""
        bridged_x = x
        x = self.relu(self.rpn_conv(x, offset))
        if self.bridged_feature:
            bridged_x = x  # update feature
        cls_score = self.rpn_cls(x) if self.with_cls else None
        bbox_pred = self.rpn_reg(x)
        return bridged_x, cls_score, bbox_pred

    def forward(self, feats, offset_list=None):
        """Forward function."""
        if offset_list is None:
            offset_list = [None for _ in range(len(feats))]
        return multi_apply(self.forward_single, feats, offset_list)

    def _region_targets_single(self,
                               anchors,
                               valid_flags,
                               gt_bboxes,
                               gt_bboxes_ignore,
                               gt_labels,
                               img_meta,
                               featmap_sizes,
                               label_channels=1):
        """Get region-based anchor targets for a single image."""
        assign_result = self.assigner.assign(
            anchors,
            valid_flags,
            gt_bboxes,
            img_meta,
            featmap_sizes,
            self.anchor_scales[0],
            self.anchor_strides,
            gt_bboxes_ignore=gt_bboxes_ignore,
            gt_labels=None,
            allowed_border=self.train_cfg.allowed_border)
        flat_anchors = torch.cat(anchors)
        sampling_result = self.sampler.sample(assign_result, flat_anchors,
                                              gt_bboxes)

        num_anchors = flat_anchors.shape[0]
        bbox_targets = torch.zeros_like(flat_anchors)
        bbox_weights = torch.zeros_like(flat_anchors)
        labels = flat_anchors.new_zeros(num_anchors, dtype=torch.long)
        label_weights = flat_anchors.new_zeros(num_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            if not self.reg_decoded_bbox:
                pos_bbox_targets = self.bbox_coder.encode(
                    sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
            else:
                pos_bbox_targets = sampling_result.pos_gt_bboxes
            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0
            if gt_labels is None:
                labels[pos_inds] = 1
            else:
                labels[pos_inds] = gt_labels[
                    sampling_result.pos_assigned_gt_inds]
            if self.train_cfg.pos_weight <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg.pos_weight
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                neg_inds)

    def region_targets(self,
                       anchor_list,
                       valid_flag_list,
                       gt_bboxes_list,
                       img_metas,
                       featmap_sizes,
                       gt_bboxes_ignore_list=None,
                       gt_labels_list=None,
                       label_channels=1,
                       unmap_outputs=True):
        """See :func:`StageCascadeRPNHead.get_targets`."""
        num_imgs = len(img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        # anchor number of multi levels
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]

        # compute targets for each image
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
        if gt_labels_list is None:
            gt_labels_list = [None for _ in range(num_imgs)]
        (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
         pos_inds_list, neg_inds_list) = multi_apply(
             self._region_targets_single,
             anchor_list,
             valid_flag_list,
             gt_bboxes_list,
             gt_bboxes_ignore_list,
             gt_labels_list,
             img_metas,
             featmap_sizes=featmap_sizes,
             label_channels=label_channels)
        # no valid anchors
        if any([labels is None for labels in all_labels]):
            return None
        # sampled anchors of all images
        num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
        num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
        # split targets to a list w.r.t. multiple levels
        labels_list = images_to_levels(all_labels, num_level_anchors)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_anchors)
        bbox_targets_list = images_to_levels(all_bbox_targets,
                                             num_level_anchors)
        bbox_weights_list = images_to_levels(all_bbox_weights,
                                             num_level_anchors)
        return (labels_list, label_weights_list, bbox_targets_list,
                bbox_weights_list, num_total_pos, num_total_neg)

    def get_targets(self,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes,
                    img_metas,
                    featmap_sizes,
                    gt_bboxes_ignore=None,
                    label_channels=1):
        """Compute regression and classification targets for anchors.

        Args:
            anchor_list (list[list]): Multi level anchors of each image.
            valid_flag_list (list[list]): Multi level valid flags of each
                image.
            gt_bboxes (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            featmap_sizes (list[tuple]): Feature map size of each level.
            gt_bboxes_ignore (list[Tensor]): Ignore bboxes of each image.
            label_channels (int): Channel of label.

        Returns:
            cls_reg_targets (tuple)
        """
        if isinstance(self.assigner, RegionAssigner):
            cls_reg_targets = self.region_targets(
                anchor_list,
                valid_flag_list,
                gt_bboxes,
                img_metas,
                featmap_sizes,
                gt_bboxes_ignore_list=gt_bboxes_ignore,
                label_channels=label_channels)
        else:
            cls_reg_targets = super(StageCascadeRPNHead, self).get_targets(
                anchor_list,
                valid_flag_list,
                gt_bboxes,
                img_metas,
                gt_bboxes_ignore_list=gt_bboxes_ignore,
                label_channels=label_channels)
        return cls_reg_targets

    def anchor_offset(self, anchor_list, anchor_strides, featmap_sizes):
        """Get offsets for deformable conv based on anchor shapes.

        NOTE: currently supports deformable kernel_size=3 and dilation=1.

        Args:
            anchor_list (list[list[Tensor]]): Multi-level anchors of each
                image, with shapes [NI, NLVL, NA, 4].
            anchor_strides (list[int]): anchor stride of each level.
            featmap_sizes (list[tuple]): feature map size of each level.

        Returns:
            offset_list (list[Tensor]): Offsets of the DeformConv kernel for
                each level, each with shape [NI, NA, 2 * ks**2].
        """

        def _shape_offset(anchors, stride, ks=3, dilation=1):
            # currently support kernel_size=3 and dilation=1
            assert ks == 3 and dilation == 1
            pad = (ks - 1) // 2
            idx = torch.arange(-pad, pad + 1, dtype=dtype, device=device)
            yy, xx = torch.meshgrid(idx, idx)  # return order matters
            xx = xx.reshape(-1)
            yy = yy.reshape(-1)
            w = (anchors[:, 2] - anchors[:, 0]) / stride
            h = (anchors[:, 3] - anchors[:, 1]) / stride
            w = w / (ks - 1) - dilation
            h = h / (ks - 1) - dilation
            offset_x = w[:, None] * xx  # (NA, ks**2)
            offset_y = h[:, None] * yy  # (NA, ks**2)
            return offset_x, offset_y

        def _ctr_offset(anchors, stride, featmap_size):
            feat_h, feat_w = featmap_size
            assert len(anchors) == feat_h * feat_w

            x = (anchors[:, 0] + anchors[:, 2]) * 0.5
            y = (anchors[:, 1] + anchors[:, 3]) * 0.5
            # compute centers on feature map
            x = x / stride
            y = y / stride
            # compute predefined centers
            xx = torch.arange(0, feat_w, device=anchors.device)
            yy = torch.arange(0, feat_h, device=anchors.device)
            yy, xx = torch.meshgrid(yy, xx)
            xx = xx.reshape(-1).type_as(x)
            yy = yy.reshape(-1).type_as(y)

            offset_x = x - xx  # (NA, )
            offset_y = y - yy  # (NA, )
            return offset_x, offset_y

        num_imgs = len(anchor_list)
        num_lvls = len(anchor_list[0])
        dtype = anchor_list[0][0].dtype
        device = anchor_list[0][0].device
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]

        offset_list = []
        for i in range(num_imgs):
            mlvl_offset = []
            for lvl in range(num_lvls):
                c_offset_x, c_offset_y = _ctr_offset(anchor_list[i][lvl],
                                                     anchor_strides[lvl],
                                                     featmap_sizes[lvl])
                s_offset_x, s_offset_y = _shape_offset(anchor_list[i][lvl],
                                                       anchor_strides[lvl])

                # offset = ctr_offset + shape_offset
                offset_x = s_offset_x + c_offset_x[:, None]
                offset_y = s_offset_y + c_offset_y[:, None]

                # offset order (y0, x0, y1, x1, ..., y8, x8)
                offset = torch.stack([offset_y, offset_x], dim=-1)
                offset = offset.reshape(offset.size(0), -1)  # [NA, 2*ks**2]
                mlvl_offset.append(offset)
            offset_list.append(torch.cat(mlvl_offset))  # [totalNA, 2*ks**2]
        offset_list = images_to_levels(offset_list, num_level_anchors)
        return offset_list
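
A standalone re-implementation of the centre-offset term above, for illustration only: on a 2x2 stride-8 grid, anchors whose centres sit exactly on the grid yield zero offsets, and shifting one anchor by a full stride yields an offset of one cell:

```python
import torch

feat_h, feat_w, stride = 2, 2, 8
# Anchors of size 16 whose centres lie exactly on the stride-8 grid points.
cx = torch.tensor([0., 8., 0., 8.])
cy = torch.tensor([0., 0., 8., 8.])
anchors = torch.stack([cx - 8, cy - 8, cx + 8, cy + 8], dim=1)
anchors[0, [0, 2]] += 8  # shift the first anchor one stride to the right

x = (anchors[:, 0] + anchors[:, 2]) * 0.5 / stride
y = (anchors[:, 1] + anchors[:, 3]) * 0.5 / stride
yy, xx = torch.meshgrid(torch.arange(feat_h), torch.arange(feat_w))
offset_x = x - xx.reshape(-1).type_as(x)
offset_y = y - yy.reshape(-1).type_as(y)
print(offset_x)  # tensor([1., 0., 0., 0.])
print(offset_y)  # tensor([0., 0., 0., 0.])
```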

    def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples):
        """Loss function for a single scale."""
        # classification loss
        if self.with_cls:
            labels = labels.reshape(-1)
            label_weights = label_weights.reshape(-1)
            cls_score = cls_score.permute(0, 2, 3,
                                          1).reshape(-1, self.cls_out_channels)
            loss_cls = self.loss_cls(
                cls_score, labels, label_weights, avg_factor=num_total_samples)
        # regression loss
        bbox_targets = bbox_targets.reshape(-1, 4)
        bbox_weights = bbox_weights.reshape(-1, 4)
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
        if self.reg_decoded_bbox:
            # decode predictions so the loss is computed on absolute boxes
            anchors = anchors.reshape(-1, 4)
            bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
        loss_reg = self.loss_bbox(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            avg_factor=num_total_samples)
        if self.with_cls:
            return loss_cls, loss_reg
        return None, loss_reg

    def loss(self,
             anchor_list,
             valid_flag_list,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            anchor_list (list[list]): Multi level anchors of each image.
            cls_scores (list[Tensor]): Box scores for each scale level,
                with shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, with shape (N, num_anchors * 4, H, W).
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss. Default: None

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in bbox_preds]
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            featmap_sizes,
            gt_bboxes_ignore=gt_bboxes_ignore,
            label_channels=label_channels)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        if self.sampling:
            num_total_samples = num_total_pos + num_total_neg
        else:
            # 200 is a hard-coded average factor,
            # which follows guided anchoring.
            num_total_samples = sum([label.numel()
                                     for label in labels_list]) / 200.0

        # change per-image, per-level anchor_list to per-level, per-image
        mlvl_anchor_list = list(zip(*anchor_list))
        # concat mlvl_anchor_list
        mlvl_anchor_list = [
            torch.cat(anchors, dim=0) for anchors in mlvl_anchor_list
        ]

        losses = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            mlvl_anchor_list,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples)
        if self.with_cls:
            return dict(loss_rpn_cls=losses[0], loss_rpn_reg=losses[1])
        return dict(loss_rpn_reg=losses[1])

    def get_bboxes(self,
                   anchor_list,
                   cls_scores,
                   bbox_preds,
                   img_metas,
                   cfg,
                   rescale=False):
        """Get proposal predictions."""
        assert len(cls_scores) == len(bbox_preds)
        num_levels = len(cls_scores)

        result_list = []
        for img_id in range(len(img_metas)):
            cls_score_list = [
                cls_scores[i][img_id].detach() for i in range(num_levels)
            ]
            bbox_pred_list = [
                bbox_preds[i][img_id].detach() for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list,
                                                anchor_list[img_id], img_shape,
                                                scale_factor, cfg, rescale)
            result_list.append(proposals)
        return result_list

    def refine_bboxes(self, anchor_list, bbox_preds, img_metas):
        """Refine bboxes through stages."""
        num_levels = len(bbox_preds)
        new_anchor_list = []
        for img_id in range(len(img_metas)):
            mlvl_anchors = []
            for i in range(num_levels):
                bbox_pred = bbox_preds[i][img_id].detach()
                bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
                img_shape = img_metas[img_id]['img_shape']
                bboxes = self.bbox_coder.decode(anchor_list[img_id][i],
                                                bbox_pred, img_shape)
                mlvl_anchors.append(bboxes)
            new_anchor_list.append(mlvl_anchors)
        return new_anchor_list
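
To see what `refine_bboxes` does per level, a small sketch with mmdet's `DeltaXYWHBBoxCoder` (the import path assumes the mmdet 2.x layout; the stds match the first stage's config). Zero deltas return the anchors unchanged, so each stage only moves anchors as far as its regression branch predicts:

```python
import torch
from mmdet.core.bbox.coder import DeltaXYWHBBoxCoder  # assumed mmdet 2.x path

coder = DeltaXYWHBBoxCoder(
    target_means=(.0, .0, .0, .0), target_stds=(0.1, 0.1, 0.5, 0.5))
anchors = torch.tensor([[32., 32., 96., 96.]])
deltas = torch.zeros(1, 4)
print(coder.decode(anchors, deltas))  # identical to the input anchors
```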


@HEADS.register_module()
class CascadeRPNHead(BaseDenseHead):
    """The CascadeRPNHead will predict more accurate region proposals, which
    are required by two-stage detectors (such as Fast/Faster R-CNN). CascadeRPN
    consists of a sequence of RPN stages that progressively improve the
    accuracy of the detected proposals.

    More details can be found in ``https://arxiv.org/abs/1909.06720``.

    Args:
        num_stages (int): number of CascadeRPN stages.
        stages (list[dict]): list of configs to build the stages.
        train_cfg (list[dict]): list of configs for each stage at training
            time.
        test_cfg (dict): config at testing time.
    """

    def __init__(self, num_stages, stages, train_cfg, test_cfg):
        super(CascadeRPNHead, self).__init__()
        assert num_stages == len(stages)
        self.num_stages = num_stages
        self.stages = nn.ModuleList()
        for i in range(len(stages)):
            train_cfg_i = train_cfg[i] if train_cfg is not None else None
            stages[i].update(train_cfg=train_cfg_i)
            stages[i].update(test_cfg=test_cfg)
            self.stages.append(build_head(stages[i]))
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def init_weights(self):
        """Init weights of CascadeRPN."""
        for i in range(self.num_stages):
            self.stages[i].init_weights()

    def loss(self):
        """loss() is implemented in StageCascadeRPNHead."""
        pass

    def get_bboxes(self):
        """get_bboxes() is implemented in StageCascadeRPNHead."""
        pass

    def forward_train(self,
                      x,
                      img_metas,
                      gt_bboxes,
                      gt_labels=None,
                      gt_bboxes_ignore=None,
                      proposal_cfg=None):
        """Forward train function."""
        assert gt_labels is None, 'RPN does not require gt_labels'

        featmap_sizes = [featmap.size()[-2:] for featmap in x]
        anchor_list, valid_flag_list = self.stages[0].get_anchors(
            featmap_sizes, img_metas)

        losses = dict()

        for i in range(self.num_stages):
            stage = self.stages[i]

            if stage.adapt_cfg['type'] == 'offset':
                offset_list = stage.anchor_offset(anchor_list,
                                                  stage.anchor_strides,
                                                  featmap_sizes)
            else:
                offset_list = None
            x, cls_score, bbox_pred = stage(x, offset_list)
            rpn_loss_inputs = (anchor_list, valid_flag_list, cls_score,
                               bbox_pred, gt_bboxes, img_metas)
            stage_loss = stage.loss(*rpn_loss_inputs)
            for name, value in stage_loss.items():
                losses['s{}.{}'.format(i, name)] = value

            # refine boxes
            if i < self.num_stages - 1:
                anchor_list = stage.refine_bboxes(anchor_list, bbox_pred,
                                                  img_metas)
        if proposal_cfg is None:
            return losses
        else:
            proposal_list = self.stages[-1].get_bboxes(anchor_list, cls_score,
                                                       bbox_pred, img_metas,
                                                       self.test_cfg)
            return losses, proposal_list

    def simple_test_rpn(self, x, img_metas):
        """Simple forward test function."""
        featmap_sizes = [featmap.size()[-2:] for featmap in x]
        anchor_list, _ = self.stages[0].get_anchors(featmap_sizes, img_metas)

        for i in range(self.num_stages):
            stage = self.stages[i]
            if stage.adapt_cfg['type'] == 'offset':
                offset_list = stage.anchor_offset(anchor_list,
                                                  stage.anchor_strides,
                                                  featmap_sizes)
            else:
                offset_list = None
            x, cls_score, bbox_pred = stage(x, offset_list)
            if i < self.num_stages - 1:
                anchor_list = stage.refine_bboxes(anchor_list, bbox_pred,
                                                  img_metas)

        proposal_list = self.stages[-1].get_bboxes(anchor_list, cls_score,
                                                   bbox_pred, img_metas,
                                                   self.test_cfg)
        return proposal_list

    def aug_test_rpn(self, x, img_metas):
        """Augmented forward test function."""
        raise NotImplementedError
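
A hedged construction sketch, assuming this module is registered under mmdet's `HEADS` registry as in the decorators above. Building the head with `train_cfg=None` skips assigner/sampler creation, which is enough for a standalone smoke test; the single stage mirrors the first stage of the RPN config earlier in this PR:

```python
from mmdet.models import build_head

stage = dict(
    type='StageCascadeRPNHead',
    in_channels=256,
    feat_channels=256,
    adapt_cfg=dict(type='dilation', dilation=3),
    bridged_feature=True,
    sampling=False,
    with_cls=False,
    reg_decoded_bbox=True,
    bbox_coder=dict(
        type='DeltaXYWHBBoxCoder',
        target_means=(.0, .0, .0, .0),
        target_stds=(0.1, 0.1, 0.5, 0.5)),
    loss_bbox=dict(type='IoULoss', linear=True, loss_weight=10.0))
head = build_head(
    dict(type='CascadeRPNHead', num_stages=1, stages=[stage],
         train_cfg=None, test_cfg=dict(nms_pre=2000, nms_thr=0.8)))
head.init_weights()
```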