Support TTA of ATSS, FCOS, YOLOv3 (#3844)

* Support TTA of ATSS, FCOS, YOLOv3

* Add comment
pull/3875/head
Yosuke Shinya 4 years ago committed by GitHub
parent c2da97c02d
commit 5b18b94f61
Changed files (lines changed):
  1. configs/yolo/yolov3_d53_mstrain-608_273e_coco.py (2)
  2. mmdet/models/dense_heads/atss_head.py (48)
  3. mmdet/models/dense_heads/dense_test_mixins.py (23)
  4. mmdet/models/dense_heads/fcos_head.py (78)
  5. mmdet/models/dense_heads/yolo_head.py (80)

@@ -47,7 +47,7 @@ test_cfg = dict(
min_bbox_size=0,
score_thr=0.05,
conf_thr=0.005,
nms=dict(type='nms', iou_thr=0.45),
nms=dict(type='nms', iou_threshold=0.45),
max_per_img=100)
# dataset settings
dataset_type = 'CocoDataset'
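
The config hunk above only renames the NMS key. As a side note (a minimal sketch, not part of the commit): mmcv's `batched_nms` forwards every key of this `nms` dict except `type` as a keyword argument to the underlying NMS op, whose parameter is named `iou_threshold`, so the old `iou_thr` key is not accepted by recent mmcv releases.

```python
import torch
from mmcv.ops import nms

# Toy call showing the keyword that the nms dict above ultimately feeds.
boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],
                      [20., 20., 30., 30.]])
scores = torch.tensor([0.9, 0.8, 0.7])
dets, keep = nms(boxes, scores, iou_threshold=0.45)  # keeps boxes 0 and 2
```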

@@ -327,22 +327,25 @@ class ATSSHead(AnchorHead):
centernesses,
img_metas,
cfg=None,
rescale=False):
rescale=False,
with_nms=True):
"""Transform network output for a batch into bbox predictions.
Args:
cls_scores (list[Tensor]): Box scores for each scale level
Has shape (N, num_anchors * num_classes, H, W)
with shape (N, num_anchors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_anchors * 4, H, W)
centernesses (list[Tensor]): Centerness for each scale
level with shape (N, num_anchors * 1, H, W)
level with shape (N, num_anchors * 4, H, W).
centernesses (list[Tensor]): Centerness for each scale level with
shape (N, num_anchors * 1, H, W).
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
cfg (mmcv.Config): Test / postprocessing configuration,
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used. Default: None.
rescale (bool): If True, return boxes in original image space.
Default: False.
with_nms (bool): If True, do nms before return boxes.
Default: True.
Returns:
list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
@@ -376,7 +379,8 @@ class ATSSHead(AnchorHead):
proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list,
centerness_pred_list,
mlvl_anchors, img_shape,
scale_factor, cfg, rescale)
scale_factor, cfg, rescale,
with_nms)
result_list.append(proposals)
return result_list
@@ -388,26 +392,29 @@ class ATSSHead(AnchorHead):
img_shape,
scale_factor,
cfg,
rescale=False):
rescale=False,
with_nms=True):
"""Transform outputs for a single batch item into labeled boxes.
Args:
cls_scores (list[Tensor]): Box scores for a single scale level
Has shape (num_anchors * num_classes, H, W).
with shape (num_anchors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for a single
scale level with shape (num_anchors * 4, H, W).
centernesses (list[Tensor]): Centerness for a single scale level
Has shape (num_anchors * 1, H, W).
with shape (num_anchors * 1, H, W).
mlvl_anchors (list[Tensor]): Box reference for a single scale level
with shape (num_total_anchors, 4).
img_shape (tuple[int]): Shape of the input image,
(height, width, 3).
scale_factor (ndarray): Scale factor of the image arange as
scale_factor (ndarray): Scale factor of the image arrange as
(w_scale, h_scale, w_scale, h_scale).
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used.
rescale (bool): If True, return boxes in original image space.
Default: False.
with_nms (bool): If True, do nms before return boxes.
Default: True.
Returns:
tuple(Tensor):
@@ -457,14 +464,17 @@ class ATSSHead(AnchorHead):
mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
mlvl_centerness = torch.cat(mlvl_centerness)
det_bboxes, det_labels = multiclass_nms(
mlvl_bboxes,
mlvl_scores,
cfg.score_thr,
cfg.nms,
cfg.max_per_img,
score_factors=mlvl_centerness)
return det_bboxes, det_labels
if with_nms:
det_bboxes, det_labels = multiclass_nms(
mlvl_bboxes,
mlvl_scores,
cfg.score_thr,
cfg.nms,
cfg.max_per_img,
score_factors=mlvl_centerness)
return det_bboxes, det_labels
else:
return mlvl_bboxes, mlvl_scores, mlvl_centerness
def get_targets(self,
anchor_list,
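
To make the new `with_nms` switch concrete, here is a rough, hypothetical sketch (not part of the commit) of the two return modes of `ATSSHead.get_bboxes`. The head config below is a trimmed guess loosely based on the default ATSS settings, and the weights are random, so the boxes themselves are meaningless; `with_nms=False` is the path that `aug_test_bboxes` relies on.

```python
import torch
from mmcv import ConfigDict
from mmdet.models import build_head

atss_head = build_head(
    dict(
        type='ATSSHead',
        num_classes=4,
        in_channels=32,
        feat_channels=32,
        stacked_convs=1,
        anchor_generator=dict(
            type='AnchorGenerator',
            ratios=[1.0],
            octave_base_scale=8,
            scales_per_octave=1,
            strides=[8, 16, 32, 64, 128]),
        # test_cfg must support attribute access, hence ConfigDict
        test_cfg=ConfigDict(
            nms_pre=1000,
            min_bbox_size=0,
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.6),
            max_per_img=100)))

# One image, five FPN levels matching the strides above (128x128 input).
feats = [torch.randn(1, 32, s, s) for s in (16, 8, 4, 2, 1)]
img_metas = [dict(img_shape=(128, 128, 3), scale_factor=1.0)]

with torch.no_grad():
    outs = atss_head(feats)  # (cls_scores, bbox_preds, centernesses)

# Default mode: NMS inside, final (n, 5) boxes and (n,) labels per image.
det_bboxes, det_labels = atss_head.get_bboxes(*outs, img_metas, rescale=False)[0]

# New TTA mode: skip NMS and hand back raw multi-level bboxes/scores/centerness
# so aug_test_bboxes can merge several augmentations before a single NMS pass.
mlvl_bboxes, mlvl_scores, mlvl_centerness = atss_head.get_bboxes(
    *outs, img_metas, rescale=False, with_nms=False)[0]
```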

@@ -62,21 +62,30 @@ class BBoxTestMixin(object):
aug_bboxes = []
aug_scores = []
aug_factors = [] # score_factors for NMS
for x, img_meta in zip(feats, img_metas):
# only one image in the batch
outs = self.forward(x)
bbox_inputs = outs + (img_meta, self.test_cfg, False, False)
det_bboxes, det_scores = self.get_bboxes(*bbox_inputs)[0]
aug_bboxes.append(det_bboxes)
aug_scores.append(det_scores)
bbox_outputs = self.get_bboxes(*bbox_inputs)[0]
aug_bboxes.append(bbox_outputs[0])
aug_scores.append(bbox_outputs[1])
# bbox_outputs of some detectors (e.g., ATSS, FCOS, YOLOv3)
# contains additional element to adjust scores before NMS
if len(bbox_outputs) >= 3:
aug_factors.append(bbox_outputs[2])
# after merging, bboxes will be rescaled to the original image size
merged_bboxes, merged_scores = self.merge_aug_bboxes(
aug_bboxes, aug_scores, img_metas)
det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
self.test_cfg.score_thr,
self.test_cfg.nms,
self.test_cfg.max_per_img)
merged_factors = torch.cat(aug_factors, dim=0) if aug_factors else None
det_bboxes, det_labels = multiclass_nms(
merged_bboxes,
merged_scores,
self.test_cfg.score_thr,
self.test_cfg.nms,
self.test_cfg.max_per_img,
score_factors=merged_factors)
if rescale:
_det_bboxes = det_bboxes
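
A toy, self-contained illustration (not part of the commit) of the final `multiclass_nms` call above: two overlapping candidates merged from different augmentations, with a per-box quality estimate (centerness for ATSS/FCOS, objectness for YOLOv3) passed as `score_factors` so classification scores are rescaled before thresholding and NMS. The numbers are made up.

```python
import torch
from mmdet.core import multiclass_nms

merged_bboxes = torch.tensor([[10., 10., 50., 50.],
                              [12., 12., 52., 52.]])
# (num_boxes, num_classes + 1); the last column is the background class.
merged_scores = torch.tensor([[0.90, 0.00],
                              [0.85, 0.00]])
merged_factors = torch.tensor([0.95, 0.40])  # box quality per candidate

det_bboxes, det_labels = multiclass_nms(
    merged_bboxes,
    merged_scores,
    score_thr=0.05,
    nms_cfg=dict(type='nms', iou_threshold=0.5),
    max_num=100,
    score_factors=merged_factors)
# The two boxes overlap heavily, so only the candidate with the higher
# adjusted score (0.90 * 0.95) survives NMS.
```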

@@ -257,29 +257,33 @@ class FCOSHead(AnchorFreeHead):
centernesses,
img_metas,
cfg=None,
rescale=None):
rescale=False,
with_nms=True):
"""Transform network output for a batch into bbox predictions.
Args:
cls_scores (list[Tensor]): Box scores for each scale level
Has shape (N, num_points * num_classes, H, W)
with shape (N, num_points * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_points * 4, H, W)
level with shape (N, num_points * 4, H, W).
centernesses (list[Tensor]): Centerness for each scale level with
shape (N, num_points * 1, H, W)
shape (N, num_points * 1, H, W).
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
cfg (mmcv.Config): Test / postprocessing configuration,
if None, test_cfg would be used
rescale (bool): If True, return boxes in original image space
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used. Default: None.
rescale (bool): If True, return boxes in original image space.
Default: False.
with_nms (bool): If True, do nms before return boxes.
Default: True.
Returns:
list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple. \
The first item is an (n, 5) tensor, where the first 4 columns \
are bounding box positions (tl_x, tl_y, br_x, br_y) and the \
5-th column is a score between 0 and 1. The second item is a \
(n,) tensor where each item is the predicted class label of \
the corresponding box.
list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
The first item is an (n, 5) tensor, where the first 4 columns
are bounding box positions (tl_x, tl_y, br_x, br_y) and the
5-th column is a score between 0 and 1. The second item is a
(n,) tensor where each item is the predicted class label of the
corresponding box.
"""
assert len(cls_scores) == len(bbox_preds)
num_levels = len(cls_scores)
@@ -300,11 +304,9 @@ class FCOSHead(AnchorFreeHead):
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
det_bboxes = self._get_bboxes_single(cls_score_list,
bbox_pred_list,
centerness_pred_list,
mlvl_points, img_shape,
scale_factor, cfg, rescale)
det_bboxes = self._get_bboxes_single(
cls_score_list, bbox_pred_list, centerness_pred_list,
mlvl_points, img_shape, scale_factor, cfg, rescale, with_nms)
result_list.append(det_bboxes)
return result_list
@@ -316,12 +318,13 @@ class FCOSHead(AnchorFreeHead):
img_shape,
scale_factor,
cfg,
rescale=False):
rescale=False,
with_nms=True):
"""Transform outputs for a single batch item into bbox predictions.
Args:
cls_scores (list[Tensor]): Box scores for a single scale level
Has shape (num_points * num_classes, H, W).
with shape (num_points * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for a single scale
level with shape (num_points * 4, H, W).
centernesses (list[Tensor]): Centerness for a single scale level
@@ -332,14 +335,21 @@ class FCOSHead(AnchorFreeHead):
(height, width, 3).
scale_factor (ndarray): Scale factor of the image arrange as
(w_scale, h_scale, w_scale, h_scale).
cfg (mmcv.Config): Test / postprocessing configuration,
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used.
rescale (bool): If True, return boxes in original image space.
Default: False.
with_nms (bool): If True, do nms before return boxes.
Default: True.
Returns:
Tensor: Labeled boxes in shape (n, 5), where the first 4 columns \
are bounding box positions (tl_x, tl_y, br_x, br_y) and the \
5-th column is a score between 0 and 1.
tuple(Tensor):
det_bboxes (Tensor): BBox predictions in shape (n, 5), where
the first 4 columns are bounding box positions
(tl_x, tl_y, br_x, br_y) and the 5-th column is a score
between 0 and 1.
det_labels (Tensor): A (n,) tensor where each item is the
predicted class label of the corresponding box.
"""
cfg = self.test_cfg if cfg is None else cfg
assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
@@ -375,14 +385,18 @@ class FCOSHead(AnchorFreeHead):
# BG cat_id: num_class
mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
mlvl_centerness = torch.cat(mlvl_centerness)
det_bboxes, det_labels = multiclass_nms(
mlvl_bboxes,
mlvl_scores,
cfg.score_thr,
cfg.nms,
cfg.max_per_img,
score_factors=mlvl_centerness)
return det_bboxes, det_labels
if with_nms:
det_bboxes, det_labels = multiclass_nms(
mlvl_bboxes,
mlvl_scores,
cfg.score_thr,
cfg.nms,
cfg.max_per_img,
score_factors=mlvl_centerness)
return det_bboxes, det_labels
else:
return mlvl_bboxes, mlvl_scores, mlvl_centerness
def _get_points_single(self,
featmap_size,
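
TTA is only exercised when the test pipeline yields more than one augmented view per image; `BaseDetector.forward_test` then dispatches to `aug_test`, which runs the merging path shown in `dense_test_mixins.py` above. A hypothetical FCOS TTA pipeline could look like the sketch below; the `Normalize` values are borrowed from the standard FCOS R-50 caffe config, and the scale list is just an example.

```python
img_norm_cfg = dict(
    mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_rgb=False)
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=[(1333, 640), (1333, 800), (1333, 960)],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
```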

@@ -13,10 +13,11 @@ from mmdet.core import (build_anchor_generator, build_assigner,
multi_apply, multiclass_nms)
from ..builder import HEADS, build_loss
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin
@HEADS.register_module()
class YOLOV3Head(BaseDenseHead):
class YOLOV3Head(BaseDenseHead, BBoxTestMixin):
"""YOLOV3Head Paper link: https://arxiv.org/abs/1804.02767.
Args:
@@ -169,16 +170,24 @@ class YOLOV3Head(BaseDenseHead):
return tuple(pred_maps),
@force_fp32(apply_to=('pred_maps', ))
def get_bboxes(self, pred_maps, img_metas, cfg=None, rescale=False):
def get_bboxes(self,
pred_maps,
img_metas,
cfg=None,
rescale=False,
with_nms=True):
"""Transform network output for a batch into bbox predictions.
Args:
pred_maps (list[Tensor]): Raw predictions for a batch of images.
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
cfg (mmcv.Config): Test / postprocessing configuration,
if None, test_cfg would be used
rescale (bool): If True, return boxes in original image space
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used. Default: None.
rescale (bool): If True, return boxes in original image space.
Default: False.
with_nms (bool): If True, do nms before return boxes.
Default: True.
Returns:
list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
@@ -196,7 +205,7 @@ class YOLOV3Head(BaseDenseHead):
]
scale_factor = img_metas[img_id]['scale_factor']
proposals = self._get_bboxes_single(pred_maps_list, scale_factor,
cfg, rescale)
cfg, rescale, with_nms)
result_list.append(proposals)
return result_list
@@ -204,7 +213,8 @@ class YOLOV3Head(BaseDenseHead):
pred_maps_list,
scale_factor,
cfg,
rescale=False):
rescale=False,
with_nms=True):
"""Transform outputs for a single batch item into bbox predictions.
Args:
@@ -212,14 +222,21 @@ class YOLOV3Head(BaseDenseHead):
of each single image in the batch.
scale_factor (ndarray): Scale factor of the image arrange as
(w_scale, h_scale, w_scale, h_scale).
cfg (mmcv.Config): Test / postprocessing configuration,
cfg (mmcv.Config | None): Test / postprocessing configuration,
if None, test_cfg would be used.
rescale (bool): If True, return boxes in original image space.
Default: False.
with_nms (bool): If True, do nms before return boxes.
Default: True.
Returns:
Tensor: Labeled boxes in shape (n, 5), where the first 4 columns
are bounding box positions (tl_x, tl_y, br_x, br_y) and the
5-th column is a score between 0 and 1.
tuple(Tensor):
det_bboxes (Tensor): BBox predictions in shape (n, 5), where
the first 4 columns are bounding box positions
(tl_x, tl_y, br_x, br_y) and the 5-th column is a score
between 0 and 1.
det_labels (Tensor): A (n,) tensor where each item is the
predicted class label of the corresponding box.
"""
cfg = self.test_cfg if cfg is None else cfg
assert len(pred_maps_list) == self.num_levels
@@ -273,7 +290,7 @@ class YOLOV3Head(BaseDenseHead):
multi_lvl_cls_scores = torch.cat(multi_lvl_cls_scores)
multi_lvl_conf_scores = torch.cat(multi_lvl_conf_scores)
if multi_lvl_conf_scores.size(0) == 0:
if with_nms and (multi_lvl_conf_scores.size(0) == 0):
return torch.zeros((0, 5)), torch.zeros((0, ))
if rescale:
@@ -286,15 +303,18 @@ class YOLOV3Head(BaseDenseHead):
multi_lvl_cls_scores = torch.cat([multi_lvl_cls_scores, padding],
dim=1)
det_bboxes, det_labels = multiclass_nms(
multi_lvl_bboxes,
multi_lvl_cls_scores,
cfg.score_thr,
cfg.nms,
cfg.max_per_img,
score_factors=multi_lvl_conf_scores)
return det_bboxes, det_labels
if with_nms:
det_bboxes, det_labels = multiclass_nms(
multi_lvl_bboxes,
multi_lvl_cls_scores,
cfg.score_thr,
cfg.nms,
cfg.max_per_img,
score_factors=multi_lvl_conf_scores)
return det_bboxes, det_labels
else:
return (multi_lvl_bboxes, multi_lvl_cls_scores,
multi_lvl_conf_scores)
@force_fp32(apply_to=('pred_maps', ))
def loss(self,
@@ -488,3 +508,21 @@ class YOLOV3Head(BaseDenseHead):
neg_map[sampling_result.neg_inds] = 1
return target_map, neg_map
def aug_test(self, feats, img_metas, rescale=False):
"""Test function with test time augmentation.
Args:
feats (list[Tensor]): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxCxHxW,
which contains features for all images in the batch.
img_metas (list[list[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch. each dict has image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
Returns:
list[ndarray]: bbox results of each class
"""
return self.aug_test_bboxes(feats, img_metas, rescale=rescale)
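
Finally, a rough, hypothetical sketch (not part of the commit) that exercises the new `YOLOV3Head.aug_test` entry point directly, with random weights/features and a tiny 64x64 "image" so it runs quickly on CPU. The `test_cfg` values follow the YOLOv3 config at the top of this diff; all shapes and the flip metadata are assumptions, and in real use the detector supplies the features from its backbone/neck.

```python
import torch
from mmcv import ConfigDict
from mmdet.models import build_head

yolo_head = build_head(
    dict(
        type='YOLOV3Head',
        num_classes=4,
        in_channels=[512, 256, 128],
        out_channels=[1024, 512, 256],
        anchor_generator=dict(
            type='YOLOAnchorGenerator',
            base_sizes=[[(116, 90), (156, 198), (373, 326)],
                        [(30, 61), (62, 45), (59, 119)],
                        [(10, 13), (16, 30), (33, 23)]],
            strides=[32, 16, 8]),
        bbox_coder=dict(type='YOLOBBoxCoder'),
        featmap_strides=[32, 16, 8],
        # test_cfg must support attribute access, hence ConfigDict
        test_cfg=ConfigDict(
            nms_pre=1000,
            min_bbox_size=0,
            score_thr=0.05,
            conf_thr=0.005,
            nms=dict(type='nms', iou_threshold=0.45),
            max_per_img=100)))

# Two augmentations (original + horizontal flip) of one image; each entry is
# the tuple of neck feature maps the head expects at strides 32, 16 and 8.
feats = [
    tuple(torch.randn(1, c, s, s) for c, s in ((512, 2), (256, 4), (128, 8))),
    tuple(torch.randn(1, c, s, s) for c, s in ((512, 2), (256, 4), (128, 8))),
]
img_metas = [
    [dict(img_shape=(64, 64, 3), scale_factor=1.0, flip=False,
          flip_direction='horizontal')],
    [dict(img_shape=(64, 64, 3), scale_factor=1.0, flip=True,
          flip_direction='horizontal')],
]

with torch.no_grad():
    # Per the docstring above: a list of per-class (n, 5) bbox arrays.
    results = yolo_head.aug_test(feats, img_metas, rescale=True)
```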
