[Feature]: Implement YOLOF (#4833)
* Add YOLOF inference
* Update Uniform match
* Update params and configs
* Fix bug in assigner
* Add reduce_mean
* Update YOLOF and docstrings
* Add iter-based config
* Update code and add docstrings
* Fix comments and update docstrings
* Update transforms and add docstrings
* Simplify the code
* Enhance random_shift and add unit test
* Add rest of unit tests and handle empty gt
* Add README.md
* Update comments
* Fix unittest
* Update model link
parent
670ecc2546
commit
2a856efb6f
19 changed files with 1168 additions and 16 deletions
configs/yolof/README.md
@@ -0,0 +1,25 @@
# You Only Look One-level Feature

## Introduction

<!-- [ALGORITHM] -->

```
@inproceedings{chen2021you,
  title={You Only Look One-level Feature},
  author={Chen, Qiang and Wang, Yingming and Yang, Tong and Zhang, Xiangyu and Cheng, Jian and Sun, Jian},
  booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
  year={2021}
}
```

## Results and Models

| Backbone | Style | Epoch | Lr schd | Mem (GB) | box AP | Config | Download |
|:--------:|:-----:|:-----:|:-------:|:--------:|:------:|:------:|:--------:|
| R-50-C5 | caffe | Y | 1x | 8.3 | 37.5 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolof/yolof_r50_c5_8x8_1x_coco.py) | [model](http://download.openmmlab.com/mmdetection/v2.0/yolof/yolof_r50_c5_8x8_1x_coco/yolof_r50_c5_8x8_1x_coco_20210425_024427-8e864411.pth) \| [log](http://download.openmmlab.com/mmdetection/v2.0/yolof/yolof_r50_c5_8x8_1x_coco/yolof_r50_c5_8x8_1x_coco_20210425_024427.log.json) |

**Note**:

1. We find that the performance is unstable and may fluctuate by about 0.3 mAP; mAP between 37.4 and 37.7 is acceptable for YOLOF_R_50_C5_1x. Such fluctuation can also be found in the [original implementation](https://github.com/chensnathan/YOLOF).
2. In addition to the instability issue, we occasionally observe large loss fluctuations and NaN values, so there may still be problems with this implementation, which will be improved subsequently.
configs/yolof/yolof_r50_c5_8x8_1x_coco.py
@@ -0,0 +1,103 @@
_base_ = [
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
    type='YOLOF',
    pretrained='open-mmlab://detectron/resnet50_caffe',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(3, ),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='caffe'),
    neck=dict(
        type='DilatedEncoder',
        in_channels=2048,
        out_channels=512,
        block_mid_channels=128,
        num_residual_blocks=4),
    bbox_head=dict(
        type='YOLOFHead',
        num_classes=80,
        in_channels=512,
        reg_decoded_bbox=True,
        anchor_generator=dict(
            type='AnchorGenerator',
            ratios=[1.0],
            scales=[1, 2, 4, 8, 16],
            strides=[32]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1., 1., 1., 1.],
            add_ctr_clamp=True,
            ctr_clamp=32),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='GIoULoss', loss_weight=1.0)),
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='UniformAssigner', pos_ignore_thr=0.15, neg_ignore_thr=0.7),
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    test_cfg=dict(
        nms_pre=1000,
        min_bbox_size=0,
        score_thr=0.05,
        nms=dict(type='nms', iou_threshold=0.6),
        max_per_img=100))
# optimizer
optimizer = dict(
    type='SGD',
    lr=0.12,
    momentum=0.9,
    weight_decay=0.0001,
    paramwise_cfg=dict(
        norm_decay_mult=0., custom_keys={'backbone': dict(lr_mult=1. / 3)}))
lr_config = dict(warmup_iters=1500, warmup_ratio=0.00066667)

# use caffe img_norm
img_norm_cfg = dict(
    mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='RandomShift', shift_ratio=0.5, max_shift_px=32),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=8,
    workers_per_gpu=8,
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))
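
For reference, a minimal sketch (not part of the PR) of loading and inspecting this config with mmcv, assuming an MMDetection v2.x working directory; the printed values follow the config above:

```python
# Load the YOLOF config and inspect a few fields.
from mmcv import Config

cfg = Config.fromfile('configs/yolof/yolof_r50_c5_8x8_1x_coco.py')
print(cfg.model.neck.type)       # 'DilatedEncoder'
print(cfg.optimizer.lr)          # 0.12
print(cfg.data.samples_per_gpu)  # 8
```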
configs/yolof/yolof_r50_c5_8x8_iter-1x_coco.py
@@ -0,0 +1,14 @@
_base_ = './yolof_r50_c5_8x8_1x_coco.py'

# We implemented the iter-based config according to the source code.
# The COCO dataset has 117266 images after filtering. We train with 8 GPUs
# and 8 images per GPU, so 22500 iterations are equivalent to
# 22500 / (117266 / (8 x 8)) = 12.3 epochs, 15000 iterations to 8.2 epochs,
# and 20000 iterations to 10.9 epochs. Because the lr (0.12) is large,
# the iter-based and epoch-based settings differ by about 0.2 in the
# mAP evaluation value.
lr_config = dict(step=[15000, 20000])
runner = dict(_delete_=True, type='IterBasedRunner', max_iters=22500)
checkpoint_config = dict(interval=2500)
evaluation = dict(interval=4500)
log_config = dict(interval=20)
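
The epoch equivalences quoted in the comments can be checked with a few lines of arithmetic; the image count and batch size are taken directly from the comment above:

```python
# Verify the iteration-to-epoch conversion quoted in the config comments.
num_images = 117266          # COCO train2017 after filtering
batch_size = 8 * 8           # 8 GPUs x 8 samples per GPU
iters_per_epoch = num_images / batch_size  # ~1832.3

for max_iters in (22500, 15000, 20000):
    print(max_iters, '->', round(max_iters / iters_per_epoch, 1), 'epochs')
# 22500 -> 12.3 epochs, 15000 -> 8.2 epochs, 20000 -> 10.9 epochs
```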
mmdet/core/bbox/assigners/uniform_assigner.py
@@ -0,0 +1,134 @@
import torch

from ..builder import BBOX_ASSIGNERS
from ..iou_calculators import build_iou_calculator
from ..transforms import bbox_xyxy_to_cxcywh
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


@BBOX_ASSIGNERS.register_module()
class UniformAssigner(BaseAssigner):
    """Uniform Matching between the anchors and gt boxes, which can achieve
    balance in positive anchors. ``gt_bboxes_ignore`` is not considered for
    now.

    Args:
        pos_ignore_thr (float): The threshold to ignore positive anchors.
        neg_ignore_thr (float): The threshold to ignore negative anchors.
        match_times (int): Number of positive anchors for each gt box.
            Default 4.
        iou_calculator (dict): Config of the iou calculator.
    """

    def __init__(self,
                 pos_ignore_thr,
                 neg_ignore_thr,
                 match_times=4,
                 iou_calculator=dict(type='BboxOverlaps2D')):
        self.match_times = match_times
        self.pos_ignore_thr = pos_ignore_thr
        self.neg_ignore_thr = neg_ignore_thr
        self.iou_calculator = build_iou_calculator(iou_calculator)

    def assign(self,
               bbox_pred,
               anchor,
               gt_bboxes,
               gt_bboxes_ignore=None,
               gt_labels=None):
        num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)

        # 1. assign 0 (background) by default
        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
                                              0,
                                              dtype=torch.long)
        assigned_labels = bbox_pred.new_full((num_bboxes, ),
                                             -1,
                                             dtype=torch.long)
        if num_gts == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            if num_gts == 0:
                # No ground truth, assign all to background
                assigned_gt_inds[:] = 0
            assign_result = AssignResult(
                num_gts, assigned_gt_inds, None, labels=assigned_labels)
            assign_result.set_extra_property(
                'pos_idx', bbox_pred.new_empty(0, dtype=torch.bool))
            assign_result.set_extra_property('pos_predicted_boxes',
                                             bbox_pred.new_empty((0, 4)))
            assign_result.set_extra_property('target_boxes',
                                             bbox_pred.new_empty((0, 4)))
            return assign_result

        # 2. Compute the L1 cost between boxes.
        # Note that we use both the anchors and the predicted boxes.
        cost_bbox = torch.cdist(
            bbox_xyxy_to_cxcywh(bbox_pred),
            bbox_xyxy_to_cxcywh(gt_bboxes),
            p=1)
        cost_bbox_anchors = torch.cdist(
            bbox_xyxy_to_cxcywh(anchor), bbox_xyxy_to_cxcywh(gt_bboxes), p=1)

        # We found that the topk function has different results in cpu and
        # cuda mode. In order to ensure consistency with the source code,
        # we also use cpu mode.
        # TODO: Check whether the performance of cpu and cuda are the same.
        C = cost_bbox.cpu()
        C1 = cost_bbox_anchors.cpu()

        # self.match_times x n
        index = torch.topk(
            C,  # shape (num_bboxes, num_gts)
            k=self.match_times,
            dim=0,
            largest=False)[1]

        # self.match_times x n
        index1 = torch.topk(C1, k=self.match_times, dim=0, largest=False)[1]
        # (self.match_times*2) x n
        indexes = torch.cat((index, index1),
                            dim=1).reshape(-1).to(bbox_pred.device)

        pred_overlaps = self.iou_calculator(bbox_pred, gt_bboxes)
        anchor_overlaps = self.iou_calculator(anchor, gt_bboxes)
        pred_max_overlaps, _ = pred_overlaps.max(dim=1)
        anchor_max_overlaps, _ = anchor_overlaps.max(dim=0)

        # 3. Compute the ignore indexes using gt_bboxes and predicted boxes
        ignore_idx = pred_max_overlaps > self.neg_ignore_thr
        assigned_gt_inds[ignore_idx] = -1

        # 4. Compute the ignore indexes of positive samples using the
        # anchors and predicted boxes
        pos_gt_index = torch.arange(
            0, C1.size(1),
            device=bbox_pred.device).repeat(self.match_times * 2)
        pos_ious = anchor_overlaps[indexes, pos_gt_index]
        pos_ignore_idx = pos_ious < self.pos_ignore_thr

        pos_gt_index_with_ignore = pos_gt_index + 1
        pos_gt_index_with_ignore[pos_ignore_idx] = -1
        assigned_gt_inds[indexes] = pos_gt_index_with_ignore

        if gt_labels is not None:
            assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
            pos_inds = torch.nonzero(
                assigned_gt_inds > 0, as_tuple=False).squeeze()
            if pos_inds.numel() > 0:
                assigned_labels[pos_inds] = gt_labels[
                    assigned_gt_inds[pos_inds] - 1]
        else:
            assigned_labels = None

        assign_result = AssignResult(
            num_gts,
            assigned_gt_inds,
            anchor_max_overlaps,
            labels=assigned_labels)
        assign_result.set_extra_property('pos_idx', ~pos_ignore_idx)
        assign_result.set_extra_property('pos_predicted_boxes',
                                         bbox_pred[indexes])
        assign_result.set_extra_property('target_boxes',
                                         gt_bboxes[pos_gt_index])
        return assign_result
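
A minimal usage sketch of the assigner on random boxes (assuming `UniformAssigner` is exported from `mmdet.core.bbox.assigners`; the box construction is purely illustrative):

```python
import torch

from mmdet.core.bbox.assigners import UniformAssigner

assigner = UniformAssigner(pos_ignore_thr=0.15, neg_ignore_thr=0.7)

# Build 100 valid xyxy boxes, reused here as both predictions and anchors.
xy = torch.rand(100, 2) * 200
wh = torch.rand(100, 2) * 50 + 1
bbox_pred = torch.cat([xy, xy + wh], dim=1)
anchor = torch.cat([xy, xy + wh], dim=1)

gt_bboxes = torch.tensor([[10., 10., 80., 90.]])
gt_labels = torch.tensor([2])

result = assigner.assign(bbox_pred, anchor, gt_bboxes, gt_labels=gt_labels)
# Each gt receives match_times candidates from the prediction cost and
# match_times from the anchor cost (4 + 4 = 8 candidates for the single gt).
print(result.gt_inds.shape)                        # torch.Size([100])
print(result.get_extra_property('pos_idx').shape)  # torch.Size([8])
```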
mmdet/models/dense_heads/yolof_head.py
@@ -0,0 +1,415 @@
import torch
import torch.nn as nn
from mmcv.cnn import (ConvModule, bias_init_with_prob, constant_init, is_norm,
                      normal_init)
from mmcv.runner import force_fp32

from mmdet.core import anchor_inside_flags, multi_apply, reduce_mean, unmap
from ..builder import HEADS
from .anchor_head import AnchorHead

INF = 1e8


def levels_to_images(mlvl_tensor):
    """Concat multi-level feature maps by image.

    [feature_level0, feature_level1...] -> [feature_image0, feature_image1...]
    Convert the shape of each element in mlvl_tensor from (N, C, H, W) to
    (N, H*W, C), then split the element to N elements with shape (H*W, C),
    and concat elements of the same image from all levels along the first
    dimension.

    Args:
        mlvl_tensor (list[torch.Tensor]): List of Tensors collected from
            the corresponding levels. Each element is of shape (N, C, H, W).

    Returns:
        list[torch.Tensor]: A list that contains N tensors, each of
            shape (num_elements, C).
    """
    batch_size = mlvl_tensor[0].size(0)
    batch_list = [[] for _ in range(batch_size)]
    channels = mlvl_tensor[0].size(1)
    for t in mlvl_tensor:
        t = t.permute(0, 2, 3, 1)
        t = t.view(batch_size, -1, channels).contiguous()
        for img in range(batch_size):
            batch_list[img].append(t[img])
    return [torch.cat(item, 0) for item in batch_list]


@HEADS.register_module()
class YOLOFHead(AnchorHead):
    """YOLOFHead Paper link: https://arxiv.org/abs/2103.09460.

    Args:
        num_classes (int): The number of object classes (w/o background).
        in_channels (list[int]): The number of input channels per scale.
        num_cls_convs (int): The number of convolutions of the cls branch.
            Default 2.
        num_reg_convs (int): The number of convolutions of the reg branch.
            Default 4.
        norm_cfg (dict): Dictionary to construct and config norm layer.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 num_cls_convs=2,
                 num_reg_convs=4,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 **kwargs):
        self.num_cls_convs = num_cls_convs
        self.num_reg_convs = num_reg_convs
        self.norm_cfg = norm_cfg
        super(YOLOFHead, self).__init__(num_classes, in_channels, **kwargs)

    def _init_layers(self):
        cls_subnet = []
        bbox_subnet = []
        for i in range(self.num_cls_convs):
            cls_subnet.append(
                ConvModule(
                    self.in_channels,
                    self.in_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=self.norm_cfg))
        for i in range(self.num_reg_convs):
            bbox_subnet.append(
                ConvModule(
                    self.in_channels,
                    self.in_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=self.norm_cfg))
        self.cls_subnet = nn.Sequential(*cls_subnet)
        self.bbox_subnet = nn.Sequential(*bbox_subnet)
        self.cls_score = nn.Conv2d(
            self.in_channels,
            self.num_anchors * self.num_classes,
            kernel_size=3,
            stride=1,
            padding=1)
        self.bbox_pred = nn.Conv2d(
            self.in_channels,
            self.num_anchors * 4,
            kernel_size=3,
            stride=1,
            padding=1)
        self.object_pred = nn.Conv2d(
            self.in_channels,
            self.num_anchors,
            kernel_size=3,
            stride=1,
            padding=1)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, mean=0, std=0.01)
            if is_norm(m):
                constant_init(m, 1)

        # Use prior in model initialization to improve stability
        bias_cls = bias_init_with_prob(0.01)
        torch.nn.init.constant_(self.cls_score.bias, bias_cls)

    def forward_single(self, feature):
        cls_score = self.cls_score(self.cls_subnet(feature))
        N, _, H, W = cls_score.shape
        cls_score = cls_score.view(N, -1, self.num_classes, H, W)

        reg_feat = self.bbox_subnet(feature)
        bbox_reg = self.bbox_pred(reg_feat)
        objectness = self.object_pred(reg_feat)

        # implicit objectness
        objectness = objectness.view(N, -1, 1, H, W)
        normalized_cls_score = cls_score + objectness - torch.log(
            1. + torch.clamp(cls_score.exp(), max=INF) +
            torch.clamp(objectness.exp(), max=INF))
        normalized_cls_score = normalized_cls_score.view(N, -1, H, W)
        return normalized_cls_score, bbox_reg

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level.
                Each has shape (batch, num_anchors * num_classes, h, w).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (batch, num_anchors * 4, h, w).
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss. Default: None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert len(cls_scores) == 1
        assert self.anchor_generator.num_levels == 1

        device = cls_scores[0].device
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)

        # The output level is always 1
        anchor_list = [anchors[0] for anchors in anchor_list]
        valid_flag_list = [valid_flags[0] for valid_flags in valid_flag_list]

        cls_scores_list = levels_to_images(cls_scores)
        bbox_preds_list = levels_to_images(bbox_preds)

        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = self.get_targets(
            cls_scores_list,
            bbox_preds_list,
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels)
        if cls_reg_targets is None:
            return None
        (batch_labels, batch_label_weights, num_total_pos, num_total_neg,
         batch_bbox_weights, batch_pos_predicted_boxes,
         batch_target_boxes) = cls_reg_targets

        flatten_labels = batch_labels.reshape(-1)
        batch_label_weights = batch_label_weights.reshape(-1)
        cls_score = cls_scores[0].permute(0, 2, 3,
                                          1).reshape(-1, self.cls_out_channels)

        num_total_samples = (num_total_pos +
                             num_total_neg) if self.sampling else num_total_pos
        num_total_samples = reduce_mean(
            cls_score.new_tensor(num_total_samples)).clamp_(1.0).item()

        # classification loss
        loss_cls = self.loss_cls(
            cls_score,
            flatten_labels,
            batch_label_weights,
            avg_factor=num_total_samples)

        # regression loss
        if batch_pos_predicted_boxes.shape[0] == 0:
            # no pos sample
            loss_bbox = batch_pos_predicted_boxes.sum() * 0
        else:
            loss_bbox = self.loss_bbox(
                batch_pos_predicted_boxes,
                batch_target_boxes,
                batch_bbox_weights.float(),
                avg_factor=num_total_samples)

        return dict(loss_cls=loss_cls, loss_bbox=loss_bbox)

    def get_targets(self,
                    cls_scores_list,
                    bbox_preds_list,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    label_channels=1,
                    unmap_outputs=True):
        """Compute regression and classification targets for anchors in
        multiple images.

        Args:
            cls_scores_list (list[Tensor]): Classification scores of
                each image. Each is a 2D tensor of shape
                (h * w, num_anchors * num_classes).
            bbox_preds_list (list[Tensor]): Bbox preds of each image.
                Each is a 2D tensor of shape (h * w, num_anchors * 4).
            anchor_list (list[Tensor]): Anchors of each image. Each element
                is a tensor of shape (h * w * num_anchors, 4).
            valid_flag_list (list[Tensor]): Valid flags of each image. Each
                element is a tensor of shape (h * w * num_anchors, ).
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
                ignored.
            gt_labels_list (list[Tensor]): Ground truth labels of each box.
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple: Usually returns a tuple containing learning targets.

                - batch_labels (Tensor): Labels of all images. A tensor \
                    of shape (batch, h * w * num_anchors).
                - batch_label_weights (Tensor): Label weights of all \
                    images. A tensor of shape (batch, h * w * num_anchors).
                - num_total_pos (int): Number of positive samples in all \
                    images.
                - num_total_neg (int): Number of negative samples in all \
                    images.
            additional_returns: This function enables user-defined returns
                from `self._get_targets_single`. These returns are currently
                refined to properties at each feature map (i.e. having HxW
                dimension). The results will be concatenated at the end.
        """
        num_imgs = len(img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        # compute targets for each image
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
        if gt_labels_list is None:
            gt_labels_list = [None for _ in range(num_imgs)]
        results = multi_apply(
            self._get_targets_single,
            bbox_preds_list,
            anchor_list,
            valid_flag_list,
            gt_bboxes_list,
            gt_bboxes_ignore_list,
            gt_labels_list,
            img_metas,
            label_channels=label_channels,
            unmap_outputs=unmap_outputs)
        (all_labels, all_label_weights, pos_inds_list, neg_inds_list,
         sampling_results_list) = results[:5]
        rest_results = list(results[5:])  # user-added return values
        # no valid anchors
        if any([labels is None for labels in all_labels]):
            return None
        # sampled anchors of all images
        num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
        num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])

        batch_labels = torch.stack(all_labels, 0)
        batch_label_weights = torch.stack(all_label_weights, 0)

        res = (batch_labels, batch_label_weights, num_total_pos,
               num_total_neg)
        for i, rests in enumerate(rest_results):  # user-added return values
            rest_results[i] = torch.cat(rests, 0)

        return res + tuple(rest_results)

    def _get_targets_single(self,
                            bbox_preds,
                            flat_anchors,
                            valid_flags,
                            gt_bboxes,
                            gt_bboxes_ignore,
                            gt_labels,
                            img_meta,
                            label_channels=1,
                            unmap_outputs=True):
        """Compute regression and classification targets for anchors in a
        single image.

        Args:
            bbox_preds (Tensor): Bbox prediction of the image, of
                shape (h * w, 4).
            flat_anchors (Tensor): Anchors of the image, of shape
                (h * w * num_anchors, 4).
            valid_flags (Tensor): Valid flags of the image, of shape
                (h * w * num_anchors, ).
            gt_bboxes (Tensor): Ground truth bboxes of the image,
                of shape (num_gts, 4).
            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
                ignored, of shape (num_ignored_gts, 4).
            gt_labels (Tensor): Ground truth labels of each box,
                of shape (num_gts, ).
            img_meta (dict): Meta info of the image.
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple:
                labels (Tensor): Labels of the image, of shape
                    (h * w * num_anchors, ).
                label_weights (Tensor): Label weights of the image, of shape
                    (h * w * num_anchors, ).
                pos_inds (Tensor): Pos indices of the image.
                neg_inds (Tensor): Neg indices of the image.
                sampling_result (obj:`SamplingResult`): Sampling result.
                pos_bbox_weights (Tensor): The weights used to calculate
                    the bbox branch loss, of shape (num, ).
                pos_predicted_boxes (Tensor): The predicted boxes used to
                    calculate the bbox branch loss, of shape (num, 4).
                pos_target_boxes (Tensor): The target boxes used to
                    calculate the bbox branch loss, of shape (num, 4).
        """
        inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                           img_meta['img_shape'][:2],
                                           self.train_cfg.allowed_border)
        if not inside_flags.any():
            return (None, ) * 8
        # assign gt and sample anchors
        anchors = flat_anchors[inside_flags, :]
        bbox_preds = bbox_preds.reshape(-1, 4)
        bbox_preds = bbox_preds[inside_flags, :]

        # decoded bbox
        decoder_bbox_preds = self.bbox_coder.decode(anchors, bbox_preds)
        assign_result = self.assigner.assign(
            decoder_bbox_preds, anchors, gt_bboxes, gt_bboxes_ignore,
            None if self.sampling else gt_labels)

        pos_bbox_weights = assign_result.get_extra_property('pos_idx')
        pos_predicted_boxes = assign_result.get_extra_property(
            'pos_predicted_boxes')
        pos_target_boxes = assign_result.get_extra_property('target_boxes')

        sampling_result = self.sampler.sample(assign_result, anchors,
                                              gt_bboxes)
        num_valid_anchors = anchors.shape[0]
        labels = anchors.new_full((num_valid_anchors, ),
                                  self.num_classes,
                                  dtype=torch.long)
        label_weights = anchors.new_zeros(num_valid_anchors,
                                          dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            if gt_labels is None:
                # Only rpn gives gt_labels as None
                # Foreground is the first class since v2.5.0
                labels[pos_inds] = 0
            else:
                labels[pos_inds] = gt_labels[
                    sampling_result.pos_assigned_gt_inds]
            if self.train_cfg.pos_weight <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg.pos_weight
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # map up to original set of anchors
        if unmap_outputs:
            num_total_anchors = flat_anchors.size(0)
            labels = unmap(
                labels, num_total_anchors, inside_flags,
                fill=self.num_classes)  # fill bg label
            label_weights = unmap(label_weights, num_total_anchors,
                                  inside_flags)

        return (labels, label_weights, pos_inds, neg_inds, sampling_result,
                pos_bbox_weights, pos_predicted_boxes, pos_target_boxes)
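
`forward_single` above fuses the classification logits with the implicit objectness logit in log space. The fused score is exactly the product of the two sigmoid probabilities (the clamps in the head only guard against overflow), which the following sketch verifies numerically; it is not part of the PR:

```python
# Numerical check: sigmoid(fused) == sigmoid(cls) * sigmoid(obj), where
# fused = cls + obj - log(1 + exp(cls) + exp(obj)).
import torch

cls_score = torch.randn(1000)
objectness = torch.randn(1000)
fused = cls_score + objectness - torch.log(
    1. + cls_score.exp() + objectness.exp())
expected = cls_score.sigmoid() * objectness.sigmoid()
assert torch.allclose(fused.sigmoid(), expected, atol=1e-6)
```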
mmdet/models/detectors/yolof.py
@@ -0,0 +1,18 @@
from ..builder import DETECTORS
from .single_stage import SingleStageDetector


@DETECTORS.register_module()
class YOLOF(SingleStageDetector):
    r"""Implementation of `You Only Look One-level Feature
    <https://arxiv.org/abs/2103.09460>`_"""

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(YOLOF, self).__init__(backbone, neck, bbox_head, train_cfg,
                                    test_cfg, pretrained)
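
Since `YOLOF` is registered in `DETECTORS`, it can be built straight from the config above. A minimal sketch (assuming an MMDetection v2.x environment where `train_cfg`/`test_cfg` nested in the model dict are supported, as in this config):

```python
# Instantiate the registered YOLOF detector from the config dict.
from mmcv import Config
from mmdet.models import build_detector

cfg = Config.fromfile('configs/yolof/yolof_r50_c5_8x8_1x_coco.py')
model = build_detector(cfg.model)
print(type(model).__name__)  # 'YOLOF'
```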
mmdet/models/necks/dilated_encoder.py
@@ -0,0 +1,107 @@
import torch.nn as nn
from mmcv.cnn import (ConvModule, caffe2_xavier_init, constant_init, is_norm,
                      normal_init)
from torch.nn import BatchNorm2d

from ..builder import NECKS


class Bottleneck(nn.Module):
    """Bottleneck block for DilatedEncoder used in `YOLOF
    <https://arxiv.org/abs/2103.09460>`_.

    The Bottleneck contains three ConvLayers and one residual connection.

    Args:
        in_channels (int): The number of input channels.
        mid_channels (int): The number of middle output channels.
        dilation (int): Dilation rate.
        norm_cfg (dict): Dictionary to construct and config norm layer.
    """

    def __init__(self,
                 in_channels,
                 mid_channels,
                 dilation,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        super(Bottleneck, self).__init__()
        self.conv1 = ConvModule(
            in_channels, mid_channels, 1, norm_cfg=norm_cfg)
        self.conv2 = ConvModule(
            mid_channels,
            mid_channels,
            3,
            padding=dilation,
            dilation=dilation,
            norm_cfg=norm_cfg)
        self.conv3 = ConvModule(
            mid_channels, in_channels, 1, norm_cfg=norm_cfg)

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = out + identity
        return out


@NECKS.register_module()
class DilatedEncoder(nn.Module):
    """Dilated Encoder for `YOLOF <https://arxiv.org/abs/2103.09460>`_.

    This module contains two types of components:
        - the original FPN lateral convolution layer and fpn convolution
          layer, which are 1x1 conv + 3x3 conv
        - the dilated residual block

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        block_mid_channels (int): The number of middle block output channels.
        num_residual_blocks (int): The number of residual blocks.
    """

    def __init__(self, in_channels, out_channels, block_mid_channels,
                 num_residual_blocks):
        super(DilatedEncoder, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.block_mid_channels = block_mid_channels
        self.num_residual_blocks = num_residual_blocks
        self.block_dilations = [2, 4, 6, 8]
        self._init_layers()

    def _init_layers(self):
        self.lateral_conv = nn.Conv2d(
            self.in_channels, self.out_channels, kernel_size=1)
        self.lateral_norm = BatchNorm2d(self.out_channels)
        self.fpn_conv = nn.Conv2d(
            self.out_channels, self.out_channels, kernel_size=3, padding=1)
        self.fpn_norm = BatchNorm2d(self.out_channels)
        encoder_blocks = []
        for i in range(self.num_residual_blocks):
            dilation = self.block_dilations[i]
            encoder_blocks.append(
                Bottleneck(
                    self.out_channels,
                    self.block_mid_channels,
                    dilation=dilation))
        self.dilated_encoder_blocks = nn.Sequential(*encoder_blocks)

    def init_weights(self):
        caffe2_xavier_init(self.lateral_conv)
        caffe2_xavier_init(self.fpn_conv)
        for m in [self.lateral_norm, self.fpn_norm]:
            constant_init(m, 1)
        for m in self.dilated_encoder_blocks.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, mean=0, std=0.01)
            if is_norm(m):
                constant_init(m, 1)

    def forward(self, feature):
        out = self.lateral_norm(self.lateral_conv(feature[-1]))
        out = self.fpn_norm(self.fpn_conv(out))
        return self.dilated_encoder_blocks(out),
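
A quick shape check for the neck (a sketch with assumed feature sizes, assuming `DilatedEncoder` is exported from `mmdet.models.necks`; `forward` takes the backbone's feature tuple and consumes its last level):

```python
# Shape sketch: the C5 feature (N, 2048, H/32, W/32) is projected to a
# single (N, 512, H/32, W/32) level.
import torch

from mmdet.models.necks import DilatedEncoder

neck = DilatedEncoder(
    in_channels=2048,
    out_channels=512,
    block_mid_channels=128,
    num_residual_blocks=4)
c5 = torch.rand(1, 2048, 20, 20)
out = neck((c5, ))   # forward uses feature[-1]
print(out[0].shape)  # torch.Size([1, 512, 20, 20])
```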
tests/test_models/test_dense_heads/test_yolof_head.py
@@ -0,0 +1,75 @@
import mmcv
import torch

from mmdet.models.dense_heads import YOLOFHead


def test_yolof_head_loss():
    """Tests yolof head loss when truth is empty and non-empty."""
    s = 256
    img_metas = [{
        'img_shape': (s, s, 3),
        'scale_factor': 1,
        'pad_shape': (s, s, 3)
    }]
    train_cfg = mmcv.Config(
        dict(
            assigner=dict(
                type='UniformAssigner',
                pos_ignore_thr=0.15,
                neg_ignore_thr=0.7),
            allowed_border=-1,
            pos_weight=-1,
            debug=False))
    self = YOLOFHead(
        num_classes=4,
        in_channels=1,
        reg_decoded_bbox=True,
        train_cfg=train_cfg,
        anchor_generator=dict(
            type='AnchorGenerator',
            ratios=[1.0],
            scales=[1, 2, 4, 8, 16],
            strides=[32]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1., 1., 1., 1.],
            add_ctr_clamp=True,
            ctr_clamp=32),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='GIoULoss', loss_weight=1.0))
    feat = [torch.rand(1, 1, s // 32, s // 32)]
    cls_scores, bbox_preds = self.forward(feat)

    # Test that empty ground truth encourages the network to predict
    # background.
    gt_bboxes = [torch.empty((0, 4))]
    gt_labels = [torch.LongTensor([])]
    gt_bboxes_ignore = None
    empty_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels,
                                img_metas, gt_bboxes_ignore)
    # When there is no truth, the cls loss should be nonzero but there should
    # be no box loss.
    empty_cls_loss = empty_gt_losses['loss_cls']
    empty_box_loss = empty_gt_losses['loss_bbox']
    assert empty_cls_loss.item() > 0, 'cls loss should be non-zero'
    assert empty_box_loss.item() == 0, (
        'there should be no box loss when there are no true boxes')

    # When truth is non-empty, both cls and box loss should be nonzero for
    # random inputs.
    gt_bboxes = [
        torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
    ]
    gt_labels = [torch.LongTensor([2])]
    one_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels,
                              img_metas, gt_bboxes_ignore)
    onegt_cls_loss = one_gt_losses['loss_cls']
    onegt_box_loss = one_gt_losses['loss_bbox']
    assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero'
    assert onegt_box_loss.item() > 0, 'box loss should be non-zero'