[Feature]: Implement YOLOF (#4833)

* Add YOLOF inference

* Update

* Update Uniform match

* Update param

* Update cfg

* Fix bug about assigner

* Add reduce_mean

* update yolof

* update docstr

* add iter base config

* update code and add docstr

* fix comment and update docstr

* update transforms and add docstr

* Simplify the code.

* Enhancement random_shift and add unit test

* add rest of unit test and process empty gt.

* add README.md

* update comment

* Fix unittest

* update model link
Authored by Haian Huang(深度眸), committed via GitHub. Commit 2a856efb6f, parent 670ecc2546.
Changed files (lines changed):

1. configs/yolof/README.md (25)
2. configs/yolof/yolof_r50_c5_8x8_1x_coco.py (103)
3. configs/yolof/yolof_r50_c5_8x8_iter-1x_coco.py (14)
4. mmdet/core/bbox/assigners/__init__.py (3)
5. mmdet/core/bbox/assigners/uniform_assigner.py (134)
6. mmdet/core/bbox/coder/delta_xywh_bbox_coder.py (44)
7. mmdet/datasets/pipelines/__init__.py (5)
8. mmdet/datasets/pipelines/transforms.py (90)
9. mmdet/models/dense_heads/__init__.py (3)
10. mmdet/models/dense_heads/yolof_head.py (415)
11. mmdet/models/detectors/__init__.py (3)
12. mmdet/models/detectors/yolof.py (18)
13. mmdet/models/necks/__init__.py (3)
14. mmdet/models/necks/dilated_encoder.py (107)
15. tests/test_data/test_pipelines/test_transform/test_transform.py (40)
16. tests/test_models/test_dense_heads/test_yolof_head.py (75)
17. tests/test_models/test_necks.py (12)
18. tests/test_utils/test_assigner.py (75)
19. tests/test_utils/test_coder.py (15)

@ -0,0 +1,25 @@
# You Only Look One-level Feature
## Introduction
<!-- [ALGORITHM] -->
```
@inproceedings{chen2021you,
title={You Only Look One-level Feature},
author={Chen, Qiang and Wang, Yingming and Yang, Tong and Zhang, Xiangyu and Cheng, Jian and Sun, Jian},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
year={2021}
}
```
## Results and Models
| Backbone | Style | Epoch | Lr schd | Mem (GB) | box AP | Config | Download |
|:---------:|:-------:|:-------:|:-------:|:--------:|:------:|:------:|:--------:|
| R-50-C5 | caffe | Y | 1x | 8.3 | 37.5 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolof/yolof_r50_c5_8x8_1x_coco.py) |[model](http://download.openmmlab.com/mmdetection/v2.0/yolof/yolof_r50_c5_8x8_1x_coco/yolof_r50_c5_8x8_1x_coco_20210425_024427-8e864411.pth) &#124; [log](http://download.openmmlab.com/mmdetection/v2.0/yolof/yolof_r50_c5_8x8_1x_coco/yolof_r50_c5_8x8_1x_coco_20210425_024427.log.json) |
**Note**:
1. We find the performance unstable: it may fluctuate by about 0.3 mAP, and a box AP between 37.4 and 37.7 is acceptable for YOLOF_R_50_C5_1x. The same fluctuation can be observed in the [original implementation](https://github.com/chensnathan/YOLOF).
2. Besides the instability, training occasionally shows large loss spikes or NaN losses, so there may still be problems with this implementation; they will be addressed in subsequent updates.
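
As a quick sanity check of the released checkpoint, here is a minimal inference sketch using the high-level MMDetection API; the checkpoint filename is assumed to be the file downloaded from the table above, and `demo/demo.jpg` is the sample image shipped with the repo.

```python
from mmdet.apis import inference_detector, init_detector

config_file = 'configs/yolof/yolof_r50_c5_8x8_1x_coco.py'
# Assumed local path of the checkpoint linked in the table above.
checkpoint_file = 'yolof_r50_c5_8x8_1x_coco_20210425_024427-8e864411.pth'

model = init_detector(config_file, checkpoint_file, device='cuda:0')
# Returns a list with one (n, 5) array of [x1, y1, x2, y2, score] per class.
result = inference_detector(model, 'demo/demo.jpg')
```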

@ -0,0 +1,103 @@
_base_ = [
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
type='YOLOF',
pretrained='open-mmlab://detectron/resnet50_caffe',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(3, ),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='caffe'),
neck=dict(
type='DilatedEncoder',
in_channels=2048,
out_channels=512,
block_mid_channels=128,
num_residual_blocks=4),
bbox_head=dict(
type='YOLOFHead',
num_classes=80,
in_channels=512,
reg_decoded_bbox=True,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
scales=[1, 2, 4, 8, 16],
strides=[32]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1., 1., 1., 1.],
add_ctr_clamp=True,
ctr_clamp=32),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=1.0)),
# training and testing settings
train_cfg=dict(
assigner=dict(
type='UniformAssigner', pos_ignore_thr=0.15, neg_ignore_thr=0.7),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100))
# optimizer
optimizer = dict(
type='SGD',
lr=0.12,
momentum=0.9,
weight_decay=0.0001,
paramwise_cfg=dict(
norm_decay_mult=0., custom_keys={'backbone': dict(lr_mult=1. / 3)}))
lr_config = dict(warmup_iters=1500, warmup_ratio=0.00066667)
# use caffe img_norm
img_norm_cfg = dict(
mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='RandomShift', shift_ratio=0.5, max_shift_px=32),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=8,
workers_per_gpu=8,
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
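
For reference, a minimal sketch of how this config is consumed programmatically (assuming an MMDetection 2.x environment; note that `train_cfg`/`test_cfg` live inside the `model` dict in this config layout):

```python
from mmcv import Config

from mmdet.models import build_detector

cfg = Config.fromfile('configs/yolof/yolof_r50_c5_8x8_1x_coco.py')
# train_cfg/test_cfg are taken from the model dict; the top-level keys
# are absent in this layout, so the fallbacks below resolve to None.
model = build_detector(
    cfg.model,
    train_cfg=cfg.get('train_cfg'),
    test_cfg=cfg.get('test_cfg'))
print(type(model).__name__)  # YOLOF
```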

@ -0,0 +1,14 @@
_base_ = './yolof_r50_c5_8x8_1x_coco.py'
# We implement the iter-based config according to the source code.
# The COCO dataset has 117266 images after filtering. We train with
# 8 GPUs and 8 images per GPU, so 22500 iterations are equivalent to
# 22500/(117266/(8x8))=12.3 epochs, 15000 iterations to 8.2 epochs and
# 20000 iterations to 10.9 epochs. Because the lr (0.12) is large, the
# iter-based and epoch-based settings differ by about 0.2 in the mAP
# evaluation value.
lr_config = dict(step=[15000, 20000])
runner = dict(_delete_=True, type='IterBasedRunner', max_iters=22500)
checkpoint_config = dict(interval=2500)
evaluation = dict(interval=4500)
log_config = dict(interval=20)
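
A quick back-of-the-envelope check of the epoch equivalences quoted in the comment above:

```python
num_imgs = 117266                        # COCO images after filtering
batch_size = 8 * 8                       # 8 GPUs x 8 images per GPU
iters_per_epoch = num_imgs / batch_size  # ~1832.3

print(22500 / iters_per_epoch)           # ~12.3 epochs (max_iters)
print(15000 / iters_per_epoch)           # ~8.2 epochs  (first lr step)
print(20000 / iters_per_epoch)           # ~10.9 epochs (second lr step)
```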

@ -8,9 +8,10 @@ from .hungarian_assigner import HungarianAssigner
from .max_iou_assigner import MaxIoUAssigner
from .point_assigner import PointAssigner
from .region_assigner import RegionAssigner
from .uniform_assigner import UniformAssigner
__all__ = [
'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner',
'HungarianAssigner', 'RegionAssigner'
'HungarianAssigner', 'RegionAssigner', 'UniformAssigner'
]

@ -0,0 +1,134 @@
import torch
from ..builder import BBOX_ASSIGNERS
from ..iou_calculators import build_iou_calculator
from ..transforms import bbox_xyxy_to_cxcywh
from .assign_result import AssignResult
from .base_assigner import BaseAssigner
@BBOX_ASSIGNERS.register_module()
class UniformAssigner(BaseAssigner):
"""Uniform Matching between the anchors and gt boxes, which can achieve
balance in positive anchors, and gt_bboxes_ignore was not considered for
now.
Args:
pos_ignore_thr (float): The IoU threshold below which a matched
anchor is ignored as a positive sample.
neg_ignore_thr (float): The IoU threshold above which an unmatched
anchor is ignored as a negative sample.
match_times (int): Number of positive anchors matched to each gt box.
Default 4.
iou_calculator (dict): Config dict of the iou calculator.
"""
def __init__(self,
pos_ignore_thr,
neg_ignore_thr,
match_times=4,
iou_calculator=dict(type='BboxOverlaps2D')):
self.match_times = match_times
self.pos_ignore_thr = pos_ignore_thr
self.neg_ignore_thr = neg_ignore_thr
self.iou_calculator = build_iou_calculator(iou_calculator)
def assign(self,
bbox_pred,
anchor,
gt_bboxes,
gt_bboxes_ignore=None,
gt_labels=None):
num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)
# 1. assign 0 (background) by default
assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
0,
dtype=torch.long)
assigned_labels = bbox_pred.new_full((num_bboxes, ),
-1,
dtype=torch.long)
if num_gts == 0 or num_bboxes == 0:
# No ground truth or boxes, return empty assignment
if num_gts == 0:
# No ground truth, assign all to background
assigned_gt_inds[:] = 0
assign_result = AssignResult(
num_gts, assigned_gt_inds, None, labels=assigned_labels)
assign_result.set_extra_property(
'pos_idx', bbox_pred.new_empty(0, dtype=torch.bool))
assign_result.set_extra_property('pos_predicted_boxes',
bbox_pred.new_empty((0, 4)))
assign_result.set_extra_property('target_boxes',
bbox_pred.new_empty((0, 4)))
return assign_result
# 2. Compute the L1 cost between boxes
# Note that we use anchors and predict boxes both
cost_bbox = torch.cdist(
bbox_xyxy_to_cxcywh(bbox_pred),
bbox_xyxy_to_cxcywh(gt_bboxes),
p=1)
cost_bbox_anchors = torch.cdist(
bbox_xyxy_to_cxcywh(anchor), bbox_xyxy_to_cxcywh(gt_bboxes), p=1)
# We found that the topk function gives different results on CPU and
# CUDA. To ensure consistency with the source code, we also run it in
# CPU mode.
# TODO: Check whether the performance of cpu and cuda are the same.
C = cost_bbox.cpu()
C1 = cost_bbox_anchors.cpu()
# self.match_times x n
index = torch.topk(
C,  # C has shape (num_bboxes, num_gts)
k=self.match_times,
dim=0,
largest=False)[1]
# self.match_times x n
index1 = torch.topk(C1, k=self.match_times, dim=0, largest=False)[1]
# (self.match_times*2) x n
indexes = torch.cat((index, index1),
dim=1).reshape(-1).to(bbox_pred.device)
pred_overlaps = self.iou_calculator(bbox_pred, gt_bboxes)
anchor_overlaps = self.iou_calculator(anchor, gt_bboxes)
pred_max_overlaps, _ = pred_overlaps.max(dim=1)
anchor_max_overlaps, _ = anchor_overlaps.max(dim=0)
# 3. Compute the ignore indexes use gt_bboxes and predict boxes
ignore_idx = pred_max_overlaps > self.neg_ignore_thr
assigned_gt_inds[ignore_idx] = -1
# 4. Compute the ignore indexes of positive sample use anchors
# and predict boxes
pos_gt_index = torch.arange(
0, C1.size(1),
device=bbox_pred.device).repeat(self.match_times * 2)
pos_ious = anchor_overlaps[indexes, pos_gt_index]
pos_ignore_idx = pos_ious < self.pos_ignore_thr
pos_gt_index_with_ignore = pos_gt_index + 1
pos_gt_index_with_ignore[pos_ignore_idx] = -1
assigned_gt_inds[indexes] = pos_gt_index_with_ignore
if gt_labels is not None:
assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
pos_inds = torch.nonzero(
assigned_gt_inds > 0, as_tuple=False).squeeze()
if pos_inds.numel() > 0:
assigned_labels[pos_inds] = gt_labels[
assigned_gt_inds[pos_inds] - 1]
else:
assigned_labels = None
assign_result = AssignResult(
num_gts,
assigned_gt_inds,
anchor_max_overlaps,
labels=assigned_labels)
assign_result.set_extra_property('pos_idx', ~pos_ignore_idx)
assign_result.set_extra_property('pos_predicted_boxes',
bbox_pred[indexes])
assign_result.set_extra_property('target_boxes',
gt_bboxes[pos_gt_index])
return assign_result
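
To see the core of uniform matching in isolation: each gt box takes the `match_times` candidates with the smallest L1 cost, so every gt receives the same number of positive candidates regardless of its size. A toy sketch of that topk step (toy numbers, not the full assigner):

```python
import torch

# Toy L1 cost between 5 candidates (rows) and 2 gt boxes (columns).
cost = torch.tensor([[0.1, 2.0],
                     [0.4, 1.5],
                     [3.0, 0.2],
                     [2.5, 0.3],
                     [1.0, 1.0]])
match_times = 2
# For every gt (column), pick the match_times candidates with lowest cost.
_, index = torch.topk(cost, k=match_times, dim=0, largest=False)
print(index)  # tensor([[0, 2], [1, 3]]): gt 0 gets candidates 0 and 1,
              # gt 1 gets candidates 2 and 3 -- two positives each.
```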

@ -21,16 +21,25 @@ class DeltaXYWHBBoxCoder(BaseBBoxCoder):
target for delta coordinates
clip_border (bool, optional): Whether clip the objects outside the
border of the image. Defaults to True.
add_ctr_clamp (bool): Whether to add center clamp. When set, the
predicted box is clamped if its center is too far away from
the original anchor's center. Only used by YOLOF. Default False.
ctr_clamp (int): The maximum pixel shift to clamp. Only used by YOLOF.
Default 32.
"""
def __init__(self,
target_means=(0., 0., 0., 0.),
target_stds=(1., 1., 1., 1.),
clip_border=True):
clip_border=True,
add_ctr_clamp=False,
ctr_clamp=32):
super(BaseBBoxCoder, self).__init__()
self.means = target_means
self.stds = target_stds
self.clip_border = clip_border
self.add_ctr_clamp = add_ctr_clamp
self.ctr_clamp = ctr_clamp
def encode(self, bboxes, gt_bboxes):
"""Get box regression transformation deltas that can be used to
@ -79,7 +88,8 @@ class DeltaXYWHBBoxCoder(BaseBBoxCoder):
if pred_bboxes.ndim == 3:
assert pred_bboxes.size(1) == bboxes.size(1)
decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means, self.stds,
max_shape, wh_ratio_clip, self.clip_border)
max_shape, wh_ratio_clip, self.clip_border,
self.add_ctr_clamp, self.ctr_clamp)
return decoded_bboxes
@ -137,7 +147,9 @@ def delta2bbox(rois,
stds=(1., 1., 1., 1.),
max_shape=None,
wh_ratio_clip=16 / 1000,
clip_border=True):
clip_border=True,
add_ctr_clamp=False,
ctr_clamp=32):
"""Apply deltas to shift/scale base boxes.
Typically the rois are anchor or proposed bounding boxes and the deltas are
@ -161,6 +173,11 @@ def delta2bbox(rois,
wh_ratio_clip (float): Maximum aspect ratio for boxes.
clip_border (bool, optional): Whether clip the objects outside the
border of the image. Defaults to True.
add_ctr_clamp (bool): Whether to add center clamp. When set, the
predicted box is clamped if its center is too far away from
the original anchor's center. Only used by YOLOF. Default False.
ctr_clamp (int): The maximum pixel shift to clamp. Only used by YOLOF.
Default 32.
Returns:
Tensor: Boxes with shape (B, N, num_classes * 4) or (B, N, 4) or
@ -194,9 +211,7 @@ def delta2bbox(rois,
dy = denorm_deltas[..., 1::4]
dw = denorm_deltas[..., 2::4]
dh = denorm_deltas[..., 3::4]
max_ratio = np.abs(np.log(wh_ratio_clip))
dw = dw.clamp(min=-max_ratio, max=max_ratio)
dh = dh.clamp(min=-max_ratio, max=max_ratio)
x1, y1 = rois[..., 0], rois[..., 1]
x2, y2 = rois[..., 2], rois[..., 3]
# Compute center of each roi
@ -205,12 +220,25 @@ def delta2bbox(rois,
# Compute width/height of each roi
pw = (x2 - x1).unsqueeze(-1).expand_as(dw)
ph = (y2 - y1).unsqueeze(-1).expand_as(dh)
dx_width = pw * dx
dy_height = ph * dy
max_ratio = np.abs(np.log(wh_ratio_clip))
if add_ctr_clamp:
dx_width = torch.clamp(dx_width, max=ctr_clamp, min=-ctr_clamp)
dy_height = torch.clamp(dy_height, max=ctr_clamp, min=-ctr_clamp)
dw = torch.clamp(dw, max=max_ratio)
dh = torch.clamp(dh, max=max_ratio)
else:
dw = dw.clamp(min=-max_ratio, max=max_ratio)
dh = dh.clamp(min=-max_ratio, max=max_ratio)
# Use exp(network energy) to enlarge/shrink each roi
gw = pw * dw.exp()
gh = ph * dh.exp()
# Use network energy to shift the center of each roi
gx = px + pw * dx
gy = py + ph * dy
gx = px + dx_width
gy = py + dy_height
# Convert center-xy/width/height to top-left, bottom-right
x1 = gx - gw * 0.5
y1 = gy - gh * 0.5
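
A small sketch of what the center clamp changes during decoding: the predicted center shift is capped at `ctr_clamp` pixels regardless of the anchor size, while `dw`/`dh` are only clamped from above (the values below are arbitrary illustrations):

```python
import numpy as np
import torch

ctr_clamp = 32
max_ratio = np.abs(np.log(16 / 1000))  # ~4.14, from wh_ratio_clip

pw = torch.tensor([100.])              # anchor width
dx = torch.tensor([0.8])               # predicted x-center delta
dx_width = pw * dx                     # raw shift: 80 px
dx_width = torch.clamp(dx_width, max=ctr_clamp, min=-ctr_clamp)
print(dx_width)                        # tensor([32.]): capped at 32 px

dw = torch.tensor([5.0])               # predicted log-width delta
print(torch.clamp(dw, max=max_ratio))  # tensor([4.1352]): growth capped,
                                       # shrinking is left unclamped
```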

@ -10,7 +10,8 @@ from .loading import (LoadAnnotations, LoadImageFromFile, LoadImageFromWebcam,
from .test_time_aug import MultiScaleFlipAug
from .transforms import (Albu, CutOut, Expand, MinIoURandomCrop, Normalize,
Pad, PhotoMetricDistortion, RandomCenterCropPad,
RandomCrop, RandomFlip, Resize, SegRescale)
RandomCrop, RandomFlip, RandomShift, Resize,
SegRescale)
__all__ = [
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
@ -21,5 +22,5 @@ __all__ = [
'MinIoURandomCrop', 'Expand', 'PhotoMetricDistortion', 'Albu',
'InstaBoost', 'RandomCenterCropPad', 'AutoAugment', 'CutOut', 'Shear',
'Rotate', 'ColorTransform', 'EqualizeTransform', 'BrightnessTransform',
'ContrastTransform', 'Translate'
'ContrastTransform', 'Translate', 'RandomShift'
]

@ -472,6 +472,96 @@ class RandomFlip(object):
return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})'
@PIPELINES.register_module()
class RandomShift(object):
"""Shift the image and box given shift pixels and probability.
Args:
shift_ratio (float): Probability of shifts. Default 0.5.
max_shift_px (int): The max pixels for shifting. Default 32.
filter_thr_px (int): The width and height threshold for filtering
bboxes. Bboxes (and their corresponding labels) whose width or height
falls below this threshold after shifting are removed. Default 1.
"""
def __init__(self, shift_ratio=0.5, max_shift_px=32, filter_thr_px=1):
assert 0 <= shift_ratio <= 1
assert max_shift_px >= 0
self.shift_ratio = shift_ratio
self.max_shift_px = max_shift_px
self.filter_thr_px = int(filter_thr_px)
# The key correspondence from bboxes to labels.
self.bbox2label = {
'gt_bboxes': 'gt_labels',
'gt_bboxes_ignore': 'gt_labels_ignore'
}
def __call__(self, results):
"""Call function to random shift images, bounding boxes.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Shift results.
"""
if random.random() < self.shift_ratio:
img_shape = results['img'].shape[:2]
random_shift_x = random.randint(-self.max_shift_px,
self.max_shift_px)
random_shift_y = random.randint(-self.max_shift_px,
self.max_shift_px)
new_x = max(0, random_shift_x)
orig_x = max(0, -random_shift_x)
new_y = max(0, random_shift_y)
orig_y = max(0, -random_shift_y)
# TODO: support mask and semantic segmentation maps.
for key in results.get('bbox_fields', []):
bboxes = results[key].copy()
bboxes[..., 0::2] += random_shift_x
bboxes[..., 1::2] += random_shift_y
# clip border
bboxes[..., 0::2] = np.clip(bboxes[..., 0::2], 0, img_shape[1])
bboxes[..., 1::2] = np.clip(bboxes[..., 1::2], 0, img_shape[0])
# remove invalid bboxes
bbox_w = bboxes[..., 2] - bboxes[..., 0]
bbox_h = bboxes[..., 3] - bboxes[..., 1]
valid_inds = (bbox_w > self.filter_thr_px) & (
bbox_h > self.filter_thr_px)
# If the shifted image no longer contains any valid gt bbox,
# skip shifting this image.
if key == 'gt_bboxes' and not valid_inds.any():
return results
bboxes = bboxes[valid_inds]
results[key] = bboxes
# label fields. e.g. gt_labels and gt_labels_ignore
label_key = self.bbox2label.get(key)
if label_key in results:
results[label_key] = results[label_key][valid_inds]
for key in results.get('img_fields', ['img']):
img = results[key]
new_img = np.zeros_like(img)
img_h, img_w = img.shape[:2]
new_h = img_h - np.abs(random_shift_y)
new_w = img_w - np.abs(random_shift_x)
new_img[new_y:new_y + new_h, new_x:new_x + new_w] \
= img[orig_y:orig_y + new_h, orig_x:orig_x + new_w]
results[key] = new_img
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(shift_ratio={self.shift_ratio}, '
repr_str += f'max_shift_px={self.max_shift_px}, '
repr_str += f'filter_thr_px={self.filter_thr_px})'
return repr_str
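
A minimal usage sketch of the transform on a dummy sample, using the `RandomShift` class defined above (the keys mirror what the loading pipeline would produce):

```python
import numpy as np

shift = RandomShift(shift_ratio=1.0, max_shift_px=32)
results = dict(
    img=np.zeros((64, 64, 3), dtype=np.uint8),
    gt_bboxes=np.array([[10., 10., 40., 40.]], dtype=np.float32),
    gt_labels=np.array([1]),
    bbox_fields=['gt_bboxes'])
results = shift(results)
# Image shape is preserved; vacated borders are zero-padded and the
# bboxes are translated and clipped consistently with the image.
print(results['img'].shape)  # (64, 64, 3)
```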
@PIPELINES.register_module()
class Pad(object):
"""Pad the image & mask.

@ -29,6 +29,7 @@ from .ssd_head import SSDHead
from .vfnet_head import VFNetHead
from .yolact_head import YOLACTHead, YOLACTProtonet, YOLACTSegmHead
from .yolo_head import YOLOV3Head
from .yolof_head import YOLOFHead
__all__ = [
'AnchorFreeHead', 'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption',
@ -39,5 +40,5 @@ __all__ = [
'YOLACTSegmHead', 'YOLACTProtonet', 'YOLOV3Head', 'PAAHead',
'SABLRetinaHead', 'CentripetalHead', 'VFNetHead', 'StageCascadeRPNHead',
'CascadeRPNHead', 'EmbeddingRPNHead', 'LDHead', 'CascadeRPNHead',
'AutoAssignHead', 'DETRHead'
'AutoAssignHead', 'DETRHead', 'YOLOFHead'
]

@ -0,0 +1,415 @@
import torch
import torch.nn as nn
from mmcv.cnn import (ConvModule, bias_init_with_prob, constant_init, is_norm,
normal_init)
from mmcv.runner import force_fp32
from mmdet.core import anchor_inside_flags, multi_apply, reduce_mean, unmap
from ..builder import HEADS
from .anchor_head import AnchorHead
INF = 1e8
def levels_to_images(mlvl_tensor):
"""Concat multi-level feature maps by image.
[feature_level0, feature_level1...] -> [feature_image0, feature_image1...]
Convert the shape of each element in mlvl_tensor from (N, C, H, W) to
(N, H*W, C), then split each element into N chunks of shape (H*W, C), and
concatenate the chunks belonging to the same image across all levels
along the first dimension.
Args:
mlvl_tensor (list[torch.Tensor]): list of Tensor which collect from
corresponding level. Each element is of shape (N, C, H, W)
Returns:
list[torch.Tensor]: A list that contains N tensors and each tensor is
of shape (num_elements, C)
"""
batch_size = mlvl_tensor[0].size(0)
batch_list = [[] for _ in range(batch_size)]
channels = mlvl_tensor[0].size(1)
for t in mlvl_tensor:
t = t.permute(0, 2, 3, 1)
t = t.view(batch_size, -1, channels).contiguous()
for img in range(batch_size):
batch_list[img].append(t[img])
return [torch.cat(item, 0) for item in batch_list]
@HEADS.register_module()
class YOLOFHead(AnchorHead):
"""YOLOFHead Paper link: https://arxiv.org/abs/2103.09460.
Args:
num_classes (int): The number of object classes (w/o background)
in_channels (List[int]): The number of input channels per scale.
num_cls_convs (int): The number of convolutions of cls branch.
Default 2.
num_reg_convs (int): The number of convolutions of reg branch.
Default 4.
norm_cfg (dict): Dictionary to construct and config norm layer.
"""
def __init__(self,
num_classes,
in_channels,
num_cls_convs=2,
num_reg_convs=4,
norm_cfg=dict(type='BN', requires_grad=True),
**kwargs):
self.num_cls_convs = num_cls_convs
self.num_reg_convs = num_reg_convs
self.norm_cfg = norm_cfg
super(YOLOFHead, self).__init__(num_classes, in_channels, **kwargs)
def _init_layers(self):
cls_subnet = []
bbox_subnet = []
for i in range(self.num_cls_convs):
cls_subnet.append(
ConvModule(
self.in_channels,
self.in_channels,
kernel_size=3,
padding=1,
norm_cfg=self.norm_cfg))
for i in range(self.num_reg_convs):
bbox_subnet.append(
ConvModule(
self.in_channels,
self.in_channels,
kernel_size=3,
padding=1,
norm_cfg=self.norm_cfg))
self.cls_subnet = nn.Sequential(*cls_subnet)
self.bbox_subnet = nn.Sequential(*bbox_subnet)
self.cls_score = nn.Conv2d(
self.in_channels,
self.num_anchors * self.num_classes,
kernel_size=3,
stride=1,
padding=1)
self.bbox_pred = nn.Conv2d(
self.in_channels,
self.num_anchors * 4,
kernel_size=3,
stride=1,
padding=1)
self.object_pred = nn.Conv2d(
self.in_channels,
self.num_anchors,
kernel_size=3,
stride=1,
padding=1)
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
normal_init(m, mean=0, std=0.01)
if is_norm(m):
constant_init(m, 1)
# Use prior in model initialization to improve stability
bias_cls = bias_init_with_prob(0.01)
torch.nn.init.constant_(self.cls_score.bias, bias_cls)
def forward_single(self, feature):
cls_score = self.cls_score(self.cls_subnet(feature))
N, _, H, W = cls_score.shape
cls_score = cls_score.view(N, -1, self.num_classes, H, W)
reg_feat = self.bbox_subnet(feature)
bbox_reg = self.bbox_pred(reg_feat)
objectness = self.object_pred(reg_feat)
# implicit objectness
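# The fused score below is equivalent to
# log(e^cls * e^obj / (1 + e^cls + e^obj)), i.e. the objectness
# branch is folded into the classification logits so that a plain
# sigmoid focal loss can be applied to a single map. Clamping the
# exponentials guards against overflow (e.g. under fp16).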
objectness = objectness.view(N, -1, 1, H, W)
normalized_cls_score = cls_score + objectness - torch.log(
1. + torch.clamp(cls_score.exp(), max=INF) +
torch.clamp(objectness.exp(), max=INF))
normalized_cls_score = normalized_cls_score.view(N, -1, H, W)
return normalized_cls_score, bbox_reg
@force_fp32(apply_to=('cls_scores', 'bbox_preds'))
def loss(self,
cls_scores,
bbox_preds,
gt_bboxes,
gt_labels,
img_metas,
gt_bboxes_ignore=None):
"""Compute losses of the head.
Args:
cls_scores (list[Tensor]): Box scores for each scale level
Has shape (batch, num_anchors * num_classes, h, w)
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (batch, num_anchors * 4, h, w)
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): class indices corresponding to each box
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes_ignore (None | list[Tensor]): specify which bounding
boxes can be ignored when computing the loss. Default: None
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
assert len(cls_scores) == 1
assert self.anchor_generator.num_levels == 1
device = cls_scores[0].device
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_metas, device=device)
# The output level is always 1
anchor_list = [anchors[0] for anchors in anchor_list]
valid_flag_list = [valid_flags[0] for valid_flags in valid_flag_list]
cls_scores_list = levels_to_images(cls_scores)
bbox_preds_list = levels_to_images(bbox_preds)
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = self.get_targets(
cls_scores_list,
bbox_preds_list,
anchor_list,
valid_flag_list,
gt_bboxes,
img_metas,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels)
if cls_reg_targets is None:
return None
(batch_labels, batch_label_weights, num_total_pos, num_total_neg,
batch_bbox_weights, batch_pos_predicted_boxes,
batch_target_boxes) = cls_reg_targets
flatten_labels = batch_labels.reshape(-1)
batch_label_weights = batch_label_weights.reshape(-1)
cls_score = cls_scores[0].permute(0, 2, 3,
1).reshape(-1, self.cls_out_channels)
num_total_samples = (num_total_pos +
num_total_neg) if self.sampling else num_total_pos
num_total_samples = reduce_mean(
cls_score.new_tensor(num_total_samples)).clamp_(1.0).item()
# classification loss
loss_cls = self.loss_cls(
cls_score,
flatten_labels,
batch_label_weights,
avg_factor=num_total_samples)
# regression loss
if batch_pos_predicted_boxes.shape[0] == 0:
# no pos sample
loss_bbox = batch_pos_predicted_boxes.sum() * 0
else:
loss_bbox = self.loss_bbox(
batch_pos_predicted_boxes,
batch_target_boxes,
batch_bbox_weights.float(),
avg_factor=num_total_samples)
return dict(loss_cls=loss_cls, loss_bbox=loss_bbox)
def get_targets(self,
cls_scores_list,
bbox_preds_list,
anchor_list,
valid_flag_list,
gt_bboxes_list,
img_metas,
gt_bboxes_ignore_list=None,
gt_labels_list=None,
label_channels=1,
unmap_outputs=True):
"""Compute regression and classification targets for anchors in
multiple images.
Args:
cls_scores_list (list[Tensor]): Classification scores of
each image. Each is a 2D tensor of shape
(h * w, num_anchors * num_classes).
bbox_preds_list (list[Tensor]): Bbox predictions of each image.
Each is a 2D tensor of shape (h * w, num_anchors * 4).
anchor_list (list[Tensor]): Anchors of each image. Each element
is a tensor of shape (h * w * num_anchors, 4).
valid_flag_list (list[Tensor]): Valid flags of each image. Each
element is a tensor of shape (h * w * num_anchors, ).
gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
img_metas (list[dict]): Meta info of each image.
gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
ignored.
gt_labels_list (list[Tensor]): Ground truth labels of each box.
label_channels (int): Channel of label.
unmap_outputs (bool): Whether to map outputs back to the original
set of anchors.
Returns:
tuple: Usually returns a tuple containing learning targets.
- batch_labels (Tensor): Labels of all images, a tensor of \
shape (batch, h * w * num_anchors).
- batch_label_weights (Tensor): Label weights of all images, \
a tensor of shape (batch, h * w * num_anchors).
- num_total_pos (int): Number of positive samples in all \
images.
- num_total_neg (int): Number of negative samples in all \
images.
additional_returns: This function enables user-defined returns from
`self._get_targets_single`. These returns are currently refined
to properties at each feature map (i.e. having HxW dimension).
The results will be concatenated at the end.
"""
num_imgs = len(img_metas)
assert len(anchor_list) == len(valid_flag_list) == num_imgs
# compute targets for each image
if gt_bboxes_ignore_list is None:
gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
if gt_labels_list is None:
gt_labels_list = [None for _ in range(num_imgs)]
results = multi_apply(
self._get_targets_single,
bbox_preds_list,
anchor_list,
valid_flag_list,
gt_bboxes_list,
gt_bboxes_ignore_list,
gt_labels_list,
img_metas,
label_channels=label_channels,
unmap_outputs=unmap_outputs)
(all_labels, all_label_weights, pos_inds_list, neg_inds_list,
sampling_results_list) = results[:5]
rest_results = list(results[5:]) # user-added return values
# no valid anchors
if any([labels is None for labels in all_labels]):
return None
# sampled anchors of all images
num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
batch_labels = torch.stack(all_labels, 0)
batch_label_weights = torch.stack(all_label_weights, 0)
res = (batch_labels, batch_label_weights, num_total_pos, num_total_neg)
for i, rests in enumerate(rest_results): # user-added return values
rest_results[i] = torch.cat(rests, 0)
return res + tuple(rest_results)
def _get_targets_single(self,
bbox_preds,
flat_anchors,
valid_flags,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
img_meta,
label_channels=1,
unmap_outputs=True):
"""Compute regression and classification targets for anchors in a
single image.
Args:
bbox_preds (Tensor): Bbox predictions of the image, with
shape (h * w, 4).
flat_anchors (Tensor): Anchors of the image, with shape
(h * w * num_anchors, 4).
valid_flags (Tensor): Valid flags of the image, with shape
(h * w * num_anchors, ).
gt_bboxes (Tensor): Ground truth bboxes of the image,
shape (num_gts, 4).
gt_bboxes_ignore (Tensor): Ground truth bboxes to be
ignored, shape (num_ignored_gts, 4).
img_meta (dict): Meta info of the image.
gt_labels (Tensor): Ground truth labels of each box,
shape (num_gts,).
label_channels (int): Channel of label.
unmap_outputs (bool): Whether to map outputs back to the original
set of anchors.
Returns:
tuple:
labels (Tensor): Labels of the image, with shape
(h * w * num_anchors, ).
label_weights (Tensor): Label weights of the image, with shape
(h * w * num_anchors, ).
pos_inds (Tensor): Pos index of the image.
neg_inds (Tensor): Neg index of the image.
sampling_result (obj:`SamplingResult`): Sampling result.
pos_bbox_weights (Tensor): The weights used to compute the bbox
branch loss, with shape (num, ).
pos_predicted_boxes (Tensor): The predicted boxes used to
compute the bbox branch loss, with shape (num, 4).
pos_target_boxes (Tensor): The target boxes used to compute
the bbox branch loss, with shape (num, 4).
"""
inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
img_meta['img_shape'][:2],
self.train_cfg.allowed_border)
if not inside_flags.any():
return (None, ) * 8
# assign gt and sample anchors
anchors = flat_anchors[inside_flags, :]
bbox_preds = bbox_preds.reshape(-1, 4)
bbox_preds = bbox_preds[inside_flags, :]
# decoded bbox
decoder_bbox_preds = self.bbox_coder.decode(anchors, bbox_preds)
assign_result = self.assigner.assign(
decoder_bbox_preds, anchors, gt_bboxes, gt_bboxes_ignore,
None if self.sampling else gt_labels)
pos_bbox_weights = assign_result.get_extra_property('pos_idx')
pos_predicted_boxes = assign_result.get_extra_property(
'pos_predicted_boxes')
pos_target_boxes = assign_result.get_extra_property('target_boxes')
sampling_result = self.sampler.sample(assign_result, anchors,
gt_bboxes)
num_valid_anchors = anchors.shape[0]
labels = anchors.new_full((num_valid_anchors, ),
self.num_classes,
dtype=torch.long)
label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
if len(pos_inds) > 0:
if gt_labels is None:
# Only rpn gives gt_labels as None
# Foreground is the first class since v2.5.0
labels[pos_inds] = 0
else:
labels[pos_inds] = gt_labels[
sampling_result.pos_assigned_gt_inds]
if self.train_cfg.pos_weight <= 0:
label_weights[pos_inds] = 1.0
else:
label_weights[pos_inds] = self.train_cfg.pos_weight
if len(neg_inds) > 0:
label_weights[neg_inds] = 1.0
# map up to original set of anchors
if unmap_outputs:
num_total_anchors = flat_anchors.size(0)
labels = unmap(
labels, num_total_anchors, inside_flags,
fill=self.num_classes) # fill bg label
label_weights = unmap(label_weights, num_total_anchors,
inside_flags)
return (labels, label_weights, pos_inds, neg_inds, sampling_result,
pos_bbox_weights, pos_predicted_boxes, pos_target_boxes)
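
For reference, a shape walkthrough of the `levels_to_images` helper defined at the top of this file (toy sizes; YOLOF itself only ever feeds it a single level):

```python
import torch

# Two levels for a batch of 2 images with 4 channels: (N, C, H, W).
lvl0 = torch.rand(2, 4, 8, 8)
lvl1 = torch.rand(2, 4, 4, 4)
per_image = levels_to_images([lvl0, lvl1])
print(len(per_image))      # 2 -- one tensor per image
print(per_image[0].shape)  # torch.Size([80, 4]) = (8*8 + 4*4, C)
```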

@ -29,6 +29,7 @@ from .two_stage import TwoStageDetector
from .vfnet import VFNet
from .yolact import YOLACT
from .yolo import YOLOV3
from .yolof import YOLOF
__all__ = [
'ATSS', 'BaseDetector', 'SingleStageDetector',
@ -37,5 +38,5 @@ __all__ = [
'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN', 'RepPointsDetector',
'FOVEA', 'FSAF', 'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA',
'YOLOV3', 'YOLACT', 'VFNet', 'DETR', 'TridentFasterRCNN', 'SparseRCNN',
'SCNet', 'AutoAssign'
'SCNet', 'YOLOF', 'AutoAssign'
]

@ -0,0 +1,18 @@
from ..builder import DETECTORS
from .single_stage import SingleStageDetector
@DETECTORS.register_module()
class YOLOF(SingleStageDetector):
r"""Implementation of `You Only Look One-level Feature
<https://arxiv.org/abs/2103.09460>`_"""
def __init__(self,
backbone,
neck,
bbox_head,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(YOLOF, self).__init__(backbone, neck, bbox_head, train_cfg,
test_cfg, pretrained)

@ -1,5 +1,6 @@
from .bfp import BFP
from .channel_mapper import ChannelMapper
from .dilated_encoder import DilatedEncoder
from .fpg import FPG
from .fpn import FPN
from .fpn_carafe import FPN_CARAFE
@ -12,5 +13,5 @@ from .yolo_neck import YOLOV3Neck
__all__ = [
'FPN', 'BFP', 'ChannelMapper', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN',
'NASFCOS_FPN', 'RFP', 'YOLOV3Neck', 'FPG'
'NASFCOS_FPN', 'RFP', 'YOLOV3Neck', 'FPG', 'DilatedEncoder'
]

@ -0,0 +1,107 @@
import torch.nn as nn
from mmcv.cnn import (ConvModule, caffe2_xavier_init, constant_init, is_norm,
normal_init)
from torch.nn import BatchNorm2d
from ..builder import NECKS
class Bottleneck(nn.Module):
"""Bottleneck block for DilatedEncoder used in `YOLOF.
<https://arxiv.org/abs/2103.09460>`.
The Bottleneck contains three ConvLayers and one residual connection.
Args:
in_channels (int): The number of input channels.
mid_channels (int): The number of middle output channels.
dilation (int): Dilation rate.
norm_cfg (dict): Dictionary to construct and config norm layer.
"""
def __init__(self,
in_channels,
mid_channels,
dilation,
norm_cfg=dict(type='BN', requires_grad=True)):
super(Bottleneck, self).__init__()
self.conv1 = ConvModule(
in_channels, mid_channels, 1, norm_cfg=norm_cfg)
self.conv2 = ConvModule(
mid_channels,
mid_channels,
3,
padding=dilation,
dilation=dilation,
norm_cfg=norm_cfg)
self.conv3 = ConvModule(
mid_channels, in_channels, 1, norm_cfg=norm_cfg)
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.conv2(out)
out = self.conv3(out)
out = out + identity
return out
@NECKS.register_module()
class DilatedEncoder(nn.Module):
"""Dilated Encoder for YOLOF <https://arxiv.org/abs/2103.09460>`.
This module contains two types of components:
- the original FPN lateral convolution layer and fpn convolution layer,
which are 1x1 conv + 3x3 conv
- the dilated residual block
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
block_mid_channels (int): The number of middle block output channels
num_residual_blocks (int): The number of residual blocks.
"""
def __init__(self, in_channels, out_channels, block_mid_channels,
num_residual_blocks):
super(DilatedEncoder, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.block_mid_channels = block_mid_channels
self.num_residual_blocks = num_residual_blocks
self.block_dilations = [2, 4, 6, 8]
self._init_layers()
def _init_layers(self):
self.lateral_conv = nn.Conv2d(
self.in_channels, self.out_channels, kernel_size=1)
self.lateral_norm = BatchNorm2d(self.out_channels)
self.fpn_conv = nn.Conv2d(
self.out_channels, self.out_channels, kernel_size=3, padding=1)
self.fpn_norm = BatchNorm2d(self.out_channels)
encoder_blocks = []
for i in range(self.num_residual_blocks):
dilation = self.block_dilations[i]
encoder_blocks.append(
Bottleneck(
self.out_channels,
self.block_mid_channels,
dilation=dilation))
self.dilated_encoder_blocks = nn.Sequential(*encoder_blocks)
def init_weights(self):
caffe2_xavier_init(self.lateral_conv)
caffe2_xavier_init(self.fpn_conv)
for m in [self.lateral_norm, self.fpn_norm]:
constant_init(m, 1)
for m in self.dilated_encoder_blocks.modules():
if isinstance(m, nn.Conv2d):
normal_init(m, mean=0, std=0.01)
if is_norm(m):
constant_init(m, 1)
def forward(self, feature):
out = self.lateral_norm(self.lateral_conv(feature[-1]))
out = self.fpn_norm(self.fpn_conv(out))
return self.dilated_encoder_blocks(out),
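
A quick shape check of the neck in isolation (mirroring the unit test further down; the forward takes a list/tuple of backbone features and only uses the last one):

```python
import torch

neck = DilatedEncoder(
    in_channels=2048, out_channels=512,
    block_mid_channels=128, num_residual_blocks=4)
feats = [torch.rand(1, 2048, 32, 32)]  # C5 feature of a 1024x1024 input
out = neck(feats)                      # forward returns a 1-tuple
print(out[0].shape)                    # torch.Size([1, 512, 32, 32])
```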

@ -750,3 +750,43 @@ def test_cutout():
cutout_module = build_from_cfg(transform, PIPELINES)
cutout_result = cutout_module(copy.deepcopy(results))
assert cutout_result['img'].sum() > img.sum()
def test_random_shift():
# test assertion for invalid shift_ratio
with pytest.raises(AssertionError):
transform = dict(type='RandomShift', shift_ratio=1.5)
build_from_cfg(transform, PIPELINES)
# test assertion for invalid max_shift_px
with pytest.raises(AssertionError):
transform = dict(type='RandomShift', max_shift_px=-1)
build_from_cfg(transform, PIPELINES)
results = dict()
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../../../data/color.jpg'), 'color')
results['img'] = img
# TODO: add img_fields test
results['bbox_fields'] = ['gt_bboxes', 'gt_bboxes_ignore']
def create_random_bboxes(num_bboxes, img_w, img_h):
bboxes_left_top = np.random.uniform(0, 0.5, size=(num_bboxes, 2))
bboxes_right_bottom = np.random.uniform(0.5, 1, size=(num_bboxes, 2))
bboxes = np.concatenate((bboxes_left_top, bboxes_right_bottom), 1)
bboxes = (bboxes * np.array([img_w, img_h, img_w, img_h])).astype(
np.int64)
return bboxes
h, w, _ = img.shape
gt_bboxes = create_random_bboxes(8, w, h)
gt_bboxes_ignore = create_random_bboxes(2, w, h)
results['gt_labels'] = torch.ones(gt_bboxes.shape[0])
results['gt_bboxes'] = gt_bboxes
results['gt_bboxes_ignore'] = gt_bboxes_ignore
transform = dict(type='RandomShift', shift_ratio=1.0)
random_shift_module = build_from_cfg(transform, PIPELINES)
results = random_shift_module(results)
assert results['img'].shape[:2] == (h, w)
assert results['gt_labels'].shape[0] == results['gt_bboxes'].shape[0]

@ -0,0 +1,75 @@
import mmcv
import torch
from mmdet.models.dense_heads import YOLOFHead
def test_yolof_head_loss():
"""Tests yolof head loss when truth is empty and non-empty."""
s = 256
img_metas = [{
'img_shape': (s, s, 3),
'scale_factor': 1,
'pad_shape': (s, s, 3)
}]
train_cfg = mmcv.Config(
dict(
assigner=dict(
type='UniformAssigner',
pos_ignore_thr=0.15,
neg_ignore_thr=0.7),
allowed_border=-1,
pos_weight=-1,
debug=False))
self = YOLOFHead(
num_classes=4,
in_channels=1,
reg_decoded_bbox=True,
train_cfg=train_cfg,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
scales=[1, 2, 4, 8, 16],
strides=[32]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1., 1., 1., 1.],
add_ctr_clamp=True,
ctr_clamp=32),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=1.0))
feat = [torch.rand(1, 1, s // 32, s // 32)]
cls_scores, bbox_preds = self.forward(feat)
# Test that empty ground truth encourages the network to predict background
gt_bboxes = [torch.empty((0, 4))]
gt_labels = [torch.LongTensor([])]
gt_bboxes_ignore = None
empty_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels,
img_metas, gt_bboxes_ignore)
# When there is no truth, the cls loss should be nonzero but there should
# be no box loss.
empty_cls_loss = empty_gt_losses['loss_cls']
empty_box_loss = empty_gt_losses['loss_bbox']
assert empty_cls_loss.item() > 0, 'cls loss should be non-zero'
assert empty_box_loss.item() == 0, (
'there should be no box loss when there are no true boxes')
# When truth is non-empty then both cls and box loss should be nonzero for
# random inputs
gt_bboxes = [
torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
]
gt_labels = [torch.LongTensor([2])]
one_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels,
img_metas, gt_bboxes_ignore)
onegt_cls_loss = one_gt_losses['loss_cls']
onegt_box_loss = one_gt_losses['loss_bbox']
assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero'
assert onegt_box_loss.item() > 0, 'box loss should be non-zero'

@ -2,7 +2,7 @@ import pytest
import torch
from torch.nn.modules.batchnorm import _BatchNorm
from mmdet.models.necks import FPN, ChannelMapper
from mmdet.models.necks import FPN, ChannelMapper, DilatedEncoder
def test_fpn():
@ -236,3 +236,13 @@ def test_channel_mapper():
for i in range(len(feats)):
outs[i].shape[1] == out_channels
outs[i].shape[2] == outs[i].shape[3] == s // (2**i)
def test_dilated_encoder():
in_channels = 16
out_channels = 32
out_shape = 34
dilated_encoder = DilatedEncoder(in_channels, out_channels, 16, 2)
feat = [torch.rand(1, in_channels, 34, 34)]
out_feat = dilated_encoder(feat)[0]
assert out_feat.shape == (1, out_channels, out_shape, out_shape)

@ -8,7 +8,8 @@ import torch
from mmdet.core.bbox.assigners import (ApproxMaxIoUAssigner,
CenterRegionAssigner, HungarianAssigner,
MaxIoUAssigner, PointAssigner)
MaxIoUAssigner, PointAssigner,
UniformAssigner)
def test_max_iou_assigner():
@ -422,3 +423,75 @@ def test_hungarian_match_assigner():
assert torch.all(assign_result.gt_inds > -1)
assert (assign_result.gt_inds > 0).sum() == gt_bboxes.size(0)
assert (assign_result.labels > -1).sum() == gt_bboxes.size(0)
def test_uniform_assigner():
self = UniformAssigner(0.15, 0.7, 1)
pred_bbox = torch.FloatTensor([
[1, 1, 12, 8],
[4, 4, 20, 20],
[1, 5, 15, 15],
[30, 5, 32, 42],
])
anchor = torch.FloatTensor([
[0, 0, 10, 10],
[10, 10, 20, 20],
[5, 5, 15, 15],
[32, 32, 38, 42],
])
gt_bboxes = torch.FloatTensor([
[0, 0, 10, 9],
[0, 10, 10, 19],
])
gt_labels = torch.LongTensor([2, 3])
assign_result = self.assign(
pred_bbox, anchor, gt_bboxes, gt_labels=gt_labels)
assert len(assign_result.gt_inds) == 4
assert len(assign_result.labels) == 4
expected_gt_inds = torch.LongTensor([-1, 0, 2, 0])
assert torch.all(assign_result.gt_inds == expected_gt_inds)
def test_uniform_assigner_with_empty_gt():
"""Test corner case where an image might have no true detections."""
self = UniformAssigner(0.15, 0.7, 1)
pred_bbox = torch.FloatTensor([
[1, 1, 12, 8],
[4, 4, 20, 20],
[1, 5, 15, 15],
[30, 5, 32, 42],
])
anchor = torch.FloatTensor([
[0, 0, 10, 10],
[10, 10, 20, 20],
[5, 5, 15, 15],
[32, 32, 38, 42],
])
gt_bboxes = torch.empty(0, 4)
assign_result = self.assign(pred_bbox, anchor, gt_bboxes)
expected_gt_inds = torch.LongTensor([0, 0, 0, 0])
assert torch.all(assign_result.gt_inds == expected_gt_inds)
def test_uniform_assigner_with_empty_boxes():
"""Test corner case where a network might predict no boxes."""
self = UniformAssigner(0.15, 0.7, 1)
pred_bbox = torch.empty((0, 4))
anchor = torch.empty((0, 4))
gt_bboxes = torch.FloatTensor([
[0, 0, 10, 9],
[0, 10, 10, 19],
])
gt_labels = torch.LongTensor([2, 3])
# Test with gt_labels
assign_result = self.assign(
pred_bbox, anchor, gt_bboxes, gt_labels=gt_labels)
assert len(assign_result.gt_inds) == 0
assert tuple(assign_result.labels.shape) == (0, )
# Test without gt_labels
assign_result = self.assign(pred_bbox, anchor, gt_bboxes, gt_labels=None)
assert len(assign_result.gt_inds) == 0

@ -58,6 +58,21 @@ def test_delta_bbox_coder():
out = coder.decode(rois, deltas, max_shape=(32, 32))
assert rois.shape == out.shape
# test add_ctr_clamp
coder = DeltaXYWHBBoxCoder(add_ctr_clamp=True, ctr_clamp=2)
rois = torch.Tensor([[0., 0., 6., 6.], [0., 0., 1., 1.], [0., 0., 1., 1.],
[5., 5., 5., 5.]])
deltas = torch.Tensor([[1., 1., 2., 2.], [1., 1., 1., 1.],
[0., 0., 2., -1.], [0.7, -1.9, -0.5, 0.3]])
expected_decode_bboxes = torch.Tensor([[0.0000, 0.0000, 27.1672, 27.1672],
[0.1409, 0.1409, 2.8591, 2.8591],
[0.0000, 0.3161, 4.1945, 0.6839],
[5.0000, 5.0000, 5.0000, 5.0000]])
out = coder.decode(rois, deltas, max_shape=(32, 32))
assert expected_decode_bboxes.allclose(out, atol=1e-04)
def test_tblr_bbox_coder():
coder = TBLRBBoxCoder(normalizer=15.)
