OpenMMLab Detection Toolbox and Benchmark https://mmdetection.readthedocs.io/
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

350 lines
14 KiB

# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from abc import abstractmethod
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import force_fp32
from mmdet.core import build_bbox_coder, multi_apply
from mmdet.core.anchor.point_generator import MlvlPointGenerator
from ..builder import HEADS, build_loss
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin
@HEADS.register_module()
class AnchorFreeHead(BaseDenseHead, BBoxTestMixin):
    """Anchor-free head (FCOS, Fovea, RepPoints, etc.).

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels. Used in child classes.
        stacked_convs (int): Number of stacking convs of the head.
        strides (tuple): Downsample factor of each feature map.
        dcn_on_last_conv (bool): If true, use dcn in the last layer of
            towers. Default: False.
        conv_bias (bool | str): If specified as `auto`, it will be decided by
            the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
            None, otherwise False. Default: "auto".
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of localization loss.
        bbox_coder (dict): Config of bbox coder. Defaults to
            'DistancePointBBoxCoder'.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        train_cfg (dict): Training config of anchor-free head.
        test_cfg (dict): Testing config of anchor-free head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """  # noqa: W605

    # Stored in the checkpoint metadata; `_load_from_state_dict` treats a
    # missing version as "legacy checkpoint, rename predictor keys".
    _version = 1

    def __init__(self,
                 num_classes,
                 in_channels,
                 feat_channels=256,
                 stacked_convs=4,
                 strides=(4, 8, 16, 32, 64),
                 dcn_on_last_conv=False,
                 conv_bias='auto',
                 loss_cls=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox=dict(type='IoULoss', loss_weight=1.0),
                 bbox_coder=dict(type='DistancePointBBoxCoder'),
                 conv_cfg=None,
                 norm_cfg=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='conv_cls',
                         std=0.01,
                         bias_prob=0.01))):
        super().__init__(init_cfg)
        self.num_classes = num_classes
        # A sigmoid classifier predicts one score per foreground class; any
        # other classifier gets one extra output channel.
        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        if self.use_sigmoid_cls:
            self.cls_out_channels = num_classes
        else:
            self.cls_out_channels = num_classes + 1
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.dcn_on_last_conv = dcn_on_last_conv
        assert conv_bias == 'auto' or isinstance(conv_bias, bool)
        self.conv_bias = conv_bias
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)
        self.bbox_coder = build_bbox_coder(bbox_coder)

        self.prior_generator = MlvlPointGenerator(strides)
        # In order to keep a more general interface and be consistent with
        # anchor_head, a point can be thought of as one anchor.
        self.num_base_priors = self.prior_generator.num_base_priors[0]

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.fp16_enabled = False

        self._init_layers()

    def _init_layers(self):
        """Initialize layers of the head."""
        self._init_cls_convs()
        self._init_reg_convs()
        self._init_predictor()

    def _build_conv_tower(self):
        """Build one tower of ``stacked_convs`` 3x3 ``ConvModule`` layers.

        The first layer maps ``in_channels`` to ``feat_channels``; the
        remaining layers keep ``feat_channels``. When ``dcn_on_last_conv``
        is set, the last layer uses a ``DCNv2`` conv config instead of
        ``self.conv_cfg``.

        Returns:
            nn.ModuleList: The stacked conv layers, in forward order.
        """
        tower = nn.ModuleList()
        last_idx = self.stacked_convs - 1
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            if self.dcn_on_last_conv and i == last_idx:
                conv_cfg = dict(type='DCNv2')
            else:
                conv_cfg = self.conv_cfg
            tower.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.conv_bias))
        return tower

    def _init_cls_convs(self):
        """Initialize classification conv layers of the head."""
        self.cls_convs = self._build_conv_tower()

    def _init_reg_convs(self):
        """Initialize bbox regression conv layers of the head."""
        self.reg_convs = self._build_conv_tower()

    def _init_predictor(self):
        """Initialize predictor layers of the head."""
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)
        self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Hack some keys of the model state dict so that can load checkpoints
        of previous version.

        When the checkpoint metadata carries no ``version``, every key under
        ``prefix`` whose first sub-module name ends with ``cls``, ``reg`` or
        ``centerness`` is renamed to ``conv_cls``, ``conv_reg`` or
        ``conv_centerness`` respectively (e.g. ``fcos_cls`` becomes
        ``conv_cls``). All other keys are left untouched. ``state_dict`` is
        modified in place before delegating to the parent implementation.
        """
        version = local_metadata.get('version', None)
        if version is None:
            # (legacy name suffix, current predictor name), tested in order.
            legacy_renames = (
                ('cls', 'conv_cls'),
                ('reg', 'conv_reg'),
                ('centerness', 'conv_centerness'),
            )
            # Snapshot the keys first: the dict is mutated further below.
            bbox_head_keys = [
                k for k in state_dict.keys() if k.startswith(prefix)
            ]
            renamed_keys = []
            for old_key in bbox_head_keys:
                # Look at the first component *after* the prefix so that
                # nested prefixes (e.g. 'module.bbox_head.') and an empty
                # prefix are handled the same way as 'bbox_head.'.
                module_name, sep, remainder = \
                    old_key[len(prefix):].partition('.')
                for suffix, conv_name in legacy_renames:
                    if module_name.endswith(suffix):
                        new_key = prefix + conv_name + sep + remainder
                        renamed_keys.append((old_key, new_key))
                        break
                # Keys that match no suffix (e.g. 'cls_convs.*') are not
                # predictor keys and deliberately keep their name.
            for old_key, new_key in renamed_keys:
                state_dict[new_key] = state_dict.pop(old_key)
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)

    def forward(self, feats):
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: Usually contain classification scores and bbox predictions.
                cls_scores (list[Tensor]): Box scores for each scale level,
                    each is a 4D-tensor, the channel number is
                    num_points * num_classes.
                bbox_preds (list[Tensor]): Box energies / deltas for each scale
                    level, each is a 4D-tensor, the channel number is
                    num_points * 4.
        """
        # `forward_single` also returns the tower features; only the first
        # two outputs (scores and bbox predictions) are exposed here.
        return multi_apply(self.forward_single, feats)[:2]

    def forward_single(self, x):
        """Forward features of a single scale level.

        Args:
            x (Tensor): FPN feature maps of the specified stride.

        Returns:
            tuple: Scores for each class, bbox predictions, features
                after classification and regression conv layers, some
                models needs these features like FCOS.
        """
        cls_feat = x
        reg_feat = x

        for cls_layer in self.cls_convs:
            cls_feat = cls_layer(cls_feat)
        cls_score = self.conv_cls(cls_feat)

        for reg_layer in self.reg_convs:
            reg_feat = reg_layer(reg_feat)
        bbox_pred = self.conv_reg(reg_feat)
        return cls_score, bbox_pred, cls_feat, reg_feat

    @abstractmethod
    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute loss of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 4.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.

        Raises:
            NotImplementedError: Always; child classes must override.
        """

        raise NotImplementedError

    @abstractmethod
    def get_targets(self, points, gt_bboxes_list, gt_labels_list):
        """Compute regression, classification and centerness targets for points
        in multiple images.

        Args:
            points (list[Tensor]): Points of each fpn level, each has shape
                (num_points, 2).
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
                each has shape (num_gt, 4).
            gt_labels_list (list[Tensor]): Ground truth labels of each box,
                each has shape (num_gt,).

        Raises:
            NotImplementedError: Always; child classes must override.
        """
        raise NotImplementedError

    def _get_points_single(self,
                           featmap_size,
                           stride,
                           dtype,
                           device,
                           flatten=False):
        """Get points of a single scale level.

        This function will be deprecated soon.

        Args:
            featmap_size (tuple): Feature map size as ``(h, w)``.
            stride: Accepted for interface compatibility; not used by this
                base implementation.
            dtype (torch.dtype): Type of the returned coordinates.
            device (torch.device): Device of the returned coordinates.
            flatten (bool): Whether to flatten the coordinate grids.
                Default: False.

        Returns:
            tuple[Tensor, Tensor]: The ``y`` and ``x`` index grids, in that
                order.
        """

        warnings.warn(
            '`_get_points_single` in `AnchorFreeHead` will be '
            'deprecated soon, we support a multi level point generator now, '
            'you can get points of a single level feature map '
            'with `self.prior_generator.single_level_grid_priors` ')

        h, w = featmap_size
        # First create Range with the default dtype, then convert to
        # target `dtype` for onnx exporting.
        x_range = torch.arange(w, device=device).to(dtype)
        y_range = torch.arange(h, device=device).to(dtype)
        y, x = torch.meshgrid(y_range, x_range)
        if flatten:
            y = y.flatten()
            x = x.flatten()
        return y, x

    def get_points(self, featmap_sizes, dtype, device, flatten=False):
        """Get points according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            dtype (torch.dtype): Type of points.
            device (torch.device): Device of points.
            flatten (bool): Forwarded to `_get_points_single` for every
                level. Default: False.

        Returns:
            list: One entry per feature level, each being the result of
                `_get_points_single` for that level.
        """
        warnings.warn(
            '`get_points` in `AnchorFreeHead` will be '
            'deprecated soon, we support a multi level point generator now, '
            'you can get points of all levels '
            'with `self.prior_generator.grid_priors` ')

        # Index `self.strides` explicitly (instead of zipping) so that a
        # mismatch between levels and strides still raises.
        return [
            self._get_points_single(featmap_size, self.strides[i], dtype,
                                    device, flatten)
            for i, featmap_size in enumerate(featmap_sizes)
        ]

    def aug_test(self, feats, img_metas, rescale=False):
        """Test function with test time augmentation.

        Args:
            feats (list[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains features for all images in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch. each dict has image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[ndarray]: bbox results of each class
        """
        return self.aug_test_bboxes(feats, img_metas, rescale=rescale)