`ultralytics 8.2.9` OpenVINO INT8 fixes and tests (#10423)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/11667/head v8.2.9
Burhan 7 months ago committed by GitHub
parent 299797ff9e
commit 2583f842b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 22
      tests/__init__.py
  2. 46
      tests/test_cli.py
  3. 25
      tests/test_cuda.py
  4. 39
      tests/test_engine.py
  5. 128
      tests/test_exports.py
  6. 15
      tests/test_integrations.py
  7. 97
      tests/test_python.py
  8. 2
      ultralytics/__init__.py
  9. 3
      ultralytics/data/augment.py
  10. 3
      ultralytics/data/base.py
  11. 6
      ultralytics/data/build.py
  12. 66
      ultralytics/engine/exporter.py
  13. 4
      ultralytics/engine/model.py

@ -0,0 +1,22 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.utils import ASSETS, ROOT, WEIGHTS_DIR, checks, is_dir_writeable
# Constants used in tests
MODEL = WEIGHTS_DIR / "path with spaces" / "yolov8n.pt" # test spaces in path
CFG = "yolov8n.yaml"
SOURCE = ASSETS / "bus.jpg"
TMP = (ROOT / "../tests/tmp").resolve() # temp directory for test files
IS_TMP_WRITEABLE = is_dir_writeable(TMP)
CUDA_IS_AVAILABLE = checks.cuda_is_available()
CUDA_DEVICE_COUNT = checks.cuda_device_count()
__all__ = (
"MODEL",
"CFG",
"SOURCE",
"TMP",
"IS_TMP_WRITEABLE",
"CUDA_IS_AVAILABLE",
"CUDA_DEVICE_COUNT",
)

@ -4,24 +4,14 @@ import subprocess
import pytest import pytest
from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS
from ultralytics.utils import ASSETS, WEIGHTS_DIR, checks from ultralytics.utils import ASSETS, WEIGHTS_DIR, checks
CUDA_IS_AVAILABLE = checks.cuda_is_available() from . import CUDA_DEVICE_COUNT, CUDA_IS_AVAILABLE
CUDA_DEVICE_COUNT = checks.cuda_device_count()
TASK_ARGS = [ # Constants
("detect", "yolov8n", "coco8.yaml"), TASK_MODEL_DATA = [(task, WEIGHTS_DIR / TASK2MODEL[task], TASK2DATA[task]) for task in TASKS]
("segment", "yolov8n-seg", "coco8-seg.yaml"), MODELS = [WEIGHTS_DIR / TASK2MODEL[task] for task in TASKS]
("classify", "yolov8n-cls", "imagenet10"),
("pose", "yolov8n-pose", "coco8-pose.yaml"),
("obb", "yolov8n-obb", "dota8.yaml"),
] # (task, model, data)
EXPORT_ARGS = [
("yolov8n", "torchscript"),
("yolov8n-seg", "torchscript"),
("yolov8n-cls", "torchscript"),
("yolov8n-pose", "torchscript"),
("yolov8n-obb", "torchscript"),
] # (model, format)
def run(cmd): def run(cmd):
@ -38,28 +28,28 @@ def test_special_modes():
run("yolo cfg") run("yolo cfg")
@pytest.mark.parametrize("task,model,data", TASK_ARGS) @pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
def test_train(task, model, data): def test_train(task, model, data):
"""Test YOLO training for a given task, model, and data.""" """Test YOLO training for a given task, model, and data."""
run(f"yolo train {task} model={model}.yaml data={data} imgsz=32 epochs=1 cache=disk") run(f"yolo train {task} model={model} data={data} imgsz=32 epochs=1 cache=disk")
@pytest.mark.parametrize("task,model,data", TASK_ARGS) @pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
def test_val(task, model, data): def test_val(task, model, data):
"""Test YOLO validation for a given task, model, and data.""" """Test YOLO validation for a given task, model, and data."""
run(f"yolo val {task} model={WEIGHTS_DIR / model}.pt data={data} imgsz=32 save_txt save_json") run(f"yolo val {task} model={model} data={data} imgsz=32 save_txt save_json")
@pytest.mark.parametrize("task,model,data", TASK_ARGS) @pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
def test_predict(task, model, data): def test_predict(task, model, data):
"""Test YOLO prediction on sample assets for a given task and model.""" """Test YOLO prediction on sample assets for a given task and model."""
run(f"yolo predict model={WEIGHTS_DIR / model}.pt source={ASSETS} imgsz=32 save save_crop save_txt") run(f"yolo predict model={model} source={ASSETS} imgsz=32 save save_crop save_txt")
@pytest.mark.parametrize("model,format", EXPORT_ARGS) @pytest.mark.parametrize("model", MODELS)
def test_export(model, format): def test_export(model):
"""Test exporting a YOLO model to different formats.""" """Test exporting a YOLO model to different formats."""
run(f"yolo export model={WEIGHTS_DIR / model}.pt format={format} imgsz=32") run(f"yolo export model={model} format=torchscript imgsz=32")
def test_rtdetr(task="detect", model="yolov8n-rtdetr.yaml", data="coco8.yaml"): def test_rtdetr(task="detect", model="yolov8n-rtdetr.yaml", data="coco8.yaml"):
@ -129,10 +119,10 @@ def test_mobilesam():
# Slow Tests ----------------------------------------------------------------------------------------------------------- # Slow Tests -----------------------------------------------------------------------------------------------------------
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.parametrize("task,model,data", TASK_ARGS) @pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA)
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason="CUDA is not available") @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason="CUDA is not available")
@pytest.mark.skipif(CUDA_DEVICE_COUNT < 2, reason="DDP is not available") @pytest.mark.skipif(CUDA_DEVICE_COUNT < 2, reason="DDP is not available")
def test_train_gpu(task, model, data): def test_train_gpu(task, model, data):
"""Test YOLO training on GPU(s) for various tasks and models.""" """Test YOLO training on GPU(s) for various tasks and models."""
run(f"yolo train {task} model={model}.yaml data={data} imgsz=32 epochs=1 device=0") # single GPU run(f"yolo train {task} model={model} data={data} imgsz=32 epochs=1 device=0") # single GPU
run(f"yolo train {task} model={model}.pt data={data} imgsz=32 epochs=1 device=0,1") # multi GPU run(f"yolo train {task} model={model} data={data} imgsz=32 epochs=1 device=0,1") # multi GPU

