fixed tests

test-quan
Francesco Mattioli 2 weeks ago
parent 8e2b138a84
commit c5ad60213c
  1. 1
      mkdocs.yml
  2. 28
      tests/test_exports.py
  3. 6
      ultralytics/engine/exporter.py

@ -407,7 +407,6 @@ nav:
- Paperspace Gradient: integrations/paperspace.md
- Ray Tune: integrations/ray-tune.md
- Roboflow: integrations/roboflow.md
- Sony MCT: integrations/sony-mct.md
- TF GraphDef: integrations/tf-graphdef.md
- TF SavedModel: integrations/tf-savedmodel.md
- TF.js: integrations/tfjs.md

@ -31,6 +31,20 @@ def test_export_onnx():
file = YOLO(MODEL).export(format="onnx", dynamic=True, imgsz=32)
YOLO(file)(SOURCE, imgsz=32) # exported model inference
@pytest.mark.skipif(not LINUX or MACOS, reason="Skipping test on Windows and Macos")
def test_export_imx500_ptq():
"""Test YOLOv8n exports to imx500 format."""
model = YOLO("yolov8n.pt")
file = model.export(format="imx500", imgsz=32, gptq=False)
YOLO(file)(SOURCE, imgsz=32)
@pytest.mark.slow
@pytest.mark.skipif(IS_RASPBERRYPI or not LINUX or MACOS, reason="Skipping test on Raspberry Pi and Windows")
def test_export_imx500_gptq():
"""Test YOLOv8n exports to imx500 format with gptq."""
model = YOLO("yolov8n.pt")
file = model.export(format="imx500", imgsz=32, gptq=True)
YOLO(file)(SOURCE, imgsz=32)
@pytest.mark.skipif(not TORCH_1_13, reason="OpenVINO requires torch>=1.13")
def test_export_openvino():
@ -207,16 +221,4 @@ def test_export_ncnn():
YOLO(file)(SOURCE, imgsz=32) # exported model inference
@pytest.mark.skipif(not LINUX or MACOS, reason="Skipping test on Windows and Macos")
def test_export_imx500_ptq():
"""Test YOLOv8n exports to imx500 format."""
model = YOLO("yolov8n.pt")
file = model.export(format="imx500", imgsz=32, gptq=False)
YOLO(file)(SOURCE, imgsz=32)
@pytest.mark.skipif(IS_RASPBERRYPI or not LINUX or MACOS, reason="Skipping test on Raspberry Pi and Windows")
def test_export_imx500_gptq():
"""Test YOLOv8n exports to imx500 format with gptq."""
model = YOLO("yolov8n.pt")
file = model.export(format="imx500", imgsz=32, gptq=True)
YOLO(file)(SOURCE, imgsz=32)

@ -1123,10 +1123,8 @@ class Exporter:
try:
subprocess.run(["java", "--version"], check=True)
except OSError:
raise OSError(
"Java 17 is required for the imx500 conversion. \n Please install Java with: \n sudo apt install openjdk-17-jdk openjdk-17-jre"
)
except FileNotFoundError:
subprocess.run(["apt", "install", "-y", "openjdk-17-jdk", "openjdk-17-jre"], check=True)
def representative_dataset_gen(dataloader=self.get_int8_calibration_dataloader(prefix)):
for batch in dataloader:

Loading…
Cancel
Save