changed mct to imx500

test-quan
Francesco Mattioli 4 weeks ago
parent 354ea825e8
commit 478d126c07
  1. 4
      tests/test_exports.py
  2. 34
      ultralytics/engine/exporter.py
  3. 10
      ultralytics/nn/autobackend.py
  4. 4
      ultralytics/nn/modules/head.py
  5. 8
      ultralytics/utils/benchmarks.py

@ -207,8 +207,8 @@ def test_export_ncnn():
@pytest.mark.skipif(True, reason="Test disabled") @pytest.mark.skipif(True, reason="Test disabled")
def test_export_mct(): def test_export_imx500():
"""Test YOLOv8n exports to MCT format.""" """Test YOLOv8n exports to MCT format."""
model = YOLO("yolov8n.pt") model = YOLO("yolov8n.pt")
file = model.export(format="mct", imgsz=32) file = model.export(format="imx500", imgsz=32)
YOLO(file)(SOURCE, imgsz=32) YOLO(file)(SOURCE, imgsz=32)

@ -18,7 +18,7 @@ TensorFlow.js | `tfjs` | yolo11n_web_model/
PaddlePaddle | `paddle` | yolo11n_paddle_model/ PaddlePaddle | `paddle` | yolo11n_paddle_model/
MNN | `mnn` | yolo11n.mnn MNN | `mnn` | yolo11n.mnn
NCNN | `ncnn` | yolo11n_ncnn_model/ NCNN | `ncnn` | yolo11n_ncnn_model/
Sony MCT | `mct` | yolo11n_mct_model.onnx imx500 | `imx500` | yolo11n_imx500_model.onnx
Requirements: Requirements:
$ pip install "ultralytics[export]" $ pip install "ultralytics[export]"
@ -45,7 +45,7 @@ Inference:
yolo11n_paddle_model # PaddlePaddle yolo11n_paddle_model # PaddlePaddle
yolo11n.mnn # MNN yolo11n.mnn # MNN
yolo11n_ncnn_model # NCNN yolo11n_ncnn_model # NCNN
yolo11n_mct_model.onnx # Sony MCT yolo11n_imx500_model.onnx # IMX500
TensorFlow.js: TensorFlow.js:
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
@ -115,7 +115,7 @@ def export_formats():
["PaddlePaddle", "paddle", "_paddle_model", True, True], ["PaddlePaddle", "paddle", "_paddle_model", True, True],
["MNN", "mnn", ".mnn", True, True], ["MNN", "mnn", ".mnn", True, True],
["NCNN", "ncnn", "_ncnn_model", True, True], ["NCNN", "ncnn", "_ncnn_model", True, True],
["Sony MCT", "mct", "_mct_model.onnx", True, True], ["IMX500", "imx500", "_imx500_model.onnx", True, True],
] ]
return dict(zip(["Format", "Argument", "Suffix", "CPU", "GPU"], zip(*x))) return dict(zip(["Format", "Argument", "Suffix", "CPU", "GPU"], zip(*x)))
@ -209,12 +209,12 @@ class Exporter:
paddle, paddle,
mnn, mnn,
ncnn, ncnn,
mct, imx500,
) = flags # export booleans ) = flags # export booleans
is_tf_format = any((saved_model, pb, tflite, edgetpu, tfjs)) is_tf_format = any((saved_model, pb, tflite, edgetpu, tfjs))
if mct: if imx500:
LOGGER.warning("WARNING ⚠ Sony MCT only supports int8 export, setting int8=True.") LOGGER.warning("WARNING ⚠ IMX500 only supports int8 export, setting int8=True.")
self.args.int8 = True self.args.int8 = True
# Device # Device
dla = None dla = None
@ -285,7 +285,7 @@ class Exporter:
elif isinstance(m, C2f) and not is_tf_format: elif isinstance(m, C2f) and not is_tf_format:
# EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
m.forward = m.forward_split m.forward = m.forward_split
if isinstance(m, Detect) and mct: if isinstance(m, Detect) and imx500:
from ultralytics.utils.tal import make_anchors from ultralytics.utils.tal import make_anchors
anchors, strides = ( anchors, strides = (
@ -299,7 +299,7 @@ class Exporter:
m.anchors = anchors m.anchors = anchors
m.strides = strides m.strides = strides
if isinstance(m, C2f) and mct: if isinstance(m, C2f) and imx500:
m.forward = m.forward_fx m.forward = m.forward_fx
y = None y = None
@ -375,8 +375,8 @@ class Exporter:
f[11], _ = self.export_mnn() f[11], _ = self.export_mnn()
if ncnn: # NCNN if ncnn: # NCNN
f[12], _ = self.export_ncnn() f[12], _ = self.export_ncnn()
if mct: if imx500:
f[13], _ = self.export_mct() f[13], _ = self.export_imx500()
# Finish # Finish
f = [str(x) for x in f if x] # filter out '' and None f = [str(x) for x in f if x] # filter out '' and None
@ -1100,12 +1100,12 @@ class Exporter:
return f, None return f, None
@try_export @try_export
def export_mct(self, prefix=colorstr("Sony MCT:")): def export_imx500(self, prefix=colorstr("IMX500:")):
"""YOLO Sony MCT export.""" """YOLO IMX500 export."""
if getattr(self.model, "end2end", False): if getattr(self.model, "end2end", False):
raise ValueError("MCT export is not supported for end2end models.") raise ValueError("IMX500 export is not supported for end2end models.")
if "C2f" not in self.model.__str__(): if "C2f" not in self.model.__str__():
raise ValueError("MCT export is only supported for YOLOv8 detection models") raise ValueError("IMX500 export is only supported for YOLOv8 detection models")
check_requirements(("model-compression-toolkit==2.1.1", "sony-custom-layers[torch]")) check_requirements(("model-compression-toolkit==2.1.1", "sony-custom-layers[torch]"))
from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms
import model_compression_toolkit as mct import model_compression_toolkit as mct
@ -1191,7 +1191,7 @@ class Exporter:
max_detections=self.args.max_det, max_detections=self.args.max_det,
).to(self.device) ).to(self.device)
f = Path(str(self.file).replace(self.file.suffix, "_mct_model.onnx")) # js dir f = Path(str(self.file).replace(self.file.suffix, "_imx500_model.onnx")) # js dir
mct.exporter.pytorch_export_model(model=quant_model, save_model_path=f, repr_dataset=representative_dataset_gen) mct.exporter.pytorch_export_model(model=quant_model, save_model_path=f, repr_dataset=representative_dataset_gen)
import onnx import onnx
@ -1204,7 +1204,7 @@ class Exporter:
onnx.save(model_onnx, f) onnx.save(model_onnx, f)
if not LINUX: if not LINUX:
LOGGER.warning(f"{prefix} WARNING ⚠ MCT imx500-converter is only supported on Linux.") LOGGER.warning(f"{prefix} WARNING ⚠ imx500-converter is only supported on Linux.")
else: else:
check_requirements("imx500-converter[pt]==3.14.3") check_requirements("imx500-converter[pt]==3.14.3")
try: try:
@ -1217,7 +1217,7 @@ class Exporter:
) )
return None return None
subprocess.run(["imxconv-pt", "-i", "yolov8n_mct_model.onnx", "-o", "yolov8n_imx500_model"], check=True) subprocess.run(["imxconv-pt", "-i", "yolov8n_imx500_model.onnx", "-o", "yolov8n_imx500_model"], check=True)
return f, None return f, None

