OpenMMLab Detection Toolbox and Benchmark https://mmdetection.readthedocs.io/
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

170 lines
5.5 KiB

import argparse
import os.path as osp
from functools import partial
import mmcv
import numpy as np
import onnx
import onnxruntime as rt
import torch
from mmcv.runner import load_checkpoint
from mmdet.models import build_detector
# register_extra_symbolics only ships with newer mmcv releases; fail early
# with an actionable message when the installed mmcv does not provide it.
try:
    from mmcv.onnx.symbolic import register_extra_symbolics
except ModuleNotFoundError as err:
    # chain explicitly so the traceback shows the missing module as the cause
    raise NotImplementedError(
        'please update mmcv to version>=v1.0.4') from err
def pytorch2onnx(model,
                 input_img,
                 input_shape,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 normalize_cfg=None):
    """Export a detector to an ONNX file and optionally verify the result.

    The model is moved to CPU and put in eval mode, traced with a single
    image loaded from ``input_img``, and written to ``output_file``.

    Args:
        model: Detector to export; its ``forward`` is temporarily wrapped so
            that ``img_metas`` and ``return_loss`` are bound during export.
        input_img (str): Path of the image used as the tracing input.
        input_shape (tuple[int]): Input shape, unpacked as (N, C, H, W).
        opset_version (int): ONNX opset version passed to the exporter.
        show (bool): Forwarded as ``verbose`` to ``torch.onnx.export``.
        output_file (str): Path of the exported ONNX file.
        verify (bool): If True, check the exported file with
            ``onnx.checker`` and compare the scores produced by ONNX Runtime
            against the PyTorch output.
        normalize_cfg (dict | None): If given, a dict with ``'mean'`` and
            ``'std'`` entries used to normalize the image before resizing.

    Raises:
        AssertionError: If ``verify`` is set and the exported graph does not
            have exactly one non-initializer input, or the compared scores
            differ between PyTorch and ONNX Runtime.
    """
    model.cpu().eval()
    (_, C, H, W) = input_shape

    # read image
    one_img = mmcv.imread(input_img)
    if normalize_cfg:
        one_img = mmcv.imnormalize(one_img, normalize_cfg['mean'],
                                   normalize_cfg['std'])
    # mmcv.imresize takes the target size as (w, h) while input_shape stores
    # (H, W); pass it swapped so the tensor agrees with the meta info below
    one_img = mmcv.imresize(one_img, (W, H)).transpose(2, 0, 1)
    one_img = torch.from_numpy(one_img).unsqueeze(0).float()
    one_meta = {
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': 1.0,
        'flip': False
    }

    # onnx.export does not support kwargs
    origin_forward = model.forward
    model.forward = partial(
        model.forward, img_metas=[[one_meta]], return_loss=False)
    try:
        # pytorch has some bug in pytorch1.3, we have to fix it
        # by replacing these existing op
        register_extra_symbolics(opset_version)
        torch.onnx.export(
            model, ([one_img]),
            output_file,
            export_params=True,
            keep_initializers_as_inputs=True,
            verbose=show,
            opset_version=opset_version)
    finally:
        # always undo the monkey patch, even when the export fails
        model.forward = origin_forward
    print(f'Successfully exported ONNX model: {output_file}')

    if verify:
        # check by onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)

        # check the numerical value
        # get pytorch output
        pytorch_result = model([one_img], [[one_meta]], return_loss=False)

        # get onnx output; initializers are also listed as graph inputs
        # (keep_initializers_as_inputs=True), so filter them out to find
        # the single real feed input
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        net_feed_input = list(set(input_all) - set(input_initializer))
        assert (len(net_feed_input) == 1)
        sess = rt.InferenceSession(output_file)
        from mmdet.core import bbox2result
        det_bboxes, det_labels = sess.run(
            None, {net_feed_input[0]: one_img.detach().numpy()})
        # only compare a part of result
        bbox_results = bbox2result(det_bboxes, det_labels, 1)
        onnx_results = bbox_results[0]
        assert np.allclose(
            pytorch_result[0][:, 4], onnx_results[:, 4]
        ), 'The outputs are different between Pytorch and ONNX'
        print('The numerical values are same between Pytorch and ONNX')
def parse_args():
    """Parse the command-line arguments of the ONNX conversion script.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``config``,
            ``checkpoint``, ``input_img``, ``show``, ``output_file``,
            ``opset_version``, ``verify``, ``shape``, ``mean`` and ``std``.
    """
    parser = argparse.ArgumentParser(
        description='Convert MMDetection models to ONNX')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--input-img', type=str, help='Images for input')
    parser.add_argument('--show', action='store_true', help='show onnx graph')
    parser.add_argument('--output-file', type=str, default='tmp.onnx')
    parser.add_argument('--opset-version', type=int, default=11)
    parser.add_argument(
        '--verify',
        action='store_true',
        help='verify the onnx model output against pytorch output')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[800, 1216],
        help='input image size')
    # mean/std are fractional (see the defaults), so they must be parsed as
    # float; with type=int a value such as "123.675" is rejected by argparse
    parser.add_argument(
        '--mean',
        type=float,
        nargs='+',
        default=[123.675, 116.28, 103.53],
        help='mean value used for preprocess input data')
    parser.add_argument(
        '--std',
        type=float,
        nargs='+',
        default=[58.395, 57.12, 57.375],
        help='variance value used for preprocess input data')
    args = parser.parse_args()
    return args
# Script entry point: parse the CLI options, build the detector described by
# the config, load the checkpoint and export the model to ONNX.
if __name__ == '__main__':
    args = parse_args()

    # Validate with explicit raises rather than assert statements, which are
    # stripped when Python runs with -O.
    if args.opset_version != 11:
        raise ValueError('MMDet only support opset 11 now')

    # fall back to the sample image located relative to this script
    if not args.input_img:
        args.input_img = osp.join(
            osp.dirname(__file__), '../tests/data/color.jpg')

    # a single value means a square input; two values are used as given
    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (1, 3) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    if len(args.mean) != 3:
        raise ValueError('--mean expects exactly 3 values')
    if len(args.std) != 3:
        raise ValueError('--std expects exactly 3 values')

    normalize_cfg = {
        'mean': np.array(args.mean, dtype=np.float32),
        'std': np.array(args.std, dtype=np.float32)
    }

    cfg = mmcv.Config.fromfile(args.config)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    # build the model
    model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
    # the return value is not needed here, only the weights loaded into model
    load_checkpoint(model, args.checkpoint, map_location='cpu')

    # convert model to onnx file
    pytorch2onnx(
        model,
        args.input_img,
        input_shape,
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify,
        normalize_cfg=normalize_cfg)