|
|
|
@ -101,31 +101,17 @@ def pytorch2onnx(model, |
|
|
|
|
|
|
|
|
|
model.forward = origin_forward |
|
|
|
|
|
|
|
|
|
# get the custom op path |
|
|
|
|
ort_custom_op_path = '' |
|
|
|
|
try: |
|
|
|
|
from mmcv.ops import get_onnxruntime_op_path |
|
|
|
|
ort_custom_op_path = get_onnxruntime_op_path() |
|
|
|
|
except (ImportError, ModuleNotFoundError): |
|
|
|
|
warnings.warn('If input model has custom op from mmcv, \ |
|
|
|
|
you may have to build mmcv with ONNXRuntime from source.') |
|
|
|
|
|
|
|
|
|
if do_simplify: |
|
|
|
|
import onnxsim |
|
|
|
|
|
|
|
|
|
from mmdet import digit_version |
|
|
|
|
|
|
|
|
|
min_required_version = '0.3.0' |
|
|
|
|
min_required_version = '0.4.0' |
|
|
|
|
assert digit_version(onnxsim.__version__) >= digit_version( |
|
|
|
|
min_required_version |
|
|
|
|
), f'Requires to install onnx-simplify>={min_required_version}' |
|
|
|
|
), f'Requires to install onnxsim>={min_required_version}' |
|
|
|
|
|
|
|
|
|
input_dic = {'input': img_list[0].detach().cpu().numpy()} |
|
|
|
|
model_opt, check_ok = onnxsim.simplify( |
|
|
|
|
output_file, |
|
|
|
|
input_data=input_dic, |
|
|
|
|
custom_lib=ort_custom_op_path, |
|
|
|
|
dynamic_input_shape=dynamic_export) |
|
|
|
|
model_opt, check_ok = onnxsim.simplify(output_file) |
|
|
|
|
if check_ok: |
|
|
|
|
onnx.save(model_opt, output_file) |
|
|
|
|
print(f'Successfully simplified ONNX model: {output_file}') |
|
|
|
|