Support exporting CornerNet to ONNX with dynamic shapes and comparable performance (#5136)

* Support exporting CornerNet to ONNX with dynamic shapes and comparable performance

* add docs for exporting CornerNet and simplify code

* fix doc

* format doc

* fix docstring
v-qjqs committed via GitHub
parent 24b6f9319f
commit 2e635eb5f6
Changed files:

1. docs/tutorials/pytorch2onnx.md (33 lines changed)
2. mmdet/core/export/model_wrappers.py (12 lines changed)
3. mmdet/models/backbones/hourglass.py (17 lines changed)
4. mmdet/models/dense_heads/corner_head.py (70 lines changed)
5. mmdet/models/detectors/single_stage.py (12 lines changed)

@@ -193,11 +193,18 @@ python tools/deployment/test.py \
 <td align="center">34.7</td>
 <td align="center">33.7</td>
 </tr>
+<tr>
+<td align="center">CornerNet</td>
+<td align="center"><code>configs/cornernet/cornernet_hourglass104_mstest_10x5_210e_coco.py</code></td>
+<td align="center">Box AP</td>
+<td align="center">40.6</td>
+<td align="center">40.4</td>
+</tr>
 </table>
 Notes:
-- All ONNX models are evaluated with dynamic shape on coco dataset and images are preprocessed according to the original config file.
+- All ONNX models are evaluated with dynamic shape on the COCO dataset, and images are preprocessed according to the original config file. Note that CornerNet is evaluated without test-time flip, since currently only single-scale evaluation is supported with ONNX Runtime.
 - Mask AP of Mask R-CNN drops by about 1% with ONNX Runtime. The main reason is that in PyTorch the predicted masks are interpolated directly to the original image, while in ONNX Runtime they are first interpolated to the preprocessed input image of the model and then to the original image.
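For context, dynamic-shape evaluation requires the graph to be exported with symbolic spatial axes. Below is a minimal sketch of how such axes are declared with the generic `torch.onnx.export` API; `TinyNet` is a hypothetical stand-in for the wrapped detector, and mmdet's actual entry point is `tools/deployment/pytorch2onnx.py`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-in for the ONNX-wrapped detector, only to show how
# dynamic batch/height/width axes are declared at export time.
class TinyNet(nn.Module):
    def forward(self, x):
        return F.max_pool2d(x, 2)

torch.onnx.export(
    TinyNet(),
    torch.randn(1, 3, 511, 511),  # CornerNet-style input resolution
    'tiny_dynamic.onnx',
    input_names=['input'],
    output_names=['out'],
    opset_version=11,
    dynamic_axes={
        'input': {0: 'batch', 2: 'height', 3: 'width'},
        'out': {0: 'batch', 2: 'height', 3: 'width'},
    })
```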
@@ -205,19 +212,23 @@ Notes:
 The table below lists the models that are guaranteed to be exportable to ONNX and runnable in ONNX Runtime.
 | Model | Config | Dynamic Shape | Batch Inference | Note |
 | :----------: | :-----------------------------------------------------------------: | :-----------: | :-------------: | :-----: |
 | FCOS | `configs/fcos/fcos_r50_caffe_fpn_gn-head_4x4_1x_coco.py` | Y | Y | |
 | FSAF | `configs/fsaf/fsaf_r50_fpn_1x_coco.py` | Y | Y | |
 | RetinaNet | `configs/retinanet/retinanet_r50_fpn_1x_coco.py` | Y | Y | |
 | SSD | `configs/ssd/ssd300_coco.py` | Y | Y | |
 | YOLOv3 | `configs/yolo/yolov3_d53_mstrain-608_273e_coco.py` | Y | Y | |
 | Faster R-CNN | `configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py` | Y | Y | |
 | Mask R-CNN | `configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py` | Y | Y | |
+| CornerNet | `configs/cornernet/cornernet_hourglass104_mstest_10x5_210e_coco.py` | Y | N | no flip, no batch inference, tested with torch==1.7.0 and onnxruntime==1.5.1 |
 Notes:
-- *All models above are tested with Pytorch==1.6.0 and onnxruntime==1.5.1*
+- *All models above are tested with PyTorch==1.6.0 and onnxruntime==1.5.1*, except for CornerNet. For details about the torch version required when exporting CornerNet to ONNX, which involves `mmcv::cummax`, please refer to the [Known Issues](https://github.com/open-mmlab/mmcv/blob/master/docs/onnxruntime_op.md#known-issues) in mmcv.
+- Currently only single-scale evaluation is supported with ONNX Runtime, and `mmcv::SoftNonMaxSuppression` only supports a single image for now.
 - If the deployed backend platform is TensorRT, please add environment variables before running the file:

@@ -80,7 +80,19 @@ class ONNXRuntimeDetector(BaseDetector):
         for i in range(batch_size):
             scale_factor = img_metas[i]['scale_factor']
             dets, labels = batch_dets[i], batch_labels[i]
+            if isinstance(scale_factor, (list, tuple, np.ndarray)):
+                assert len(scale_factor) == 4
+                scale_factor = np.array(scale_factor)[None, :]  # [1, 4]
             dets[:, :4] /= scale_factor
+            if 'border' in img_metas[i]:
+                # pixel offset of the top-left corner between the original
+                # image and the padded/enlarged image; 'border' is used when
+                # exporting CornerNet and CentripetalNet to ONNX
+                x_off = img_metas[i]['border'][2]
+                y_off = img_metas[i]['border'][0]
+                dets[:, [0, 2]] -= x_off
+                dets[:, [1, 3]] -= y_off
+                dets[:, :4] *= (dets[:, :4] > 0).astype(dets.dtype)
             dets_results = bbox2result(dets, labels, len(self.CLASSES))
             if batch_masks is not None:
                 masks = batch_masks[i]
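A hypothetical usage sketch of the wrapper logic above: run the exported model in ONNX Runtime, then undo the `border` offset in numpy outside the graph. The file name and the `border` values are illustrative, not taken from a real config.

```python
import numpy as np
import onnxruntime as ort

# Run the exported CornerNet graph; it returns [1, num_det, 5] dets and
# [1, num_det] labels, as set up in CornerHead.get_bboxes.
sess = ort.InferenceSession('cornernet.onnx')
img = np.random.randn(1, 3, 511, 511).astype(np.float32)
batch_dets, batch_labels = sess.run(None, {sess.get_inputs()[0].name: img})

dets = batch_dets[0]             # [num_det, 5], last column is the score
border = (64, 447, 64, 447)      # illustrative 'border' values from img_meta
dets[:, [0, 2]] -= border[2]     # x offset of the top-left corner
dets[:, [1, 3]] -= border[0]     # y offset of the top-left corner
dets[:, :4] *= (dets[:, :4] > 0).astype(dets.dtype)  # clamp negatives to 0
```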

@@ -1,4 +1,5 @@
 import torch.nn as nn
+import torch.nn.functional as F
 from mmcv.cnn import ConvModule
 from mmcv.runner import BaseModule
@@ -21,6 +22,8 @@ class HourglassModule(BaseModule):
         norm_cfg (dict): Dictionary to construct and config norm layer.
         init_cfg (dict or list[dict], optional): Initialization config dict.
             Default: None
+        upsample_cfg (dict, optional): Config dict for the interpolate layer.
+            Default: `dict(mode='nearest')`
     """

     def __init__(self,
@@ -28,7 +31,8 @@ class HourglassModule(BaseModule):
                  stage_channels,
                  stage_blocks,
                  norm_cfg=dict(type='BN', requires_grad=True),
-                 init_cfg=None):
+                 init_cfg=None,
+                 upsample_cfg=dict(mode='nearest')):
         super(HourglassModule, self).__init__(init_cfg)
         self.depth = depth
@@ -69,7 +73,8 @@ class HourglassModule(BaseModule):
             norm_cfg=norm_cfg,
             downsample_first=False)

-        self.up2 = nn.Upsample(scale_factor=2)
+        self.up2 = F.interpolate
+        self.upsample_cfg = upsample_cfg

     def forward(self, x):
         """Forward function."""
@@ -77,7 +82,13 @@ class HourglassModule(BaseModule):
         low1 = self.low1(x)
         low2 = self.low2(low1)
         low3 = self.low3(low2)
-        up2 = self.up2(low3)
+        # Upsampling by a fixed `scale_factor` (e.g. 2) is common, but in
+        # some cases the resulting spatial size does not match `up1` and an
+        # error will arise, so fall back to size-based interpolation.
+        if 'scale_factor' in self.upsample_cfg:
+            up2 = self.up2(low3, **self.upsample_cfg)
+        else:
+            shape = up1.shape[2:]
+            up2 = self.up2(low3, size=shape, **self.upsample_cfg)
         return up1 + up2
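The hourglass change above switches from a fixed `scale_factor=2` upsample to size-based interpolation when no `scale_factor` is configured. A small sketch of the shape mismatch this avoids:

```python
import torch
import torch.nn.functional as F

# Why a fixed scale_factor can break: with an odd input size,
# downsample-then-upsample does not return to the original size.
x = torch.randn(1, 1, 25, 25)
down = F.max_pool2d(x, 2)                         # -> 12 x 12
up_fixed = F.interpolate(down, scale_factor=2)    # -> 24 x 24, mismatch
up_sized = F.interpolate(down, size=x.shape[2:])  # -> 25 x 25, matches
print(up_fixed.shape, up_sized.shape)
```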

@@ -694,6 +694,15 @@ class CornerHead(BaseDenseHead):
                     rescale=rescale,
                     with_nms=with_nms))

+        if torch.onnx.is_in_onnx_export():
+            assert len(img_metas) == 1, \
+                'Only one input image is supported when exporting to ONNX'
+            detections, labels = result_list[0]
+            # batch_size is 1 here: [1, num_det, 5] and [1, num_det]
+            return detections.unsqueeze(0), labels.unsqueeze(0)
         return result_list

     def _get_bboxes_single(self,
@@ -758,9 +767,11 @@ class CornerHead(BaseDenseHead):
         scores = batch_scores.view([-1, 1])
         clses = batch_clses.view([-1, 1])

-        idx = scores.argsort(dim=0, descending=True)
+        # use `sort` instead of `argsort` here, since currently exporting
+        # `argsort` to ONNX opset version 11 is not supported
+        scores, idx = scores.sort(dim=0, descending=True)
         bboxes = bboxes[idx].view([-1, 4])
-        scores = scores[idx].view(-1)
+        scores = scores.view(-1)
         clses = clses[idx].view(-1)

         detections = torch.cat([bboxes, scores.unsqueeze(-1)], -1)
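A quick sketch of why this swap is safe: `Tensor.sort` returns the same indices as `argsort` plus the sorted values, and it was exportable at opset 11 while `argsort` was not (per the comment above).

```python
import torch

# `sort` yields the same ordering indices as `argsort`, so replacing one
# with the other does not change the selected boxes.
scores = torch.rand(6, 1)
idx_argsort = scores.argsort(dim=0, descending=True)
sorted_scores, idx_sort = scores.sort(dim=0, descending=True)
assert torch.equal(idx_argsort, idx_sort)
```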
@@ -789,8 +800,15 @@ class CornerHead(BaseDenseHead):
             out_labels = labels[keep]

         if len(out_bboxes) > 0:
-            idx = torch.argsort(out_bboxes[:, -1], descending=True)
-            idx = idx[:cfg.max_per_img]
+            # use `sort` instead of `argsort` here (see the note above)
+            _, idx = torch.sort(out_bboxes[:, -1], descending=True)
+            max_per_img = out_bboxes.new_tensor(cfg.max_per_img).to(torch.long)
+            nms_after = max_per_img
+            if torch.onnx.is_in_onnx_export():
+                # Always keep a topk op for dynamic input in ONNX
+                from mmdet.core.export import get_k_for_topk
+                nms_after = get_k_for_topk(max_per_img, out_bboxes.shape[0])
+            idx = idx[:nms_after]
             out_bboxes = out_bboxes[idx]
             out_labels = out_labels[idx]
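A hedged sketch of the idea behind `get_k_for_topk` (not necessarily the exact mmdet implementation): clamp `k` against the possibly dynamic dimension with tensor ops, so the comparison is traced into the graph instead of being constant-folded to a Python int.

```python
import torch

# k_for_topk_sketch is an illustrative name, not the mmdet helper itself.
def k_for_topk_sketch(k: int, size: int) -> torch.Tensor:
    k_t = torch.tensor(k, dtype=torch.long)
    size_t = torch.tensor(size, dtype=torch.long)
    # keep the min(k, size) decision as a tensor op for the traced graph
    return torch.where(k_t < size_t, k_t, size_t)

nms_after = k_for_topk_sketch(100, 37)  # tensor(37): fewer boxes than k
print(nms_after)
```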
@@ -852,7 +870,10 @@ class CornerHead(BaseDenseHead):
                 and br_centripetal_shift is not None)
         assert with_embedding + with_centripetal_shift == 1
         batch, _, height, width = tl_heat.size()
-        inp_h, inp_w, _ = img_meta['pad_shape']
+        if torch.onnx.is_in_onnx_export():
+            inp_h, inp_w = img_meta['pad_shape_for_onnx'][:2]
+        else:
+            inp_h, inp_w, _ = img_meta['pad_shape']

         # perform nms on heatmaps
         tl_heat = get_local_maximum(tl_heat, kernel=kernel)
@@ -905,18 +926,31 @@ class CornerHead(BaseDenseHead):
         br_ctxs *= (inp_w / width)
         br_ctys *= (inp_h / height)

-        x_off = img_meta['border'][2]
-        y_off = img_meta['border'][0]
+        x_off, y_off = 0, 0  # no crop
+        if not torch.onnx.is_in_onnx_export():
+            # `RandomCenterCropPad` is done on CPU with numpy and is not
+            # dynamically traceable when exporting to ONNX, so 'border' does
+            # not appear as a key in 'img_meta' during export. As a temporary
+            # solution, we move the 'border' handling to the post-processing
+            # that runs after the exported model, implemented in
+            # `mmdet/core/export/model_wrappers.py`. Although this causes a
+            # difference between the PyTorch model and the exported ONNX
+            # model, it can be ignored since comparable performance is
+            # achieved (e.g. 40.4 vs 40.6 on COCO val2017 for CornerNet
+            # without test-time flip).
+            if 'border' in img_meta:
+                x_off = img_meta['border'][2]
+                y_off = img_meta['border'][0]

         tl_xs -= x_off
         tl_ys -= y_off
         br_xs -= x_off
         br_ys -= y_off

-        tl_xs *= tl_xs.gt(0.0).type_as(tl_xs)
-        tl_ys *= tl_ys.gt(0.0).type_as(tl_ys)
-        br_xs *= br_xs.gt(0.0).type_as(br_xs)
-        br_ys *= br_ys.gt(0.0).type_as(br_ys)
+        zeros = tl_xs.new_zeros(*tl_xs.size())
+        tl_xs = torch.where(tl_xs > 0.0, tl_xs, zeros)
+        tl_ys = torch.where(tl_ys > 0.0, tl_ys, zeros)
+        br_xs = torch.where(br_xs > 0.0, br_xs, zeros)
+        br_ys = torch.where(br_ys > 0.0, br_ys, zeros)

         bboxes = torch.stack((tl_xs, tl_ys, br_xs, br_ys), dim=3)
         area_bboxes = ((br_xs - tl_xs) * (br_ys - tl_ys)).abs()
@@ -988,10 +1022,16 @@ class CornerHead(BaseDenseHead):
         width_inds = (br_xs <= tl_xs)
         height_inds = (br_ys <= tl_ys)

-        scores[cls_inds] = -1
-        scores[width_inds] = -1
-        scores[height_inds] = -1
-        scores[dist_inds] = -1
+        # Do not use boolean assignment like `scores[cls_inds] = -1` here;
+        # use `torch.where` instead. Only 1-D indices of type 'tensor(bool)'
+        # are supported when exporting to ONNX; bool indices with more
+        # dimensions (e.g. a 2-D bool tensor) are invalid as node inputs.
+        negative_scores = -1 * torch.ones_like(scores)
+        scores = torch.where(cls_inds, negative_scores, scores)
+        scores = torch.where(width_inds, negative_scores, scores)
+        scores = torch.where(height_inds, negative_scores, scores)
+        scores = torch.where(dist_inds, negative_scores, scores)

         if with_centripetal_shift:
             scores[tl_ctx_inds] = -1
             scores[tl_cty_inds] = -1
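A standalone sketch of the pattern above: multi-dimensional boolean assignment does not export, while `torch.where` with a same-shaped mask lowers to a plain ONNX `Where` node.

```python
import torch

# Replace `scores[mask] = -1` (unsupported for multi-dimensional bool
# masks at export time) with an element-wise torch.where select.
scores = torch.rand(2, 100, 1)   # e.g. [batch, k, 1]
mask = scores < 0.5              # 3-D bool tensor
scores = torch.where(mask, -torch.ones_like(scores), scores)
print(scores.min() >= -1)
```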

@@ -94,9 +94,21 @@ class SingleStageDetector(BaseDetector):
         """
         x = self.extract_feat(img)
         outs = self.bbox_head(x)
+        # get the original input shape to support onnx dynamic shape
+        if torch.onnx.is_in_onnx_export():
+            # get the shape as a tensor so it stays dynamic in the graph
+            img_shape = torch._shape_as_tensor(img)[2:]
+            img_metas[0]['img_shape_for_onnx'] = img_shape
+            # also record the padded input shape to support onnx dynamic
+            # shape when exporting `CornerNet` and `CentripetalNet`, whose
+            # 'pad_shape' is used for inference
+            img_metas[0]['pad_shape_for_onnx'] = img_shape
         bbox_list = self.bbox_head.get_bboxes(
             *outs, img_metas, rescale=rescale)
+        # skip post-processing when exporting to ONNX
+        if torch.onnx.is_in_onnx_export():
+            return bbox_list
         bbox_results = [
             bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
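A small sketch of the distinction the shape handling above relies on: during tracing, `img.shape` is baked in as Python ints, while `torch._shape_as_tensor` emits a `Shape` node whose value stays dynamic.

```python
import torch

# Outside export both expressions agree; inside torch.onnx.export only the
# tensor form remains a function of the actual input resolution.
img = torch.randn(1, 3, 511, 511)
static_shape = img.shape[2:]                     # torch.Size, constant-folded
dynamic_shape = torch._shape_as_tensor(img)[2:]  # tensor([511, 511]), dynamic
print(static_shape, dynamic_shape)
```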
