`ultralytics 8.1.24` new OpenVINO 2023.3 export updates (#8417)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/8680/head^2 v8.1.24
Adrian Boguszewski 12 months ago committed by GitHub
parent 16a91a9b6b
commit a7cfd83c5f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      pyproject.toml
  2. 3
      tests/test_python.py
  3. 2
      ultralytics/__init__.py
  4. 47
      ultralytics/engine/exporter.py
  5. 10
      ultralytics/nn/autobackend.py
  6. 1
      ultralytics/utils/torch_utils.py

@ -98,7 +98,7 @@ dev = [
export = [
"onnx>=1.12.0", # ONNX export
"coremltools>=7.0; platform_system != 'Windows' and python_version <= '3.11'", # CoreML supported on macOS and Linux
"openvino-dev>=2023.0; python_version <= '3.11'", # OpenVINO export
"openvino>=2023.3; python_version <= '3.11'", # OpenVINO export
"tensorflow<=2.13.1; python_version <= '3.11'", # TF bug https://github.com/ultralytics/ultralytics/issues/5161
"tensorflowjs>=3.9.0; python_version <= '3.11'", # TF.js export, automatically installs tensorflow
]

@ -29,7 +29,7 @@ from ultralytics.utils import (
is_dir_writeable,
)
from ultralytics.utils.downloads import download
from ultralytics.utils.torch_utils import TORCH_1_9
from ultralytics.utils.torch_utils import TORCH_1_9, TORCH_1_13
MODEL = WEIGHTS_DIR / "path with spaces" / "yolov8n.pt" # test spaces in path
CFG = "yolov8n.yaml"
@ -219,6 +219,7 @@ def test_export_onnx():
@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="OpenVINO not supported in Python 3.12")
@pytest.mark.skipif(not TORCH_1_13, reason="OpenVINO requires torch>=1.13")
def test_export_openvino():
"""Test exporting the YOLO model to OpenVINO format."""
f = YOLO(MODEL).export(format="openvino")

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.1.23"
__version__ = "8.1.24"
from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld

@ -87,7 +87,7 @@ from ultralytics.utils.checks import PYTHON_VERSION, check_imgsz, check_is_path_
from ultralytics.utils.downloads import attempt_download_asset, get_github_assets
from ultralytics.utils.files import file_size, spaces_in_path
from ultralytics.utils.ops import Profile
from ultralytics.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode
from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device, smart_inference_mode
def export_formats():
@ -283,7 +283,7 @@ class Exporter:
f[0], _ = self.export_torchscript()
if engine: # TensorRT required before ONNX
f[1], _ = self.export_engine()
if onnx or xml: # OpenVINO requires ONNX
if onnx: # ONNX
f[2], _ = self.export_onnx()
if xml: # OpenVINO
f[3], _ = self.export_openvino()
@ -411,16 +411,16 @@ class Exporter:
@try_export
def export_openvino(self, prefix=colorstr("OpenVINO:")):
"""YOLOv8 OpenVINO export."""
check_requirements("openvino-dev>=2023.0") # requires openvino-dev: https://pypi.org/project/openvino-dev/
import openvino.runtime as ov # noqa
from openvino.tools import mo # noqa
check_requirements("openvino>=2023.3") # requires openvino: https://pypi.org/project/openvino-dev/
import openvino as ov # noqa
LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
f = str(self.file).replace(self.file.suffix, f"_openvino_model{os.sep}")
fq = str(self.file).replace(self.file.suffix, f"_int8_openvino_model{os.sep}")
f_onnx = self.file.with_suffix(".onnx")
f_ov = str(Path(f) / self.file.with_suffix(".xml").name)
fq_ov = str(Path(fq) / self.file.with_suffix(".xml").name)
assert TORCH_1_13, f"OpenVINO export requires torch>=1.13.0 but torch=={torch.__version__} is installed"
ov_model = ov.convert_model(
self.model.cpu(),
input=None if self.args.dynamic else [self.im.shape],
example_input=self.im,
)
def serialize(ov_model, file):
"""Set RT info, serialize and save metadata YAML."""
@ -433,21 +433,19 @@ class Exporter:
if self.model.task != "classify":
ov_model.set_rt_info("fit_to_window_letterbox", ["model_info", "resize_type"])
ov.serialize(ov_model, file) # save
ov.save_model(ov_model, file, compress_to_fp16=self.args.half)
yaml_save(Path(file).parent / "metadata.yaml", self.metadata) # add metadata.yaml
ov_model = mo.convert_model(
f_onnx, model_name=self.pretty_name, framework="onnx", compress_to_fp16=self.args.half
) # export
if self.args.int8:
fq = str(self.file).replace(self.file.suffix, f"_int8_openvino_model{os.sep}")
fq_ov = str(Path(fq) / self.file.with_suffix(".xml").name)
if not self.args.data:
self.args.data = DEFAULT_CFG.data or "coco128.yaml"
LOGGER.warning(
f"{prefix} WARNING ⚠ INT8 export requires a missing 'data' arg for calibration. "
f"Using default 'data={self.args.data}'."
)
check_requirements("nncf>=2.5.0")
check_requirements("nncf>=2.8.0")
import nncf
def transform_fn(data_item):
@ -466,6 +464,7 @@ class Exporter:
if n < 300:
LOGGER.warning(f"{prefix} WARNING ⚠ >300 images recommended for INT8 calibration, found {n} images.")
quantization_dataset = nncf.Dataset(dataset, transform_fn)
ignored_scope = None
if isinstance(self.model.model[-1], Detect):
# Includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
@ -473,20 +472,24 @@ class Exporter:
ignored_scope = nncf.IgnoredScope( # ignore operations
patterns=[
f"/{head_module_name}/Add",
f"/{head_module_name}/Sub",
f"/{head_module_name}/Mul",
f"/{head_module_name}/Div",
f"/{head_module_name}/dfl",
f".*{head_module_name}/.*/Add",
f".*{head_module_name}/.*/Sub*",
f".*{head_module_name}/.*/Mul*",
f".*{head_module_name}/.*/Div*",
f".*{head_module_name}\\.dfl.*",
],
names=[f"/{head_module_name}/Sigmoid"],
types=["Sigmoid"],
)
quantized_ov_model = nncf.quantize(
ov_model, quantization_dataset, preset=nncf.QuantizationPreset.MIXED, ignored_scope=ignored_scope
)
serialize(quantized_ov_model, fq_ov)
return fq, None
f = str(self.file).replace(self.file.suffix, f"_openvino_model{os.sep}")
f_ov = str(Path(f) / self.file.with_suffix(".xml").name)
serialize(ov_model, f_ov)
return f, None

@ -180,17 +180,17 @@ class AutoBackend(nn.Module):
metadata = session.get_modelmeta().custom_metadata_map # metadata
elif xml: # OpenVINO
LOGGER.info(f"Loading {w} for OpenVINO inference...")
check_requirements("openvino>=2023.0") # requires openvino-dev: https://pypi.org/project/openvino-dev/
from openvino.runtime import Core, Layout, get_batch # noqa
check_requirements("openvino>=2023.3") # requires openvino: https://pypi.org/project/openvino-dev/
import openvino as ov # noqa
core = Core()
core = ov.Core()
w = Path(w)
if not w.is_file(): # if not *.xml
w = next(w.glob("*.xml")) # get *.xml file from *_openvino_model dir
ov_model = core.read_model(model=str(w), weights=w.with_suffix(".bin"))
if ov_model.get_parameters()[0].get_layout().empty:
ov_model.get_parameters()[0].set_layout(Layout("NCHW"))
batch_dim = get_batch(ov_model)
ov_model.get_parameters()[0].set_layout(ov.Layout("NCHW"))
batch_dim = ov.get_batch(ov_model)
if batch_dim.is_static:
batch_size = batch_dim.get_length()
ov_compiled_model = core.compile_model(ov_model, device_name="AUTO") # AUTO selects best available device

@ -25,6 +25,7 @@ except ImportError:
thop = None
TORCH_1_9 = check_version(torch.__version__, "1.9.0")
TORCH_1_13 = check_version(torch.__version__, "1.13.0")
TORCH_2_0 = check_version(torch.__version__, "2.0.0")
TORCHVISION_0_10 = check_version(torchvision.__version__, "0.10.0")
TORCHVISION_0_11 = check_version(torchvision.__version__, "0.11.0")

Loading…
Cancel
Save