OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io/
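"""Convert an MMDetection detector checkpoint to ONNX.

Example invocation (a sketch: the script path and file names are placeholders,
and the flags simply mirror the argparse definitions below)::

    python pytorch2onnx.py ${CONFIG_FILE} ${CHECKPOINT_FILE} \
        --output-file model.onnx --shape 800 1216 --verify
"""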
import argparse
import os.path as osp
from functools import partial

import mmcv
import numpy as np
import onnx
import onnxruntime as rt
import torch
from mmcv.runner import load_checkpoint

from mmdet.models import build_detector

try:
    from mmcv.onnx.symbolic import register_extra_symbolics
except ModuleNotFoundError:
    raise NotImplementedError('please update mmcv to version>=v1.0.4')

def pytorch2onnx(model,
                 input_img,
                 input_shape,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 normalize_cfg=None):
    model.cpu().eval()
    # read and preprocess the demo image
    one_img = mmcv.imread(input_img)
    if normalize_cfg:
        one_img = mmcv.imnormalize(one_img, normalize_cfg['mean'],
                                   normalize_cfg['std'])
    one_img = mmcv.imresize(one_img, input_shape[2:]).transpose(2, 0, 1)
    one_img = torch.from_numpy(one_img).unsqueeze(0).float()
    (_, C, H, W) = input_shape
    one_meta = {
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': 1.0,
        'flip': False
    }
    # onnx.export does not support kwargs, so bind img_metas and return_loss
    # to the forward method before tracing
    origin_forward = model.forward
    model.forward = partial(
        model.forward, img_metas=[[one_meta]], return_loss=False)
    # PyTorch 1.3 exports some ops incorrectly, so replace the affected
    # symbolic functions with the fixed versions provided by mmcv
    register_extra_symbolics(opset_version)
    torch.onnx.export(
        model, ([one_img]),
        output_file,
        export_params=True,
        keep_initializers_as_inputs=True,
        verbose=show,
        opset_version=opset_version)
    model.forward = origin_forward
    print(f'Successfully exported ONNX model: {output_file}')
    if verify:
        # check the exported model with the onnx checker
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)

        # check the numerical values:
        # get the pytorch output
        pytorch_result = model([one_img], [[one_meta]], return_loss=False)

        # get the onnx output
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        net_feed_input = list(set(input_all) - set(input_initializer))
        assert (len(net_feed_input) == 1)
        sess = rt.InferenceSession(output_file)
        from mmdet.core import bbox2result
        det_bboxes, det_labels = sess.run(
            None, {net_feed_input[0]: one_img.detach().numpy()})
        # only compare a part of the result: the detection scores in column 4
        bbox_results = bbox2result(det_bboxes, det_labels, 1)
        onnx_results = bbox_results[0]
        assert np.allclose(
            pytorch_result[0][:, 4], onnx_results[:, 4]
        ), 'The outputs are different between PyTorch and ONNX'
        print('The numerical values are the same between PyTorch and ONNX')

def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert MMDetection models to ONNX')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--input-img', type=str, help='image used as input')
    parser.add_argument('--show', action='store_true', help='show onnx graph')
    parser.add_argument('--output-file', type=str, default='tmp.onnx')
    parser.add_argument('--opset-version', type=int, default=11)
    parser.add_argument(
        '--verify',
        action='store_true',
        help='verify the onnx model output against the pytorch output')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[800, 1216],
        help='input image size')
    parser.add_argument(
        '--mean',
        type=float,
        nargs='+',
        default=[123.675, 116.28, 103.53],
        help='mean value used to preprocess the input data')
    parser.add_argument(
        '--std',
        type=float,
        nargs='+',
        default=[58.395, 57.12, 57.375],
        help='standard deviation used to preprocess the input data')
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = parse_args()

    assert args.opset_version == 11, 'MMDet only supports opset 11 for now'

    if not args.input_img:
        args.input_img = osp.join(
            osp.dirname(__file__), '../tests/data/color.jpg')

    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (1, 3) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    assert len(args.mean) == 3
    assert len(args.std) == 3

    normalize_cfg = {
        'mean': np.array(args.mean, dtype=np.float32),
        'std': np.array(args.std, dtype=np.float32)
    }

    cfg = mmcv.Config.fromfile(args.config)
    # import modules from the string list in the config, if any
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    # build the model
    model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')

    # convert the model to an ONNX file
    pytorch2onnx(
        model,
        args.input_img,
        input_shape,
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify,
        normalize_cfg=normalize_cfg)
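# A minimal standalone sketch of running the exported file with onnxruntime,
# kept as a comment; the preprocessing must match the normalization, resize and
# shape used at export time, and the file names here are placeholders:
#
#     import mmcv
#     import numpy as np
#     import onnxruntime as rt
#
#     img = mmcv.imread('demo.jpg')
#     img = mmcv.imnormalize(img,
#                            np.array([123.675, 116.28, 103.53], np.float32),
#                            np.array([58.395, 57.12, 57.375], np.float32))
#     img = mmcv.imresize(img, (800, 1216)).transpose(2, 0, 1)
#     img = img[None].astype(np.float32)
#     sess = rt.InferenceSession('tmp.onnx')
#     input_name = sess.get_inputs()[0].name
#     det_bboxes, det_labels = sess.run(None, {input_name: img})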