@ -4,14 +4,9 @@ import pytest
import torch import torch
from ultralytics import YOLO from ultralytics import YOLO
from ultralytics.utils import ASSETS, WEIGHTS_DIR, checks from ultralytics.utils import ASSETS, WEIGHTS_DIR
CUDA_IS_AVAILABLE = checks.cuda_is_available() from . import CUDA_DEVICE_COUNT, CUDA_IS_AVAILABLE, MODEL, SOURCE
CUDA_DEVICE_COUNT = checks.cuda_device_count()
MODEL = WEIGHTS_DIR / "path with spaces" / "yolov8n.pt" # test spaces in path
DATA = "coco8.yaml"
BUS = ASSETS / "bus.jpg"
def test_checks(): def test_checks():
@ -25,14 +20,14 @@ def test_checks():
def test_export_engine(): def test_export_engine():
"""Test exporting the YOLO model to NVIDIA TensorRT format.""" """Test exporting the YOLO model to NVIDIA TensorRT format."""
f = YOLO(MODEL).export(format="engine", device=0) f = YOLO(MODEL).export(format="engine", device=0)
YOLO(f)(BUS, device=0) YOLO(f)(SOURCE, device=0)
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason="CUDA is not available") @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason="CUDA is not available")
def test_train(): def test_train():
"""Test model training on a minimal dataset.""" """Test model training on a minimal dataset."""
device = 0 if CUDA_DEVICE_COUNT == 1 else [0, 1] device = 0 if CUDA_DEVICE_COUNT == 1 else [0, 1]
YOLO(MODEL).train(data=DATA, imgsz=64, epochs=1, device=device) # requires imgsz>=64 YOLO(MODEL).train(data="coco8.yaml", imgsz=64, epochs=1, device=device) # requires imgsz>=64
@pytest.mark.slow @pytest.mark.slow
@ -42,22 +37,22 @@ def test_predict_multiple_devices():
model = YOLO("yolov8n.pt") model = YOLO("yolov8n.pt")
model = model.cpu() model = model.cpu()
assert str(model.device) == "cpu" assert str(model.device) == "cpu"
_ = model(BUS) # CPU inference _ = model(SOURCE) # CPU inference
assert str(model.device) == "cpu" assert str(model.device) == "cpu"
model = model.to("cuda:0") model = model.to("cuda:0")
assert str(model.device) == "cuda:0" assert str(model.device) == "cuda:0"
_ = model(BUS) # CUDA inference _ = model(SOURCE) # CUDA inference
assert str(model.device) == "cuda:0" assert str(model.device) == "cuda:0"
model = model.cpu() model = model.cpu()
assert str(model.device) == "cpu" assert str(model.device) == "cpu"
_ = model(BUS) # CPU inference _ = model(SOURCE) # CPU inference
assert str(model.device) == "cpu" assert str(model.device) == "cpu"
model = model.cuda() model = model.cuda()
assert str(model.device) == "cuda:0" assert str(model.device) == "cuda:0"
_ = model(BUS) # CUDA inference _ = model(SOURCE) # CUDA inference
assert str(model.device) == "cuda:0" assert str(model.device) == "cuda:0"
@ -93,10 +88,10 @@ def test_predict_sam():
model.info() model.info()
# Run inference # Run inference
model(BUS, device=0) model(SOURCE, device=0)
# Run inference with bboxes prompt # Run inference with bboxes prompt
model(BUS, bboxes=[439, 437, 524, 709], device=0) model(SOURCE, bboxes=[439, 437, 524, 709], device=0)
# Run inference with points prompt # Run inference with points prompt
model(ASSETS / "zidane.jpg", points=[900, 370], labels=[1], device=0) model(ASSETS / "zidane.jpg", points=[900, 370], labels=[1], device=0)

