add onnx simplify option to pytorch2onnx tool (#4468)

* add onnx simplify option

* resolve comments

* add warnings
pull/4501/head
RunningLeon 4 years ago committed by GitHub
parent 1b6895a9c5
commit 247785ceb3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      docs/tutorials/pytorch2onnx.md
  2. 48
      tools/pytorch2onnx.py

@ -55,6 +55,7 @@ Description of all arguments:
- `--opset-version` : The opset version of ONNX. If not specified, it will be set to `11`. - `--opset-version` : The opset version of ONNX. If not specified, it will be set to `11`.
- `--show`: Determines whether to print the architecture of the exported model. If not specified, it will be set to `False`. - `--show`: Determines whether to print the architecture of the exported model. If not specified, it will be set to `False`.
- `--verify`: Determines whether to verify the correctness of an exported model. If not specified, it will be set to `False`. - `--verify`: Determines whether to verify the correctness of an exported model. If not specified, it will be set to `False`.
- `--simplify`: Determines whether to simplify the exported ONNX model. If not specified, it will be set to `False`.
Example: Example:
@ -90,6 +91,8 @@ Notes:
## Reminders ## Reminders
- When the input model has custom op such as `RoIAlign` and if you want to verify the exported ONNX model, you may have to build `mmcv` with [ONNXRuntime](https://mmcv.readthedocs.io/en/latest/onnxruntime_op.html) from source.
- `mmcv.onnx.simplify` feature is based on [onnx-simplifier](https://github.com/daquexian/onnx-simplifier). If you want to try it, please refer to [onnx in `mmcv`](https://mmcv.readthedocs.io/en/latest/onnx.html) and [onnxruntime op in `mmcv`](https://mmcv.readthedocs.io/en/latest/onnxruntime_op.html) for more information.
- If you meet any problem with the listed models above, please create an issue and it would be taken care of soon. For models not included in the list, please try to dig a little deeper and debug a little bit more and hopefully solve them by yourself. - If you meet any problem with the listed models above, please create an issue and it would be taken care of soon. For models not included in the list, please try to dig a little deeper and debug a little bit more and hopefully solve them by yourself.
- Because this feature is experimental and may change fast, please always try with the latest `mmcv` and `mmdetecion`. - Because this feature is experimental and may change fast, please always try with the latest `mmcv` and `mmdetecion`.

@ -1,5 +1,6 @@
import argparse import argparse
import os.path as osp import os.path as osp
import warnings
import numpy as np import numpy as np
import onnx import onnx
@ -20,7 +21,8 @@ def pytorch2onnx(config_path,
verify=False, verify=False,
normalize_cfg=None, normalize_cfg=None,
dataset='coco', dataset='coco',
test_img=None): test_img=None,
do_simplify=False):
input_config = { input_config = {
'input_shape': input_shape, 'input_shape': input_shape,
@ -52,10 +54,32 @@ def pytorch2onnx(config_path,
opset_version=opset_version) opset_version=opset_version)
model.forward = orig_model.forward model.forward = orig_model.forward
# simplify onnx model
if do_simplify:
from mmdet import digit_version
import mmcv
min_required_version = '1.2.5'
assert digit_version(mmcv.__version__) >= digit_version(
min_required_version
), f'Requires to install mmcv>={min_required_version}'
from mmcv.onnx.simplify import simplify
input_dic = {'input': one_img.detach().cpu().numpy()}
_ = simplify(output_file, [input_dic], output_file)
print(f'Successfully exported ONNX model: {output_file}') print(f'Successfully exported ONNX model: {output_file}')
if verify: if verify:
from mmdet.core import get_classes from mmdet.core import get_classes, bbox2result
from mmdet.apis import show_result_pyplot from mmdet.apis import show_result_pyplot
ort_custom_op_path = ''
try:
from mmcv.ops import get_onnxruntime_op_path
ort_custom_op_path = get_onnxruntime_op_path()
except (ImportError, ModuleNotFoundError):
warnings.warn('If input model has custom op from mmcv, \
you may have to build mmcv with ONNXRuntime from source.')
model.CLASSES = get_classes(dataset) model.CLASSES = get_classes(dataset)
num_classes = len(model.CLASSES) num_classes = len(model.CLASSES)
# check by onnx # check by onnx
@ -76,8 +100,11 @@ def pytorch2onnx(config_path,
] ]
net_feed_input = list(set(input_all) - set(input_initializer)) net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 1) assert (len(net_feed_input) == 1)
sess = rt.InferenceSession(output_file) session_options = rt.SessionOptions()
from mmdet.core import bbox2result # register custom op for onnxruntime
if osp.exists(ort_custom_op_path):
session_options.register_custom_ops_library(ort_custom_op_path)
sess = rt.InferenceSession(output_file, session_options)
onnx_outputs = sess.run(None, onnx_outputs = sess.run(None,
{net_feed_input[0]: one_img.detach().numpy()}) {net_feed_input[0]: one_img.detach().numpy()})
output_names = [_.name for _ in sess.get_outputs()] output_names = [_.name for _ in sess.get_outputs()]
@ -102,11 +129,7 @@ def pytorch2onnx(config_path,
if show: if show:
show_result_pyplot( show_result_pyplot(
model, model, one_meta['show_img'], pytorch_results, title='Pytorch')
one_meta['show_img'],
pytorch_results,
title='Pytorch',
block=False)
show_result_pyplot( show_result_pyplot(
model, one_meta['show_img'], onnx_results, title='ONNX') model, one_meta['show_img'], onnx_results, title='ONNX')
@ -144,6 +167,10 @@ def parse_args():
'--verify', '--verify',
action='store_true', action='store_true',
help='verify the onnx model output against pytorch output') help='verify the onnx model output against pytorch output')
parser.add_argument(
'--simplify',
action='store_true',
help='Whether to simplify onnx model.')
parser.add_argument( parser.add_argument(
'--shape', '--shape',
type=int, type=int,
@ -199,4 +226,5 @@ if __name__ == '__main__':
verify=args.verify, verify=args.verify,
normalize_cfg=normalize_cfg, normalize_cfg=normalize_cfg,
dataset=args.dataset, dataset=args.dataset,
test_img=args.test_img) test_img=args.test_img,
do_simplify=args.simplify)

Loading…
Cancel
Save