OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io/
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
111 lines
4.1 KiB
111 lines
4.1 KiB
# Copyright (c) OpenMMLab. All rights reserved. |
|
import torch |
|
|
|
from mmdet.core import bbox2result |
|
from mmdet.models.builder import DETECTORS |
|
from ...core.utils import flip_tensor |
|
from .single_stage import SingleStageDetector |
|
|
|
|
|
@DETECTORS.register_module()
class CenterNet(SingleStageDetector):
    """CenterNet detector ("Objects as Points").

    Paper: <https://arxiv.org/abs/1904.07850>.

    On top of what ``SingleStageDetector`` provides, this class implements
    flip-pair test-time augmentation (``aug_test``) and the merging of the
    per-pair detections (``merge_aug_results``).
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        """Pass every argument, unchanged and in order, to the parent."""
        super().__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
                         pretrained, init_cfg)

    def merge_aug_results(self, aug_results, with_nms):
        """Concatenate the detections gathered from all augmented pairs.

        Args:
            aug_results (list[list[Tensor]]): One entry per augmented image
                pair. Only the first element of each entry is read, and it
                is indexed as ``(det_bboxes, det_labels)``.
            with_nms (bool): If True, run the bbox head's NMS on the merged
                detections before returning them.

        Returns:
            tuple: (out_bboxes, out_labels)
        """
        # Each entry holds the detections of one image pair at index 0.
        per_pair_dets = [pair_result[0] for pair_result in aug_results]

        merged_bboxes = torch.cat(
            [dets[0] for dets in per_pair_dets], dim=0).contiguous()
        merged_labels = torch.cat(
            [dets[1] for dets in per_pair_dets]).contiguous()

        if not with_nms:
            return merged_bboxes, merged_labels

        out_bboxes, out_labels = self.bbox_head._bboxes_nms(
            merged_bboxes, merged_labels, self.bbox_head.test_cfg)
        return out_bboxes, out_labels

    def aug_test(self, imgs, img_metas, rescale=True):
        """Run test-time augmentation on (image, flipped image) pairs.

        Images are consumed as consecutive pairs ``(0, 1), (2, 3), ...``; a
        trailing image without a partner is ignored. Unlike CornerNet, the
        two images of a pair are fused on the feature map (by averaging)
        rather than by detecting boxes on each image separately.

        Args:
            imgs (list[Tensor]): Augmented images.
            img_metas (list[list[dict]]): Meta information of each image,
                e.g., image size, scaling factor, etc. The keys ``'flip'``
                and ``'flip_direction'`` are read here.
            rescale (bool): If True, return boxes in original image space.
                Default: True.

        Note:
            ``imgs`` must include flipped image pairs. Only the first two
            meta entries are checked for this.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        any_flipped = img_metas[0][0]['flip'] + img_metas[1][0]['flip']
        assert any_flipped, 'aug test must have flipped image pair'

        head = self.bbox_head
        aug_results = []
        # Even index: first image of the pair; the following odd index: its
        # flipped partner.
        for ind in range(0, len(imgs) - 1, 2):
            flip_ind = ind + 1
            flip_direction = img_metas[flip_ind][0]['flip_direction']

            # Both images go through the network as one batch of two.
            feats = self.extract_feat(
                torch.cat([imgs[ind], imgs[flip_ind]]))
            center_heatmap_preds, wh_preds, offset_preds = head(feats)
            num_levels = (len(center_heatmap_preds), len(wh_preds),
                          len(offset_preds))
            assert num_levels == (1, 1, 1)

            # Feature map averaging: batch item 0 is combined with batch
            # item 1 after the latter went through ``flip_tensor``
            # (presumably undoing the flip so both maps align).
            heatmap = center_heatmap_preds[0]
            center_heatmap_preds[0] = (
                heatmap[0:1] + flip_tensor(heatmap[1:2], flip_direction)) / 2
            wh = wh_preds[0]
            wh_preds[0] = (
                wh[0:1] + flip_tensor(wh[1:2], flip_direction)) / 2

            # Offsets are not averaged; only batch item 0 is used.
            pair_dets = head.get_bboxes(
                center_heatmap_preds,
                wh_preds,
                [offset_preds[0][0:1]],
                img_metas[ind],
                rescale=rescale,
                with_nms=False)
            aug_results.append(pair_dets)

        # NMS on the merged result is applied only when the test config
        # carries an ``nms_cfg`` entry.
        with_nms = head.test_cfg.get('nms_cfg', None) is not None
        det_bboxes, det_labels = self.merge_aug_results(aug_results, with_nms)
        return [bbox2result(det_bboxes, det_labels, head.num_classes)]
|
|
|