@ -123,7 +123,7 @@ class AutoBackend(nn.Module):
paddle, paddle,
mnn, mnn,
ncnn, ncnn,
mct, imx500,
triton, triton,
) = self._model_type(w) ) = self._model_type(w)
fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16 fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16
@ -183,7 +183,7 @@ class AutoBackend(nn.Module):
check_requirements("opencv-python>=4.5.4") check_requirements("opencv-python>=4.5.4")
net = cv2.dnn.readNetFromONNX(w) net = cv2.dnn.readNetFromONNX(w)
# ONNX Runtime and MCT # ONNX Runtime and IMX500
elif onnx: elif onnx:
LOGGER.info(f"Loading {w} for ONNX Runtime inference...") LOGGER.info(f"Loading {w} for ONNX Runtime inference...")
check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime")) check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime"))
@ -202,11 +202,11 @@ class AutoBackend(nn.Module):
device = torch.device("cpu") device = torch.device("cpu")
cuda = False cuda = False
LOGGER.info(f"Preferring ONNX Runtime {providers[0]}") LOGGER.info(f"Preferring ONNX Runtime {providers[0]}")
if mct: if imx500:
check_requirements( check_requirements(
["model-compression-toolkit==2.1.1", "sony-custom-layers[torch]==0.2.0", "onnxruntime-extensions"] ["model-compression-toolkit==2.1.1", "sony-custom-layers[torch]==0.2.0", "onnxruntime-extensions"]
) )
LOGGER.info(f"Loading {w} for ONNX MCT quantization inference...") LOGGER.info(f"Loading {w} for ONNX IMX500 inference...")
import mct_quantizers as mctq import mct_quantizers as mctq
from sony_custom_layers.pytorch.object_detection import nms_ort # noqa from sony_custom_layers.pytorch.object_detection import nms_ort # noqa
@ -554,7 +554,7 @@ class AutoBackend(nn.Module):
) )
self.session.run_with_iobinding(self.io) self.session.run_with_iobinding(self.io)
y = self.bindings y = self.bindings
if self.mct: if self.imx500:
y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None]], axis=-1) y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None]], axis=-1)
# OpenVINO # OpenVINO

