`ultralytics 8.3.21` NVIDIA DLA export support (#16449)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
Co-authored-by: Lakshantha Dissanayake <lakshanthad@yahoo.com>
Co-authored-by: Lakshantha <lakshantha@ultralytics.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
pull/17128/head v8.3.21
Justin Davis 1 month ago committed by GitHub
parent b8fbee3a97
commit 8f0a94409f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 45
      docs/en/guides/nvidia-jetson.md
  2. 1
      docs/en/macros/export-args.md
  3. 2
      ultralytics/__init__.py
  4. 23
      ultralytics/engine/exporter.py

@ -240,7 +240,7 @@ pip install onnxruntime_gpu-1.17.0-cp38-cp38-linux_aarch64.whl
Out of all the model export formats supported by Ultralytics, TensorRT delivers the best inference performance when working with NVIDIA Jetson devices and our recommendation is to use TensorRT with Jetson. We also have a detailed document on TensorRT [here](../integrations/tensorrt.md).
## Convert Model to TensorRT and Run Inference
### Convert Model to TensorRT and Run Inference
The YOLOv8n model in PyTorch format is converted to TensorRT to run inference with the exported model.
@ -254,7 +254,7 @@ The YOLOv8n model in PyTorch format is converted to TensorRT to run inference wi
# Load a YOLOv8n PyTorch model
model = YOLO("yolov8n.pt")
# Export the model
# Export the model to TensorRT
model.export(format="engine") # creates 'yolov8n.engine'
# Load the exported TensorRT model
@ -274,6 +274,47 @@ The YOLOv8n model in PyTorch format is converted to TensorRT to run inference wi
yolo predict model=yolov8n.engine source='https://ultralytics.com/images/bus.jpg'
```
### Use NVIDIA Deep Learning Accelerator (DLA)
[NVIDIA Deep Learning Accelerator (DLA)](https://developer.nvidia.com/deep-learning-accelerator) is a specialized hardware component built into NVIDIA Jetson devices that optimizes deep learning inference for energy efficiency and performance. By offloading tasks from the GPU (freeing it up for more intensive processes), DLA enables models to run with lower power consumption while maintaining high throughput, ideal for embedded systems and real-time AI applications.
The following Jetson devices are equipped with DLA hardware:
- Jetson Orin NX 16GB
- Jetson AGX Orin Series
- Jetson AGX Xavier Series
- Jetson Xavier NX Series
!!! example
=== "Python"
```python
from ultralytics import YOLO
# Load a YOLOv8n PyTorch model
model = YOLO("yolov8n.pt")
# Export the model to TensorRT with DLA enabled (only works with FP16 or INT8)
model.export(format="engine", device="dla:0", half=True) # dla:0 or dla:1 corresponds to the DLA cores
# Load the exported TensorRT model
trt_model = YOLO("yolov8n.engine")
# Run inference
results = trt_model("https://ultralytics.com/images/bus.jpg")
```
=== "CLI"
```bash
# Export a YOLOv8n PyTorch model to TensorRT format with DLA enabled (only works with FP16 or INT8)
yolo export model=yolov8n.pt format=engine device="dla:0" half=True # dla:0 or dla:1 corresponds to the DLA cores
# Run inference with the exported model on the DLA
yolo predict model=yolov8n.engine source='https://ultralytics.com/images/bus.jpg'
```
!!! note
Visit the [Export page](../modes/export.md#arguments) to access additional arguments when exporting models to different model formats

@ -12,3 +12,4 @@
| `workspace` | `float` | `4.0` | Sets the maximum workspace size in GiB for TensorRT optimizations, balancing memory usage and performance. |
| `nms` | `bool` | `False` | Adds Non-Maximum Suppression (NMS) to the CoreML export, essential for accurate and efficient detection post-processing. |
| `batch` | `int` | `1` | Specifies export model batch inference size or the max number of images the exported model will process concurrently in `predict` mode. |
| `device` | `str` | `None` | Specifies the device for exporting: GPU (`device=0`), CPU (`device=cpu`), MPS for Apple silicon (`device=mps`) or DLA for NVIDIA Jetson (`device=dla:0` or `device=dla:1`). |

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.3.20"
__version__ = "8.3.21"
import os

@ -194,6 +194,11 @@ class Exporter:
is_tf_format = any((saved_model, pb, tflite, edgetpu, tfjs))
# Device
dla = None
if fmt == "engine" and "dla" in self.args.device:
dla = self.args.device.split(":")[-1]
assert dla in {"0", "1"}, f"Expected self.args.device='dla:0' or 'dla:1, but got {self.args.device}."
self.args.device = "0"
if fmt == "engine" and self.args.device is None:
LOGGER.warning("WARNING ⚠ TensorRT requires GPU export, automatically assigning device=0")
self.args.device = "0"
@ -309,7 +314,7 @@ class Exporter:
if jit or ncnn: # TorchScript
f[0], _ = self.export_torchscript()
if engine: # TensorRT required before ONNX
f[1], _ = self.export_engine()
f[1], _ = self.export_engine(dla=dla)
if onnx: # ONNX
f[2], _ = self.export_onnx()
if xml: # OpenVINO
@ -682,7 +687,7 @@ class Exporter:
return f, ct_model
@try_export
def export_engine(self, prefix=colorstr("TensorRT:")):
def export_engine(self, dla=None, prefix=colorstr("TensorRT:")):
"""YOLO TensorRT export https://developer.nvidia.com/tensorrt."""
assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
f_onnx, _ = self.export_onnx() # run before TRT import https://github.com/ultralytics/ultralytics/issues/7016
@ -717,6 +722,20 @@ class Exporter:
network = builder.create_network(flag)
half = builder.platform_has_fast_fp16 and self.args.half
int8 = builder.platform_has_fast_int8 and self.args.int8
# Optionally switch to DLA if enabled
if dla is not None:
if not IS_JETSON:
raise ValueError("DLA is only available on NVIDIA Jetson devices")
LOGGER.info(f"{prefix} enabling DLA on core {dla}...")
if not self.args.half and not self.args.int8:
raise ValueError(
"DLA requires either 'half=True' (FP16) or 'int8=True' (INT8) to be enabled. Please enable one of them and try again."
)
config.default_device_type = trt.DeviceType.DLA
config.DLA_core = int(dla)
config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
# Read ONNX file
parser = trt.OnnxParser(network, logger)
if not parser.parse_from_file(f_onnx):

Loading…
Cancel
Save