@ -9,11 +9,7 @@ from ultralytics.engine.exporter import Exporter
from ultralytics.models.yolo import classify, detect, segment from ultralytics.models.yolo import classify, detect, segment
from ultralytics.utils import ASSETS, DEFAULT_CFG, WEIGHTS_DIR from ultralytics.utils import ASSETS, DEFAULT_CFG, WEIGHTS_DIR
CFG_DET = "yolov8n.yaml" from . import MODEL
CFG_SEG = "yolov8n-seg.yaml"
CFG_CLS = "yolov8n-cls.yaml" # or 'squeezenet1_0'
CFG = get_cfg(DEFAULT_CFG)
MODEL = WEIGHTS_DIR / "yolov8n"
def test_func(*args): # noqa def test_func(*args): # noqa
@ -26,15 +22,16 @@ def test_export():
exporter = Exporter() exporter = Exporter()
exporter.add_callback("on_export_start", test_func) exporter.add_callback("on_export_start", test_func)
assert test_func in exporter.callbacks["on_export_start"], "callback test failed" assert test_func in exporter.callbacks["on_export_start"], "callback test failed"
f = exporter(model=YOLO(CFG_DET).model) f = exporter(model=YOLO("yolov8n.yaml").model)
YOLO(f)(ASSETS) # exported model inference YOLO(f)(ASSETS) # exported model inference
def test_detect(): def test_detect():
"""Test object detection functionality.""" """Test object detection functionality."""
overrides = {"data": "coco8.yaml", "model": CFG_DET, "imgsz": 32, "epochs": 1, "save": False} overrides = {"data": "coco8.yaml", "model": "yolov8n.yaml", "imgsz": 32, "epochs": 1, "save": False}
CFG.data = "coco8.yaml" cfg = get_cfg(DEFAULT_CFG)
CFG.imgsz = 32 cfg.data = "coco8.yaml"
cfg.imgsz = 32
# Trainer # Trainer
trainer = detect.DetectionTrainer(overrides=overrides) trainer = detect.DetectionTrainer(overrides=overrides)
@ -43,7 +40,7 @@ def test_detect():
trainer.train() trainer.train()
# Validator # Validator
val = detect.DetectionValidator(args=CFG) val = detect.DetectionValidator(args=cfg)
val.add_callback("on_val_start", test_func) val.add_callback("on_val_start", test_func)
assert test_func in val.callbacks["on_val_start"], "callback test failed" assert test_func in val.callbacks["on_val_start"], "callback test failed"
val(model=trainer.best) # validate best.pt val(model=trainer.best) # validate best.pt
@ -54,7 +51,7 @@ def test_detect():
assert test_func in pred.callbacks["on_predict_start"], "callback test failed" assert test_func in pred.callbacks["on_predict_start"], "callback test failed"
# Confirm there is no issue with sys.argv being empty. # Confirm there is no issue with sys.argv being empty.
with mock.patch.object(sys, "argv", []): with mock.patch.object(sys, "argv", []):
result = pred(source=ASSETS, model=f"{MODEL}.pt") result = pred(source=ASSETS, model=MODEL)
assert len(result), "predictor test failed" assert len(result), "predictor test failed"
overrides["resume"] = trainer.last overrides["resume"] = trainer.last
@ -70,9 +67,10 @@ def test_detect():
def test_segment(): def test_segment():
"""Test image segmentation functionality.""" """Test image segmentation functionality."""
overrides = {"data": "coco8-seg.yaml", "model": CFG_SEG, "imgsz": 32, "epochs": 1, "save": False} overrides = {"data": "coco8-seg.yaml", "model": "yolov8n-seg.yaml", "imgsz": 32, "epochs": 1, "save": False}
CFG.data = "coco8-seg.yaml" cfg = get_cfg(DEFAULT_CFG)
CFG.imgsz = 32 cfg.data = "coco8-seg.yaml"
cfg.imgsz = 32
# YOLO(CFG_SEG).train(**overrides) # works # YOLO(CFG_SEG).train(**overrides) # works
# Trainer # Trainer
@ -82,7 +80,7 @@ def test_segment():
trainer.train() trainer.train()
# Validator # Validator
val = segment.SegmentationValidator(args=CFG) val = segment.SegmentationValidator(args=cfg)
val.add_callback("on_val_start", test_func) val.add_callback("on_val_start", test_func)
assert test_func in val.callbacks["on_val_start"], "callback test failed" assert test_func in val.callbacks["on_val_start"], "callback test failed"
val(model=trainer.best) # validate best.pt val(model=trainer.best) # validate best.pt
@ -91,7 +89,7 @@ def test_segment():
pred = segment.SegmentationPredictor(overrides={"imgsz": [64, 64]}) pred = segment.SegmentationPredictor(overrides={"imgsz": [64, 64]})
pred.add_callback("on_predict_start", test_func) pred.add_callback("on_predict_start", test_func)
assert test_func in pred.callbacks["on_predict_start"], "callback test failed" assert test_func in pred.callbacks["on_predict_start"], "callback test failed"
result = pred(source=ASSETS, model=f"{MODEL}-seg.pt") result = pred(source=ASSETS, model=WEIGHTS_DIR / "yolov8n-seg.pt")
assert len(result), "predictor test failed" assert len(result), "predictor test failed"
# Test resume # Test resume
@ -108,9 +106,10 @@ def test_segment():
def test_classify(): def test_classify():
"""Test image classification functionality.""" """Test image classification functionality."""
overrides = {"data": "imagenet10", "model": CFG_CLS, "imgsz": 32, "epochs": 1, "save": False} overrides = {"data": "imagenet10", "model": "yolov8n-cls.yaml", "imgsz": 32, "epochs": 1, "save": False}
CFG.data = "imagenet10" cfg = get_cfg(DEFAULT_CFG)
CFG.imgsz = 32 cfg.data = "imagenet10"
cfg.imgsz = 32
# YOLO(CFG_SEG).train(**overrides) # works # YOLO(CFG_SEG).train(**overrides) # works
# Trainer # Trainer
@ -120,7 +119,7 @@ def test_classify():
trainer.train() trainer.train()
# Validator # Validator
val = classify.ClassificationValidator(args=CFG) val = classify.ClassificationValidator(args=cfg)
val.add_callback("on_val_start", test_func) val.add_callback("on_val_start", test_func)
assert test_func in val.callbacks["on_val_start"], "callback test failed" assert test_func in val.callbacks["on_val_start"], "callback test failed"
val(model=trainer.best) val(model=trainer.best)

