@ -1,31 +1,27 @@
import argparse
import os . path as osp
import warnings
from functools import partial
import numpy as np
import onnx
import onnxruntime as rt
import torch
from mmcv import DictAction
from mmcv import Config , DictAction
from mmdet . core . export import ( build_model_from_cfg ,
generate_inputs_and_wrap_model ,
preprocess_example_input )
from mmdet . core . export import build_model_from_cfg , preprocess_example_input
from mmdet . core . export . model_wrappers import ONNXRuntimeDetector
def pytorch2onnx ( config_path ,
checkpoint_path ,
def pytorch2onnx ( model ,
input_img ,
input_shape ,
normalize_cfg ,
opset_version = 11 ,
show = False ,
output_file = ' tmp.onnx ' ,
verify = False ,
normalize_cfg = None ,
dataset = ' coco ' ,
test_img = None ,
do_simplify = False ,
cfg_options = None ,
dynamic_export = None ) :
input_config = {
@ -33,13 +29,17 @@ def pytorch2onnx(config_path,
' input_path ' : input_img ,
' normalize_cfg ' : normalize_cfg
}
# prepare original model and meta for verifying the onnx model
orig_model = build_model_from_cfg (
config_path , checkpoint_path , cfg_options = cfg_options )
# prepare input
one_img , one_meta = preprocess_example_input ( input_config )
model , tensor_data = generate_inputs_and_wrap_model (
config_path , checkpoint_path , input_config , cfg_options = cfg_options )
img_list , img_meta_list = [ one_img ] , [ [ one_meta ] ]
# replace original forward function
origin_forward = model . forward
model . forward = partial (
model . forward ,
img_metas = img_meta_list ,
return_loss = False ,
rescale = False )
output_names = [ ' dets ' , ' labels ' ]
if model . with_mask :
output_names . append ( ' masks ' )
@ -66,7 +66,7 @@ def pytorch2onnx(config_path,
torch . onnx . export (
model ,
tensor_data ,
img_list ,
output_file ,
input_names = [ input_name ] ,
output_names = output_names ,
@ -77,7 +77,7 @@ def pytorch2onnx(config_path,
opset_version = opset_version ,
dynamic_axes = dynamic_axes )
model . forward = orig_model . forward
model . forward = origin_ forward
# get the custom op path
ort_custom_op_path = ' '
@ -89,79 +89,74 @@ def pytorch2onnx(config_path,
you may have to build mmcv with ONNXRuntime from source . ' )
if do_simplify :
from mmdet import digit_version
import onnxsim
from mmdet import digit_version
min_required_version = ' 0.3.0 '
assert digit_version ( onnxsim . __version__ ) > = digit_version (
min_required_version
) , f ' Requires to install onnx-simplify>= { min_required_version } '
input_dic = { ' input ' : one_ img. detach ( ) . cpu ( ) . numpy ( ) }
input_dic = { ' input ' : img_list [ 0 ] . detach ( ) . cpu ( ) . numpy ( ) }
onnxsim . simplify (
output_file , input_data = input_dic , custom_lib = ort_custom_op_path )
print ( f ' Successfully exported ONNX model: { output_file } ' )
if verify :
from mmdet . core import get_classes , bbox2result
from mmdet . apis import show_result_pyplot
model . CLASSES = get_classes ( dataset )
num_classes = len ( model . CLASSES )
# check by onnx
onnx_model = onnx . load ( output_file )
onnx . checker . check_model ( onnx_model )
# wrap onnx model
onnx_model = ONNXRuntimeDetector ( output_file , model . CLASSES , 0 )
if dynamic_export :
# scale up to test dynamic shape
h , w = [ int ( ( _ * 1.5 ) / / 32 * 32 ) for _ in input_shape [ 2 : ] ]
h , w = min ( 1344 , h ) , min ( 1344 , w )
input_config [ ' input_shape ' ] = ( 1 , 3 , h , w )
if test_img is not None :
input_config [ ' input_path ' ] = test_img
if test_img is None :
input_config [ ' input_path ' ] = input_img
# prepare input once again
one_img , one_meta = preprocess_example_input ( input_config )
tensor_data = [ one_img ]
img_list , img_meta_list = [ one_img ] , [ [ one_meta ] ]
# get pytorch output
pytorch_results = model ( tensor_data , [ [ one_meta ] ] , return_loss = False )
pytorch_results = pytorch_results [ 0 ]
# get onnx output
input_all = [ node . name for node in onnx_model . graph . input ]
input_initializer = [
node . name for node in onnx_model . graph . initializer
]
net_feed_input = list ( set ( input_all ) - set ( input_initializer ) )
assert ( len ( net_feed_input ) == 1 )
session_options = rt . SessionOptions ( )
# register custom op for ONNX Runtime
if osp . exists ( ort_custom_op_path ) :
session_options . register_custom_ops_library ( ort_custom_op_path )
feed_input_img = one_img . detach ( ) . numpy ( )
pytorch_results = model (
img_list , img_metas = img_meta_list , return_loss = False ,
rescale = True ) [ 0 ]
img_list = [ _ . cuda ( ) . contiguous ( ) for _ in img_list ]
if dynamic_export :
# test batch with two input images
feed_input_img = np . vstack ( [ feed_input_img , feed_input_img ] )
sess = rt . InferenceSession ( output_file , session_options )
onnx_outputs = sess . run ( None , { net_feed_input [ 0 ] : feed_input_img } )
output_names = [ _ . name for _ in sess . get_outputs ( ) ]
output_shapes = [ _ . shape for _ in onnx_outputs ]
print ( f ' ONNX Runtime output names: { output_names } , \
output shapes : { output_shapes } ' )
# get last image's outputs
onnx_outputs = [ _ [ - 1 ] for _ in onnx_outputs ]
ort_dets , ort_labels = onnx_outputs [ : 2 ]
onnx_results = bbox2result ( ort_dets , ort_labels , num_classes )
if model . with_mask :
segm_results = onnx_outputs [ 2 ]
if segm_results . dtype != np . bool :
segm_results = ( segm_results * 255 ) . astype ( np . uint8 )
cls_segms = [ [ ] for _ in range ( num_classes ) ]
for i in range ( ort_dets . shape [ 0 ] ) :
cls_segms [ ort_labels [ i ] ] . append ( segm_results [ i ] )
onnx_results = ( onnx_results , cls_segms )
img_list = img_list + [ _ . flip ( - 1 ) . contiguous ( ) for _ in img_list ]
img_meta_list = img_meta_list * 2
# get onnx output
onnx_results = onnx_model (
img_list , img_metas = img_meta_list , return_loss = False ) [ 0 ]
# visualize predictions
score_thr = 0.3
if show :
show_result_pyplot (
model , one_meta [ ' show_img ' ] , pytorch_results , title = ' Pytorch ' )
show_result_pyplot (
model , one_meta [ ' show_img ' ] , onnx_results , title = ' ONNXRuntime ' )
out_file_ort , out_file_pt = None , None
else :
out_file_ort , out_file_pt = ' show-ort.png ' , ' show-pt.png '
show_img = one_meta [ ' show_img ' ]
model . show_result (
show_img ,
pytorch_results ,
score_thr = score_thr ,
show = True ,
win_name = ' PyTorch ' ,
out_file = out_file_pt )
onnx_model . show_result (
show_img ,
onnx_results ,
score_thr = score_thr ,
show = True ,
win_name = ' ONNXRuntime ' ,
out_file = out_file_ort )
# compare a part of result
if model . with_mask :
@ -179,6 +174,19 @@ def pytorch2onnx(config_path,
print ( ' The numerical values are the same between Pytorch and ONNX ' )
def parse_normalize_cfg ( test_pipeline ) :
transforms = None
for pipeline in test_pipeline :
if ' transforms ' in pipeline :
transforms = pipeline [ ' transforms ' ]
break
assert transforms is not None , ' Failed to find `transforms` '
norm_config_li = [ _ for _ in transforms if _ [ ' type ' ] == ' Normalize ' ]
assert len ( norm_config_li ) == 1 , ' `norm_config` should only have one '
norm_config = norm_config_li [ 0 ]
return norm_config
def parse_args ( ) :
parser = argparse . ArgumentParser (
description = ' Convert MMDetection models to ONNX ' )
@ -194,7 +202,11 @@ def parse_args():
parser . add_argument (
' --test-img ' , type = str , default = None , help = ' Images for test ' )
parser . add_argument (
' --dataset ' , type = str , default = ' coco ' , help = ' Dataset name ' )
' --dataset ' ,
type = str ,
default = ' coco ' ,
help = ' Dataset name. This argument is deprecated and will be removed \
in future releases . ' )
parser . add_argument (
' --verify ' ,
action = ' store_true ' ,
@ -214,13 +226,15 @@ def parse_args():
type = float ,
nargs = ' + ' ,
default = [ 123.675 , 116.28 , 103.53 ] ,
help = ' mean value used for preprocess input data ' )
help = ' mean value used for preprocess input data.This argument \
is deprecated and will be removed in future releases . ' )
parser . add_argument (
' --std ' ,
type = float ,
nargs = ' + ' ,
default = [ 58.395 , 57.12 , 57.375 ] ,
help = ' variance value used for preprocess input data ' )
help = ' variance value used for preprocess input data. '
' This argument is deprecated and will be removed in future releases. ' )
parser . add_argument (
' --cfg-options ' ,
nargs = ' + ' ,
@ -241,38 +255,51 @@ def parse_args():
if __name__ == ' __main__ ' :
args = parse_args ( )
warnings . warn ( ' Arguments like `--mean`, `--std`, `--dataset` would be \
parsed directly from config file and are deprecated and \
will be removed in future releases . ' )
assert args . opset_version == 11 , ' MMDet only support opset 11 now '
if not args . input_img :
args . input_img = osp . join (
osp . dirname ( __file__ ) , ' ../../tests/data/color.jpg ' )
try :
from mmcv . onnx . symbolic import register_extra_symbolics
except ModuleNotFoundError :
raise NotImplementedError ( ' please update mmcv to version>=v1.0.4 ' )
register_extra_symbolics ( args . opset_version )
cfg = Config . fromfile ( args . config )
if args . cfg_options is not None :
cfg . merge_from_dict ( args . cfg_options )
if len ( args . shape ) == 1 :
if args . shape is None :
img_scale = cfg . test_pipeline [ 1 ] [ ' img_scale ' ]
input_shape = ( 1 , 3 , img_scale [ 1 ] , img_scale [ 0 ] )
elif len ( args . shape ) == 1 :
input_shape = ( 1 , 3 , args . shape [ 0 ] , args . shape [ 0 ] )
elif len ( args . shape ) == 2 :
input_shape = ( 1 , 3 ) + tuple ( args . shape )
else :
raise ValueError ( ' invalid input shape ' )
assert len ( args . mean ) == 3
assert len ( args . std ) == 3
# build the model and load checkpoint
model = build_model_from_cfg ( args . config , args . checkpoint ,
args . cfg_options )
normalize_cfg = { ' mean ' : args . mean , ' std ' : args . std }
if not args . input_img :
args . input_img = osp . join ( osp . dirname ( __file__ ) , ' ../../demo/demo.jpg ' )
normalize_cfg = parse_normalize_cfg ( cfg . test_pipeline )
# convert model to onnx file
pytorch2onnx (
args . config ,
args . checkpoint ,
model ,
args . input_img ,
input_shape ,
normalize_cfg ,
opset_version = args . opset_version ,
show = args . show ,
output_file = args . output_file ,
verify = args . verify ,
normalize_cfg = normalize_cfg ,
dataset = args . dataset ,
test_img = args . test_img ,
do_simplify = args . simplify ,
cfg_options = args . cfg_options ,
dynamic_export = args . dynamic_export )