From 4febf34fc9f26832631bcaaf6e0938cab563db32 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Haian=20Huang=28=E6=B7=B1=E5=BA=A6=E7=9C=B8=29?= <1286304229@qq.com>
Date: Fri, 12 Mar 2021 13:18:09 +0800
Subject: [PATCH] [Refactor]: Support batch inference traceable by ONNX in
 SSD, YOLOv3, FSAF, RetinaNet, and FCOS (#4699)

* Support batch inference in RetinaNet
* Support batch multiclass_nms
* Revert multiclass_nms
* Fix api deprecation warning
* do not repeat anchors
* Move img_shapes
* Update Yolov3
* Support FCOS
* Support RPN
* Fix RPN topk_inds error
* make batch exportable to onnx for yolohead
* make fcos_head exportable to onnx with batch dim
* Support ATSS
* Support CornerNet and CentripetalNet
* Update RetinaNet and delta_xywh
* Remove ugly code
* Remove ugly code of FCOS
* Remove ugly code of ATSS/YOLOV3
* Support RPN and revert bbox_head
* expand anchors to batch and remove BG class when using deploy_nms_pre
* Update
* Use dim=-1 instead of dim=2
* Rename anchor_head method
* Keep the original format output when nms is not used
* Rename method and unify code style
* Fix paa_head and unittest
* Fix FSAF onnx export error
* Fix error
* fix single stage img_shapes for onnx
* move conf_thr
* fix rpn_head for onnx
* Add distance2bbox unittest
* Remove TODO
* Fix RPN
* Update docstrings

Co-authored-by: maningsheng
---
 configs/_base_/models/ssd300.py                |   1 +
 .../core/bbox/coder/delta_xywh_bbox_coder.py   |   2 +-
 mmdet/core/bbox/coder/tblr_bbox_coder.py       |   7 +-
 mmdet/core/bbox/transforms.py                  |  42 ++--
 mmdet/models/dense_heads/anchor_head.py        | 183 +++++++++-------
 mmdet/models/dense_heads/atss_head.py          | 190 +++++++++-------
 mmdet/models/dense_heads/cascade_rpn_head.py   | 132 ++++++++++-
 mmdet/models/dense_heads/dense_test_mixins.py  |   5 +-
 mmdet/models/dense_heads/fcos_head.py          | 181 +++++++++------
 mmdet/models/dense_heads/gfl_head.py           |  97 +++++----
 mmdet/models/dense_heads/paa_head.py           | 102 +++++----
 mmdet/models/dense_heads/rpn_head.py           | 129 ++++++-----
 mmdet/models/dense_heads/yolo_head.py          | 206 ++++++++++--------
 .../test_dense_heads/test_paa_head.py          |  10 +-
 tests/test_models/test_forward.py              |   2 +-
 tests/test_utils/test_misc.py                  |  45 ++++
 16 files changed, 867 insertions(+), 467 deletions(-)

diff --git a/configs/_base_/models/ssd300.py b/configs/_base_/models/ssd300.py
index 4ea797503..1b839ad43 100644
--- a/configs/_base_/models/ssd300.py
+++ b/configs/_base_/models/ssd300.py
@@ -42,6 +42,7 @@ model = dict(
         neg_pos_ratio=3,
         debug=False),
     test_cfg=dict(
+        nms_pre=1000,
         nms=dict(type='nms', iou_threshold=0.45),
         min_bbox_size=0,
         score_thr=0.02,
diff --git a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
index 07473fe52..da317184a 100644
--- a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
+++ b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
@@ -179,7 +179,7 @@ def delta2bbox(rois,
        >>>               [ 1.,  1.,  1.,  1.],
        >>>               [ 0.,  0.,  2., -1.],
        >>>               [ 0.7, -1.9, -0.5,  0.3]])
-       >>> delta2bbox(rois, deltas, max_shape=(32, 32))
+       >>> delta2bbox(rois, deltas, max_shape=(32, 32, 3))
        tensor([[0.0000, 0.0000, 1.0000, 1.0000],
               [0.1409, 0.1409, 2.8591, 2.8591],
               [0.0000, 0.3161, 4.1945, 0.6839],
diff --git a/mmdet/core/bbox/coder/tblr_bbox_coder.py b/mmdet/core/bbox/coder/tblr_bbox_coder.py
index ccf796ad5..edaffaf1f 100644
--- a/mmdet/core/bbox/coder/tblr_bbox_coder.py
+++ b/mmdet/core/bbox/coder/tblr_bbox_coder.py
@@ -168,9 +168,10 @@ def tblr2bboxes(priors,
     if normalize_by_wh:
         wh = priors[..., 2:4] - priors[..., 0:2]
         w, h = torch.split(wh, 1, dim=-1)
-        loc_decode[..., :2] *= h  # tb
-        loc_decode[..., 2:] *= w  # lr
+        # Inplace operation with slice would fail when exporting to ONNX
+        th = h * loc_decode[..., :2]  # tb
+        tw = w * loc_decode[..., 2:]  # lr
+        loc_decode = torch.cat([th, tw], dim=-1)
     # Cannot be exported using onnx when loc_decode.split(1, dim=-1)
     top, bottom, left, right = loc_decode.split((1, 1, 1, 1), dim=-1)
     xmin = prior_centers[..., 0].unsqueeze(-1) - left
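The transforms.py hunk that follows generalizes distance2bbox from (N, 2) points to batched (B, N, 2) inputs and replaces the four per-coordinate clamp calls with torch.where against per-image bounds. A minimal sketch of the new clamping step (toy tensors and shapes; these values are illustrative, not taken from the patch):

    import torch

    bboxes = torch.tensor([[[-5., -5., 40., 20.]]])  # (B=1, N=1, 4)
    max_shape = torch.tensor([[32., 32.]])           # (B=1, 2), per-image (H, W)

    # (H, W) -> (W, H, W, H) per image, broadcast over the N boxes
    max_xy = torch.cat([max_shape, max_shape], dim=-1).flip(-1).unsqueeze(-2)
    min_xy = bboxes.new_tensor(0)
    bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
    bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
    # bboxes is now tensor([[[0., 0., 32., 20.]]])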
diff --git a/mmdet/core/bbox/transforms.py b/mmdet/core/bbox/transforms.py
index 102db0d1f..df55b0a49 100644
--- a/mmdet/core/bbox/transforms.py
+++ b/mmdet/core/bbox/transforms.py
@@ -120,24 +120,40 @@ def distance2bbox(points, distance, max_shape=None):
     """Decode distance prediction to bounding box.

     Args:
-        points (Tensor): Shape (n, 2), [x, y].
+        points (Tensor): Shape (B, N, 2) or (N, 2).
         distance (Tensor): Distance from the given point to 4
-            boundaries (left, top, right, bottom).
-        max_shape (tuple): Shape of the image.
+            boundaries (left, top, right, bottom). Shape (B, N, 4) or (N, 4).
+        max_shape (Sequence[int] or torch.Tensor or Sequence[
+            Sequence[int]], optional): Maximum bounds for boxes, specifies
+            (H, W, C) or (H, W). If distance has shape (B, N, 4), then
+            max_shape should be a Sequence[Sequence[int]] whose length
+            is also B.

     Returns:
-        Tensor: Decoded bboxes.
+        Tensor: Boxes with shape (N, 4) or (B, N, 4).
     """
-    x1 = points[:, 0] - distance[:, 0]
-    y1 = points[:, 1] - distance[:, 1]
-    x2 = points[:, 0] + distance[:, 2]
-    y2 = points[:, 1] + distance[:, 3]
+    x1 = points[..., 0] - distance[..., 0]
+    y1 = points[..., 1] - distance[..., 1]
+    x2 = points[..., 0] + distance[..., 2]
+    y2 = points[..., 1] + distance[..., 3]
+
+    bboxes = torch.stack([x1, y1, x2, y2], -1)
+
     if max_shape is not None:
-        x1 = x1.clamp(min=0, max=max_shape[1])
-        y1 = y1.clamp(min=0, max=max_shape[0])
-        x2 = x2.clamp(min=0, max=max_shape[1])
-        y2 = y2.clamp(min=0, max=max_shape[0])
-    return torch.stack([x1, y1, x2, y2], -1)
+        if not isinstance(max_shape, torch.Tensor):
+            max_shape = x1.new_tensor(max_shape)
+        max_shape = max_shape[..., :2].type_as(x1)
+        if max_shape.ndim == 2:
+            assert bboxes.ndim == 3
+            assert max_shape.size(0) == bboxes.size(0)
+
+        min_xy = x1.new_tensor(0)
+        max_xy = torch.cat([max_shape, max_shape],
+                           dim=-1).flip(-1).unsqueeze(-2)
+        bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
+        bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
+
+    return bboxes


 def bbox2distance(points, bbox, max_dis=None, eps=0.1):
diff --git a/mmdet/models/dense_heads/anchor_head.py b/mmdet/models/dense_heads/anchor_head.py
index 007b04940..0e55892df 100644
--- a/mmdet/models/dense_heads/anchor_head.py
+++ b/mmdet/models/dense_heads/anchor_head.py
@@ -519,11 +519,11 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin):

         Returns:
             list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
-                The first item is an (n, 5) tensor, where the first 4 columns
-                are bounding box positions (tl_x, tl_y, br_x, br_y) and the
-                5-th column is a score between 0 and 1. The second item is a
-                (n,) tensor where each item is the predicted class labelof the
-                corresponding box.
+                The first item is an (n, 5) tensor, where the 5 columns are
+                (tl_x, tl_y, br_x, br_y, score) and the score is between
+                0 and 1. The second tensor has shape (n,), and each element
+                is the predicted class label of the corresponding box.
Example: >>> import mmcv @@ -559,57 +559,57 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin): mlvl_anchors = self.anchor_generator.grid_anchors( featmap_sizes, device=device) - result_list = [] - for img_id in range(len(img_metas)): - cls_score_list = [ - cls_scores[i][img_id].detach() for i in range(num_levels) - ] - bbox_pred_list = [ - bbox_preds[i][img_id].detach() for i in range(num_levels) + cls_score_list = [cls_scores[i].detach() for i in range(num_levels)] + bbox_pred_list = [bbox_preds[i].detach() for i in range(num_levels)] + + if torch.onnx.is_in_onnx_export(): + assert len( + img_metas + ) == 1, 'Only support one input image while in exporting to ONNX' + img_shapes = img_metas[0]['img_shape_for_onnx'] + else: + img_shapes = [ + img_metas[i]['img_shape'] + for i in range(cls_scores[0].shape[0]) ] - # get origin input shape to support onnx dynamic shape - if torch.onnx.is_in_onnx_export(): - img_shape = img_metas[img_id]['img_shape_for_onnx'] - else: - img_shape = img_metas[img_id]['img_shape'] - scale_factor = img_metas[img_id]['scale_factor'] - if with_nms: - # some heads don't support with_nms argument - proposals = self._get_bboxes_single(cls_score_list, - bbox_pred_list, - mlvl_anchors, img_shape, - scale_factor, cfg, rescale) - else: - proposals = self._get_bboxes_single(cls_score_list, - bbox_pred_list, - mlvl_anchors, img_shape, - scale_factor, cfg, rescale, - with_nms) - result_list.append(proposals) + scale_factors = [ + img_metas[i]['scale_factor'] for i in range(cls_scores[0].shape[0]) + ] + + if with_nms: + # some heads don't support with_nms argument + result_list = self._get_bboxes(cls_score_list, bbox_pred_list, + mlvl_anchors, img_shapes, + scale_factors, cfg, rescale) + else: + result_list = self._get_bboxes(cls_score_list, bbox_pred_list, + mlvl_anchors, img_shapes, + scale_factors, cfg, rescale, + with_nms) return result_list - def _get_bboxes_single(self, - cls_score_list, - bbox_pred_list, - mlvl_anchors, - img_shape, - scale_factor, - cfg, - rescale=False, - with_nms=True): - """Transform outputs for a single batch item into bbox predictions. + def _get_bboxes(self, + cls_score_list, + bbox_pred_list, + mlvl_anchors, + img_shapes, + scale_factors, + cfg, + rescale=False, + with_nms=True): + """Transform outputs for a batch item into bbox predictions. Args: cls_score_list (list[Tensor]): Box scores for a single scale level - Has shape (num_anchors * num_classes, H, W). + Has shape (N, num_anchors * num_classes, H, W). bbox_pred_list (list[Tensor]): Box energies / deltas for a single - scale level with shape (num_anchors * 4, H, W). + scale level with shape (N, num_anchors * 4, H, W). mlvl_anchors (list[Tensor]): Box reference for a single scale level with shape (num_total_anchors, 4). - img_shape (tuple[int]): Shape of the input image, - (height, width, 3). - scale_factor (ndarray): Scale factor of the image arange as - (w_scale, h_scale, w_scale, h_scale). + img_shapes (list[tuple[int]]): Shape of the batch input image, + list[(height, width, 3)]. + scale_factors (list[ndarray]): Scale factor of the batch + image arange as list[(w_scale, h_scale, w_scale, h_scale)]. cfg (mmcv.Config): Test / postprocessing configuration, if None, test_cfg would be used. rescale (bool): If True, return boxes in original image space. @@ -618,78 +618,113 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin): Default: True. 
Returns: - Tensor: Labeled boxes in shape (n, 5), where the first 4 columns - are bounding box positions (tl_x, tl_y, br_x, br_y) and the - 5-th column is a score between 0 and 1. + list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple. + The first item is an (n, 5) tensor, where 5 represent + (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1. + The shape of the second tensor in the tuple is (n,), and + each element represents the class label of the corresponding + box. """ cfg = self.test_cfg if cfg is None else cfg assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors) + batch_size = cls_score_list[0].shape[0] # convert to tensor to keep tracing nms_pre_tensor = torch.tensor( cfg.get('nms_pre', -1), device=cls_score_list[0].device, dtype=torch.long) + mlvl_bboxes = [] mlvl_scores = [] for cls_score, bbox_pred, anchors in zip(cls_score_list, bbox_pred_list, mlvl_anchors): assert cls_score.size()[-2:] == bbox_pred.size()[-2:] - cls_score = cls_score.permute(1, 2, - 0).reshape(-1, self.cls_out_channels) + cls_score = cls_score.permute(0, 2, 3, + 1).reshape(batch_size, -1, + self.cls_out_channels) if self.use_sigmoid_cls: scores = cls_score.sigmoid() else: scores = cls_score.softmax(-1) - bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) + bbox_pred = bbox_pred.permute(0, 2, 3, + 1).reshape(batch_size, -1, 4) + anchors = anchors.expand_as(bbox_pred) # Always keep topk op for dynamic input in onnx if nms_pre_tensor > 0 and (torch.onnx.is_in_onnx_export() or scores.shape[-2] > nms_pre_tensor): from torch import _shape_as_tensor # keep shape as tensor and get k - num_anchor = _shape_as_tensor(scores)[-2].to(nms_pre_tensor) + num_anchor = _shape_as_tensor(scores)[-2].to( + nms_pre_tensor.device) nms_pre = torch.where(nms_pre_tensor < num_anchor, nms_pre_tensor, num_anchor) + # Get maximum scores for foreground classes. if self.use_sigmoid_cls: - max_scores, _ = scores.max(dim=1) + max_scores, _ = scores.max(-1) else: # remind that we set FG labels to [0, num_class-1] # since mmdet v2.0 # BG cat_id: num_class - max_scores, _ = scores[:, :-1].max(dim=1) + max_scores, _ = scores[..., :-1].max(-1) + _, topk_inds = max_scores.topk(nms_pre) - anchors = anchors[topk_inds, :] - bbox_pred = bbox_pred[topk_inds, :] - scores = scores[topk_inds, :] + batch_inds = torch.arange(batch_size).view( + -1, 1).expand_as(topk_inds) + anchors = anchors[batch_inds, topk_inds, :] + bbox_pred = bbox_pred[batch_inds, topk_inds, :] + scores = scores[batch_inds, topk_inds, :] + bboxes = self.bbox_coder.decode( - anchors, bbox_pred, max_shape=img_shape) + anchors, bbox_pred, max_shape=img_shapes) mlvl_bboxes.append(bboxes) mlvl_scores.append(scores) - mlvl_bboxes = torch.cat(mlvl_bboxes) + + batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1) if rescale: - mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) - mlvl_scores = torch.cat(mlvl_scores) + batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor( + scale_factors).unsqueeze(1) + batch_mlvl_scores = torch.cat(mlvl_scores, dim=1) + # Set max number of box to be feed into nms in deployment deploy_nms_pre = cfg.get('deploy_nms_pre', -1) if deploy_nms_pre > 0 and torch.onnx.is_in_onnx_export(): - max_scores, _ = mlvl_scores.max(dim=1) - _, topk_inds = max_scores.topk(deploy_nms_pre) - mlvl_scores = mlvl_scores[topk_inds, :] - mlvl_bboxes = mlvl_bboxes[topk_inds, :] + # Get maximum scores for foreground classes. 
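+            # Sigmoid-based heads treat every channel as a foreground
+            # class, while softmax-based heads carry an extra background
+            # channel at index num_classes that is dropped before ranking.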
+ if self.use_sigmoid_cls: + batch_mlvl_scores, _ = batch_mlvl_scores.max(-1) + else: + # remind that we set FG labels to [0, num_class-1] + # since mmdet v2.0 + # BG cat_id: num_class + batch_mlvl_scores, _ = batch_mlvl_scores[..., :-1].max(-1) + _, topk_inds = batch_mlvl_scores.topk(deploy_nms_pre) + batch_inds = torch.arange(batch_size).view(-1, + 1).expand_as(topk_inds) + batch_mlvl_scores = batch_mlvl_scores[batch_inds, topk_inds] + batch_mlvl_bboxes = batch_mlvl_bboxes[batch_inds, topk_inds] if self.use_sigmoid_cls: # Add a dummy background class to the backend when using sigmoid # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 # BG cat_id: num_class - padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) - mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) + padding = batch_mlvl_scores.new_zeros(batch_size, + batch_mlvl_scores.shape[1], + 1) + batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1) if with_nms: - det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores, - cfg.score_thr, cfg.nms, - cfg.max_per_img) - return det_bboxes, det_labels + det_results = [] + for (mlvl_bboxes, mlvl_scores) in zip(batch_mlvl_bboxes, + batch_mlvl_scores): + det_bbox, det_label = multiclass_nms(mlvl_bboxes, mlvl_scores, + cfg.score_thr, cfg.nms, + cfg.max_per_img) + det_results.append(tuple([det_bbox, det_label])) else: - return mlvl_bboxes, mlvl_scores + det_results = [ + tuple(mlvl_bs) + for mlvl_bs in zip(batch_mlvl_bboxes, batch_mlvl_scores) + ] + return det_results def aug_test(self, feats, img_metas, rescale=False): """Test function with test time augmentation. diff --git a/mmdet/models/dense_heads/atss_head.py b/mmdet/models/dense_heads/atss_head.py index e96ea7ff1..7526d5470 100644 --- a/mmdet/models/dense_heads/atss_head.py +++ b/mmdet/models/dense_heads/atss_head.py @@ -342,11 +342,11 @@ class ATSSHead(AnchorHead): Returns: list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple. - The first item is an (n, 5) tensor, where the first 4 columns - are bounding box positions (tl_x, tl_y, br_x, br_y) and the - 5-th column is a score between 0 and 1. The second item is a - (n,) tensor where each item is the predicted class label of the - corresponding box. + The first item is an (n, 5) tensor, where 5 represent + (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1. + The shape of the second tensor in the tuple is (n,), and + each element represents the class label of the corresponding + box. 
""" cfg = self.test_cfg if cfg is None else cfg assert len(cls_scores) == len(bbox_preds) @@ -356,51 +356,47 @@ class ATSSHead(AnchorHead): mlvl_anchors = self.anchor_generator.grid_anchors( featmap_sizes, device=device) - result_list = [] - for img_id in range(len(img_metas)): - cls_score_list = [ - cls_scores[i][img_id].detach() for i in range(num_levels) - ] - bbox_pred_list = [ - bbox_preds[i][img_id].detach() for i in range(num_levels) - ] - centerness_pred_list = [ - centernesses[i][img_id].detach() for i in range(num_levels) - ] - img_shape = img_metas[img_id]['img_shape'] - scale_factor = img_metas[img_id]['scale_factor'] - proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list, - centerness_pred_list, - mlvl_anchors, img_shape, - scale_factor, cfg, rescale, - with_nms) - result_list.append(proposals) + cls_score_list = [cls_scores[i].detach() for i in range(num_levels)] + bbox_pred_list = [bbox_preds[i].detach() for i in range(num_levels)] + centerness_pred_list = [ + centernesses[i].detach() for i in range(num_levels) + ] + img_shapes = [ + img_metas[i]['img_shape'] for i in range(cls_scores[0].shape[0]) + ] + scale_factors = [ + img_metas[i]['scale_factor'] for i in range(cls_scores[0].shape[0]) + ] + result_list = self._get_bboxes(cls_score_list, bbox_pred_list, + centerness_pred_list, mlvl_anchors, + img_shapes, scale_factors, cfg, rescale, + with_nms) return result_list - def _get_bboxes_single(self, - cls_scores, - bbox_preds, - centernesses, - mlvl_anchors, - img_shape, - scale_factor, - cfg, - rescale=False, - with_nms=True): + def _get_bboxes(self, + cls_scores, + bbox_preds, + centernesses, + mlvl_anchors, + img_shapes, + scale_factors, + cfg, + rescale=False, + with_nms=True): """Transform outputs for a single batch item into labeled boxes. Args: cls_scores (list[Tensor]): Box scores for a single scale level - with shape (num_anchors * num_classes, H, W). + with shape (N, num_anchors * num_classes, H, W). bbox_preds (list[Tensor]): Box energies / deltas for a single - scale level with shape (num_anchors * 4, H, W). + scale level with shape (N, num_anchors * 4, H, W). centernesses (list[Tensor]): Centerness for a single scale level - with shape (num_anchors * 1, H, W). + with shape (N, num_anchors * 1, H, W). mlvl_anchors (list[Tensor]): Box reference for a single scale level with shape (num_total_anchors, 4). - img_shape (tuple[int]): Shape of the input image, - (height, width, 3). - scale_factor (ndarray): Scale factor of the image arrange as + img_shapes (list[tuple[int]]): Shape of the input image, + list[(height, width, 3)]. + scale_factors (list[ndarray]): Scale factor of the image arrange as (w_scale, h_scale, w_scale, h_scale). cfg (mmcv.Config | None): Test / postprocessing configuration, if None, test_cfg would be used. @@ -410,64 +406,106 @@ class ATSSHead(AnchorHead): Default: True. Returns: - tuple(Tensor): - det_bboxes (Tensor): BBox predictions in shape (n, 5), where - the first 4 columns are bounding box positions - (tl_x, tl_y, br_x, br_y) and the 5-th column is a score - between 0 and 1. - det_labels (Tensor): A (n,) tensor where each item is the - predicted class label of the corresponding box. + list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple. + The first item is an (n, 5) tensor, where 5 represent + (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1. + The shape of the second tensor in the tuple is (n,), and + each element represents the class label of the corresponding + box. 
""" assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors) + device = cls_scores[0].device + batch_size = cls_scores[0].shape[0] + # convert to tensor to keep tracing + nms_pre_tensor = torch.tensor( + cfg.get('nms_pre', -1), device=device, dtype=torch.long) mlvl_bboxes = [] mlvl_scores = [] mlvl_centerness = [] for cls_score, bbox_pred, centerness, anchors in zip( cls_scores, bbox_preds, centernesses, mlvl_anchors): assert cls_score.size()[-2:] == bbox_pred.size()[-2:] - - scores = cls_score.permute(1, 2, 0).reshape( - -1, self.cls_out_channels).sigmoid() - bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) - centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid() - - nms_pre = cfg.get('nms_pre', -1) - if nms_pre > 0 and scores.shape[0] > nms_pre: - max_scores, _ = (scores * centerness[:, None]).max(dim=1) + scores = cls_score.permute(0, 2, 3, 1).reshape( + batch_size, -1, self.cls_out_channels).sigmoid() + centerness = centerness.permute(0, 2, 3, + 1).reshape(batch_size, + -1).sigmoid() + bbox_pred = bbox_pred.permute(0, 2, 3, + 1).reshape(batch_size, -1, 4) + + # Always keep topk op for dynamic input in onnx + if nms_pre_tensor > 0 and (torch.onnx.is_in_onnx_export() + or scores.shape[-2] > nms_pre_tensor): + from torch import _shape_as_tensor + # keep shape as tensor and get k + num_anchor = _shape_as_tensor(scores)[-2].to(device) + nms_pre = torch.where(nms_pre_tensor < num_anchor, + nms_pre_tensor, num_anchor) + + max_scores, _ = (scores * centerness[..., None]).max(-1) _, topk_inds = max_scores.topk(nms_pre) anchors = anchors[topk_inds, :] - bbox_pred = bbox_pred[topk_inds, :] - scores = scores[topk_inds, :] - centerness = centerness[topk_inds] + batch_inds = torch.arange(batch_size).view( + -1, 1).expand_as(topk_inds).long() + bbox_pred = bbox_pred[batch_inds, topk_inds, :] + scores = scores[batch_inds, topk_inds, :] + centerness = centerness[batch_inds, topk_inds] + else: + anchors = anchors.expand_as(bbox_pred) bboxes = self.bbox_coder.decode( - anchors, bbox_pred, max_shape=img_shape) + anchors, bbox_pred, max_shape=img_shapes) mlvl_bboxes.append(bboxes) mlvl_scores.append(scores) mlvl_centerness.append(centerness) - mlvl_bboxes = torch.cat(mlvl_bboxes) + batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1) if rescale: - mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) - mlvl_scores = torch.cat(mlvl_scores) - # Add a dummy background class to the backend when using sigmoid + batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor( + scale_factors).unsqueeze(1) + batch_mlvl_scores = torch.cat(mlvl_scores, dim=1) + batch_mlvl_centerness = torch.cat(mlvl_centerness, dim=1) + + # Set max number of box to be feed into nms in deployment + deploy_nms_pre = cfg.get('deploy_nms_pre', -1) + if deploy_nms_pre > 0 and torch.onnx.is_in_onnx_export(): + batch_mlvl_scores, _ = ( + batch_mlvl_scores * + batch_mlvl_centerness.unsqueeze(2).expand_as(batch_mlvl_scores) + ).max(-1) + _, topk_inds = batch_mlvl_scores.topk(deploy_nms_pre) + batch_inds = torch.arange(batch_size).view(-1, + 1).expand_as(topk_inds) + batch_mlvl_scores = batch_mlvl_scores[batch_inds, topk_inds, :] + batch_mlvl_bboxes = batch_mlvl_bboxes[batch_inds, topk_inds, :] + batch_mlvl_centerness = batch_mlvl_centerness[batch_inds, + topk_inds] # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 # BG cat_id: num_class - padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) - mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) - mlvl_centerness = torch.cat(mlvl_centerness) + padding = 
batch_mlvl_scores.new_zeros(batch_size, + batch_mlvl_scores.shape[1], 1) + batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1) if with_nms: - det_bboxes, det_labels = multiclass_nms( - mlvl_bboxes, - mlvl_scores, - cfg.score_thr, - cfg.nms, - cfg.max_per_img, - score_factors=mlvl_centerness) - return det_bboxes, det_labels + det_results = [] + for (mlvl_bboxes, mlvl_scores, + mlvl_centerness) in zip(batch_mlvl_bboxes, batch_mlvl_scores, + batch_mlvl_centerness): + det_bbox, det_label = multiclass_nms( + mlvl_bboxes, + mlvl_scores, + cfg.score_thr, + cfg.nms, + cfg.max_per_img, + score_factors=mlvl_centerness) + det_results.append(tuple([det_bbox, det_label])) else: - return mlvl_bboxes, mlvl_scores, mlvl_centerness + det_results = [ + tuple(mlvl_bs) + for mlvl_bs in zip(batch_mlvl_bboxes, batch_mlvl_scores, + batch_mlvl_centerness) + ] + return det_results def get_targets(self, anchor_list, diff --git a/mmdet/models/dense_heads/cascade_rpn_head.py b/mmdet/models/dense_heads/cascade_rpn_head.py index c01d048c7..210927935 100644 --- a/mmdet/models/dense_heads/cascade_rpn_head.py +++ b/mmdet/models/dense_heads/cascade_rpn_head.py @@ -1,9 +1,12 @@ from __future__ import division +import copy +import warnings import torch import torch.nn as nn +from mmcv import ConfigDict from mmcv.cnn import normal_init -from mmcv.ops import DeformConv2d +from mmcv.ops import DeformConv2d, batched_nms from mmdet.core import (RegionAssigner, build_assigner, build_sampler, images_to_levels, multi_apply) @@ -536,6 +539,133 @@ class StageCascadeRPNHead(RPNHead): new_anchor_list.append(mlvl_anchors) return new_anchor_list + # TODO: temporary plan + def _get_bboxes_single(self, + cls_scores, + bbox_preds, + mlvl_anchors, + img_shape, + scale_factor, + cfg, + rescale=False): + """Transform outputs for a single batch item into bbox predictions. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (num_anchors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (num_anchors * 4, H, W). + mlvl_anchors (list[Tensor]): Box reference for each scale level + with shape (num_total_anchors, 4). + img_shape (tuple[int]): Shape of the input image, + (height, width, 3). + scale_factor (ndarray): Scale factor of the image arange as + (w_scale, h_scale, w_scale, h_scale). + cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + + Returns: + Tensor: Labeled boxes have the shape of (n,5), where the + first 4 columns are bounding box positions + (tl_x, tl_y, br_x, br_y) and the 5-th column is a score + between 0 and 1. 
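+
+        Example:
+            The `sort`-based and `topk`-based pre-NMS selection paths below
+            pick identical scores (a shape-only sketch with random values):
+
+            >>> import torch
+            >>> scores = torch.rand(1000)
+            >>> ranked, rank_inds = scores.sort(descending=True)
+            >>> torch.equal(ranked[:10], scores.topk(10)[0])
+            True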
+ """ + cfg = self.test_cfg if cfg is None else cfg + cfg = copy.deepcopy(cfg) + # bboxes from different level should be independent during NMS, + # level_ids are used as labels for batched NMS to separate them + level_ids = [] + mlvl_scores = [] + mlvl_bbox_preds = [] + mlvl_valid_anchors = [] + for idx in range(len(cls_scores)): + rpn_cls_score = cls_scores[idx] + rpn_bbox_pred = bbox_preds[idx] + assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:] + rpn_cls_score = rpn_cls_score.permute(1, 2, 0) + if self.use_sigmoid_cls: + rpn_cls_score = rpn_cls_score.reshape(-1) + scores = rpn_cls_score.sigmoid() + else: + rpn_cls_score = rpn_cls_score.reshape(-1, 2) + # We set FG labels to [0, num_class-1] and BG label to + # num_class in RPN head since mmdet v2.5, which is unified to + # be consistent with other head since mmdet v2.0. In mmdet v2.0 + # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head. + scores = rpn_cls_score.softmax(dim=1)[:, 0] + rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4) + anchors = mlvl_anchors[idx] + if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre: + # sort is faster than topk + # _, topk_inds = scores.topk(cfg.nms_pre) + if torch.onnx.is_in_onnx_export(): + # sort op will be converted to TopK in onnx + # and k<=3480 in TensorRT + _, topk_inds = scores.topk(cfg.nms_pre) + scores = scores[topk_inds] + else: + ranked_scores, rank_inds = scores.sort(descending=True) + topk_inds = rank_inds[:cfg.nms_pre] + scores = ranked_scores[:cfg.nms_pre] + rpn_bbox_pred = rpn_bbox_pred[topk_inds, :] + anchors = anchors[topk_inds, :] + mlvl_scores.append(scores) + mlvl_bbox_preds.append(rpn_bbox_pred) + mlvl_valid_anchors.append(anchors) + level_ids.append( + scores.new_full((scores.size(0), ), idx, dtype=torch.long)) + + scores = torch.cat(mlvl_scores) + anchors = torch.cat(mlvl_valid_anchors) + rpn_bbox_pred = torch.cat(mlvl_bbox_preds) + proposals = self.bbox_coder.decode( + anchors, rpn_bbox_pred, max_shape=img_shape) + ids = torch.cat(level_ids) + + # Skip nonzero op while exporting to ONNX + if cfg.min_bbox_size > 0 and (not torch.onnx.is_in_onnx_export()): + w = proposals[:, 2] - proposals[:, 0] + h = proposals[:, 3] - proposals[:, 1] + valid_inds = torch.nonzero( + (w >= cfg.min_bbox_size) + & (h >= cfg.min_bbox_size), + as_tuple=False).squeeze() + if valid_inds.sum().item() != len(proposals): + proposals = proposals[valid_inds, :] + scores = scores[valid_inds] + ids = ids[valid_inds] + + # deprecate arguments warning + if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg: + warnings.warn( + 'In rpn_proposal or test_cfg, ' + 'nms_thr has been moved to a dict named nms as ' + 'iou_threshold, max_num has been renamed as max_per_img, ' + 'name of original arguments and the way to specify ' + 'iou_threshold of NMS will be deprecated.') + if 'nms' not in cfg: + cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr)) + if 'max_num' in cfg: + if 'max_per_img' in cfg: + assert cfg.max_num == cfg.max_per_img, f'You ' \ + f'set max_num and ' \ + f'max_per_img at the same time, but get {cfg.max_num} ' \ + f'and {cfg.max_per_img} respectively' \ + 'Please delete max_num which will be deprecated.' + else: + cfg.max_per_img = cfg.max_num + if 'nms_thr' in cfg: + assert cfg.nms.iou_threshold == cfg.nms_thr, f'You set' \ + f' iou_threshold in nms and ' \ + f'nms_thr at the same time, but get' \ + f' {cfg.nms.iou_threshold} and {cfg.nms_thr}' \ + f' respectively. Please delete the nms_thr ' \ + f'which will be deprecated.' 
+ + dets, keep = batched_nms(proposals, scores, ids, cfg.nms) + return dets[:cfg.max_per_img] + @HEADS.register_module() class CascadeRPNHead(BaseDenseHead): diff --git a/mmdet/models/dense_heads/dense_test_mixins.py b/mmdet/models/dense_heads/dense_test_mixins.py index a07c9d423..dd81364de 100644 --- a/mmdet/models/dense_heads/dense_test_mixins.py +++ b/mmdet/models/dense_heads/dense_test_mixins.py @@ -54,7 +54,10 @@ class BBoxTestMixin(object): # check with_nms argument gb_sig = signature(self.get_bboxes) gb_args = [p.name for p in gb_sig.parameters.values()] - gbs_sig = signature(self._get_bboxes_single) + if hasattr(self, '_get_bboxes'): + gbs_sig = signature(self._get_bboxes) + else: + gbs_sig = signature(self._get_bboxes_single) gbs_args = [p.name for p in gbs_sig.parameters.values()] assert ('with_nms' in gb_args) and ('with_nms' in gbs_args), \ f'{self.__class__.__name__}' \ diff --git a/mmdet/models/dense_heads/fcos_head.py b/mmdet/models/dense_heads/fcos_head.py index c2b9dc59b..284742b48 100644 --- a/mmdet/models/dense_heads/fcos_head.py +++ b/mmdet/models/dense_heads/fcos_head.py @@ -282,11 +282,11 @@ class FCOSHead(AnchorFreeHead): Returns: list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple. - The first item is an (n, 5) tensor, where the first 4 columns - are bounding box positions (tl_x, tl_y, br_x, br_y) and the - 5-th column is a score between 0 and 1. The second item is a - (n,) tensor where each item is the predicted class label of the - corresponding box. + The first item is an (n, 5) tensor, where 5 represent + (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1. + The shape of the second tensor in the tuple is (n,), and + each element represents the class label of the corresponding + box. """ assert len(cls_scores) == len(bbox_preds) num_levels = len(cls_scores) @@ -294,49 +294,55 @@ class FCOSHead(AnchorFreeHead): featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] mlvl_points = self.get_points(featmap_sizes, bbox_preds[0].dtype, bbox_preds[0].device) - result_list = [] - for img_id in range(len(img_metas)): - cls_score_list = [ - cls_scores[i][img_id].detach() for i in range(num_levels) - ] - bbox_pred_list = [ - bbox_preds[i][img_id].detach() for i in range(num_levels) - ] - centerness_pred_list = [ - centernesses[i][img_id].detach() for i in range(num_levels) + + cls_score_list = [cls_scores[i].detach() for i in range(num_levels)] + bbox_pred_list = [bbox_preds[i].detach() for i in range(num_levels)] + centerness_pred_list = [ + centernesses[i].detach() for i in range(num_levels) + ] + if torch.onnx.is_in_onnx_export(): + assert len( + img_metas + ) == 1, 'Only support one input image while in exporting to ONNX' + img_shapes = img_metas[0]['img_shape_for_onnx'] + else: + img_shapes = [ + img_metas[i]['img_shape'] + for i in range(cls_scores[0].shape[0]) ] - img_shape = img_metas[img_id]['img_shape'] - scale_factor = img_metas[img_id]['scale_factor'] - det_bboxes = self._get_bboxes_single( - cls_score_list, bbox_pred_list, centerness_pred_list, - mlvl_points, img_shape, scale_factor, cfg, rescale, with_nms) - result_list.append(det_bboxes) + scale_factors = [ + img_metas[i]['scale_factor'] for i in range(cls_scores[0].shape[0]) + ] + result_list = self._get_bboxes(cls_score_list, bbox_pred_list, + centerness_pred_list, mlvl_points, + img_shapes, scale_factors, cfg, rescale, + with_nms) return result_list - def _get_bboxes_single(self, - cls_scores, - bbox_preds, - centernesses, - mlvl_points, - img_shape, - scale_factor, - cfg, - 
rescale=False, - with_nms=True): + def _get_bboxes(self, + cls_scores, + bbox_preds, + centernesses, + mlvl_points, + img_shapes, + scale_factors, + cfg, + rescale=False, + with_nms=True): """Transform outputs for a single batch item into bbox predictions. Args: cls_scores (list[Tensor]): Box scores for a single scale level - with shape (num_points * num_classes, H, W). + with shape (N, num_points * num_classes, H, W). bbox_preds (list[Tensor]): Box energies / deltas for a single scale - level with shape (num_points * 4, H, W). + level with shape (N, num_points * 4, H, W). centernesses (list[Tensor]): Centerness for a single scale level - with shape (num_points * 4, H, W). + with shape (N, num_points * 4, H, W). mlvl_points (list[Tensor]): Box reference for a single scale level with shape (num_total_points, 4). - img_shape (tuple[int]): Shape of the input image, - (height, width, 3). - scale_factor (ndarray): Scale factor of the image arrange as + img_shapes (list[tuple[int]]): Shape of the input image, + list[(height, width, 3)]. + scale_factors (list[ndarray]): Scale factor of the image arrange as (w_scale, h_scale, w_scale, h_scale). cfg (mmcv.Config | None): Test / postprocessing configuration, if None, test_cfg would be used. @@ -356,59 +362,96 @@ class FCOSHead(AnchorFreeHead): """ cfg = self.test_cfg if cfg is None else cfg assert len(cls_scores) == len(bbox_preds) == len(mlvl_points) + device = cls_scores[0].device + batch_size = cls_scores[0].shape[0] + # convert to tensor to keep tracing + nms_pre_tensor = torch.tensor( + cfg.get('nms_pre', -1), device=device, dtype=torch.long) mlvl_bboxes = [] mlvl_scores = [] mlvl_centerness = [] for cls_score, bbox_pred, centerness, points in zip( cls_scores, bbox_preds, centernesses, mlvl_points): assert cls_score.size()[-2:] == bbox_pred.size()[-2:] - scores = cls_score.permute(1, 2, 0).reshape( - -1, self.cls_out_channels).sigmoid() - centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid() - - bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) - nms_pre = cfg.get('nms_pre', -1) - if nms_pre > 0 and scores.shape[0] > nms_pre: - max_scores, _ = (scores * centerness[:, None]).max(dim=1) + scores = cls_score.permute(0, 2, 3, 1).reshape( + batch_size, -1, self.cls_out_channels).sigmoid() + centerness = centerness.permute(0, 2, 3, + 1).reshape(batch_size, + -1).sigmoid() + + bbox_pred = bbox_pred.permute(0, 2, 3, + 1).reshape(batch_size, -1, 4) + # Always keep topk op for dynamic input in onnx + if nms_pre_tensor > 0 and (torch.onnx.is_in_onnx_export() + or scores.shape[-2] > nms_pre_tensor): + from torch import _shape_as_tensor + # keep shape as tensor and get k + num_anchor = _shape_as_tensor(scores)[-2].to(device) + nms_pre = torch.where(nms_pre_tensor < num_anchor, + nms_pre_tensor, num_anchor) + + max_scores, _ = (scores * centerness[..., None]).max(-1) _, topk_inds = max_scores.topk(nms_pre) points = points[topk_inds, :] - bbox_pred = bbox_pred[topk_inds, :] - scores = scores[topk_inds, :] - centerness = centerness[topk_inds] - bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape) + batch_inds = torch.arange(batch_size).view( + -1, 1).expand_as(topk_inds).long() + bbox_pred = bbox_pred[batch_inds, topk_inds, :] + scores = scores[batch_inds, topk_inds, :] + centerness = centerness[batch_inds, topk_inds] + + bboxes = distance2bbox(points, bbox_pred, max_shape=img_shapes) mlvl_bboxes.append(bboxes) mlvl_scores.append(scores) mlvl_centerness.append(centerness) - mlvl_bboxes = torch.cat(mlvl_bboxes) + + batch_mlvl_bboxes = 
torch.cat(mlvl_bboxes, dim=1) if rescale: - mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) - mlvl_scores = torch.cat(mlvl_scores) - mlvl_centerness = torch.cat(mlvl_centerness) + batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor( + scale_factors).unsqueeze(1) + batch_mlvl_scores = torch.cat(mlvl_scores, dim=1) + batch_mlvl_centerness = torch.cat(mlvl_centerness, dim=1) # Set max number of box to be feed into nms in deployment deploy_nms_pre = cfg.get('deploy_nms_pre', -1) if deploy_nms_pre > 0 and torch.onnx.is_in_onnx_export(): - max_scores, _ = (mlvl_scores * mlvl_centerness[:, None]).max(dim=1) - _, topk_inds = max_scores.topk(deploy_nms_pre) - mlvl_scores = mlvl_scores[topk_inds, :] - mlvl_bboxes = mlvl_bboxes[topk_inds, :] - mlvl_centerness = mlvl_centerness[topk_inds] - padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) + batch_mlvl_scores, _ = ( + batch_mlvl_scores * + batch_mlvl_centerness.unsqueeze(2).expand_as(batch_mlvl_scores) + ).max(-1) + _, topk_inds = batch_mlvl_scores.topk(deploy_nms_pre) + batch_inds = torch.arange(batch_mlvl_scores.shape[0]).view( + -1, 1).expand_as(topk_inds) + batch_mlvl_scores = batch_mlvl_scores[batch_inds, topk_inds, :] + batch_mlvl_bboxes = batch_mlvl_bboxes[batch_inds, topk_inds, :] + batch_mlvl_centerness = batch_mlvl_centerness[batch_inds, + topk_inds] + # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 # BG cat_id: num_class - mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) + padding = batch_mlvl_scores.new_zeros(batch_size, + batch_mlvl_scores.shape[1], 1) + batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1) if with_nms: - det_bboxes, det_labels = multiclass_nms( - mlvl_bboxes, - mlvl_scores, - cfg.score_thr, - cfg.nms, - cfg.max_per_img, - score_factors=mlvl_centerness) - return det_bboxes, det_labels + det_results = [] + for (mlvl_bboxes, mlvl_scores, + mlvl_centerness) in zip(batch_mlvl_bboxes, batch_mlvl_scores, + batch_mlvl_centerness): + det_bbox, det_label = multiclass_nms( + mlvl_bboxes, + mlvl_scores, + cfg.score_thr, + cfg.nms, + cfg.max_per_img, + score_factors=mlvl_centerness) + det_results.append(tuple([det_bbox, det_label])) else: - return mlvl_bboxes, mlvl_scores, mlvl_centerness + det_results = [ + tuple(mlvl_bs) + for mlvl_bs in zip(batch_mlvl_bboxes, batch_mlvl_scores, + batch_mlvl_centerness) + ] + return det_results def _get_points_single(self, featmap_size, diff --git a/mmdet/models/dense_heads/gfl_head.py b/mmdet/models/dense_heads/gfl_head.py index 0b59a6e7d..80b647bc3 100644 --- a/mmdet/models/dense_heads/gfl_head.py +++ b/mmdet/models/dense_heads/gfl_head.py @@ -202,8 +202,8 @@ class GFLHead(AnchorHead): Returns: Tensor: Anchor centers with shape (N, 2), "xy" format. 
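+
+        Example:
+            With the switch to ellipsis indexing below, batched anchors of
+            shape (B, N, 4) are handled as well (a toy sketch):
+
+            >>> import torch
+            >>> anchors = torch.rand(2, 5, 4)
+            >>> cx = (anchors[..., 2] + anchors[..., 0]) / 2
+            >>> cy = (anchors[..., 3] + anchors[..., 1]) / 2
+            >>> torch.stack([cx, cy], dim=-1).shape
+            torch.Size([2, 5, 2])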
""" - anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2 - anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2 + anchors_cx = (anchors[..., 2] + anchors[..., 0]) / 2 + anchors_cy = (anchors[..., 3] + anchors[..., 1]) / 2 return torch.stack([anchors_cx, anchors_cy], dim=-1) def loss_single(self, anchors, cls_score, bbox_pred, labels, label_weights, @@ -368,28 +368,28 @@ class GFLHead(AnchorHead): return dict( loss_cls=losses_cls, loss_bbox=losses_bbox, loss_dfl=losses_dfl) - def _get_bboxes_single(self, - cls_scores, - bbox_preds, - mlvl_anchors, - img_shape, - scale_factor, - cfg, - rescale=False, - with_nms=True): + def _get_bboxes(self, + cls_scores, + bbox_preds, + mlvl_anchors, + img_shapes, + scale_factors, + cfg, + rescale=False, + with_nms=True): """Transform outputs for a single batch item into labeled boxes. Args: cls_scores (list[Tensor]): Box scores for a single scale level - has shape (num_classes, H, W). + has shape (N, num_classes, H, W). bbox_preds (list[Tensor]): Box distribution logits for a single - scale level with shape (4*(n+1), H, W), n is max value of + scale level with shape (N, 4*(n+1), H, W), n is max value of integral set. mlvl_anchors (list[Tensor]): Box reference for a single scale level with shape (num_total_anchors, 4). - img_shape (tuple[int]): Shape of the input image, - (height, width, 3). - scale_factor (ndarray): Scale factor of the image arange as + img_shapes (list[tuple[int]]): Shape of the input image, + list[(height, width, 3)]. + scale_factors (list[ndarray]): Scale factor of the image arange as (w_scale, h_scale, w_scale, h_scale). cfg (mmcv.Config | None): Test / postprocessing configuration, if None, test_cfg would be used. @@ -399,16 +399,17 @@ class GFLHead(AnchorHead): Default: True. Returns: - tuple(Tensor): - det_bboxes (Tensor): Bbox predictions in shape (N, 5), where - the first 4 columns are bounding box positions - (tl_x, tl_y, br_x, br_y) and the 5-th column is a score - between 0 and 1. - det_labels (Tensor): A (N,) tensor where each item is the - predicted class label of the corresponding box. + list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple. + The first item is an (n, 5) tensor, where 5 represent + (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1. + The shape of the second tensor in the tuple is (n,), and + each element represents the class label of the corresponding + box. 
""" cfg = self.test_cfg if cfg is None else cfg assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors) + batch_size = cls_scores[0].shape[0] + mlvl_bboxes = [] mlvl_scores = [] for cls_score, bbox_pred, stride, anchors in zip( @@ -416,43 +417,57 @@ class GFLHead(AnchorHead): mlvl_anchors): assert cls_score.size()[-2:] == bbox_pred.size()[-2:] assert stride[0] == stride[1] + scores = cls_score.permute(0, 2, 3, 1).reshape( + batch_size, -1, self.cls_out_channels).sigmoid() + bbox_pred = bbox_pred.permute(0, 2, 3, 1) - scores = cls_score.permute(1, 2, 0).reshape( - -1, self.cls_out_channels).sigmoid() - bbox_pred = bbox_pred.permute(1, 2, 0) bbox_pred = self.integral(bbox_pred) * stride[0] + bbox_pred = bbox_pred.reshape(batch_size, -1, 4) nms_pre = cfg.get('nms_pre', -1) - if nms_pre > 0 and scores.shape[0] > nms_pre: - max_scores, _ = scores.max(dim=1) + if nms_pre > 0 and scores.shape[1] > nms_pre: + max_scores, _ = scores.max(-1) _, topk_inds = max_scores.topk(nms_pre) + batch_inds = torch.arange(batch_size).view( + -1, 1).expand_as(topk_inds).long() anchors = anchors[topk_inds, :] - bbox_pred = bbox_pred[topk_inds, :] - scores = scores[topk_inds, :] + bbox_pred = bbox_pred[batch_inds, topk_inds, :] + scores = scores[batch_inds, topk_inds, :] + else: + anchors = anchors.expand_as(bbox_pred) bboxes = distance2bbox( - self.anchor_center(anchors), bbox_pred, max_shape=img_shape) + self.anchor_center(anchors), bbox_pred, max_shape=img_shapes) mlvl_bboxes.append(bboxes) mlvl_scores.append(scores) - mlvl_bboxes = torch.cat(mlvl_bboxes) + batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1) if rescale: - mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) + batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor( + scale_factors).unsqueeze(1) - mlvl_scores = torch.cat(mlvl_scores) + batch_mlvl_scores = torch.cat(mlvl_scores, dim=1) # Add a dummy background class to the backend when using sigmoid # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 # BG cat_id: num_class - padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) - mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) + padding = batch_mlvl_scores.new_zeros(batch_size, + batch_mlvl_scores.shape[1], 1) + batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1) if with_nms: - det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores, - cfg.score_thr, cfg.nms, - cfg.max_per_img) - return det_bboxes, det_labels + det_results = [] + for (mlvl_bboxes, mlvl_scores) in zip(batch_mlvl_bboxes, + batch_mlvl_scores): + det_bbox, det_label = multiclass_nms(mlvl_bboxes, mlvl_scores, + cfg.score_thr, cfg.nms, + cfg.max_per_img) + det_results.append(tuple([det_bbox, det_label])) else: - return mlvl_bboxes, mlvl_scores + det_results = [ + tuple(mlvl_bs) + for mlvl_bs in zip(batch_mlvl_bboxes, batch_mlvl_scores) + ] + return det_results def get_targets(self, anchor_list, diff --git a/mmdet/models/dense_heads/paa_head.py b/mmdet/models/dense_heads/paa_head.py index 3ef0ac67c..4edc3ef57 100644 --- a/mmdet/models/dense_heads/paa_head.py +++ b/mmdet/models/dense_heads/paa_head.py @@ -516,25 +516,27 @@ class PAAHead(ATSSHead): label_channels=1, unmap_outputs=True) - def _get_bboxes_single(self, - cls_scores, - bbox_preds, - iou_preds, - mlvl_anchors, - img_shape, - scale_factor, - cfg, - rescale=False, - with_nms=True): + def _get_bboxes(self, + cls_scores, + bbox_preds, + iou_preds, + mlvl_anchors, + img_shapes, + scale_factors, + cfg, + rescale=False, + with_nms=True): """Transform outputs for a single batch item into 
labeled boxes. - This method is almost same as `ATSSHead._get_bboxes_single()`. + This method is almost same as `ATSSHead._get_bboxes()`. We use sqrt(iou_preds * cls_scores) in NMS process instead of just cls_scores. Besides, score voting is used when `` score_voting`` is set to True. """ assert with_nms, 'PAA only supports "with_nms=True" now' assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors) + batch_size = cls_scores[0].shape[0] + mlvl_bboxes = [] mlvl_scores = [] mlvl_iou_preds = [] @@ -542,50 +544,64 @@ class PAAHead(ATSSHead): cls_scores, bbox_preds, iou_preds, mlvl_anchors): assert cls_score.size()[-2:] == bbox_pred.size()[-2:] - scores = cls_score.permute(1, 2, 0).reshape( - -1, self.cls_out_channels).sigmoid() - bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) - iou_preds = iou_preds.permute(1, 2, 0).reshape(-1).sigmoid() + scores = cls_score.permute(0, 2, 3, 1).reshape( + batch_size, -1, self.cls_out_channels).sigmoid() + bbox_pred = bbox_pred.permute(0, 2, 3, + 1).reshape(batch_size, -1, 4) + iou_preds = iou_preds.permute(0, 2, 3, 1).reshape(batch_size, + -1).sigmoid() + nms_pre = cfg.get('nms_pre', -1) - if nms_pre > 0 and scores.shape[0] > nms_pre: - max_scores, _ = (scores * iou_preds[:, None]).sqrt().max(dim=1) + if nms_pre > 0 and scores.shape[1] > nms_pre: + max_scores, _ = (scores * iou_preds[..., None]).sqrt().max(-1) _, topk_inds = max_scores.topk(nms_pre) + batch_inds = torch.arange(batch_size).view( + -1, 1).expand_as(topk_inds).long() anchors = anchors[topk_inds, :] - bbox_pred = bbox_pred[topk_inds, :] - scores = scores[topk_inds, :] - iou_preds = iou_preds[topk_inds] + bbox_pred = bbox_pred[batch_inds, topk_inds, :] + scores = scores[batch_inds, topk_inds, :] + iou_preds = iou_preds[batch_inds, topk_inds] + else: + anchors = anchors.expand_as(bbox_pred) bboxes = self.bbox_coder.decode( - anchors, bbox_pred, max_shape=img_shape) + anchors, bbox_pred, max_shape=img_shapes) mlvl_bboxes.append(bboxes) mlvl_scores.append(scores) mlvl_iou_preds.append(iou_preds) - mlvl_bboxes = torch.cat(mlvl_bboxes) + batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1) if rescale: - mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) - mlvl_scores = torch.cat(mlvl_scores) + batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor( + scale_factors).unsqueeze(1) + batch_mlvl_scores = torch.cat(mlvl_scores, dim=1) # Add a dummy background class to the backend when using sigmoid # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 # BG cat_id: num_class - padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) - mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) - mlvl_iou_preds = torch.cat(mlvl_iou_preds) - mlvl_nms_scores = (mlvl_scores * mlvl_iou_preds[:, None]).sqrt() - det_bboxes, det_labels = multiclass_nms( - mlvl_bboxes, - mlvl_nms_scores, - cfg.score_thr, - cfg.nms, - cfg.max_per_img, - score_factors=None) - if self.with_score_voting and len(det_bboxes) > 0: - det_bboxes, det_labels = self.score_voting(det_bboxes, det_labels, - mlvl_bboxes, - mlvl_nms_scores, - cfg.score_thr) - - return det_bboxes, det_labels + padding = batch_mlvl_scores.new_zeros(batch_size, + batch_mlvl_scores.shape[1], 1) + batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1) + batch_mlvl_iou_preds = torch.cat(mlvl_iou_preds, dim=1) + batch_mlvl_nms_scores = (batch_mlvl_scores * + batch_mlvl_iou_preds[..., None]).sqrt() + + det_results = [] + for (mlvl_bboxes, mlvl_scores) in zip(batch_mlvl_bboxes, + batch_mlvl_nms_scores): + det_bbox, det_label = multiclass_nms( + 
+                mlvl_bboxes,
+                mlvl_scores,
+                cfg.score_thr,
+                cfg.nms,
+                cfg.max_per_img,
+                score_factors=None)
+            if self.with_score_voting and len(det_bbox) > 0:
+                det_bbox, det_label = self.score_voting(
+                    det_bbox, det_label, mlvl_bboxes, mlvl_scores,
+                    cfg.score_thr)
+            det_results.append(tuple([det_bbox, det_label]))
+
+        return det_results

     def score_voting(self, det_bboxes, det_labels, mlvl_bboxes,
                      mlvl_nms_scores, score_thr):
@@ -602,7 +618,7 @@ class PAAHead(ATSSHead):
                 with shape (num_anchors,4).
             mlvl_nms_scores (Tensor): The scores of all boxes which is used
                 in the NMS procedure, with shape (num_anchors, num_class)
-            mlvl_iou_preds (Tensot): The predictions of IOU of all boxes
+            mlvl_iou_preds (Tensor): The predictions of IOU of all boxes
                 before the NMS procedure, with shape (num_anchors, 1)
             score_thr (float): The score threshold of bboxes.

diff --git a/mmdet/models/dense_heads/rpn_head.py b/mmdet/models/dense_heads/rpn_head.py
index da8b5f63d..a888cb8c1 100644
--- a/mmdet/models/dense_heads/rpn_head.py
+++ b/mmdet/models/dense_heads/rpn_head.py
@@ -79,35 +79,38 @@ class RPNHead(RPNTestMixin, AnchorHead):
         return dict(
             loss_rpn_cls=losses['loss_cls'], loss_rpn_bbox=losses['loss_bbox'])

-    def _get_bboxes_single(self,
-                           cls_scores,
-                           bbox_preds,
-                           mlvl_anchors,
-                           img_shape,
-                           scale_factor,
-                           cfg,
-                           rescale=False):
+    def _get_bboxes(self,
+                    cls_scores,
+                    bbox_preds,
+                    mlvl_anchors,
+                    img_shapes,
+                    scale_factors,
+                    cfg,
+                    rescale=False):
         """Transform outputs for a single batch item into bbox predictions.

         Args:
             cls_scores (list[Tensor]): Box scores for each scale level
-                Has shape (num_anchors * num_classes, H, W).
+                Has shape (N, num_anchors * num_classes, H, W).
             bbox_preds (list[Tensor]): Box energies / deltas for each scale
-                level with shape (num_anchors * 4, H, W).
+                level with shape (N, num_anchors * 4, H, W).
             mlvl_anchors (list[Tensor]): Box reference for each scale level
                 with shape (num_total_anchors, 4).
-            img_shape (tuple[int]): Shape of the input image,
+            img_shapes (list[tuple[int]]): Shape of the input image,
                 (height, width, 3).
-            scale_factor (ndarray): Scale factor of the image arange as
+            scale_factors (list[ndarray]): Scale factor of the image arranged as
                 (w_scale, h_scale, w_scale, h_scale).
             cfg (mmcv.Config): Test / postprocessing configuration,
                 if None, test_cfg would be used.
             rescale (bool): If True, return boxes in original image space.

         Returns:
-            Tensor: Labeled boxes in shape (n, 5), where the first 4 columns
-                are bounding box positions (tl_x, tl_y, br_x, br_y) and the
-                5-th column is a score between 0 and 1.
+            list[Tensor]: Proposals of each image, each with shape (n, 5),
+                where the first 4 columns are bounding box positions
+                (tl_x, tl_y, br_x, br_y) and the 5-th column is a score
+                between 0 and 1.
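+
+        Example:
+            For the softmax branch, channel 0 holds the foreground
+            probability since mmdet v2.5 (a shape-only sketch):
+
+            >>> import torch
+            >>> rpn_cls_score = torch.rand(2, 6, 2)   # (B, N, 2)
+            >>> scores = rpn_cls_score.softmax(-1)[..., 0]
+            >>> scores.shape
+            torch.Size([2, 6])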
""" cfg = self.test_cfg if cfg is None else cfg cfg = copy.deepcopy(cfg) @@ -117,26 +120,29 @@ class RPNHead(RPNTestMixin, AnchorHead): mlvl_scores = [] mlvl_bbox_preds = [] mlvl_valid_anchors = [] + batch_size = cls_scores[0].shape[0] nms_pre_tensor = torch.tensor( cfg.nms_pre, device=cls_scores[0].device, dtype=torch.long) for idx in range(len(cls_scores)): rpn_cls_score = cls_scores[idx] rpn_bbox_pred = bbox_preds[idx] assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:] - rpn_cls_score = rpn_cls_score.permute(1, 2, 0) + rpn_cls_score = rpn_cls_score.permute(0, 2, 3, 1) if self.use_sigmoid_cls: - rpn_cls_score = rpn_cls_score.reshape(-1) + rpn_cls_score = rpn_cls_score.reshape(batch_size, -1) scores = rpn_cls_score.sigmoid() else: - rpn_cls_score = rpn_cls_score.reshape(-1, 2) + rpn_cls_score = rpn_cls_score.reshape(batch_size, -1, 2) # We set FG labels to [0, num_class-1] and BG label to # num_class in RPN head since mmdet v2.5, which is unified to # be consistent with other head since mmdet v2.0. In mmdet v2.0 # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head. - scores = rpn_cls_score.softmax(dim=1)[:, 0] - rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4) + scores = rpn_cls_score.softmax(-1)[..., 0] + rpn_bbox_pred = rpn_bbox_pred.permute(0, 2, 3, 1).reshape( + batch_size, -1, 4) anchors = mlvl_anchors[idx] - if cfg.nms_pre > 0: + anchors = anchors.expand_as(rpn_bbox_pred) + if nms_pre_tensor > 0: # sort is faster than topk # _, topk_inds = scores.topk(cfg.nms_pre) # keep topk op for dynamic k in onnx model @@ -144,43 +150,41 @@ class RPNHead(RPNTestMixin, AnchorHead): # sort op will be converted to TopK in onnx # and k<=3480 in TensorRT scores_shape = torch._shape_as_tensor(scores) - nms_pre = torch.where(scores_shape[0] < nms_pre_tensor, - scores_shape[0], nms_pre_tensor) + nms_pre = torch.where(scores_shape[1] < nms_pre_tensor, + scores_shape[1], nms_pre_tensor) _, topk_inds = scores.topk(nms_pre) - scores = scores[topk_inds] - rpn_bbox_pred = rpn_bbox_pred[topk_inds, :] - anchors = anchors[topk_inds, :] - elif scores.shape[0] > cfg.nms_pre: + batch_inds = torch.arange(batch_size).view( + -1, 1).expand_as(topk_inds) + scores = scores[batch_inds, topk_inds] + rpn_bbox_pred = rpn_bbox_pred[batch_inds, topk_inds, :] + anchors = anchors[batch_inds, topk_inds, :] + + elif scores.shape[-1] > cfg.nms_pre: ranked_scores, rank_inds = scores.sort(descending=True) - topk_inds = rank_inds[:cfg.nms_pre] - scores = ranked_scores[:cfg.nms_pre] - rpn_bbox_pred = rpn_bbox_pred[topk_inds, :] - anchors = anchors[topk_inds, :] + topk_inds = rank_inds[:, :cfg.nms_pre] + scores = ranked_scores[:, :cfg.nms_pre] + batch_inds = torch.arange(batch_size).view( + -1, 1).expand_as(topk_inds) + rpn_bbox_pred = rpn_bbox_pred[batch_inds, topk_inds, :] + anchors = anchors[batch_inds, topk_inds, :] + mlvl_scores.append(scores) mlvl_bbox_preds.append(rpn_bbox_pred) mlvl_valid_anchors.append(anchors) level_ids.append( - scores.new_full((scores.size(0), ), idx, dtype=torch.long)) - - scores = torch.cat(mlvl_scores) - anchors = torch.cat(mlvl_valid_anchors) - rpn_bbox_pred = torch.cat(mlvl_bbox_preds) - proposals = self.bbox_coder.decode( - anchors, rpn_bbox_pred, max_shape=img_shape) - ids = torch.cat(level_ids) - - # Skip nonzero op while exporting to ONNX - if cfg.min_bbox_size > 0 and (not torch.onnx.is_in_onnx_export()): - w = proposals[:, 2] - proposals[:, 0] - h = proposals[:, 3] - proposals[:, 1] - valid_inds = torch.nonzero( - (w >= cfg.min_bbox_size) - & (h >= 
cfg.min_bbox_size), - as_tuple=False).squeeze() - if valid_inds.sum().item() != len(proposals): - proposals = proposals[valid_inds, :] - scores = scores[valid_inds] - ids = ids[valid_inds] + scores.new_full(( + batch_size, + scores.size(1), + ), + idx, + dtype=torch.long)) + + batch_mlvl_scores = torch.cat(mlvl_scores, dim=1) + batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1) + batch_mlvl_rpn_bbox_pred = torch.cat(mlvl_bbox_preds, dim=1) + batch_mlvl_proposals = self.bbox_coder.decode( + batch_mlvl_anchors, batch_mlvl_rpn_bbox_pred, max_shape=img_shapes) + batch_mlvl_ids = torch.cat(level_ids, dim=1) # deprecate arguments warning if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg: @@ -209,5 +213,24 @@ class RPNHead(RPNTestMixin, AnchorHead): f' respectively. Please delete the nms_thr ' \ f'which will be deprecated.' - dets, keep = batched_nms(proposals, scores, ids, cfg.nms) - return dets[:cfg.max_per_img] + result_list = [] + for (mlvl_proposals, mlvl_scores, + mlvl_ids) in zip(batch_mlvl_proposals, batch_mlvl_scores, + batch_mlvl_ids): + # Skip nonzero op while exporting to ONNX + if cfg.min_bbox_size > 0 and (not torch.onnx.is_in_onnx_export()): + w = mlvl_proposals[:, 2] - mlvl_proposals[:, 0] + h = mlvl_proposals[:, 3] - mlvl_proposals[:, 1] + valid_ind = torch.nonzero( + (w >= cfg.min_bbox_size) + & (h >= cfg.min_bbox_size), + as_tuple=False).squeeze() + if valid_ind.sum().item() != len(mlvl_proposals): + mlvl_proposals = mlvl_proposals[valid_ind, :] + mlvl_scores = mlvl_scores[valid_ind] + mlvl_ids = mlvl_ids[valid_ind] + + dets, keep = batched_nms(mlvl_proposals, mlvl_scores, mlvl_ids, + cfg.nms) + result_list.append(dets[:cfg.max_per_img]) + return result_list diff --git a/mmdet/models/dense_heads/yolo_head.py b/mmdet/models/dense_heads/yolo_head.py index 83346ab10..25a005d36 100644 --- a/mmdet/models/dense_heads/yolo_head.py +++ b/mmdet/models/dense_heads/yolo_head.py @@ -191,36 +191,34 @@ class YOLOV3Head(BaseDenseHead, BBoxTestMixin): Returns: list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple. - The first item is an (n, 5) tensor, where the first 4 columns - are bounding box positions (tl_x, tl_y, br_x, br_y) and the - 5-th column is a score between 0 and 1. The second item is a - (n,) tensor where each item is the predicted class label of the - corresponding box. + The first item is an (n, 5) tensor, where 5 represent + (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1. + The shape of the second tensor in the tuple is (n,), and + each element represents the class label of the corresponding + box. """ - result_list = [] num_levels = len(pred_maps) - for img_id in range(len(img_metas)): - pred_maps_list = [ - pred_maps[i][img_id].detach() for i in range(num_levels) - ] - scale_factor = img_metas[img_id]['scale_factor'] - proposals = self._get_bboxes_single(pred_maps_list, scale_factor, - cfg, rescale, with_nms) - result_list.append(proposals) + pred_maps_list = [pred_maps[i].detach() for i in range(num_levels)] + scale_factors = [ + img_metas[i]['scale_factor'] + for i in range(pred_maps_list[0].shape[0]) + ] + result_list = self._get_bboxes(pred_maps_list, scale_factors, cfg, + rescale, with_nms) return result_list - def _get_bboxes_single(self, - pred_maps_list, - scale_factor, - cfg, - rescale=False, - with_nms=True): + def _get_bboxes(self, + pred_maps_list, + scale_factors, + cfg, + rescale=False, + with_nms=True): """Transform outputs for a single batch item into bbox predictions. 
-    def _get_bboxes_single(self,
-                           pred_maps_list,
-                           scale_factor,
-                           cfg,
-                           rescale=False,
-                           with_nms=True):
-        """Transform outputs for a single batch item into bbox predictions.
+    def _get_bboxes(self,
+                    pred_maps_list,
+                    scale_factors,
+                    cfg,
+                    rescale=False,
+                    with_nms=True):
+        """Transform network outputs of a batch into bbox predictions.
 
         Args:
             pred_maps_list (list[Tensor]): Prediction maps for different
                 scales of each single image in the batch.
-            scale_factor (ndarray): Scale factor of the image arrange as
-                (w_scale, h_scale, w_scale, h_scale).
+            scale_factors (list[ndarray]): Scale factor of each image in the
+                batch, arranged as (w_scale, h_scale, w_scale, h_scale).
             cfg (mmcv.Config | None): Test / postprocessing configuration,
                 if None, test_cfg would be used.
@@ -230,62 +228,71 @@ class YOLOV3Head(BaseDenseHead, BBoxTestMixin):
                 Default: True.
 
         Returns:
-            tuple(Tensor):
-                det_bboxes (Tensor): BBox predictions in shape (n, 5), where
-                    the first 4 columns are bounding box positions
-                    (tl_x, tl_y, br_x, br_y) and the 5-th column is a score
-                    between 0 and 1.
-                det_labels (Tensor): A (n,) tensor where each item is the
-                    predicted class label of the corresponding box.
+            list[tuple[Tensor, Tensor]]: Each item in result_list is a
+                2-tuple. The first item is an (n, 5) tensor, where the 5
+                columns represent (tl_x, tl_y, br_x, br_y, score), and the
+                score is between 0 and 1. The second tensor in the tuple
+                has shape (n,), and each element is the class label of the
+                corresponding box.
         """
         cfg = self.test_cfg if cfg is None else cfg
         assert len(pred_maps_list) == self.num_levels
-        multi_lvl_bboxes = []
-        multi_lvl_cls_scores = []
-        multi_lvl_conf_scores = []
-        num_levels = len(pred_maps_list)
+
+        device = pred_maps_list[0].device
+        batch_size = pred_maps_list[0].shape[0]
+
         featmap_sizes = [
-            pred_maps_list[i].shape[-2:] for i in range(num_levels)
+            pred_maps_list[i].shape[-2:] for i in range(self.num_levels)
         ]
         multi_lvl_anchors = self.anchor_generator.grid_anchors(
-            featmap_sizes, pred_maps_list[0][0].device)
+            featmap_sizes, device)
+        # convert to tensor to keep tracing
+        nms_pre_tensor = torch.tensor(
+            cfg.get('nms_pre', -1), device=device, dtype=torch.long)
+
+        multi_lvl_bboxes = []
+        multi_lvl_cls_scores = []
+        multi_lvl_conf_scores = []
         for i in range(self.num_levels):
             # get some key info for current scale
             pred_map = pred_maps_list[i]
             stride = self.featmap_strides[i]
-
-            # (h, w, num_anchors*num_attrib) -> (h*w*num_anchors, num_attrib)
-            pred_map = pred_map.permute(1, 2, 0).reshape(-1, self.num_attrib)
-
-            pred_map[..., :2] = torch.sigmoid(pred_map[..., :2])
-            bbox_pred = self.bbox_coder.decode(multi_lvl_anchors[i],
-                                               pred_map[..., :4], stride)
+            # (b, h, w, num_anchors*num_attrib) ->
+            # (b, h*w*num_anchors, num_attrib)
+            pred_map = pred_map.permute(0, 2, 3,
+                                        1).reshape(batch_size, -1,
+                                                   self.num_attrib)
+            # An inplace operation like
+            # `pred_map[..., :2] = torch.sigmoid(pred_map[..., :2])`
+            # would create a constant tensor when exporting to ONNX,
+            # so the sigmoid of the xy offsets is written back via concat
+            pred_map_conf = torch.sigmoid(pred_map[..., :2])
+            pred_map_rest = pred_map[..., 2:]
+            pred_map = torch.cat([pred_map_conf, pred_map_rest], dim=-1)
+            pred_map_boxes = pred_map[..., :4]
+            multi_lvl_anchor = multi_lvl_anchors[i]
+            multi_lvl_anchor = multi_lvl_anchor.expand_as(pred_map_boxes)
+            bbox_pred = self.bbox_coder.decode(multi_lvl_anchor,
                                               pred_map_boxes, stride)
             # conf and cls
-            conf_pred = torch.sigmoid(pred_map[..., 4]).view(-1)
+            conf_pred = torch.sigmoid(pred_map[..., 4])
             cls_pred = torch.sigmoid(pred_map[..., 5:]).view(
-                -1, self.num_classes)  # Cls pred one-hot.
-
-            # Filtering out all predictions with conf < conf_thr
-            conf_thr = cfg.get('conf_thr', -1)
-            if conf_thr > 0 and (not torch.onnx.is_in_onnx_export()):
-                # TensorRT not support NonZero
-                # add as_tuple=False for compatibility in Pytorch 1.6
-                # flatten would create a Reshape op with constant values,
-                # and raise RuntimeError when doing inference in ONNX Runtime
-                # with a different input image (#4221).
-                conf_inds = conf_pred.ge(conf_thr).nonzero(
-                    as_tuple=False).squeeze(1)
-                bbox_pred = bbox_pred[conf_inds, :]
-                cls_pred = cls_pred[conf_inds, :]
-                conf_pred = conf_pred[conf_inds]
+                batch_size, -1, self.num_classes)  # Cls pred one-hot.
 
             # Get top-k prediction
-            nms_pre = cfg.get('nms_pre', -1)
-            if 0 < nms_pre < conf_pred.size(0):
+            # Always keep topk op for dynamic input in onnx
+            if nms_pre_tensor > 0 and (torch.onnx.is_in_onnx_export()
+                                       or conf_pred.shape[1] > nms_pre_tensor):
+                from torch import _shape_as_tensor
+                # keep shape as tensor and get k
+                num_anchor = _shape_as_tensor(conf_pred)[1].to(device)
+                nms_pre = torch.where(nms_pre_tensor < num_anchor,
+                                      nms_pre_tensor, num_anchor)
                 _, topk_inds = conf_pred.topk(nms_pre)
-                bbox_pred = bbox_pred[topk_inds, :]
-                cls_pred = cls_pred[topk_inds, :]
-                conf_pred = conf_pred[topk_inds]
+                batch_inds = torch.arange(batch_size).view(
+                    -1, 1).expand_as(topk_inds).long()
+                bbox_pred = bbox_pred[batch_inds, topk_inds, :]
+                cls_pred = cls_pred[batch_inds, topk_inds, :]
+                conf_pred = conf_pred[batch_inds, topk_inds]
 
             # Save the result of current scale
             multi_lvl_bboxes.append(bbox_pred)
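
The in-place note in the hunk above deserves a standalone illustration: under tracing, slice assignment can be recorded as a constant instead of as graph ops, while rebuilding the tensor with `torch.cat` keeps everything symbolic. A minimal sketch of the two variants, which are equivalent in eager mode (illustrative names only):

```python
import torch

def sigmoid_xy_inplace(pred_map):
    # slice assignment; fine in eager mode, but when traced for ONNX the
    # assigned values can be frozen into the graph as constants
    pred_map = pred_map.clone()
    pred_map[..., :2] = torch.sigmoid(pred_map[..., :2])
    return pred_map

def sigmoid_xy_traceable(pred_map):
    # rebuild the tensor instead; every op stays in the traced graph
    xy = torch.sigmoid(pred_map[..., :2])
    rest = pred_map[..., 2:]
    return torch.cat([xy, rest], dim=-1)

x = torch.randn(2, 8, 85)  # (batch, anchors, num_attrib)
assert torch.allclose(sigmoid_xy_inplace(x), sigmoid_xy_traceable(x))
```
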
@@ -293,43 +300,70 @@ class YOLOV3Head(BaseDenseHead, BBoxTestMixin):
             multi_lvl_cls_scores.append(cls_pred)
             multi_lvl_conf_scores.append(conf_pred)
 
         # Merge the results of different scales together
-        multi_lvl_bboxes = torch.cat(multi_lvl_bboxes)
-        multi_lvl_cls_scores = torch.cat(multi_lvl_cls_scores)
-        multi_lvl_conf_scores = torch.cat(multi_lvl_conf_scores)
+        batch_mlvl_bboxes = torch.cat(multi_lvl_bboxes, dim=1)
+        batch_mlvl_scores = torch.cat(multi_lvl_cls_scores, dim=1)
+        batch_mlvl_conf_scores = torch.cat(multi_lvl_conf_scores, dim=1)
 
+        # Set max number of boxes to be fed into nms in deployment
         deploy_nms_pre = cfg.get('deploy_nms_pre', -1)
         if deploy_nms_pre > 0 and torch.onnx.is_in_onnx_export():
-            _, topk_inds = multi_lvl_conf_scores.topk(deploy_nms_pre)
-            multi_lvl_bboxes = multi_lvl_bboxes[topk_inds, :]
-            multi_lvl_cls_scores = multi_lvl_cls_scores[topk_inds, :]
-            multi_lvl_conf_scores = multi_lvl_conf_scores[topk_inds]
-
-        if with_nms and (multi_lvl_conf_scores.size(0) == 0):
+            _, topk_inds = batch_mlvl_conf_scores.topk(deploy_nms_pre)
+            batch_inds = torch.arange(batch_size).view(
+                -1, 1).expand_as(topk_inds).long()
+            batch_mlvl_bboxes = batch_mlvl_bboxes[batch_inds, topk_inds, :]
+            batch_mlvl_scores = batch_mlvl_scores[batch_inds, topk_inds, :]
+            batch_mlvl_conf_scores = batch_mlvl_conf_scores[batch_inds,
+                                                            topk_inds]
+
+        if with_nms and (batch_mlvl_conf_scores.size(0) == 0):
             return torch.zeros((0, 5)), torch.zeros((0, ))
 
         if rescale:
-            multi_lvl_bboxes /= multi_lvl_bboxes.new_tensor(scale_factor)
+            batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor(
+                scale_factors).unsqueeze(1)
 
         # In mmdet 2.x, the class_id for background is num_classes.
         # i.e., the last column.
-        padding = multi_lvl_cls_scores.new_zeros(multi_lvl_cls_scores.shape[0],
-                                                 1)
-        multi_lvl_cls_scores = torch.cat([multi_lvl_cls_scores, padding],
-                                         dim=1)
+        padding = batch_mlvl_scores.new_zeros(batch_size,
+                                              batch_mlvl_scores.shape[1], 1)
+        batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1)
 
         # Support exporting to onnx without nms
         if with_nms and cfg.get('nms', None) is not None:
-            det_bboxes, det_labels = multiclass_nms(
-                multi_lvl_bboxes,
-                multi_lvl_cls_scores,
-                cfg.score_thr,
-                cfg.nms,
-                cfg.max_per_img,
-                score_factors=multi_lvl_conf_scores)
-            return det_bboxes, det_labels
+            det_results = []
+            for (mlvl_bboxes, mlvl_scores,
+                 mlvl_conf_scores) in zip(batch_mlvl_bboxes, batch_mlvl_scores,
+                                          batch_mlvl_conf_scores):
+                # Filtering out all predictions with conf < conf_thr
+                conf_thr = cfg.get('conf_thr', -1)
+                if conf_thr > 0 and (not torch.onnx.is_in_onnx_export()):
+                    # TensorRT does not support NonZero;
+                    # add as_tuple=False for compatibility in Pytorch 1.6;
+                    # flatten would create a Reshape op with constant values,
+                    # and raise RuntimeError when doing inference in ONNX
+                    # Runtime with a different input image (#4221).
+                    conf_inds = mlvl_conf_scores.ge(conf_thr).nonzero(
+                        as_tuple=False).squeeze(1)
+                    mlvl_bboxes = mlvl_bboxes[conf_inds, :]
+                    mlvl_scores = mlvl_scores[conf_inds, :]
+                    mlvl_conf_scores = mlvl_conf_scores[conf_inds]
+
+                det_bboxes, det_labels = multiclass_nms(
+                    mlvl_bboxes,
+                    mlvl_scores,
+                    cfg.score_thr,
+                    cfg.nms,
+                    cfg.max_per_img,
+                    score_factors=mlvl_conf_scores)
+                det_results.append(tuple([det_bboxes, det_labels]))
+
         else:
-            return (multi_lvl_bboxes, multi_lvl_cls_scores,
-                    multi_lvl_conf_scores)
+            det_results = [
+                tuple(mlvl_bs)
+                for mlvl_bs in zip(batch_mlvl_bboxes, batch_mlvl_scores,
+                                   batch_mlvl_conf_scores)
+            ]
+        return det_results
 
     @force_fp32(apply_to=('pred_maps', ))
     def loss(self,
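
With the batch dimension kept symbolic end to end, a head refactored this way can be traced once and then run at any batch size. A minimal, self-contained export sketch with a toy module (this is not mmdet's pytorch2onnx entry point, just the underlying `torch.onnx.export` mechanics it builds on):

```python
import torch

class ToyHead(torch.nn.Module):
    def forward(self, feat):
        # any shape-preserving batch computation; (B, C, H, W) -> (B, C)
        return feat.flatten(2).max(-1).values

model = ToyHead().eval()
dummy = torch.randn(2, 16, 8, 8)
torch.onnx.export(
    model,
    dummy,
    'toy_head.onnx',
    input_names=['feat'],
    output_names=['scores'],
    # mark dim 0 as dynamic so the same graph accepts any batch size
    dynamic_axes={'feat': {0: 'batch'}, 'scores': {0: 'batch'}},
    opset_version=11)
```
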
diff --git a/tests/test_models/test_dense_heads/test_paa_head.py b/tests/test_models/test_dense_heads/test_paa_head.py
index 358e660d3..262e89d2b 100644
--- a/tests/test_models/test_dense_heads/test_paa_head.py
+++ b/tests/test_models/test_dense_heads/test_paa_head.py
@@ -97,10 +97,10 @@ def test_paa_head_loss():
     assert len(results) == n
     assert results[0].size() == (h * w * 5, c)
     assert self.with_score_voting
-    cls_scores = [torch.ones(4, 5, 5)]
-    bbox_preds = [torch.ones(4, 5, 5)]
-    iou_preds = [torch.ones(1, 5, 5)]
-    mlvl_anchors = [torch.ones(5 * 5, 4)]
+    cls_scores = [torch.ones(2, 4, 5, 5)]
+    bbox_preds = [torch.ones(2, 4, 5, 5)]
+    iou_preds = [torch.ones(2, 1, 5, 5)]
+    mlvl_anchors = [torch.ones(2, 5 * 5, 4)]
     img_shape = None
     scale_factor = [0.5, 0.5]
     cfg = mmcv.Config(
@@ -111,7 +111,7 @@ def test_paa_head_loss():
         nms=dict(type='nms', iou_threshold=0.6),
         max_per_img=100))
     rescale = False
-    self._get_bboxes_single(
+    self._get_bboxes(
         cls_scores,
         bbox_preds,
         iou_preds,
diff --git a/tests/test_models/test_forward.py b/tests/test_models/test_forward.py
index cba001df8..4e5589c80 100644
--- a/tests/test_models/test_forward.py
+++ b/tests/test_models/test_forward.py
@@ -137,7 +137,7 @@ def test_rpn_forward():
     'retinanet/retinanet_r50_fpn_1x_coco.py',
     'guided_anchoring/ga_retinanet_r50_fpn_1x_coco.py',
     'ghm/retinanet_ghm_r50_fpn_1x_coco.py',
-    'fcos/fcos_center_r50_caffe_fpn_gn-head_4x4_1x_coco.py',
+    'fcos/fcos_center_r50_caffe_fpn_gn-head_1x_coco.py',
     'foveabox/fovea_align_r50_fpn_gn-head_4x4_2x_coco.py',
     # 'free_anchor/retinanet_free_anchor_r50_fpn_1x_coco.py',
     # 'atss/atss_r50_fpn_1x_coco.py',  # not ready for topk
diff --git a/tests/test_utils/test_misc.py b/tests/test_utils/test_misc.py
index 2deb31e34..16be906c8 100644
--- a/tests/test_utils/test_misc.py
+++ b/tests/test_utils/test_misc.py
@@ -2,6 +2,7 @@ import numpy as np
 import pytest
 import torch
 
+from mmdet.core.bbox import distance2bbox
 from mmdet.core.mask.structures import BitmapMasks, PolygonMasks
 from mmdet.core.utils import mask2ndarray
 
@@ -45,3 +46,47 @@ def test_mask2ndarray():
     raw_masks = []
     with pytest.raises(TypeError):
         output_mask = mask2ndarray(raw_masks)
+
+
+def test_distance2bbox():
+    point = torch.Tensor([[74., 61.], [-29., 106.], [138., 61.], [29., 170.]])
+
+    distance = torch.Tensor([[0., 0, 1., 1.], [1., 2., 10., 6.],
+                             [22., -29., 138., 61.], [54., -29., 170., 61.]])
+    expected_decode_bboxes = torch.Tensor([[74., 61., 75., 62.],
+                                           [0., 104., 0., 112.],
+                                           [100., 90., 100., 120.],
+                                           [0., 120., 100., 120.]])
+    out_bbox = distance2bbox(point, distance, max_shape=(120, 100))
+    assert expected_decode_bboxes.allclose(out_bbox)
+    out = distance2bbox(point, distance, max_shape=torch.Tensor((120, 100)))
+    assert expected_decode_bboxes.allclose(out)
+
+    batch_point = point.unsqueeze(0).repeat(2, 1, 1)
+    batch_distance = distance.unsqueeze(0).repeat(2, 1, 1)
+    batch_out = distance2bbox(
+        batch_point, batch_distance, max_shape=(120, 100))[0]
+    assert out.allclose(batch_out)
+    batch_out = distance2bbox(
+        batch_point, batch_distance, max_shape=[(120, 100), (120, 100)])[0]
+    assert out.allclose(batch_out)
+
+    batch_out = distance2bbox(point, batch_distance, max_shape=(120, 100))[0]
+    assert out.allclose(batch_out)
+
+    # a max_shape whose length does not match the batch size should raise
+    with pytest.raises(AssertionError):
+        distance2bbox(
+            batch_point,
+            batch_distance,
+            max_shape=[(120, 100), (120, 100), (32, 32)])
+
+    rois = torch.zeros((0, 4))
+    deltas = torch.zeros((0, 4))
+    out = distance2bbox(rois, deltas, max_shape=(120, 100))
+    assert rois.shape == out.shape
+
+    rois = torch.zeros((2, 0, 4))
+    deltas = torch.zeros((2, 0, 4))
+    out = distance2bbox(rois, deltas, max_shape=(120, 100))
+    assert rois.shape == out.shape
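
As a hand check of the expected values in `test_distance2bbox`, take the second row: point (-29, 106) with distance (1, 2, 10, 6) decodes to (x1, y1, x2, y2) = (-29 - 1, 106 - 2, -29 + 10, 106 + 6) = (-30, 104, -19, 112); clamping x into [0, 100] and y into [0, 120] for max_shape = (H, W) = (120, 100) gives (0, 104, 0, 112), which is exactly `expected_decode_bboxes[1]`. The same check in code (assuming an environment with this patch applied):

```python
import torch
from mmdet.core.bbox import distance2bbox

point = torch.Tensor([[-29., 106.]])
distance = torch.Tensor([[1., 2., 10., 6.]])
# raw decode is (-30, 104, -19, 112); (H, W) = (120, 100) clamps it
print(distance2bbox(point, distance, max_shape=(120, 100)))
# tensor([[  0., 104.,   0., 112.]])
```
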