@ -0,0 +1,128 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import shutil
import uuid
from itertools import product
from pathlib import Path
import pytest
from ultralytics import YOLO
from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS
from ultralytics.utils import (
IS_RASPBERRYPI,
LINUX,
MACOS,
WINDOWS,
Retry,
checks,
)
from ultralytics.utils.torch_utils import TORCH_1_9, TORCH_1_13
from . import MODEL, SOURCE
# Constants
EXPORT_PARAMETERS_LIST = [ # generate all combinations but exclude those where both int8 and half are True
(task, dynamic, int8, half, batch)
for task, dynamic, int8, half, batch in product(TASKS, [True, False], [True, False], [True, False], [1, 2])
if not (int8 and half) # exclude cases where both int8 and half are True
]
def test_export_torchscript():
"""Test exporting the YOLO model to TorchScript format."""
f = YOLO(MODEL).export(format="torchscript", optimize=False, imgsz=32)
YOLO(f)(SOURCE, imgsz=32) # exported model inference
def test_export_onnx():
"""Test exporting the YOLO model to ONNX format."""
f = YOLO(MODEL).export(format="onnx", dynamic=True, imgsz=32)
YOLO(f)(SOURCE, imgsz=32) # exported model inference
@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="OpenVINO not supported in Python 3.12")
@pytest.mark.skipif(not TORCH_1_13, reason="OpenVINO requires torch>=1.13")
def test_export_openvino():
"""Test exporting the YOLO model to OpenVINO format."""
f = YOLO(MODEL).export(format="openvino", imgsz=32)
YOLO(f)(SOURCE, imgsz=32) # exported model inference
@pytest.mark.slow
@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="OpenVINO not supported in Python 3.12")
@pytest.mark.skipif(not TORCH_1_13, reason="OpenVINO requires torch>=1.13")
@pytest.mark.parametrize("task, dynamic, int8, half, batch", EXPORT_PARAMETERS_LIST)
def test_export_openvino_matrix(task, dynamic, int8, half, batch):
"""Test exporting the YOLO model to OpenVINO format."""
file = YOLO(TASK2MODEL[task]).export(
format="openvino",
imgsz=32,
dynamic=dynamic,
int8=int8,
half=half,
batch=batch,
data=TASK2DATA[task],
)
if WINDOWS:
# Use unique filenames due to Windows file permissions bug possibly due to latent threaded use
# See https://github.com/ultralytics/ultralytics/actions/runs/8957949304/job/24601616830?pr=10423
file = Path(file)
file = file.rename(file.with_stem(f"{file.stem}-{uuid.uuid4()}"))
YOLO(file)([SOURCE] * batch, imgsz=64 if dynamic else 32) # exported model inference
with Retry(times=3, delay=1): # retry in case of potential lingering multi-threaded file usage errors
shutil.rmtree(file)
@pytest.mark.skipif(not TORCH_1_9, reason="CoreML>=7.2 not supported with PyTorch<=1.8")
@pytest.mark.skipif(WINDOWS, reason="CoreML not supported on Windows") # RuntimeError: BlobWriter not loaded
@pytest.mark.skipif(IS_RASPBERRYPI, reason="CoreML not supported on Raspberry Pi")
@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="CoreML not supported in Python 3.12")
def test_export_coreml():
"""Test exporting the YOLO model to CoreML format."""
if MACOS:
f = YOLO(MODEL).export(format="coreml", imgsz=32)
YOLO(f)(SOURCE, imgsz=32) # model prediction only supported on macOS for nms=False models
else:
YOLO(MODEL).export(format="coreml", nms=True, imgsz=32)
@pytest.mark.skipif(not LINUX, reason="Test disabled as TF suffers from install conflicts on Windows and macOS")
def test_export_tflite():
"""
Test exporting the YOLO model to TFLite format.
Note TF suffers from install conflicts on Windows and macOS.
"""
model = YOLO(MODEL)
f = model.export(format="tflite", imgsz=32)
YOLO(f)(SOURCE, imgsz=32)
@pytest.mark.skipif(True, reason="Test disabled")
@pytest.mark.skipif(not LINUX, reason="TF suffers from install conflicts on Windows and macOS")
def test_export_pb():
"""
Test exporting the YOLO model to *.pb format.
Note TF suffers from install conflicts on Windows and macOS.
"""
model = YOLO(MODEL)
f = model.export(format="pb", imgsz=32)
YOLO(f)(SOURCE, imgsz=32)
@pytest.mark.skipif(True, reason="Test disabled as Paddle protobuf and ONNX protobuf requirementsk conflict.")
def test_export_paddle():
"""
Test exporting the YOLO model to Paddle format.
Note Paddle protobuf requirements conflicting with onnx protobuf requirements.
"""
YOLO(MODEL).export(format="paddle", imgsz=32)
@pytest.mark.slow
def test_export_ncnn():
"""Test exporting the YOLO model to NCNN format."""
f = YOLO(MODEL).export(format="ncnn", imgsz=32)
YOLO(f)(SOURCE, imgsz=32) # exported model inference

