diff --git a/docs/tutorials/pytorch2onnx.md b/docs/tutorials/pytorch2onnx.md index 0b9934554..ccd9bb4bf 100644 --- a/docs/tutorials/pytorch2onnx.md +++ b/docs/tutorials/pytorch2onnx.md @@ -193,11 +193,18 @@ python tools/deployment/test.py \ 34.7 33.7 + + CornerNet + configs/cornernet/cornernet_hourglass104_mstest_10x5_210e_coco.py + Box AP + 40.6 + 40.4 + Notes: -- All ONNX models are evaluated with dynamic shape on coco dataset and images are preprocessed according to the original config file. +- All ONNX models are evaluated with dynamic shape on coco dataset and images are preprocessed according to the original config file. Note that CornerNet is evaluated without test-time flip, since currently only single-scale evaluation is supported with ONNX Runtime. - Mask AP of Mask R-CNN drops by 1% for ONNXRuntime. The main reason is that the predicted masks are directly interpolated to original image in PyTorch, while they are at first interpolated to the preprocessed input image of the model and then to original image in ONNXRuntime. @@ -205,19 +212,23 @@ Notes: The table below lists the models that are guaranteed to be exportable to ONNX and runnable in ONNX Runtime. -| Model | Config | Dynamic Shape | Batch Inference | Note | -| :----------: | :------------------------------------------------------: | :-----------: | :-------------: | :---: | -| FCOS | `configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py` | Y | Y | | -| FSAF | `configs/fsaf/fsaf_r50_fpn_1x_coco.py` | Y | Y | | -| RetinaNet | `configs/retinanet/retinanet_r50_fpn_1x_coco.py` | Y | Y | | -| SSD | `configs/ssd/ssd300_coco.py` | Y | Y | | -| YOLOv3 | `configs/yolo/yolov3_d53_mstrain-608_273e_coco.py` | Y | Y | | -| Faster R-CNN | `configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py` | Y | Y | | -| Mask R-CNN | `configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py` | Y | Y | | +| Model | Config | Dynamic Shape | Batch Inference | Note | +| :----------: | :-----------------------------------------------------------------: | :-----------: | :-------------: | :-----: | +| FCOS | `configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py` | Y | Y | | +| FSAF | `configs/fsaf/fsaf_r50_fpn_1x_coco.py` | Y | Y | | +| RetinaNet | `configs/retinanet/retinanet_r50_fpn_1x_coco.py` | Y | Y | | +| SSD | `configs/ssd/ssd300_coco.py` | Y | Y | | +| YOLOv3 | `configs/yolo/yolov3_d53_mstrain-608_273e_coco.py` | Y | Y | | +| Faster R-CNN | `configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py` | Y | Y | | +| Mask R-CNN | `configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py` | Y | Y | | +| CornerNet | `configs/cornernet/cornernet_hourglass104_mstest_10x5_210e_coco.py` | Y | N | no flip, no batch inference, tested with torch==1.7.0 and onnxruntime==1.5.1. | Notes: -- *All models above are tested with Pytorch==1.6.0 and onnxruntime==1.5.1* +- *All models above are tested with Pytorch==1.6.0 and onnxruntime==1.5.1*, except for CornerNet. For more details about the +torch version when exporting CornerNet to ONNX, which involves `mmcv::cummax`, please refer to the [Known Issues](https://github.com/open-mmlab/mmcv/blob/master/docs/onnxruntime_op.md#known-issues) in mmcv. + +- Currently only single-scale evaluation is supported with ONNX Runtime, also `mmcv::SoftNonMaxSuppression` is only supported for single image by now. - If the deployed backend platform is TensorRT, please add environment variables before running the file: diff --git a/mmdet/core/export/model_wrappers.py b/mmdet/core/export/model_wrappers.py index e9988ba4e..dfdbe5d4b 100644 --- a/mmdet/core/export/model_wrappers.py +++ b/mmdet/core/export/model_wrappers.py @@ -80,7 +80,19 @@ class ONNXRuntimeDetector(BaseDetector): for i in range(batch_size): scale_factor = img_metas[i]['scale_factor'] dets, labels = batch_dets[i], batch_labels[i] + if isinstance(scale_factor, (list, tuple, np.ndarray)): + assert len(scale_factor) == 4 + scale_factor = np.array(scale_factor)[None, :] # [1,4] dets[:, :4] /= scale_factor + if 'border' in img_metas[i]: + # offset pixel of the top-left corners between original image + # and padded/enlarged image, 'border' is used when exporting + # CornerNet and CentripetalNet to onnx + x_off = img_metas[i]['border'][2] + y_off = img_metas[i]['border'][0] + dets[:, [0, 2]] -= x_off + dets[:, [1, 3]] -= y_off + dets[:, :4] *= (dets[:, :4] > 0).astype(dets.dtype) dets_results = bbox2result(dets, labels, len(self.CLASSES)) if batch_masks is not None: masks = batch_masks[i] diff --git a/mmdet/models/backbones/hourglass.py b/mmdet/models/backbones/hourglass.py index ea210f59f..d9e16e675 100644 --- a/mmdet/models/backbones/hourglass.py +++ b/mmdet/models/backbones/hourglass.py @@ -1,4 +1,5 @@ import torch.nn as nn +import torch.nn.functional as F from mmcv.cnn import ConvModule from mmcv.runner import BaseModule @@ -21,6 +22,8 @@ class HourglassModule(BaseModule): norm_cfg (dict): Dictionary to construct and config norm layer. init_cfg (dict or list[dict], optional): Initialization config dict. Default: None + upsample_cfg (dict, optional): Config dict for interpolate layer. + Default: `dict(mode='nearest')` """ def __init__(self, @@ -28,7 +31,8 @@ class HourglassModule(BaseModule): stage_channels, stage_blocks, norm_cfg=dict(type='BN', requires_grad=True), - init_cfg=None): + init_cfg=None, + upsample_cfg=dict(mode='nearest')): super(HourglassModule, self).__init__(init_cfg) self.depth = depth @@ -69,7 +73,8 @@ class HourglassModule(BaseModule): norm_cfg=norm_cfg, downsample_first=False) - self.up2 = nn.Upsample(scale_factor=2) + self.up2 = F.interpolate + self.upsample_cfg = upsample_cfg def forward(self, x): """Forward function.""" @@ -77,7 +82,13 @@ class HourglassModule(BaseModule): low1 = self.low1(x) low2 = self.low2(low1) low3 = self.low3(low2) - up2 = self.up2(low3) + # Fixing `scale factor` (e.g. 2) is common for upsampling, but + # in some cases the spatial size is mismatched and error will arise. + if 'scale_factor' in self.upsample_cfg: + up2 = self.up2(low3, **self.upsample_cfg) + else: + shape = up1.shape[2:] + up2 = self.up2(low3, size=shape, **self.upsample_cfg) return up1 + up2 diff --git a/mmdet/models/dense_heads/corner_head.py b/mmdet/models/dense_heads/corner_head.py index 1e7c61273..aad39ed36 100644 --- a/mmdet/models/dense_heads/corner_head.py +++ b/mmdet/models/dense_heads/corner_head.py @@ -694,6 +694,15 @@ class CornerHead(BaseDenseHead): rescale=rescale, with_nms=with_nms)) + if torch.onnx.is_in_onnx_export(): + assert len( + img_metas + ) == 1, 'Only support one input image while in exporting to ONNX' + + detections, labels = result_list[0] + # batch_size 1 here, [1, num_det, 5], [1, num_det] + return detections.unsqueeze(0), labels.unsqueeze(0) + return result_list def _get_bboxes_single(self, @@ -758,9 +767,11 @@ class CornerHead(BaseDenseHead): scores = batch_scores.view([-1, 1]) clses = batch_clses.view([-1, 1]) - idx = scores.argsort(dim=0, descending=True) + # use `sort` instead of `argsort` here, since currently exporting + # `argsort` to ONNX opset version 11 is not supported + scores, idx = scores.sort(dim=0, descending=True) bboxes = bboxes[idx].view([-1, 4]) - scores = scores[idx].view(-1) + scores = scores.view(-1) clses = clses[idx].view(-1) detections = torch.cat([bboxes, scores.unsqueeze(-1)], -1) @@ -789,8 +800,15 @@ class CornerHead(BaseDenseHead): out_labels = labels[keep] if len(out_bboxes) > 0: - idx = torch.argsort(out_bboxes[:, -1], descending=True) - idx = idx[:cfg.max_per_img] + # use `sort` to replace with `argsort` here + _, idx = torch.sort(out_bboxes[:, -1], descending=True) + max_per_img = out_bboxes.new_tensor(cfg.max_per_img).to(torch.long) + nms_after = max_per_img + if torch.onnx.is_in_onnx_export(): + # Always keep topk op for dynamic input in onnx + from mmdet.core.export import get_k_for_topk + nms_after = get_k_for_topk(max_per_img, out_bboxes.shape[0]) + idx = idx[:nms_after] out_bboxes = out_bboxes[idx] out_labels = out_labels[idx] @@ -852,7 +870,10 @@ class CornerHead(BaseDenseHead): and br_centripetal_shift is not None) assert with_embedding + with_centripetal_shift == 1 batch, _, height, width = tl_heat.size() - inp_h, inp_w, _ = img_meta['pad_shape'] + if torch.onnx.is_in_onnx_export(): + inp_h, inp_w = img_meta['pad_shape_for_onnx'][:2] + else: + inp_h, inp_w, _ = img_meta['pad_shape'] # perform nms on heatmaps tl_heat = get_local_maximum(tl_heat, kernel=kernel) @@ -905,18 +926,31 @@ class CornerHead(BaseDenseHead): br_ctxs *= (inp_w / width) br_ctys *= (inp_h / height) - x_off = img_meta['border'][2] - y_off = img_meta['border'][0] + x_off, y_off = 0, 0 # no crop + if not torch.onnx.is_in_onnx_export(): + # since `RandomCenterCropPad` is done on CPU with numpy and it's + # not dynamic traceable when exporting to ONNX, thus 'border' + # does not appears as key in 'img_meta'. As a tmp solution, + # we move this 'border' handle part to the postprocess after + # finished exporting to ONNX, which is handle in + # `mmdet/core/export/model_wrappers.py`. Though difference between + # pytorch and exported onnx model, it might be ignored since + # comparable performance is achieved between them (e.g. 40.4 vs + # 40.6 on COCO val2017, for CornerNet without test-time flip) + if 'border' in img_meta: + x_off = img_meta['border'][2] + y_off = img_meta['border'][0] tl_xs -= x_off tl_ys -= y_off br_xs -= x_off br_ys -= y_off - tl_xs *= tl_xs.gt(0.0).type_as(tl_xs) - tl_ys *= tl_ys.gt(0.0).type_as(tl_ys) - br_xs *= br_xs.gt(0.0).type_as(br_xs) - br_ys *= br_ys.gt(0.0).type_as(br_ys) + zeros = tl_xs.new_zeros(*tl_xs.size()) + tl_xs = torch.where(tl_xs > 0.0, tl_xs, zeros) + tl_ys = torch.where(tl_ys > 0.0, tl_ys, zeros) + br_xs = torch.where(br_xs > 0.0, br_xs, zeros) + br_ys = torch.where(br_ys > 0.0, br_ys, zeros) bboxes = torch.stack((tl_xs, tl_ys, br_xs, br_ys), dim=3) area_bboxes = ((br_xs - tl_xs) * (br_ys - tl_ys)).abs() @@ -988,10 +1022,16 @@ class CornerHead(BaseDenseHead): width_inds = (br_xs <= tl_xs) height_inds = (br_ys <= tl_ys) - scores[cls_inds] = -1 - scores[width_inds] = -1 - scores[height_inds] = -1 - scores[dist_inds] = -1 + # No use `scores[cls_inds]`, instead we use `torch.where` here. + # Since only 1-D indices with type 'tensor(bool)' are supported + # when exporting to ONNX, any other bool indices with more dimensions + # (e.g. 2-D bool tensor) as input parameter in node is invalid + negative_scores = -1 * torch.ones_like(scores) + scores = torch.where(cls_inds, negative_scores, scores) + scores = torch.where(width_inds, negative_scores, scores) + scores = torch.where(height_inds, negative_scores, scores) + scores = torch.where(dist_inds, negative_scores, scores) + if with_centripetal_shift: scores[tl_ctx_inds] = -1 scores[tl_cty_inds] = -1 diff --git a/mmdet/models/detectors/single_stage.py b/mmdet/models/detectors/single_stage.py index b8dfd88e5..c85498933 100644 --- a/mmdet/models/detectors/single_stage.py +++ b/mmdet/models/detectors/single_stage.py @@ -94,9 +94,21 @@ class SingleStageDetector(BaseDetector): """ x = self.extract_feat(img) outs = self.bbox_head(x) + # get origin input shape to support onnx dynamic shape + if torch.onnx.is_in_onnx_export(): + # get shape as tensor + img_shape = torch._shape_as_tensor(img)[2:] + img_metas[0]['img_shape_for_onnx'] = img_shape + # get pad input shape to support onnx dynamic shape for exporting + # `CornerNet` and `CentripetalNet`, which 'pad_shape' is used + # for inference + img_metas[0]['pad_shape_for_onnx'] = img_shape bbox_list = self.bbox_head.get_bboxes( *outs, img_metas, rescale=rescale) + # skip post-processing when exporting to ONNX + if torch.onnx.is_in_onnx_export(): + return bbox_list bbox_results = [ bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)