@ -102,7 +102,7 @@ class Detect(nn.Module):
# Inference path # Inference path
shape = x[0].shape # BCHW shape = x[0].shape # BCHW
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2) x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
if self.format != "mct" and (self.dynamic or self.shape != shape): if self.format != "imx500" and (self.dynamic or self.shape != shape):
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
self.shape = shape self.shape = shape
@ -120,7 +120,7 @@ class Detect(nn.Module):
grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1) grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
norm = self.strides / (self.stride[0] * grid_size) norm = self.strides / (self.stride[0] * grid_size)
dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2]) dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
elif self.export and self.format == "mct": elif self.export and self.format == "imx500":
dbox = self.decode_bboxes( dbox = self.decode_bboxes(
self.dfl(box) * self.strides, self.anchors.unsqueeze(0) * self.strides, xywh=False self.dfl(box) * self.strides, self.anchors.unsqueeze(0) * self.strides, xywh=False
) )

@ -114,11 +114,11 @@ def benchmark(
assert LINUX or MACOS, "Windows Paddle exports not supported yet" assert LINUX or MACOS, "Windows Paddle exports not supported yet"
if i in {12, 13}: # MNN, NCNN if i in {12, 13}: # MNN, NCNN
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 MNN, NCNN exports not supported yet" assert not isinstance(model, YOLOWorld), "YOLOWorldv2 MNN, NCNN exports not supported yet"
if i in {14}: # MCT if i in {14}: # IMX500
assert not is_end2end assert not is_end2end
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 MCT exports not supported" assert not isinstance(model, YOLOWorld), "YOLOWorldv2 IMX500 exports not supported"
assert model.task == "detect", "MCT only supported for detection task" assert model.task == "detect", "IMX500 only supported for detection task"
assert "C2f" in model.__str__(), "MCT only supported for YOLOv8" assert "C2f" in model.__str__(), "IMX500 only supported for YOLOv8"
if "cpu" in device.type: if "cpu" in device.type:
assert cpu, "inference not supported on CPU" assert cpu, "inference not supported on CPU"
if "cuda" in device.type: if "cuda" in device.type:

Loading…
Cancel
Save