@ -1,18 +1,18 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib import contextlib
import os
import subprocess
import time
from pathlib import Path from pathlib import Path
import pytest import pytest
from ultralytics import YOLO, download from ultralytics import YOLO, download
from ultralytics.utils import ASSETS, DATASETS_DIR, ROOT, SETTINGS, WEIGHTS_DIR from ultralytics.utils import DATASETS_DIR, SETTINGS
from ultralytics.utils.checks import check_requirements from ultralytics.utils.checks import check_requirements
MODEL = WEIGHTS_DIR / "path with spaces" / "yolov8n.pt" # test spaces in path from . import MODEL, SOURCE, TMP
CFG = "yolov8n.yaml"
SOURCE = ASSETS / "bus.jpg"
TMP = (ROOT / "../tests/tmp").resolve() # temp directory for test files
@pytest.mark.skipif(not check_requirements("ray", install=False), reason="ray[tune] not installed") @pytest.mark.skipif(not check_requirements("ray", install=False), reason="ray[tune] not installed")
@ -33,8 +33,6 @@ def test_mlflow():
@pytest.mark.skipif(True, reason="Test failing in scheduled CI https://github.com/ultralytics/ultralytics/pull/8868") @pytest.mark.skipif(True, reason="Test failing in scheduled CI https://github.com/ultralytics/ultralytics/pull/8868")
@pytest.mark.skipif(not check_requirements("mlflow", install=False), reason="mlflow not installed") @pytest.mark.skipif(not check_requirements("mlflow", install=False), reason="mlflow not installed")
def test_mlflow_keep_run_active(): def test_mlflow_keep_run_active():
import os
import mlflow import mlflow
"""Test training with MLflow tracking enabled.""" """Test training with MLflow tracking enabled."""
@ -67,9 +65,6 @@ def test_mlflow_keep_run_active():
def test_triton(): def test_triton():
"""Test NVIDIA Triton Server functionalities.""" """Test NVIDIA Triton Server functionalities."""
check_requirements("tritonclient[all]") check_requirements("tritonclient[all]")
import subprocess
import time
from tritonclient.http import InferenceServerClient # noqa from tritonclient.http import InferenceServerClient # noqa
# Create variables # Create variables

