diff --git a/README.zh-CN.md b/README.zh-CN.md
index ac87d1bd4c..53cb7e05d6 100644
--- a/README.zh-CN.md
+++ b/README.zh-CN.md
@@ -26,7 +26,7 @@
想申请企业许可证，请完成 [Ultralytics Licensing](https://www.ultralytics.com/license) 上的表单。
-
diff --git a/docs/en/models/index.md b/docs/en/models/index.md
index 5e9d07f3d5..c0f4fd333d 100644
--- a/docs/en/models/index.md
+++ b/docs/en/models/index.md
@@ -8,7 +8,7 @@ keywords: Ultralytics, supported models, YOLOv3, YOLOv4, YOLOv5, YOLOv6, YOLOv7,
Welcome to Ultralytics' model documentation! We offer support for a wide range of models, each tailored to specific tasks like [object detection](../tasks/detect.md), [instance segmentation](../tasks/segment.md), [image classification](../tasks/classify.md), [pose estimation](../tasks/pose.md), and [multi-object tracking](../modes/track.md). If you're interested in contributing your model architecture to Ultralytics, check out our [Contributing Guide](../help/contributing.md).
-![Ultralytics YOLO11 Comparison Plots](https://github.com/user-attachments/assets/a311a4ed-bbf2-43b5-8012-5f183a28a845)
+![Ultralytics YOLO11 Comparison Plots](https://raw.githubusercontent.com/ultralytics/assets/refs/heads/main/yolo/performance-comparison.png)
## Featured Models
diff --git a/docs/en/models/yolo11.md b/docs/en/models/yolo11.md
index 0c755147ab..8baf2dd725 100644
--- a/docs/en/models/yolo11.md
+++ b/docs/en/models/yolo11.md
@@ -10,7 +10,7 @@ keywords: YOLO11, state-of-the-art object detection, YOLO series, Ultralytics, c
YOLO11 is the latest iteration in the [Ultralytics](https://www.ultralytics.com/) YOLO series of real-time object detectors, redefining what's possible with cutting-edge [accuracy](https://www.ultralytics.com/glossary/accuracy), speed, and efficiency. Building upon the impressive advancements of previous YOLO versions, YOLO11 introduces significant improvements in architecture and training methods, making it a versatile choice for a wide range of [computer vision](https://www.ultralytics.com/glossary/computer-vision-cv) tasks.
-![Ultralytics YOLO11 Comparison Plots](https://github.com/user-attachments/assets/a311a4ed-bbf2-43b5-8012-5f183a28a845)
+![Ultralytics YOLO11 Comparison Plots](hhttps://raw.githubusercontent.com/ultralytics/assets/refs/heads/main/yolo/performance-comparison.png)
From 6ebbe17bd82498140e5001ef8d39dfc023163a5d Mon Sep 17 00:00:00 2001
From: Muhammad Rizwan Munawar
Date: Tue, 22 Oct 2024 22:50:08 +0500
Subject: [PATCH 35/40] Add YOLO publication notice in Docs (#17095)
Co-authored-by: UltralyticsAssistant
Co-authored-by: Glenn Jocher
---
docs/en/models/yolo11.md | 6 +++++-
docs/en/models/yolov5.md | 6 +++++-
docs/en/models/yolov8.md | 4 ++++
3 files changed, 14 insertions(+), 2 deletions(-)
diff --git a/docs/en/models/yolo11.md b/docs/en/models/yolo11.md
index 8baf2dd725..fe9115f2ed 100644
--- a/docs/en/models/yolo11.md
+++ b/docs/en/models/yolo11.md
@@ -8,9 +8,13 @@ keywords: YOLO11, state-of-the-art object detection, YOLO series, Ultralytics, c
## Overview
+!!! tip "Ultralytics YOLO11 Publication"
+
+ Ultralytics has not published a formal research paper for YOLO11 due to the rapidly evolving nature of the models. We focus on advancing the technology and making it easier to use, rather than producing static documentation. For the most up-to-date information on YOLO architecture, features, and usage, please refer to our [GitHub repository](https://github.com/ultralytics/ultralytics) and [documentation](https://docs.ultralytics.com).
+
YOLO11 is the latest iteration in the [Ultralytics](https://www.ultralytics.com/) YOLO series of real-time object detectors, redefining what's possible with cutting-edge [accuracy](https://www.ultralytics.com/glossary/accuracy), speed, and efficiency. Building upon the impressive advancements of previous YOLO versions, YOLO11 introduces significant improvements in architecture and training methods, making it a versatile choice for a wide range of [computer vision](https://www.ultralytics.com/glossary/computer-vision-cv) tasks.
-![Ultralytics YOLO11 Comparison Plots](hhttps://raw.githubusercontent.com/ultralytics/assets/refs/heads/main/yolo/performance-comparison.png)
+![Ultralytics YOLO11 Comparison Plots](https://raw.githubusercontent.com/ultralytics/assets/refs/heads/main/yolo/performance-comparison.png)
diff --git a/docs/en/models/yolov5.md b/docs/en/models/yolov5.md
index 8ff1c36ec0..91c562a44e 100644
--- a/docs/en/models/yolov5.md
+++ b/docs/en/models/yolov5.md
@@ -4,7 +4,11 @@ description: Explore YOLOv5u, an advanced object detection model with optimized
keywords: YOLOv5, YOLOv5u, object detection, Ultralytics, anchor-free, pre-trained models, accuracy, speed, real-time detection
---
-# YOLOv5
+# Ultralytics YOLOv5
+
+!!! tip "Ultralytics YOLOv5 Publication"
+
+ Ultralytics has not published a formal research paper for YOLOv5 due to the rapidly evolving nature of the models. We focus on advancing the technology and making it easier to use, rather than producing static documentation. For the most up-to-date information on YOLO architecture, features, and usage, please refer to our [GitHub repository](https://github.com/ultralytics/ultralytics) and [documentation](https://docs.ultralytics.com).
## Overview
diff --git a/docs/en/models/yolov8.md b/docs/en/models/yolov8.md
index 036cd305a1..c8e4397d15 100644
--- a/docs/en/models/yolov8.md
+++ b/docs/en/models/yolov8.md
@@ -6,6 +6,10 @@ keywords: YOLOv8, real-time object detection, YOLO series, Ultralytics, computer
# Ultralytics YOLOv8
+!!! tip "Ultralytics YOLOv8 Publication"
+
+ Ultralytics has not published a formal research paper for YOLOv8 due to the rapidly evolving nature of the models. We focus on advancing the technology and making it easier to use, rather than producing static documentation. For the most up-to-date information on YOLO architecture, features, and usage, please refer to our [GitHub repository](https://github.com/ultralytics/ultralytics) and [documentation](https://docs.ultralytics.com).
+
## Overview
YOLOv8 is the latest iteration in the YOLO series of real-time object detectors, offering cutting-edge performance in terms of accuracy and speed. Building upon the advancements of previous YOLO versions, YOLOv8 introduces new features and optimizations that make it an ideal choice for various [object detection](https://www.ultralytics.com/glossary/object-detection) tasks in a wide range of applications.
From 797f2374618a231c82d82f3e572635ad776af3b1 Mon Sep 17 00:00:00 2001
From: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com>
Date: Wed, 23 Oct 2024 23:09:21 +0800
Subject: [PATCH 36/40] Add `project` and `name` args to docs for predict and
val task (#17114)
Co-authored-by: UltralyticsAssistant
---
docs/en/macros/predict-args.md | 2 ++
docs/en/macros/validation-args.md | 2 ++
2 files changed, 4 insertions(+)
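
Note: a minimal sketch of how the two newly documented arguments relate, assuming the run directory is simply `project/name` as the table rows describe. `resolve_run_dir` and both fallback values are hypothetical stand-ins for illustration; the library's own resolution of defaults is not reproduced here.

```python
from pathlib import PurePosixPath
from typing import Optional


def resolve_run_dir(project: Optional[str], name: Optional[str]) -> PurePosixPath:
    """Join project and name into one output directory (illustrative only)."""
    project = project or "runs/demo"  # hypothetical fallback, not the library default
    name = name or "exp"  # hypothetical fallback, not the library default
    return PurePosixPath(project) / name


print(resolve_run_dir("my_project", "bus_run"))
```

With the library itself, the same two values would be passed as keyword arguments, for example `model.predict(source, save=True, project="my_project", name="bus_run")`.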
diff --git a/docs/en/macros/predict-args.md b/docs/en/macros/predict-args.md
index 35c285afe0..091e692a69 100644
--- a/docs/en/macros/predict-args.md
+++ b/docs/en/macros/predict-args.md
@@ -15,3 +15,5 @@
| `classes` | `list[int]` | `None` | Filters predictions to a set of class IDs. Only detections belonging to the specified classes will be returned. Useful for focusing on relevant objects in multi-class detection tasks. |
| `retina_masks` | `bool` | `False` | Uses high-resolution segmentation masks if available in the model. This can enhance mask quality for segmentation tasks, providing finer detail. |
| `embed` | `list[int]` | `None` | Specifies the layers from which to extract feature vectors or [embeddings](https://www.ultralytics.com/glossary/embeddings). Useful for downstream tasks like clustering or similarity search. |
+| `project` | `str` | `None` | Name of the project directory where prediction outputs are saved if `save` is enabled. |
+| `name` | `str` | `None` | Name of the prediction run. Used for creating a subdirectory within the project folder, where prediction outputs are stored if `save` is enabled. |
diff --git a/docs/en/macros/validation-args.md b/docs/en/macros/validation-args.md
index 5c709f7bfc..5eeea81f49 100644
--- a/docs/en/macros/validation-args.md
+++ b/docs/en/macros/validation-args.md
@@ -14,3 +14,5 @@
| `plots` | `bool` | `False` | When set to `True`, generates and saves plots of predictions versus ground truth for visual evaluation of the model's performance. |
| `rect` | `bool` | `False` | If `True`, uses rectangular inference for batching, reducing padding and potentially increasing speed and efficiency. |
| `split` | `str` | `val` | Determines the dataset split to use for validation (`val`, `test`, or `train`). Allows flexibility in choosing the data segment for performance evaluation. |
+| `project` | `str` | `None` | Name of the project directory where validation outputs are saved. |
+| `name` | `str` | `None` | Name of the validation run. Used for creating a subdirectory within the project folder, where validation logs and outputs are stored. |
From d7eef9f330dfafa43fb9f1ca5b6d3ceb7fd51296 Mon Sep 17 00:00:00 2001
From: Iaroslav Omelianenko
Date: Wed, 23 Oct 2024 19:55:42 +0300
Subject: [PATCH 37/40] Comet integration fix (#17099)
Co-authored-by: UltralyticsAssistant
Co-authored-by: Glenn Jocher
---
ultralytics/utils/callbacks/comet.py | 43 +++++++++++++++++++++-------
1 file changed, 33 insertions(+), 10 deletions(-)
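
Note: a simplified sketch of the plot filename selection this patch introduces. The patch dispatches on the metrics object type with `isinstance`; the sketch below keys on a task string instead so that it runs without the library. `PLOT_PREFIXES` and `evaluation_plot_files` are hypothetical names used only for illustration.

```python
from pathlib import Path

EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve"

# Segment and pose runs get one plot per prefix; detect and OBB plots have no prefix.
PLOT_PREFIXES = {
    "segment": ("Box", "Mask"),
    "pose": ("Box", "Pose"),
    "detect": ("",),
    "obb": ("",),
}


def evaluation_plot_files(save_dir: Path, task: str) -> list:
    """Return the evaluation plot paths expected for a task (empty if none apply)."""
    prefixes = PLOT_PREFIXES.get(task)
    if prefixes is None:  # e.g. classification: no evaluation curves are collected
        return []
    return [save_dir / f"{prefix}{plot}.png" for plot in EVALUATION_PLOT_NAMES for prefix in prefixes]


for path in evaluation_plot_files(Path("runs/segment/train"), "segment"):
    print(path.name)
```

Iterating plot names in the outer loop and prefixes in the inner loop matches the comprehension order in the patch, so the `Box` and `Mask` variants of each curve end up adjacent.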
diff --git a/ultralytics/utils/callbacks/comet.py b/ultralytics/utils/callbacks/comet.py
index 3a217c3f25..3fae97f917 100644
--- a/ultralytics/utils/callbacks/comet.py
+++ b/ultralytics/utils/callbacks/comet.py
@@ -1,6 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops
+from ultralytics.utils.metrics import ClassifyMetrics, DetMetrics, OBBMetrics, PoseMetrics, SegmentMetrics
try:
assert not TESTS_RUNNING # do not log pytest
@@ -16,8 +17,11 @@ try:
COMET_SUPPORTED_TASKS = ["detect"]
# Names of plots created by Ultralytics that are logged to Comet
- EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve", "confusion_matrix"
+ CONFUSION_MATRIX_PLOT_NAMES = "confusion_matrix", "confusion_matrix_normalized"
+ EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve"
LABEL_PLOT_NAMES = "labels", "labels_correlogram"
+ SEGMENT_METRICS_PLOT_PREFIX = "Box", "Mask"
+ POSE_METRICS_PLOT_PREFIX = "Box", "Pose"
_comet_image_prediction_count = 0
@@ -86,7 +90,7 @@ def _create_experiment(args):
"max_image_predictions": _get_max_image_predictions_to_log(),
}
)
- experiment.log_other("Created from", "yolov8")
+ experiment.log_other("Created from", "ultralytics")
except Exception as e:
LOGGER.warning(f"WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}")
@@ -274,11 +278,31 @@ def _log_image_predictions(experiment, validator, curr_step):
def _log_plots(experiment, trainer):
"""Logs evaluation plots and label plots for the experiment."""
- plot_filenames = [trainer.save_dir / f"{plots}.png" for plots in EVALUATION_PLOT_NAMES]
- _log_images(experiment, plot_filenames, None)
-
- label_plot_filenames = [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES]
- _log_images(experiment, label_plot_filenames, None)
+ plot_filenames = None
+ if isinstance(trainer.validator.metrics, SegmentMetrics) and trainer.validator.metrics.task == "segment":
+ plot_filenames = [
+ trainer.save_dir / f"{prefix}{plots}.png"
+ for plots in EVALUATION_PLOT_NAMES
+ for prefix in SEGMENT_METRICS_PLOT_PREFIX
+ ]
+ elif isinstance(trainer.validator.metrics, PoseMetrics):
+ plot_filenames = [
+ trainer.save_dir / f"{prefix}{plots}.png"
+ for plots in EVALUATION_PLOT_NAMES
+ for prefix in POSE_METRICS_PLOT_PREFIX
+ ]
+ elif isinstance(trainer.validator.metrics, (DetMetrics, OBBMetrics)):
+ plot_filenames = [trainer.save_dir / f"{plots}.png" for plots in EVALUATION_PLOT_NAMES]
+
+ if plot_filenames is not None:
+ _log_images(experiment, plot_filenames, None)
+
+ confusion_matrix_filenames = [trainer.save_dir / f"{plots}.png" for plots in CONFUSION_MATRIX_PLOT_NAMES]
+ _log_images(experiment, confusion_matrix_filenames, None)
+
+ if not isinstance(trainer.validator.metrics, ClassifyMetrics):
+ label_plot_filenames = [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES]
+ _log_images(experiment, label_plot_filenames, None)
def _log_model(experiment, trainer):
@@ -307,9 +331,6 @@ def on_train_epoch_end(trainer):
experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix="train"), step=curr_step, epoch=curr_epoch)
- if curr_epoch == 1:
- _log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step)
-
def on_fit_epoch_end(trainer):
"""Logs model assets at the end of each epoch."""
@@ -356,6 +377,8 @@ def on_train_end(trainer):
_log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
_log_image_predictions(experiment, trainer.validator, curr_step)
+ _log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step)
+ _log_images(experiment, trainer.save_dir.glob("val_batch*.jpg"), curr_step)
experiment.end()
global _comet_image_prediction_count
From b8fbee3a975dee918ea3ce9369847246500327e9 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Wed, 23 Oct 2024 18:57:42 +0200
Subject: [PATCH 38/40] Update datasets index.md (#17098)
---
docs/en/datasets/index.md | 2 +-
mkdocs.yml | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/docs/en/datasets/index.md b/docs/en/datasets/index.md
index 9d7a10ed7e..4b6bd6c968 100644
--- a/docs/en/datasets/index.md
+++ b/docs/en/datasets/index.md
@@ -46,7 +46,7 @@ Create [embeddings](https://www.ultralytics.com/glossary/embeddings) for your da
- [VisDrone](detect/visdrone.md): A dataset containing object detection and multi-object tracking data from drone-captured imagery with over 10K images and video sequences.
- [VOC](detect/voc.md): The Pascal Visual Object Classes (VOC) dataset for object detection and segmentation with 20 object classes and over 11K images.
- [xView](detect/xview.md): A dataset for object detection in overhead imagery with 60 object categories and over 1 million annotated objects.
-- [Roboflow 100](detect/roboflow-100.md): A diverse object detection benchmark with 100 datasets spanning seven imagery domains for comprehensive model evaluation.
+- [RF100](detect/roboflow-100.md): A diverse object detection benchmark with 100 datasets spanning seven imagery domains for comprehensive model evaluation.
- [Brain-tumor](detect/brain-tumor.md): A dataset for detecting brain tumors includes MRI or CT scan images with details on tumor presence, location, and characteristics.
- [African-wildlife](detect/african-wildlife.md): A dataset featuring images of African wildlife, including buffalo, elephant, rhino, and zebras.
- [Signature](detect/signature.md): A dataset featuring images of various documents with annotated signatures, supporting document verification and fraud detection research.
diff --git a/mkdocs.yml b/mkdocs.yml
index 17a72c2e2f..a7157ec942 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -274,7 +274,7 @@ nav:
- VisDrone: datasets/detect/visdrone.md
- VOC: datasets/detect/voc.md
- xView: datasets/detect/xview.md
- - Roboflow 100: datasets/detect/roboflow-100.md
+ - RF100: datasets/detect/roboflow-100.md
- Brain-tumor: datasets/detect/brain-tumor.md
- African-wildlife: datasets/detect/african-wildlife.md
- Signature: datasets/detect/signature.md
From 8f0a94409fb2f6320b2d42db9feb4dea7ec40ac1 Mon Sep 17 00:00:00 2001
From: Justin Davis
Date: Wed, 23 Oct 2024 11:00:15 -0600
Subject: [PATCH 39/40] `ultralytics 8.3.21` NVIDIA DLA export support (#16449)
Co-authored-by: UltralyticsAssistant
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
Co-authored-by: Lakshantha Dissanayake
Co-authored-by: Lakshantha
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Glenn Jocher
Co-authored-by: Laughing-q <1185102784@qq.com>
---
docs/en/guides/nvidia-jetson.md | 45 +++++++++++++++++++++++++++++++--
docs/en/macros/export-args.md | 1 +
ultralytics/__init__.py | 2 +-
ultralytics/engine/exporter.py | 23 +++++++++++++++--
4 files changed, 66 insertions(+), 5 deletions(-)
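
Note: a standalone sketch of the two checks this patch adds, written without the library so it can be run anywhere. `split_dla_device` and `check_dla_precision` are hypothetical helper names; the sketch raises `ValueError` where the patch uses an `assert` for the core check, and it remaps the device to GPU `"0"` as this patch does.

```python
from typing import Optional, Tuple


def split_dla_device(device) -> Tuple[Optional[str], object]:
    """Return (dla_core, device); dla_core is None when DLA was not requested."""
    if device is not None and "dla" in str(device):
        core = str(device).split(":")[-1]
        if core not in {"0", "1"}:
            raise ValueError(f"Expected device='dla:0' or 'dla:1', but got {device}.")
        return core, "0"  # the engine is still built through GPU 0
    return None, device


def check_dla_precision(half: bool, int8: bool) -> None:
    """DLA export needs reduced precision: FP16 (half) or INT8."""
    if not half and not int8:
        raise ValueError("DLA requires either half=True (FP16) or int8=True (INT8).")


core, device = split_dla_device("dla:0")
check_dla_precision(half=True, int8=False)
print(core, device)
```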
diff --git a/docs/en/guides/nvidia-jetson.md b/docs/en/guides/nvidia-jetson.md
index f352c76b8c..16793288a2 100644
--- a/docs/en/guides/nvidia-jetson.md
+++ b/docs/en/guides/nvidia-jetson.md
@@ -240,7 +240,7 @@ pip install onnxruntime_gpu-1.17.0-cp38-cp38-linux_aarch64.whl
Out of all the model export formats supported by Ultralytics, TensorRT delivers the best inference performance when working with NVIDIA Jetson devices and our recommendation is to use TensorRT with Jetson. We also have a detailed document on TensorRT [here](../integrations/tensorrt.md).
-## Convert Model to TensorRT and Run Inference
+### Convert Model to TensorRT and Run Inference
The YOLOv8n model in PyTorch format is converted to TensorRT to run inference with the exported model.
@@ -254,7 +254,7 @@ The YOLOv8n model in PyTorch format is converted to TensorRT to run inference wi
# Load a YOLOv8n PyTorch model
model = YOLO("yolov8n.pt")
- # Export the model
+ # Export the model to TensorRT
model.export(format="engine") # creates 'yolov8n.engine'
# Load the exported TensorRT model
@@ -274,6 +274,47 @@ The YOLOv8n model in PyTorch format is converted to TensorRT to run inference wi
yolo predict model=yolov8n.engine source='https://ultralytics.com/images/bus.jpg'
```
+### Use NVIDIA Deep Learning Accelerator (DLA)
+
+[NVIDIA Deep Learning Accelerator (DLA)](https://developer.nvidia.com/deep-learning-accelerator) is a specialized hardware component built into NVIDIA Jetson devices that optimizes deep learning inference for energy efficiency and performance. By offloading tasks from the GPU (freeing it up for more intensive processes), DLA enables models to run with lower power consumption while maintaining high throughput, ideal for embedded systems and real-time AI applications.
+
+The following Jetson devices are equipped with DLA hardware:
+
+- Jetson Orin NX 16GB
+- Jetson AGX Orin Series
+- Jetson AGX Xavier Series
+- Jetson Xavier NX Series
+
+!!! example
+
+ === "Python"
+
+ ```python
+ from ultralytics import YOLO
+
+ # Load a YOLOv8n PyTorch model
+ model = YOLO("yolov8n.pt")
+
+ # Export the model to TensorRT with DLA enabled (only works with FP16 or INT8)
+ model.export(format="engine", device="dla:0", half=True) # dla:0 or dla:1 corresponds to the DLA cores
+
+ # Load the exported TensorRT model
+ trt_model = YOLO("yolov8n.engine")
+
+ # Run inference
+ results = trt_model("https://ultralytics.com/images/bus.jpg")
+ ```
+
+ === "CLI"
+
+ ```bash
+ # Export a YOLOv8n PyTorch model to TensorRT format with DLA enabled (only works with FP16 or INT8)
+ yolo export model=yolov8n.pt format=engine device="dla:0" half=True # dla:0 or dla:1 corresponds to the DLA cores
+
+ # Run inference with the exported model on the DLA
+ yolo predict model=yolov8n.engine source='https://ultralytics.com/images/bus.jpg'
+ ```
+
!!! note
Visit the [Export page](../modes/export.md#arguments) to access additional arguments when exporting models to different model formats
diff --git a/docs/en/macros/export-args.md b/docs/en/macros/export-args.md
index 99dd5f4d0a..242090d7c6 100644
--- a/docs/en/macros/export-args.md
+++ b/docs/en/macros/export-args.md
@@ -12,3 +12,4 @@
| `workspace` | `float` | `4.0` | Sets the maximum workspace size in GiB for TensorRT optimizations, balancing memory usage and performance. |
| `nms` | `bool` | `False` | Adds Non-Maximum Suppression (NMS) to the CoreML export, essential for accurate and efficient detection post-processing. |
| `batch` | `int` | `1` | Specifies export model batch inference size or the max number of images the exported model will process concurrently in `predict` mode. |
+| `device` | `str` | `None` | Specifies the device for exporting: GPU (`device=0`), CPU (`device=cpu`), MPS for Apple silicon (`device=mps`) or DLA for NVIDIA Jetson (`device=dla:0` or `device=dla:1`). |
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 09e2fde550..ac22fe8620 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = "8.3.20"
+__version__ = "8.3.21"
import os
diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py
index b25b837fd1..da2e746cbe 100644
--- a/ultralytics/engine/exporter.py
+++ b/ultralytics/engine/exporter.py
@@ -194,6 +194,11 @@ class Exporter:
is_tf_format = any((saved_model, pb, tflite, edgetpu, tfjs))
# Device
+ dla = None
+ if fmt == "engine" and "dla" in self.args.device:
+ dla = self.args.device.split(":")[-1]
+ assert dla in {"0", "1"}, f"Expected self.args.device='dla:0' or 'dla:1, but got {self.args.device}."
+ self.args.device = "0"
if fmt == "engine" and self.args.device is None:
LOGGER.warning("WARNING ⚠️ TensorRT requires GPU export, automatically assigning device=0")
self.args.device = "0"
@@ -309,7 +314,7 @@ class Exporter:
if jit or ncnn: # TorchScript
f[0], _ = self.export_torchscript()
if engine: # TensorRT required before ONNX
- f[1], _ = self.export_engine()
+ f[1], _ = self.export_engine(dla=dla)
if onnx: # ONNX
f[2], _ = self.export_onnx()
if xml: # OpenVINO
@@ -682,7 +687,7 @@ class Exporter:
return f, ct_model
@try_export
- def export_engine(self, prefix=colorstr("TensorRT:")):
+ def export_engine(self, dla=None, prefix=colorstr("TensorRT:")):
"""YOLO TensorRT export https://developer.nvidia.com/tensorrt."""
assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
f_onnx, _ = self.export_onnx() # run before TRT import https://github.com/ultralytics/ultralytics/issues/7016
@@ -717,6 +722,20 @@ class Exporter:
network = builder.create_network(flag)
half = builder.platform_has_fast_fp16 and self.args.half
int8 = builder.platform_has_fast_int8 and self.args.int8
+
+ # Optionally switch to DLA if enabled
+ if dla is not None:
+ if not IS_JETSON:
+ raise ValueError("DLA is only available on NVIDIA Jetson devices")
+ LOGGER.info(f"{prefix} enabling DLA on core {dla}...")
+ if not self.args.half and not self.args.int8:
+ raise ValueError(
+ "DLA requires either 'half=True' (FP16) or 'int8=True' (INT8) to be enabled. Please enable one of them and try again."
+ )
+ config.default_device_type = trt.DeviceType.DLA
+ config.DLA_core = int(dla)
+ config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
+
# Read ONNX file
parser = trt.OnnxParser(network, logger)
if not parser.parse_from_file(f_onnx):
From be40a45ec32d9173ff42b4bdb00dbfe5a8c7a838 Mon Sep 17 00:00:00 2001
From: Laughing <61612323+Laughing-q@users.noreply.github.com>
Date: Thu, 24 Oct 2024 17:58:10 +0800
Subject: [PATCH 40/40] Fix DLA export when device=None (#17128)
Co-authored-by: UltralyticsAssistant
---
ultralytics/engine/exporter.py | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
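
Note: a small sketch of the failure this patch addresses. A membership test such as `"dla" in device` raises `TypeError` when `device` is `None` or an integer, which is why the check is wrapped in `str(...)` and moved after the `None` handling. The function names below are hypothetical and exist only for the demonstration.

```python
def wants_dla_unsafe(device) -> bool:
    """Pre-fix style check: fails for None or int devices."""
    return "dla" in device


def wants_dla_safe(device) -> bool:
    """Post-fix style check: convert to str before the membership test."""
    return "dla" in str(device)


for device in (None, 0, "cpu", "dla:0"):
    try:
        unsafe = wants_dla_unsafe(device)
    except TypeError as err:
        unsafe = f"TypeError: {err}"
    print(repr(device), "| unsafe:", unsafe, "| safe:", wants_dla_safe(device))
```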
diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py
index da2e746cbe..6d403a2afb 100644
--- a/ultralytics/engine/exporter.py
+++ b/ultralytics/engine/exporter.py
@@ -195,13 +195,12 @@ class Exporter:
# Device
dla = None
- if fmt == "engine" and "dla" in self.args.device:
- dla = self.args.device.split(":")[-1]
- assert dla in {"0", "1"}, f"Expected self.args.device='dla:0' or 'dla:1, but got {self.args.device}."
- self.args.device = "0"
if fmt == "engine" and self.args.device is None:
LOGGER.warning("WARNING ⚠️ TensorRT requires GPU export, automatically assigning device=0")
self.args.device = "0"
+ if fmt == "engine" and "dla" in str(self.args.device): # convert int/list to str first
+ dla = self.args.device.split(":")[-1]
+ assert dla in {"0", "1"}, f"Expected self.args.device='dla:0' or 'dla:1', but got {self.args.device}."
self.device = select_device("cpu" if self.args.device is None else self.args.device)
# Checks