Fix TFLite INT8 quant bug (#13082)

pull/13094/head
Glenn Jocher 9 months ago committed by GitHub
parent cb99f71728
commit 11623eeb00
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 6
      .github/workflows/ci.yaml
  2. 30
      tests/test_exports.py
  3. 26
      ultralytics/engine/exporter.py
  4. 7
      ultralytics/utils/checks.py

@ -164,7 +164,7 @@ jobs:
Tests:
if: github.event_name != 'workflow_dispatch' || github.event.inputs.tests == 'true'
timeout-minutes: 60
timeout-minutes: 120
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
@ -241,7 +241,7 @@ jobs:
RaspberryPi:
if: github.repository == 'ultralytics/ultralytics' && (github.event_name == 'schedule' || github.event.inputs.raspberrypi == 'true')
timeout-minutes: 60
timeout-minutes: 120
runs-on: raspberry-pi
steps:
- uses: actions/checkout@v4
@ -253,7 +253,7 @@ jobs:
- name: Install requirements
run: |
python -m pip install --upgrade pip wheel
pip install -e ".[export]" pytest mlflow pycocotools "ray[tune]"
pip install -e ".[export]" pytest
- name: Check environment
run: |
yolo checks

@ -23,22 +23,22 @@ from tests import MODEL, SOURCE
def test_export_torchscript():
"""Test YOLO exports to TorchScript format."""
f = YOLO(MODEL).export(format="torchscript", optimize=False, imgsz=32)
YOLO(f)(SOURCE, imgsz=32) # exported model inference
file = YOLO(MODEL).export(format="torchscript", optimize=False, imgsz=32)
YOLO(file)(SOURCE, imgsz=32) # exported model inference
def test_export_onnx():
"""Test YOLO exports to ONNX format."""
f = YOLO(MODEL).export(format="onnx", dynamic=True, imgsz=32)
YOLO(f)(SOURCE, imgsz=32) # exported model inference
file = YOLO(MODEL).export(format="onnx", dynamic=True, imgsz=32)
YOLO(file)(SOURCE, imgsz=32) # exported model inference
@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="OpenVINO not supported in Python 3.12")
@pytest.mark.skipif(not TORCH_1_13, reason="OpenVINO requires torch>=1.13")
def test_export_openvino():
"""Test YOLO exports to OpenVINO format."""
f = YOLO(MODEL).export(format="openvino", imgsz=32)
YOLO(f)(SOURCE, imgsz=32) # exported model inference
file = YOLO(MODEL).export(format="openvino", imgsz=32)
YOLO(file)(SOURCE, imgsz=32) # exported model inference
@pytest.mark.slow
@ -118,7 +118,7 @@ def test_export_torchscript_matrix(task, dynamic, int8, half, batch):
],
)
def test_export_coreml_matrix(task, dynamic, int8, half, batch):
"""Test YOLO exports to TorchScript format."""
"""Test YOLO exports to CoreML format."""
file = YOLO(TASK2MODEL[task]).export(
format="coreml",
imgsz=32,
@ -138,8 +138,8 @@ def test_export_coreml_matrix(task, dynamic, int8, half, batch):
def test_export_coreml():
"""Test YOLO exports to CoreML format."""
if MACOS:
f = YOLO(MODEL).export(format="coreml", imgsz=32)
YOLO(f)(SOURCE, imgsz=32) # model prediction only supported on macOS for nms=False models
file = YOLO(MODEL).export(format="coreml", imgsz=32)
YOLO(file)(SOURCE, imgsz=32) # model prediction only supported on macOS for nms=False models
else:
YOLO(MODEL).export(format="coreml", nms=True, imgsz=32)
@ -152,8 +152,8 @@ def test_export_tflite():
Note TF suffers from install conflicts on Windows and macOS.
"""
model = YOLO(MODEL)
f = model.export(format="tflite", imgsz=32)
YOLO(f)(SOURCE, imgsz=32)
file = model.export(format="tflite", imgsz=32)
YOLO(file)(SOURCE, imgsz=32)
@pytest.mark.skipif(True, reason="Test disabled")
@ -165,8 +165,8 @@ def test_export_pb():
Note TF suffers from install conflicts on Windows and macOS.
"""
model = YOLO(MODEL)
f = model.export(format="pb", imgsz=32)
YOLO(f)(SOURCE, imgsz=32)
file = model.export(format="pb", imgsz=32)
YOLO(file)(SOURCE, imgsz=32)
@pytest.mark.skipif(True, reason="Test disabled as Paddle protobuf and ONNX protobuf requirementsk conflict.")
@ -182,5 +182,5 @@ def test_export_paddle():
@pytest.mark.slow
def test_export_ncnn():
"""Test YOLO exports to NCNN format."""
f = YOLO(MODEL).export(format="ncnn", imgsz=32)
YOLO(f)(SOURCE, imgsz=32) # exported model inference
file = YOLO(MODEL).export(format="ncnn", imgsz=32)
YOLO(file)(SOURCE, imgsz=32) # exported model inference

@ -83,6 +83,7 @@ from ultralytics.utils import (
WINDOWS,
__version__,
callbacks,
checks,
colorstr,
get_default_args,
yaml_save,
@ -184,6 +185,7 @@ class Exporter:
if sum(flags) != 1:
raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}")
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn = flags # export booleans
is_tf_format = any((saved_model, pb, tflite, edgetpu, tfjs))
# Device
if fmt == "engine" and self.args.device is None:
@ -243,7 +245,7 @@ class Exporter:
m.dynamic = self.args.dynamic
m.export = True
m.format = self.args.format
elif isinstance(m, C2f) and not any((saved_model, pb, tflite, edgetpu, tfjs)):
elif isinstance(m, C2f) and not is_tf_format:
# EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
m.forward = m.forward_split
@ -303,7 +305,7 @@ class Exporter:
f[3], _ = self.export_openvino()
if coreml: # CoreML
f[4], _ = self.export_coreml()
if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
if is_tf_format: # TensorFlow formats
self.args.int8 |= edgetpu
f[5], keras_model = self.export_saved_model()
if pb or tfjs: # pb prerequisite to tfjs
@ -777,11 +779,10 @@ class Exporter:
_ = self.cache.write_bytes(cache)
# Load dataset w/ builder (for batching) and calibrate
dataset = self.get_int8_calibration_dataloader(prefix)
config.int8_calibrator = EngineCalibrator(
dataset=dataset,
dataset=self.get_int8_calibration_dataloader(prefix),
batch=2 * self.args.batch,
cache=self.file.with_suffix(".cache"),
cache=str(self.file.with_suffix(".cache")),
)
elif half:
@ -813,7 +814,7 @@ class Exporter:
except ImportError:
suffix = "-macos" if MACOS else "-aarch64" if ARM64 else "" if cuda else "-cpu"
version = "" if ARM64 else "<=2.13.1"
check_requirements(f"tensorflow{suffix}{version}")
check_requirements((f"tensorflow{suffix}{version}", "keras"))
import tensorflow as tf # noqa
if ARM64:
check_requirements("cmake") # 'cmake' is needed to build onnxsim on aarch64
@ -855,24 +856,17 @@ class Exporter:
f_onnx, _ = self.export_onnx()
# Export to TF
tmp_file = f / "tmp_tflite_int8_calibration_images.npy" # int8 calibration images file
np_data = None
if self.args.int8:
tmp_file = f / "tmp_tflite_int8_calibration_images.npy" # int8 calibration images file
verbosity = "info"
if self.args.data:
# Generate calibration data for integer quantization
dataloader = self.get_int8_calibration_dataloader(prefix)
images = []
for i, batch in enumerate(dataloader):
if i >= 100: # maximum number of calibration images
break
im = batch["img"].permute(1, 2, 0)[None] # list to nparray, CHW to BHWC
images.append(im)
f.mkdir()
images = [batch["img"].permute(0, 2, 3, 1) for batch in self.get_int8_calibration_dataloader(prefix)]
images = torch.cat(images, 0).float()
# mean = images.view(-1, 3).mean(0) # imagenet mean [123.675, 116.28, 103.53]
# std = images.view(-1, 3).std(0) # imagenet std [58.395, 57.12, 57.375]
np.save(str(tmp_file), images.numpy()) # BHWC
np.save(str(tmp_file), images.numpy().astype(np.float32)) # BHWC
np_data = [["images", tmp_file, [[[[0, 0, 0]]]], [[[[255, 255, 255]]]]]]
else:
verbosity = "error"

@ -23,7 +23,6 @@ from ultralytics.utils import (
ASSETS,
AUTOINSTALL,
IS_COLAB,
IS_DOCKER,
IS_JUPYTER,
IS_KAGGLE,
IS_PIP_PACKAGE,
@ -322,17 +321,18 @@ def check_font(font="Arial.ttf"):
return file
def check_python(minimum: str = "3.8.0") -> bool:
def check_python(minimum: str = "3.8.0", hard: bool = True) -> bool:
"""
Check current python version against the required minimum version.
Args:
minimum (str): Required minimum version of python.
hard (bool, optional): If True, raise an AssertionError if the requirement is not met.
Returns:
(bool): Whether the installed Python version meets the minimum constraints.
"""
return check_version(PYTHON_VERSION, minimum, name="Python ", hard=True)
return check_version(PYTHON_VERSION, minimum, name="Python", hard=hard)
@TryExcept()
@ -735,4 +735,5 @@ def cuda_is_available() -> bool:
# Define constants
IS_PYTHON_MINIMUM_3_10 = check_python("3.10", hard=False)
IS_PYTHON_3_12 = PYTHON_VERSION.startswith("3.12")

Loading…
Cancel
Save