@ -18,25 +18,17 @@ from ultralytics.utils import (
ASSETS, ASSETS,
DEFAULT_CFG, DEFAULT_CFG,
DEFAULT_CFG_PATH, DEFAULT_CFG_PATH,
LINUX,
MACOS,
ONLINE, ONLINE,
ROOT, ROOT,
WEIGHTS_DIR, WEIGHTS_DIR,
WINDOWS, WINDOWS,
Retry, Retry,
checks, checks,
is_dir_writeable,
IS_RASPBERRYPI,
) )
from ultralytics.utils.downloads import download from ultralytics.utils.downloads import download
from ultralytics.utils.torch_utils import TORCH_1_9, TORCH_1_13 from ultralytics.utils.torch_utils import TORCH_1_9
MODEL = WEIGHTS_DIR / "path with spaces" / "yolov8n.pt" # test spaces in path from . import CFG, IS_TMP_WRITEABLE, MODEL, SOURCE, TMP
CFG = "yolov8n.yaml"
SOURCE = ASSETS / "bus.jpg"
TMP = (ROOT / "../tests/tmp").resolve() # temp directory for test files
IS_TMP_WRITEABLE = is_dir_writeable(TMP)
def test_model_forward(): def test_model_forward():
@ -202,81 +194,6 @@ def test_train_pretrained():
model(SOURCE) model(SOURCE)
def test_export_torchscript():
"""Test exporting the YOLO model to TorchScript format."""
f = YOLO(MODEL).export(format="torchscript", optimize=False)
YOLO(f)(SOURCE) # exported model inference
def test_export_onnx():
"""Test exporting the YOLO model to ONNX format."""
f = YOLO(MODEL).export(format="onnx", dynamic=True)
YOLO(f)(SOURCE) # exported model inference
@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="OpenVINO not supported in Python 3.12")
@pytest.mark.skipif(not TORCH_1_13, reason="OpenVINO requires torch>=1.13")
def test_export_openvino():
"""Test exporting the YOLO model to OpenVINO format."""
f = YOLO(MODEL).export(format="openvino")
YOLO(f)(SOURCE) # exported model inference
@pytest.mark.skipif(not TORCH_1_9, reason="CoreML>=7.2 not supported with PyTorch<=1.8")
@pytest.mark.skipif(WINDOWS, reason="CoreML not supported on Windows") # RuntimeError: BlobWriter not loaded
@pytest.mark.skipif(IS_RASPBERRYPI, reason="CoreML not supported on Raspberry Pi")
@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="CoreML not supported in Python 3.12")
def test_export_coreml():
"""Test exporting the YOLO model to CoreML format."""
if MACOS:
f = YOLO(MODEL).export(format="coreml")
YOLO(f)(SOURCE) # model prediction only supported on macOS for nms=False models
else:
YOLO(MODEL).export(format="coreml", nms=True)
@pytest.mark.skipif(not LINUX, reason="Test disabled as TF suffers from install conflicts on Windows and macOS")
def test_export_tflite():
"""
Test exporting the YOLO model to TFLite format.
Note TF suffers from install conflicts on Windows and macOS.
"""
model = YOLO(MODEL)
f = model.export(format="tflite")
YOLO(f)(SOURCE)
@pytest.mark.skipif(True, reason="Test disabled")
@pytest.mark.skipif(not LINUX, reason="TF suffers from install conflicts on Windows and macOS")
def test_export_pb():
"""
Test exporting the YOLO model to *.pb format.
Note TF suffers from install conflicts on Windows and macOS.
"""
model = YOLO(MODEL)
f = model.export(format="pb")
YOLO(f)(SOURCE)
@pytest.mark.skipif(True, reason="Test disabled as Paddle protobuf and ONNX protobuf requirementsk conflict.")
def test_export_paddle():
"""
Test exporting the YOLO model to Paddle format.
Note Paddle protobuf requirements conflicting with onnx protobuf requirements.
"""
YOLO(MODEL).export(format="paddle")
@pytest.mark.slow
def test_export_ncnn():
"""Test exporting the YOLO model to NCNN format."""
f = YOLO(MODEL).export(format="ncnn")
YOLO(f)(SOURCE) # exported model inference
def test_all_model_yamls(): def test_all_model_yamls():
"""Test YOLO model creation for all available YAML configurations.""" """Test YOLO model creation for all available YAML configurations."""
for m in (ROOT / "cfg" / "models").rglob("*.yaml"): for m in (ROOT / "cfg" / "models").rglob("*.yaml"):
@ -293,7 +210,7 @@ def test_workflow():
model.train(data="coco8.yaml", epochs=1, imgsz=32, optimizer="SGD") model.train(data="coco8.yaml", epochs=1, imgsz=32, optimizer="SGD")
model.val(imgsz=32) model.val(imgsz=32)
model.predict(SOURCE, imgsz=32) model.predict(SOURCE, imgsz=32)
model.export(format="onnx") # export a model to ONNX format model.export(format="torchscript")
def test_predict_callback_and_setup(): def test_predict_callback_and_setup():
@ -641,7 +558,7 @@ def test_yolo_world():
"""Tests YOLO world models with different configurations, including classes, detection, and training scenarios.""" """Tests YOLO world models with different configurations, including classes, detection, and training scenarios."""
model = YOLO("yolov8s-world.pt") # no YOLOv8n-world model yet model = YOLO("yolov8s-world.pt") # no YOLOv8n-world model yet
model.set_classes(["tree", "window"]) model.set_classes(["tree", "window"])
model(ASSETS / "bus.jpg", conf=0.01) model(SOURCE, conf=0.01)
model = YOLO("yolov8s-worldv2.pt") # no YOLOv8n-world model yet model = YOLO("yolov8s-worldv2.pt") # no YOLOv8n-world model yet
# Training from a pretrained model. Eval is included at the final stage of training. # Training from a pretrained model. Eval is included at the final stage of training.
@ -651,11 +568,7 @@ def test_yolo_world():
epochs=1, epochs=1,
imgsz=32, imgsz=32,
cache="disk", cache="disk",
batch=4,
close_mosaic=1, close_mosaic=1,
name="yolo-world",
save_txt=True,
save_json=True,
) )
# test WorWorldTrainerFromScratch # test WorWorldTrainerFromScratch
@ -667,8 +580,6 @@ def test_yolo_world():
epochs=1, epochs=1,
imgsz=32, imgsz=32,
cache="disk", cache="disk",
batch=4,
close_mosaic=1, close_mosaic=1,
name="yolo-world",
trainer=WorldTrainerFromScratch, trainer=WorldTrainerFromScratch,
) )

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.8" __version__ = "8.2.9"
from ultralytics.data.explorer.explorer import Explorer from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld

