`ultralytics 8.3.29` Sony IMX500 export (#14878)

Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
Co-authored-by: Francesco Mattioli <Francesco.mttl@gmail.com>
Co-authored-by: Lakshantha Dissanayake <lakshantha@ultralytics.com>
Co-authored-by: Lakshantha Dissanayake <lakshanthad@yahoo.com>
Co-authored-by: Chizkiyahu Raful <37312901+Chizkiyahu@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Muhammad Rizwan Munawar <muhammadrizwanmunawar123@gmail.com>
Co-authored-by: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com>
Branch: pull/17479/head · Tag: v8.3.29
Authored by Laughing, committed by GitHub
parent 2c6cd68144
commit 0fa1d7d5a6
16 changed files:

 1. .gitignore (1 change)
 2. docs/en/integrations/index.md (2 changes)
 3. docs/en/integrations/sony-imx500.md (2 changes)
 4. docs/en/macros/export-table.md (1 change)
 5. docs/en/reference/utils/torch_utils.md (4 changes)
 6. docs/mkdocs_github_authors.yaml (3 changes)
 7. mkdocs.yml (4 changes)
 8. tests/test_exports.py (9 changes)
 9. ultralytics/__init__.py (2 changes)
10. ultralytics/engine/exporter.py (176 changes)
11. ultralytics/nn/autobackend.py (27 changes)
12. ultralytics/nn/modules/block.py (3 changes)
13. ultralytics/nn/modules/head.py (12 changes)
14. ultralytics/utils/benchmarks.py (5 changes)
15. ultralytics/utils/tal.py (2 changes)
16. ultralytics/utils/torch_utils.py (45 changes)

.gitignore (vendored)

@@ -163,6 +163,7 @@ weights/
 *_openvino_model/
 *_paddle_model/
 *_ncnn_model/
+*_imx_model/
 pnnx*

 # Autogenerated files for tests

docs/en/integrations/index.md

@@ -61,6 +61,8 @@ Welcome to the Ultralytics Integrations page! This page provides an overview of
 - [Albumentations](albumentations.md): Enhance your Ultralytics models with powerful image augmentations to improve model robustness and generalization.

+- [SONY IMX500](sony-imx500.md): Optimize and deploy [Ultralytics YOLOv8](https://docs.ultralytics.com/models/yolov8/) models on Raspberry Pi AI Cameras with the IMX500 sensor for fast, low-power performance.
+
 ## Deployment Integrations

 - [CoreML](coreml.md): CoreML, developed by [Apple](https://www.apple.com/), is a framework designed for efficiently integrating machine learning models into applications across iOS, macOS, watchOS, and tvOS, using Apple's hardware for effective and secure [model deployment](https://www.ultralytics.com/glossary/model-deployment).

docs/en/integrations/sony-imx500.md

@@ -4,7 +4,7 @@ description: Learn to export Ultralytics YOLOv8 models to Sony's IMX500 format t
 keywords: Sony, IMX500, IMX 500, Atrios, MCT, model export, quantization, pruning, deep learning optimization, Raspberry Pi AI Camera, edge AI, PyTorch, IMX
 ---

-# IMX500 Export for Ultralytics YOLOv8
+# Sony IMX500 Export for Ultralytics YOLOv8

 This guide covers exporting and deploying Ultralytics YOLOv8 models to Raspberry Pi AI Cameras that feature the Sony IMX500 sensor.

docs/en/macros/export-table.md

@@ -14,3 +14,4 @@
 | [PaddlePaddle](../integrations/paddlepaddle.md) | `paddle` | `{{ model_name or "yolo11n" }}_paddle_model/` | ✅ | `imgsz`, `batch` |
 | [MNN](../integrations/mnn.md) | `mnn` | `{{ model_name or "yolo11n" }}.mnn` | ✅ | `imgsz`, `batch`, `int8`, `half` |
 | [NCNN](../integrations/ncnn.md) | `ncnn` | `{{ model_name or "yolo11n" }}_ncnn_model/` | ✅ | `imgsz`, `half`, `batch` |
+| [IMX500](../integrations/sony-imx500.md) | `imx` | `{{ model_name or "yolo11n" }}_imx_model/` | ✅ | `imgsz`, `int8` |

docs/en/reference/utils/torch_utils.md

@@ -19,6 +19,10 @@ keywords: Ultralytics, torch utils, model optimization, device selection, infere
 <br><br><hr><br>

+## ::: ultralytics.utils.torch_utils.FXModel
+
+<br><br><hr><br>
+
 ## ::: ultralytics.utils.torch_utils.torch_distributed_zero_first

 <br><br><hr><br>

docs/mkdocs_github_authors.yaml

@@ -109,6 +109,9 @@ chr043416@gmail.com:
 davis.justin@mssm.org:
   avatar: https://avatars.githubusercontent.com/u/23462437?v=4
   username: justincdavis
+francesco.mttl@gmail.com:
+  avatar: https://avatars.githubusercontent.com/u/3855193?v=4
+  username: ambitious-octopus
 glenn.jocher@ultralytics.com:
   avatar: https://avatars.githubusercontent.com/u/26833433?v=4
   username: glenn-jocher

mkdocs.yml

@@ -412,12 +412,14 @@ nav:
       - TF.js: integrations/tfjs.md
       - TFLite: integrations/tflite.md
       - TFLite Edge TPU: integrations/edge-tpu.md
+      - Sony IMX500: integrations/sony-imx500.md
       - TensorBoard: integrations/tensorboard.md
       - TensorRT: integrations/tensorrt.md
       - TorchScript: integrations/torchscript.md
       - VS Code: integrations/vscode.md
       - Weights & Biases: integrations/weights-biases.md
       - Albumentations: integrations/albumentations.md
+      - SONY IMX500: integrations/sony-imx500.md
   - HUB:
       - hub/index.md
       - Web:

@@ -559,7 +561,6 @@ nav:
           - utils: reference/nn/modules/utils.md
       - tasks: reference/nn/tasks.md
       - solutions:
-          - solutions: reference/solutions/solutions.md
           - ai_gym: reference/solutions/ai_gym.md
           - analytics: reference/solutions/analytics.md
           - distance_calculation: reference/solutions/distance_calculation.md

@@ -567,6 +568,7 @@ nav:
           - object_counter: reference/solutions/object_counter.md
           - parking_management: reference/solutions/parking_management.md
           - queue_management: reference/solutions/queue_management.md
+          - solutions: reference/solutions/solutions.md
           - speed_estimation: reference/solutions/speed_estimation.md
           - streamlit_inference: reference/solutions/streamlit_inference.md
   - trackers:

tests/test_exports.py

@@ -205,3 +205,12 @@ def test_export_ncnn():
     """Test YOLO exports to NCNN format."""
     file = YOLO(MODEL).export(format="ncnn", imgsz=32)
     YOLO(file)(SOURCE, imgsz=32)  # exported model inference
+
+
+@pytest.mark.skipif(True, reason="Test disabled as keras and tensorflow version conflicts with tflite export.")
+@pytest.mark.skipif(not LINUX or MACOS, reason="Skipping test on Windows and macOS")
+def test_export_imx():
+    """Test YOLOv8n exports to IMX format."""
+    model = YOLO("yolov8n.pt")
+    file = model.export(format="imx", imgsz=32)
+    YOLO(file)(SOURCE, imgsz=32)

ultralytics/__init__.py

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = "8.3.28"
+__version__ = "8.3.29"

 import os

ultralytics/engine/exporter.py

@@ -18,6 +18,7 @@ TensorFlow.js | `tfjs` | yolo11n_web_model/
 PaddlePaddle | `paddle` | yolo11n_paddle_model/
 MNN | `mnn` | yolo11n.mnn
 NCNN | `ncnn` | yolo11n_ncnn_model/
+IMX | `imx` | yolo11n_imx_model/

 Requirements:
     $ pip install "ultralytics[export]"

@@ -44,6 +45,7 @@ Inference:
     yolo11n_paddle_model  # PaddlePaddle
     yolo11n.mnn  # MNN
     yolo11n_ncnn_model  # NCNN
+    yolo11n_imx_model  # IMX

 TensorFlow.js:
     $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
@@ -94,7 +96,7 @@ from ultralytics.utils.checks import check_imgsz, check_is_path_safe, check_requ
 from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download
 from ultralytics.utils.files import file_size, spaces_in_path
 from ultralytics.utils.ops import Profile
-from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device, smart_inference_mode
+from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device


 def export_formats():
@@ -114,6 +116,7 @@ def export_formats():
         ["PaddlePaddle", "paddle", "_paddle_model", True, True],
         ["MNN", "mnn", ".mnn", True, True],
         ["NCNN", "ncnn", "_ncnn_model", True, True],
+        ["IMX", "imx", "_imx_model", True, True],
     ]
     return dict(zip(["Format", "Argument", "Suffix", "CPU", "GPU"], zip(*x)))
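For context, a quick sketch of how the new row surfaces through this helper; the printed list is an assumption based on the column layout above:

```python
from ultralytics.engine.exporter import export_formats

fmts = export_formats()  # dict of parallel columns: Format, Argument, Suffix, CPU, GPU
print(list(fmts["Argument"]))  # [..., 'paddle', 'mnn', 'ncnn', 'imx']
```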
@@ -171,7 +174,6 @@ class Exporter:
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
         callbacks.add_integration_callbacks(self)

-    @smart_inference_mode()
     def __call__(self, model=None) -> str:
         """Returns list of exported files/dirs after running callbacks."""
         self.run_callbacks("on_export_start")
@@ -194,9 +196,22 @@ class Exporter:
         flags = [x == fmt for x in fmts]
         if sum(flags) != 1:
             raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}")
-        jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, mnn, ncnn = (
-            flags  # export booleans
-        )
+        (
+            jit,
+            onnx,
+            xml,
+            engine,
+            coreml,
+            saved_model,
+            pb,
+            tflite,
+            edgetpu,
+            tfjs,
+            paddle,
+            mnn,
+            ncnn,
+            imx,
+        ) = flags  # export booleans
         is_tf_format = any((saved_model, pb, tflite, edgetpu, tfjs))

         # Device
@@ -210,6 +225,9 @@ class Exporter:
             self.device = select_device("cpu" if self.args.device is None else self.args.device)

         # Checks
+        if imx and not self.args.int8:
+            LOGGER.warning("WARNING ⚠ IMX only supports int8 export, setting int8=True.")
+            self.args.int8 = True
         if not hasattr(model, "names"):
             model.names = default_class_names()
         model.names = check_class_names(model.names)
@@ -249,6 +267,7 @@ class Exporter:
             )
         if mnn and (IS_RASPBERRYPI or IS_JETSON):
             raise SystemError("MNN export not supported on Raspberry Pi and NVIDIA Jetson")
+
         # Input
         im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
         file = Path(
@@ -264,6 +283,11 @@ class Exporter:
         model.eval()
         model.float()
         model = model.fuse()
+
+        if imx:
+            from ultralytics.utils.torch_utils import FXModel
+
+            model = FXModel(model)
         for m in model.modules():
             if isinstance(m, (Detect, RTDETRDecoder)):  # includes all Detect subclasses like Segment, Pose, OBB
                 m.dynamic = self.args.dynamic
@@ -273,6 +297,15 @@ class Exporter:
             elif isinstance(m, C2f) and not is_tf_format:
                 # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
                 m.forward = m.forward_split
+            if isinstance(m, Detect) and imx:
+                from ultralytics.utils.tal import make_anchors
+
+                m.anchors, m.strides = (
+                    x.transpose(0, 1)
+                    for x in make_anchors(
+                        torch.cat([s / m.stride.unsqueeze(-1) for s in self.imgsz], dim=1), m.stride, 0.5
+                    )
+                )

         y = None
         for _ in range(2):
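A worked example of the tensor this hunk hands to make_anchors, assuming the standard YOLOv8 strides (8, 16, 32) and imgsz=(640, 640); each row is an (h, w) grid size, which the updated make_anchors further down consumes directly:

```python
import torch

stride = torch.tensor([8.0, 16.0, 32.0])
imgsz = (640, 640)
# One (h, w) row per detection layer: [[80, 80], [40, 40], [20, 20]]
print(torch.cat([s / stride.unsqueeze(-1) for s in imgsz], dim=1))
```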
@@ -347,6 +380,8 @@ class Exporter:
             f[11], _ = self.export_mnn()
         if ncnn:  # NCNN
             f[12], _ = self.export_ncnn()
+        if imx:
+            f[13], _ = self.export_imx()

         # Finish
         f = [str(x) for x in f if x]  # filter out '' and None
@@ -1068,6 +1103,137 @@ class Exporter:
             yaml_save(Path(f) / "metadata.yaml", self.metadata)  # add metadata.yaml
         return f, None

+    @try_export
+    def export_imx(self, prefix=colorstr("IMX:")):
+        """YOLO IMX export."""
+        gptq = False
+        assert LINUX, "export only supported on Linux. See https://developer.aitrios.sony-semicon.com/en/raspberrypi-ai-camera/documentation/imx500-converter"
+        if getattr(self.model, "end2end", False):
+            raise ValueError("IMX export is not supported for end2end models.")
+        if "C2f" not in self.model.__str__():
+            raise ValueError("IMX export is only supported for YOLOv8 detection models")
+        check_requirements(("model-compression-toolkit==2.1.1", "sony-custom-layers==0.2.0", "tensorflow==2.12.0"))
+        check_requirements("imx500-converter[pt]==3.14.3")  # Separate requirements for imx500-converter
+
+        import model_compression_toolkit as mct
+        import onnx
+        from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms
+
+        try:
+            out = subprocess.run(
+                ["java", "--version"], check=True, capture_output=True
+            )  # Java 17 is required for imx500-converter
+            if "openjdk 17" not in str(out.stdout):
+                raise FileNotFoundError
+        except FileNotFoundError:
+            subprocess.run(["sudo", "apt", "install", "-y", "openjdk-17-jdk", "openjdk-17-jre"], check=True)
+
+        def representative_dataset_gen(dataloader=self.get_int8_calibration_dataloader(prefix)):
+            for batch in dataloader:
+                img = batch["img"]
+                img = img / 255.0
+                yield [img]
+
+        tpc = mct.get_target_platform_capabilities(
+            fw_name="pytorch", target_platform_name="imx500", target_platform_version="v1"
+        )
+
+        config = mct.core.CoreConfig(
+            mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=10),
+            quantization_config=mct.core.QuantizationConfig(concat_threshold_update=True),
+        )
+
+        resource_utilization = mct.core.ResourceUtilization(weights_memory=3146176 * 0.76)
+
+        quant_model = (
+            mct.gptq.pytorch_gradient_post_training_quantization(  # Perform Gradient-Based Post Training Quantization
+                model=self.model,
+                representative_data_gen=representative_dataset_gen,
+                target_resource_utilization=resource_utilization,
+                gptq_config=mct.gptq.get_pytorch_gptq_config(n_epochs=1000, use_hessian_based_weights=False),
+                core_config=config,
+                target_platform_capabilities=tpc,
+            )[0]
+            if gptq
+            else mct.ptq.pytorch_post_training_quantization(  # Perform post training quantization
+                in_module=self.model,
+                representative_data_gen=representative_dataset_gen,
+                target_resource_utilization=resource_utilization,
+                core_config=config,
+                target_platform_capabilities=tpc,
+            )[0]
+        )
+
+        class NMSWrapper(torch.nn.Module):
+            def __init__(
+                self,
+                model: torch.nn.Module,
+                score_threshold: float = 0.001,
+                iou_threshold: float = 0.7,
+                max_detections: int = 300,
+            ):
+                """
+                Wrapping PyTorch Module with multiclass_nms layer from sony_custom_layers.
+
+                Args:
+                    model (nn.Module): Model instance.
+                    score_threshold (float): Score threshold for non-maximum suppression.
+                    iou_threshold (float): Intersection over union threshold for non-maximum suppression.
+                    max_detections (int): The maximum number of detections to return.
+                """
+                super().__init__()
+                self.model = model
+                self.score_threshold = score_threshold
+                self.iou_threshold = iou_threshold
+                self.max_detections = max_detections
+
+            def forward(self, images):
+                # model inference
+                outputs = self.model(images)
+
+                boxes = outputs[0]
+                scores = outputs[1]
+                nms = multiclass_nms(
+                    boxes=boxes,
+                    scores=scores,
+                    score_threshold=self.score_threshold,
+                    iou_threshold=self.iou_threshold,
+                    max_detections=self.max_detections,
+                )
+                return nms
+
+        quant_model = NMSWrapper(
+            model=quant_model,
+            score_threshold=self.args.conf or 0.001,
+            iou_threshold=self.args.iou,
+            max_detections=self.args.max_det,
+        ).to(self.device)
+
+        f = Path(str(self.file).replace(self.file.suffix, "_imx_model"))
+        f.mkdir(exist_ok=True)
+        onnx_model = f / Path(str(self.file).replace(self.file.suffix, "_imx.onnx"))  # js dir
+        mct.exporter.pytorch_export_model(
+            model=quant_model, save_model_path=onnx_model, repr_dataset=representative_dataset_gen
+        )
+
+        model_onnx = onnx.load(onnx_model)  # load onnx model
+        for k, v in self.metadata.items():
+            meta = model_onnx.metadata_props.add()
+            meta.key, meta.value = k, str(v)
+
+        onnx.save(model_onnx, onnx_model)
+
+        subprocess.run(
+            ["imxconv-pt", "-i", str(onnx_model), "-o", str(f), "--no-input-persistency", "--overwrite-output"],
+            check=True,
+        )
+
+        # Needed for imx models.
+        with open(f / "labels.txt", "w") as file:
+            file.writelines([f"{name}\n" for _, name in self.model.names.items()])
+
+        return f, None
+
     def _add_tflite_metadata(self, file):
         """Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata."""
         import flatbuffers
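A minimal end-to-end sketch of the new export path (file names and image URL are illustrative; int8=True is implied by the guard added earlier in this diff):

```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")
path = model.export(format="imx", imgsz=640)  # writes yolov8n_imx_model/ with the quantized ONNX and labels.txt
YOLO(path)("https://ultralytics.com/images/bus.jpg")  # inference via the IMX branch of AutoBackend below
```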

ultralytics/nn/autobackend.py

@@ -123,6 +123,7 @@ class AutoBackend(nn.Module):
             paddle,
             mnn,
             ncnn,
+            imx,
             triton,
         ) = self._model_type(w)
         fp16 &= pt or jit or onnx or xml or engine or nn_module or triton  # FP16
@@ -182,8 +183,8 @@ class AutoBackend(nn.Module):
             check_requirements("opencv-python>=4.5.4")
             net = cv2.dnn.readNetFromONNX(w)

-        # ONNX Runtime
-        elif onnx:
+        # ONNX Runtime and IMX
+        elif onnx or imx:
             LOGGER.info(f"Loading {w} for ONNX Runtime inference...")
             check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime"))
             if IS_RASPBERRYPI or IS_JETSON:
@@ -199,7 +200,22 @@ class AutoBackend(nn.Module):
                 device = torch.device("cpu")
                 cuda = False
             LOGGER.info(f"Preferring ONNX Runtime {providers[0]}")
-            session = onnxruntime.InferenceSession(w, providers=providers)
+            if onnx:
+                session = onnxruntime.InferenceSession(w, providers=providers)
+            else:
+                check_requirements(
+                    ["model-compression-toolkit==2.1.1", "sony-custom-layers[torch]==0.2.0", "onnxruntime-extensions"]
+                )
+                w = next(Path(w).glob("*.onnx"))
+                LOGGER.info(f"Loading {w} for ONNX IMX inference...")
+                import mct_quantizers as mctq
+                from sony_custom_layers.pytorch.object_detection import nms_ort  # noqa
+
+                session = onnxruntime.InferenceSession(
+                    w, mctq.get_ort_session_options(), providers=["CPUExecutionProvider"]
+                )
+                task = "detect"
+
             output_names = [x.name for x in session.get_outputs()]
             metadata = session.get_modelmeta().custom_metadata_map
             dynamic = isinstance(session.get_outputs()[0].shape[0], str)
@@ -520,7 +536,7 @@ class AutoBackend(nn.Module):
             y = self.net.forward()

         # ONNX Runtime
-        elif self.onnx:
+        elif self.onnx or self.imx:
             if self.dynamic:
                 im = im.cpu().numpy()  # torch to numpy
                 y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
@@ -537,6 +553,9 @@ class AutoBackend(nn.Module):
                 )
                 self.session.run_with_iobinding(self.io)
                 y = self.bindings
+            if self.imx:
+                # boxes, conf, cls
+                y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None]], axis=-1)

         # OpenVINO
         elif self.xml:
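Shape sketch of the concatenation above, assuming the max_detections=300 default from the exporter's NMSWrapper:

```python
import numpy as np

boxes = np.zeros((1, 300, 4))   # x1, y1, x2, y2
scores = np.zeros((1, 300))     # confidence
classes = np.zeros((1, 300))    # class index
y = np.concatenate([boxes, scores[:, :, None], classes[:, :, None]], axis=-1)
print(y.shape)  # (1, 300, 6): x1, y1, x2, y2, conf, cls
```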

ultralytics/nn/modules/block.py

@@ -240,7 +240,8 @@ class C2f(nn.Module):
     def forward_split(self, x):
         """Forward pass using split() instead of chunk()."""
-        y = list(self.cv1(x).split((self.c, self.c), 1))
+        y = self.cv1(x).split((self.c, self.c), 1)
+        y = [y[0], y[1]]
         y.extend(m(y[-1]) for m in self.m)
         return self.cv2(torch.cat(y, 1))
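The reason for this two-line dance: torch.fx symbolic tracing cannot iterate a Proxy, so list(tensor.split(...)) fails during tracing, while fixed y[0]/y[1] indexing lowers to plain getitem nodes. A toy illustration (not C2f itself):

```python
import torch
from torch import nn


class ToySplit(nn.Module):
    def forward(self, x):
        y = x.split((2, 2), 1)
        y = [y[0], y[1]]  # static indexing traces; list(y) raises "Proxy object cannot be iterated"
        return torch.cat(y, 1)


print(torch.fx.symbolic_trace(ToySplit()).graph)
```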

ultralytics/nn/modules/head.py

@@ -23,6 +23,7 @@ class Detect(nn.Module):
     dynamic = False  # force grid reconstruction
     export = False  # export mode
+    format = None  # export format
     end2end = False  # end2end
     max_det = 300  # max_det
     shape = None
@@ -101,7 +102,7 @@ class Detect(nn.Module):
         # Inference path
         shape = x[0].shape  # BCHW
         x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
-        if self.dynamic or self.shape != shape:
+        if self.format != "imx" and (self.dynamic or self.shape != shape):
             self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
             self.shape = shape
@@ -119,6 +120,11 @@ class Detect(nn.Module):
             grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
             norm = self.strides / (self.stride[0] * grid_size)
             dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
+        elif self.export and self.format == "imx":
+            dbox = self.decode_bboxes(
+                self.dfl(box) * self.strides, self.anchors.unsqueeze(0) * self.strides, xywh=False
+            )
+            return dbox.transpose(1, 2), cls.sigmoid().permute(0, 2, 1)
         else:
             dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
@@ -137,9 +143,9 @@ class Detect(nn.Module):
             a[-1].bias.data[:] = 1.0  # box
             b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)

-    def decode_bboxes(self, bboxes, anchors):
+    def decode_bboxes(self, bboxes, anchors, xywh=True):
         """Decode bounding boxes."""
-        return dist2bbox(bboxes, anchors, xywh=not self.end2end, dim=1)
+        return dist2bbox(bboxes, anchors, xywh=xywh and (not self.end2end), dim=1)

     @staticmethod
     def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
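To make the new xywh flag concrete, a small sketch of both decodings for a single hand-picked anchor:

```python
import torch
from ultralytics.utils.tal import dist2bbox

anchor = torch.tensor([[10.0, 10.0]])        # one anchor point
dist = torch.tensor([[2.0, 2.0, 3.0, 3.0]])  # left, top, right, bottom distances
print(dist2bbox(dist, anchor, xywh=False))   # corners for IMX: [[8., 8., 13., 13.]]
print(dist2bbox(dist, anchor, xywh=True))    # center + size: [[10.5, 10.5, 5., 5.]]
```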

ultralytics/utils/benchmarks.py

@@ -118,6 +118,11 @@ def benchmark(
                 assert not IS_JETSON, "MNN export not supported on NVIDIA Jetson"
             if i == 13:  # NCNN
                 assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
+            if i == 14:  # IMX
+                assert not is_end2end
+                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 IMX exports not supported"
+                assert model.task == "detect", "IMX only supported for detection task"
+                assert "C2f" in model.__str__(), "IMX only supported for YOLOv8"
             if "cpu" in device.type:
                 assert cpu, "inference not supported on CPU"
             if "cuda" in device.type:

ultralytics/utils/tal.py

@@ -306,7 +306,7 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
     assert feats is not None
     dtype, device = feats[0].dtype, feats[0].device
     for i, stride in enumerate(strides):
-        _, _, h, w = feats[i].shape
+        h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1]))
         sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
         sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
         sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
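A sketch of the new non-list path, assuming the (h, w) tensor precomputed by the exporter hunk above:

```python
import torch
from ultralytics.utils.tal import make_anchors

feats = torch.tensor([[80.0, 80.0], [40.0, 40.0], [20.0, 20.0]])  # (h, w) per detection layer
anchors, strides = make_anchors(feats, torch.tensor([8.0, 16.0, 32.0]), 0.5)
print(anchors.shape, strides.shape)  # torch.Size([8400, 2]) torch.Size([8400, 1])
```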

ultralytics/utils/torch_utils.py

@@ -729,3 +729,48 @@ class EarlyStopping:
                 f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
             )
         return stop
+
+
+class FXModel(nn.Module):
+    """
+    A custom model class for torch.fx compatibility.
+
+    This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph
+    manipulation. It copies attributes from an existing model and explicitly sets the model attribute to ensure proper
+    copying.
+
+    Args:
+        model (torch.nn.Module): The original model to wrap for torch.fx compatibility.
+    """
+
+    def __init__(self, model):
+        """
+        Initialize the FXModel.
+
+        Args:
+            model (torch.nn.Module): The original model to wrap for torch.fx compatibility.
+        """
+        super().__init__()
+        copy_attr(self, model)
+        # Explicitly set `model` since `copy_attr` somehow does not copy it.
+        self.model = model.model
+
+    def forward(self, x):
+        """
+        Forward pass through the model.
+
+        This method performs the forward pass through the model, handling the dependencies between layers and saving
+        intermediate outputs.
+
+        Args:
+            x (torch.Tensor): The input tensor to the model.
+
+        Returns:
+            (torch.Tensor): The output tensor from the model.
+        """
+        y = []  # outputs
+        for m in self.model:
+            if m.f != -1:  # if not from previous layer
+                # from earlier layers
+                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
+            x = m(x)  # run
+            y.append(x)  # save output
+        return x
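A minimal sketch of FXModel in use (assumes yolov8n.pt is available locally); the forward loop above reproduces the usual layer-by-layer execution, so outputs should match the unwrapped model:

```python
import torch
from ultralytics import YOLO
from ultralytics.utils.torch_utils import FXModel

m = YOLO("yolov8n.pt").model.eval().float()
fx_model = FXModel(m)  # same attributes, fx-friendly forward
with torch.no_grad():
    y = fx_model(torch.zeros(1, 3, 640, 640))
```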
