`ultralytics 8.2.27` replace `onnxsim` with `onnxslim` (#12989)

Co-authored-by: inisis <desmond.yao@buaa.edu.cn>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: inisis <46103969+inisis@users.noreply.github.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
pull/13289/head v8.2.27
Authored by Kayzwer, committed via GitHub 6 months ago
parent dd13707bf8
commit 8fb140688a
9 changed files:

  1. docker/Dockerfile-arm64 (3 changed lines)
  2. docs/en/modes/export.md (2 changed lines)
  3. tests/test_cuda.py (1 changed line)
  4. tests/test_exports.py (7 changed lines)
  5. ultralytics/__init__.py (2 changed lines)
  6. ultralytics/cfg/default.yaml (2 changed lines)
  7. ultralytics/engine/exporter.py (24 changed lines)
  8. ultralytics/nn/modules/__init__.py (2 changed lines)
  9. ultralytics/utils/benchmarks.py (4 changed lines)

docker/Dockerfile-arm64
@@ -16,10 +16,9 @@ ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/Arial.ttf \
 
 # Install linux packages
 # g++ required to build 'tflite_support' and 'lap' packages, libusb-1.0-0 required for 'tflite_support' package
-# cmake and build-essential are needed to build 'onnxsim' when exporting to TFLite
 # pkg-config and libhdf5-dev (not included) are needed to build 'h5py==3.11.0' aarch64 wheel required by 'tensorflow'
 RUN apt update \
-    && apt install --no-install-recommends -y python3-pip git zip curl htop gcc libgl1 libglib2.0-0 libpython3-dev gnupg g++ libusb-1.0-0 build-essential
+    && apt install --no-install-recommends -y python3-pip git zip curl htop gcc libgl1 libglib2.0-0 libpython3-dev gnupg g++ libusb-1.0-0
 
 # Create working directory
 WORKDIR $APP_HOME
docs/en/modes/export.md
@@ -83,7 +83,7 @@ This table details the configurations and options available for exporting YOLO m
 | `half` | `bool` | `False` | Enables FP16 (half-precision) quantization, reducing model size and potentially speeding up inference on supported hardware. |
 | `int8` | `bool` | `False` | Activates INT8 quantization, further compressing the model and speeding up inference with minimal accuracy loss, primarily for edge devices. |
 | `dynamic` | `bool` | `False` | Allows dynamic input sizes for ONNX and TensorRT exports, enhancing flexibility in handling varying image dimensions. |
-| `simplify` | `bool` | `False` | Simplifies the model graph for ONNX exports, potentially improving performance and compatibility. |
+| `simplify` | `bool` | `False` | Simplifies the model graph for ONNX exports with `onnxslim`, potentially improving performance and compatibility. |
 | `opset` | `int` | `None` | Specifies the ONNX opset version for compatibility with different ONNX parsers and runtimes. If not set, uses the latest supported version. |
 | `workspace` | `float` | `4.0` | Sets the maximum workspace size in GiB for TensorRT optimizations, balancing memory usage and performance. |
 | `nms` | `bool` | `False` | Adds Non-Maximum Suppression (NMS) to the CoreML export, essential for accurate and efficient detection post-processing. |
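For context, this is the flag the table documents; a minimal sketch of driving it from the Python API (the weights file and image size below are illustrative, not part of this diff):

```python
from ultralytics import YOLO

# Any YOLOv8 weights file works here; yolov8n.pt is simply the smallest.
model = YOLO("yolov8n.pt")

# With this release, simplify=True routes the exported graph through onnxslim rather than onnxsim.
onnx_file = model.export(format="onnx", simplify=True, imgsz=640)
print(f"Slimmed ONNX model written to {onnx_file}")
```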

tests/test_cuda.py
@@ -41,6 +41,7 @@ def test_export_engine_matrix(task, dynamic, int8, half, batch):
         batch=batch,
         data=TASK2DATA[task],
         workspace=1,  # reduce workspace GB for less resource utilization during testing
+        simplify=True,  # use 'onnxslim'
     )
     YOLO(file)([SOURCE] * batch, imgsz=64 if dynamic else 32)  # exported model inference
     Path(file).unlink()  # cleanup

tests/test_exports.py
@@ -72,8 +72,10 @@ def test_export_openvino_matrix(task, dynamic, int8, half, batch):
 
 @pytest.mark.slow
-@pytest.mark.parametrize("task, dynamic, int8, half, batch", product(TASKS, [True, False], [False], [False], [1, 2]))
-def test_export_onnx_matrix(task, dynamic, int8, half, batch):
+@pytest.mark.parametrize(
+    "task, dynamic, int8, half, batch, simplify", product(TASKS, [True, False], [False], [False], [1, 2], [True, False])
+)
+def test_export_onnx_matrix(task, dynamic, int8, half, batch, simplify):
     """Test YOLO exports to ONNX format."""
     file = YOLO(TASK2MODEL[task]).export(
         format="onnx",

@@ -82,6 +84,7 @@ def test_export_onnx_matrix(task, dynamic, int8, half, batch):
         int8=int8,
         half=half,
         batch=batch,
+        simplify=simplify,
     )
     YOLO(file)([SOURCE] * batch, imgsz=64 if dynamic else 32)  # exported model inference
     Path(file).unlink()  # cleanup

ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = "8.2.26"
+__version__ = "8.2.27"
 
 import os

ultralytics/cfg/default.yaml
@@ -81,7 +81,7 @@ keras: False # (bool) use Kera=s
 optimize: False # (bool) TorchScript: optimize for mobile
 int8: False # (bool) CoreML/TF INT8 quantization
 dynamic: False # (bool) ONNX/TF/TensorRT: dynamic axes
-simplify: False # (bool) ONNX: simplify model
+simplify: False # (bool) ONNX: simplify model using `onnxslim`
 opset: # (int, optional) ONNX: opset version
 workspace: 4 # (int) TensorRT: workspace size (GB)
 nms: False # (bool) CoreML: add NMS

ultralytics/engine/exporter.py
@@ -384,7 +384,7 @@ class Exporter:
         """YOLOv8 ONNX export."""
         requirements = ["onnx>=1.12.0"]
         if self.args.simplify:
-            requirements += ["cmake", "onnxsim>=0.4.33", "onnxruntime" + ("-gpu" if torch.cuda.is_available() else "")]
+            requirements += ["onnxslim==0.1.28", "onnxruntime" + ("-gpu" if torch.cuda.is_available() else "")]
         check_requirements(requirements)
 
         import onnx  # noqa
@@ -421,14 +421,17 @@ class Exporter:
         # Simplify
         if self.args.simplify:
             try:
-                import onnxsim
+                import onnxslim
 
-                LOGGER.info(f"{prefix} simplifying with onnxsim {onnxsim.__version__}...")
-                # subprocess.run(f'onnxsim "{f}" "{f}"', shell=True)
-                model_onnx, check = onnxsim.simplify(model_onnx)
-                assert check, "Simplified ONNX model could not be validated"
+                LOGGER.info(f"{prefix} slimming with onnxslim {onnxslim.__version__}...")
+                model_onnx = onnxslim.slim(model_onnx)
+
+                # ONNX Simplifier (deprecated as must be compiled with 'cmake' in aarch64 and Conda CI environments)
+                # import onnxsim
+                # model_onnx, check = onnxsim.simplify(model_onnx)
+                # assert check, "Simplified ONNX model could not be validated"
             except Exception as e:
-                LOGGER.info(f"{prefix} simplifier failure: {e}")
+                LOGGER.warning(f"{prefix} simplifier failure: {e}")
 
         # Metadata
         for k, v in self.metadata.items():
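The new dependency can also be applied to an existing ONNX file outside the exporter; a minimal sketch assuming onnxslim 0.1.x as pinned above, where `slim()` accepts a loaded `onnx.ModelProto` (as in the hunk) and returns the slimmed graph. File names are placeholders:

```python
import onnx
import onnxslim  # pip install onnxslim==0.1.28

# Load a previously exported ONNX graph (path is a placeholder).
model_onnx = onnx.load("yolov8n.onnx")

# Slim the graph in memory, mirroring what Exporter.export_onnx() now does when simplify=True.
model_onnx = onnxslim.slim(model_onnx)

# Persist the slimmed model next to the original.
onnx.save(model_onnx, "yolov8n-slim.onnx")
```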
@@ -672,8 +675,8 @@ class Exporter:
     def export_engine(self, prefix=colorstr("TensorRT:")):
         """YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt."""
         assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
-        self.args.simplify = True
-        f_onnx, _ = self.export_onnx()  # run before trt import https://github.com/ultralytics/ultralytics/issues/7016
+        # self.args.simplify = True
+        f_onnx, _ = self.export_onnx()  # run before TRT import https://github.com/ultralytics/ultralytics/issues/7016
 
         try:
             import tensorrt as trt  # noqa
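With the forced `self.args.simplify = True` commented out, TensorRT exports no longer slim the intermediate ONNX graph implicitly; the updated `test_export_engine_matrix` above passes `simplify=True` explicitly for that reason. A hedged sketch of the equivalent call (requires a CUDA device; weights and workspace values are illustrative):

```python
from ultralytics import YOLO

# TensorRT export must run on a GPU; device=0 selects the first CUDA device.
model = YOLO("yolov8n.pt")
engine_file = model.export(
    format="engine",
    device=0,
    simplify=True,  # no longer implied for format="engine"; pass explicitly to slim the ONNX graph first
    workspace=1,  # GiB, kept small here as in the updated test
)
```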
@@ -815,13 +818,12 @@ class Exporter:
         import tensorflow as tf  # noqa
         check_requirements(
             (
-                "cmake",  # 'cmake' is needed to build onnxsim on aarch64 and Conda runners
                 "keras",  # required by onnx2tf package
                 "tf_keras",  # required by onnx2tf package
                 "onnx>=1.12.0",
                 "onnx2tf>1.17.5,<=1.22.3",
                 "sng4onnx>=1.0.1",
-                "onnxsim>=0.4.33",
+                "onnxslim==0.1.28",
                 "onnx_graphsurgeon>=0.3.26",
                 "tflite_support<=0.4.3" if IS_JETSON else "tflite_support",  # fix ImportError 'GLIBCXX_3.4.29'
                 "flatbuffers>=23.5.26,<100",  # update old 'flatbuffers' included inside tensorflow package

ultralytics/nn/modules/__init__.py
@@ -13,7 +13,7 @@ Example:
     m = Conv(128, 128)
     f = f'{m._get_name()}.onnx'
     torch.onnx.export(m, x, f)
-    os.system(f'onnxsim {f} {f} && open {f}')
+    os.system(f'onnxslim {f} {f} && open {f}')  # pip install onnxslim
     ```
 """

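The module docstring only shows the tail of its example; a self-contained version of that snippet, assuming the usual `Conv(c1, c2)` signature and an arbitrary 128-channel input (the macOS-specific `open` call from the docstring is omitted here):

```python
import os

import torch

from ultralytics.nn.modules import Conv

# Dummy activation matching the block's 128 input channels.
x = torch.ones(1, 128, 40, 40)
m = Conv(128, 128)

# Export the single module to ONNX, then slim it in place with the onnxslim CLI.
f = f"{m._get_name()}.onnx"
torch.onnx.export(m, x, f)
os.system(f"onnxslim {f} {f}")  # pip install onnxslim
```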
ultralytics/utils/benchmarks.py
@@ -457,6 +457,8 @@ class ProfileModels:
         input_tensor = sess.get_inputs()[0]
         input_type = input_tensor.type
+        dynamic = not all(isinstance(dim, int) and dim >= 0 for dim in input_tensor.shape)  # dynamic input shape
+        input_shape = (1, 3, self.imgsz, self.imgsz) if dynamic else input_tensor.shape
 
         # Mapping ONNX datatype to numpy datatype
         if "float16" in input_type:

@@ -472,7 +474,7 @@ class ProfileModels:
         else:
             raise ValueError(f"Unsupported ONNX datatype {input_type}")
 
-        input_data = np.random.rand(*input_tensor.shape).astype(input_dtype)
+        input_data = np.random.rand(*input_shape).astype(input_dtype)
         input_name = input_tensor.name
         output_name = sess.get_outputs()[0].name
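The new guard exists because ONNX Runtime reports symbolic axes (strings such as "batch" or "height", or None) for models exported with `dynamic=True`, and `np.random.rand(*shape)` cannot expand those. A standalone sketch of the same probing logic (file name and fallback size are placeholders, and float32 is assumed here instead of mapping the ONNX dtype as the benchmark code does):

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("yolov8n.onnx", providers=["CPUExecutionProvider"])
inp = sess.get_inputs()[0]

# Dynamic/symbolic dimensions are not non-negative ints, so fall back to a fixed shape.
dynamic = not all(isinstance(dim, int) and dim >= 0 for dim in inp.shape)
shape = (1, 3, 640, 640) if dynamic else inp.shape

# Build a dummy tensor of the resolved shape and run a single inference.
dummy = np.random.rand(*shape).astype(np.float32)
outputs = sess.run([sess.get_outputs()[0].name], {inp.name: dummy})
print(outputs[0].shape)
```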