@ -10,6 +10,7 @@ import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from ultralytics.data.utils import polygons2masks, polygons2masks_overlap
from ultralytics.utils import LOGGER, colorstr from ultralytics.utils import LOGGER, colorstr
from ultralytics.utils.checks import check_version from ultralytics.utils.checks import check_version
from ultralytics.utils.instance import Instances from ultralytics.utils.instance import Instances
@ -17,8 +18,6 @@ from ultralytics.utils.metrics import bbox_ioa
from ultralytics.utils.ops import segment2box, xyxyxyxy2xywhr from ultralytics.utils.ops import segment2box, xyxyxyxy2xywhr
from ultralytics.utils.torch_utils import TORCHVISION_0_10, TORCHVISION_0_11, TORCHVISION_0_13 from ultralytics.utils.torch_utils import TORCHVISION_0_10, TORCHVISION_0_11, TORCHVISION_0_13
from .utils import polygons2masks, polygons2masks_overlap
DEFAULT_MEAN = (0.0, 0.0, 0.0) DEFAULT_MEAN = (0.0, 0.0, 0.0)
DEFAULT_STD = (1.0, 1.0, 1.0) DEFAULT_STD = (1.0, 1.0, 1.0)
DEFAULT_CROP_FRACTION = 1.0 DEFAULT_CROP_FRACTION = 1.0

@ -14,10 +14,9 @@ import numpy as np
import psutil import psutil
from torch.utils.data import Dataset from torch.utils.data import Dataset
from ultralytics.data.utils import FORMATS_HELP_MSG, HELP_URL, IMG_FORMATS
from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM
from .utils import FORMATS_HELP_MSG, HELP_URL, IMG_FORMATS
class BaseDataset(Dataset): class BaseDataset(Dataset):
""" """

@ -9,6 +9,7 @@ import torch
from PIL import Image from PIL import Image
from torch.utils.data import dataloader, distributed from torch.utils.data import dataloader, distributed
from ultralytics.data.dataset import GroundingDataset, YOLODataset, YOLOMultiModalDataset
from ultralytics.data.loaders import ( from ultralytics.data.loaders import (
LOADERS, LOADERS,
LoadImagesAndVideos, LoadImagesAndVideos,
@ -19,13 +20,10 @@ from ultralytics.data.loaders import (
SourceTypes, SourceTypes,
autocast_list, autocast_list,
) )
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS from ultralytics.data.utils import IMG_FORMATS, PIN_MEMORY, VID_FORMATS
from ultralytics.utils import LINUX, NUM_THREADS, RANK, colorstr from ultralytics.utils import LINUX, NUM_THREADS, RANK, colorstr
from ultralytics.utils.checks import check_file from ultralytics.utils.checks import check_file
from .dataset import GroundingDataset, YOLODataset, YOLOMultiModalDataset
from .utils import PIN_MEMORY
class InfiniteDataLoader(dataloader.DataLoader): class InfiniteDataLoader(dataloader.DataLoader):
""" """

@ -64,9 +64,10 @@ from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from ultralytics.cfg import get_cfg from ultralytics.cfg import TASK2DATA, get_cfg
from ultralytics.data import build_dataloader
from ultralytics.data.dataset import YOLODataset from ultralytics.data.dataset import YOLODataset
from ultralytics.data.utils import check_det_dataset from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.autobackend import check_class_names, default_class_names from ultralytics.nn.autobackend import check_class_names, default_class_names
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel
@ -169,7 +170,7 @@ class Exporter:
callbacks.add_integration_callbacks(self) callbacks.add_integration_callbacks(self)
@smart_inference_mode() @smart_inference_mode()
def __call__(self, model=None): def __call__(self, model=None) -> str:
"""Returns list of exported files/dirs after running callbacks.""" """Returns list of exported files/dirs after running callbacks."""
self.run_callbacks("on_export_start") self.run_callbacks("on_export_start")
t = time.time() t = time.time()
@ -211,7 +212,12 @@ class Exporter:
"(torchscript, onnx, openvino, engine, coreml) formats. " "(torchscript, onnx, openvino, engine, coreml) formats. "
"See https://docs.ultralytics.com/models/yolo-world for details." "See https://docs.ultralytics.com/models/yolo-world for details."
) )
if self.args.int8 and not self.args.data:
self.args.data = DEFAULT_CFG.data or TASK2DATA[getattr(model, "task", "detect")] # assign default data
LOGGER.warning(
"WARNING ⚠ INT8 export requires a missing 'data' arg for calibration. "
f"Using default 'data={self.args.data}'."
)
# Input # Input
im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device) im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
file = Path( file = Path(
@ -333,6 +339,23 @@ class Exporter:
self.run_callbacks("on_export_end") self.run_callbacks("on_export_end")
return f # return list of exported files/dirs return f # return list of exported files/dirs
def get_int8_calibration_dataloader(self, prefix=""):
"""Build and return a dataloader suitable for calibration of INT8 models."""
LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
data = (check_cls_dataset if self.model.task == "classify" else check_det_dataset)(self.args.data)
dataset = YOLODataset(
data[self.args.split or "val"],
data=data,
task=self.model.task,
imgsz=self.imgsz[0],
augment=False,
batch_size=self.args.batch,
)
n = len(dataset)
if n < 300:
LOGGER.warning(f"{prefix} WARNING ⚠ >300 images recommended for INT8 calibration, found {n} images.")
return build_dataloader(dataset, batch=self.args.batch, workers=0) # required for batch loading
@try_export @try_export
def export_torchscript(self, prefix=colorstr("TorchScript:")): def export_torchscript(self, prefix=colorstr("TorchScript:")):
"""YOLOv8 TorchScript model export.""" """YOLOv8 TorchScript model export."""
@ -442,37 +465,21 @@ class Exporter:
if self.args.int8: if self.args.int8:
fq = str(self.file).replace(self.file.suffix, f"_int8_openvino_model{os.sep}") fq = str(self.file).replace(self.file.suffix, f"_int8_openvino_model{os.sep}")
fq_ov = str(Path(fq) / self.file.with_suffix(".xml").name) fq_ov = str(Path(fq) / self.file.with_suffix(".xml").name)
if not self.args.data:
self.args.data = DEFAULT_CFG.data or "coco128.yaml"
LOGGER.warning(
f"{prefix} WARNING ⚠ INT8 export requires a missing 'data' arg for calibration. "
f"Using default 'data={self.args.data}'."
)
check_requirements("nncf>=2.8.0") check_requirements("nncf>=2.8.0")
import nncf import nncf
def transform_fn(data_item): def transform_fn(data_item) -> np.ndarray:
"""Quantization transform function.""" """Quantization transform function."""
assert ( data_item: torch.Tensor = data_item["img"] if isinstance(data_item, dict) else data_item
data_item["img"].dtype == torch.uint8 assert data_item.dtype == torch.uint8, "Input image must be uint8 for the quantization preprocessing"
), "Input image must be uint8 for the quantization preprocessing" im = data_item.numpy().astype(np.float32) / 255.0 # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0
im = data_item["img"].numpy().astype(np.float32) / 255.0 # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0
return np.expand_dims(im, 0) if im.ndim == 3 else im return np.expand_dims(im, 0) if im.ndim == 3 else im
# Generate calibration data for integer quantization # Generate calibration data for integer quantization
LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")
data = check_det_dataset(self.args.data)
dataset = YOLODataset(data["val"], data=data, task=self.model.task, imgsz=self.imgsz[0], augment=False)
n = len(dataset)
if n < 300:
LOGGER.warning(f"{prefix} WARNING ⚠ >300 images recommended for INT8 calibration, found {n} images.")
quantization_dataset = nncf.Dataset(dataset, transform_fn)
ignored_scope = None ignored_scope = None
if isinstance(self.model.model[-1], Detect): if isinstance(self.model.model[-1], Detect):
# Includes all Detect subclasses like Segment, Pose, OBB, WorldDetect # Includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
head_module_name = ".".join(list(self.model.named_modules())[-1][0].split(".")[:2]) head_module_name = ".".join(list(self.model.named_modules())[-1][0].split(".")[:2])
ignored_scope = nncf.IgnoredScope( # ignore operations ignored_scope = nncf.IgnoredScope( # ignore operations
patterns=[ patterns=[
f".*{head_module_name}/.*/Add", f".*{head_module_name}/.*/Add",
@ -485,7 +492,10 @@ class Exporter:
) )
quantized_ov_model = nncf.quantize( quantized_ov_model = nncf.quantize(
ov_model, quantization_dataset, preset=nncf.QuantizationPreset.MIXED, ignored_scope=ignored_scope model=ov_model,
calibration_dataset=nncf.Dataset(self.get_int8_calibration_dataloader(prefix), transform_fn),
preset=nncf.QuantizationPreset.MIXED,
ignored_scope=ignored_scope,
) )
serialize(quantized_ov_model, fq_ov) serialize(quantized_ov_model, fq_ov)
return fq, None return fq, None
@ -787,11 +797,9 @@ class Exporter:
verbosity = "info" verbosity = "info"
if self.args.data: if self.args.data:
# Generate calibration data for integer quantization # Generate calibration data for integer quantization
LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'") dataloader = self.get_int8_calibration_dataloader(prefix)
data = check_det_dataset(self.args.data)
dataset = YOLODataset(data["val"], data=data, imgsz=self.imgsz[0], augment=False)
images = [] images = []
for i, batch in enumerate(dataset): for i, batch in enumerate(dataloader):
if i >= 100: # maximum number of calibration images if i >= 100: # maximum number of calibration images
break break
im = batch["img"].permute(1, 2, 0)[None] # list to nparray, CHW to BHWC im = batch["img"].permute(1, 2, 0)[None] # list to nparray, CHW to BHWC

@ -572,7 +572,7 @@ class Model(nn.Module):
def export( def export(
self, self,
**kwargs, **kwargs,
): ) -> str:
""" """
Exports the model to a different format suitable for deployment. Exports the model to a different format suitable for deployment.
@ -588,7 +588,7 @@ class Model(nn.Module):
model's overrides and method defaults. model's overrides and method defaults.
Returns: Returns:
(object): The exported model in the specified format, or an object related to the export process. (str): The exported model filename in the specified format, or an object related to the export process.
Raises: Raises:
AssertionError: If the model is not a PyTorch model. AssertionError: If the model is not a PyTorch model.

Loading…
Cancel
Save