From bfa6f9a8e76a6bc1799d9505935c4ad116d8baa1 Mon Sep 17 00:00:00 2001 From: Muhammad Rizwan Munawar Date: Thu, 31 Oct 2024 16:42:05 +0500 Subject: [PATCH 1/6] Update `sam.md` and `sam-2.md` (#17286) --- docs/en/models/sam-2.md | 4 ++-- docs/en/models/sam.md | 4 ++-- ultralytics/data/annotator.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/en/models/sam-2.md b/docs/en/models/sam-2.md index d5e8888e2..9083899ea 100644 --- a/docs/en/models/sam-2.md +++ b/docs/en/models/sam-2.md @@ -250,13 +250,13 @@ To auto-annotate your dataset using SAM 2, follow this example: ```python from ultralytics.data.annotator import auto_annotate - auto_annotate(data="path/to/images", det_model="yolov8x.pt", sam_model="sam2_b.pt") + auto_annotate(data="path/to/images", det_model="yolo11x.pt", sam_model="sam2_b.pt") ``` | Argument | Type | Description | Default | | ------------ | ----------------------- | ------------------------------------------------------------------------------------------------------- | -------------- | | `data` | `str` | Path to a folder containing images to be annotated. | | -| `det_model` | `str`, optional | Pre-trained YOLO detection model. Defaults to 'yolov8x.pt'. | `'yolov8x.pt'` | +| `det_model` | `str`, optional | Pre-trained YOLO detection model. Defaults to 'yolo11x.pt'. | `'yolov8x.pt'` | | `sam_model` | `str`, optional | Pre-trained SAM 2 segmentation model. Defaults to 'sam2_b.pt'. | `'sam2_b.pt'` | | `device` | `str`, optional | Device to run the models on. Defaults to an empty string (CPU or GPU, if available). | | | `output_dir` | `str`, `None`, optional | Directory to save the annotated results. Defaults to a 'labels' folder in the same directory as 'data'. | `None` | diff --git a/docs/en/models/sam.md b/docs/en/models/sam.md index f9acad72d..c38b06e35 100644 --- a/docs/en/models/sam.md +++ b/docs/en/models/sam.md @@ -205,13 +205,13 @@ To auto-annotate your dataset with the Ultralytics framework, use the `auto_anno ```python from ultralytics.data.annotator import auto_annotate - auto_annotate(data="path/to/images", det_model="yolov8x.pt", sam_model="sam_b.pt") + auto_annotate(data="path/to/images", det_model="yolo11x.pt", sam_model="sam_b.pt") ``` | Argument | Type | Description | Default | | ------------ | --------------------- | ------------------------------------------------------------------------------------------------------- | -------------- | | `data` | `str` | Path to a folder containing images to be annotated. | | -| `det_model` | `str`, optional | Pre-trained YOLO detection model. Defaults to 'yolov8x.pt'. | `'yolov8x.pt'` | +| `det_model` | `str`, optional | Pre-trained YOLO detection model. Defaults to 'yolo11x.pt'. | `'yolov8x.pt'` | | `sam_model` | `str`, optional | Pre-trained SAM segmentation model. Defaults to 'sam_b.pt'. | `'sam_b.pt'` | | `device` | `str`, optional | Device to run the models on. Defaults to an empty string (CPU or GPU, if available). | | | `output_dir` | `str`, None, optional | Directory to save the annotated results. Defaults to a 'labels' folder in the same directory as 'data'. | `None` | diff --git a/ultralytics/data/annotator.py b/ultralytics/data/annotator.py index 30d02d9d7..3880741d3 100644 --- a/ultralytics/data/annotator.py +++ b/ultralytics/data/annotator.py @@ -5,7 +5,7 @@ from pathlib import Path from ultralytics import SAM, YOLO -def auto_annotate(data, det_model="yolov8x.pt", sam_model="sam_b.pt", device="", output_dir=None): +def auto_annotate(data, det_model="yolo11x.pt", sam_model="sam_b.pt", device="", output_dir=None): """ Automatically annotates images using a YOLO object detection model and a SAM segmentation model. From 66adbd79ad4cdb87082cec0023588a1dc4b6d88c Mon Sep 17 00:00:00 2001 From: Compunet <117437050+dme-compunet@users.noreply.github.com> Date: Thu, 31 Oct 2024 13:46:46 +0200 Subject: [PATCH 2/6] Update examples/README.md (#17284) --- examples/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/README.md b/examples/README.md index 22da53f29..ab875b3ba 100644 --- a/examples/README.md +++ b/examples/README.md @@ -8,7 +8,7 @@ This directory features a collection of real-world applications and walkthroughs | ----------------------------------------------------------------------------------------------------------------------------------------- | ------------------ | ----------------------------------------------------------------------------------------- | | [YOLO ONNX Detection Inference with C++](./YOLOv8-CPP-Inference) | C++/ONNX | [Justas Bartnykas](https://github.com/JustasBart) | | [YOLO OpenCV ONNX Detection Python](./YOLOv8-OpenCV-ONNX-Python) | OpenCV/Python/ONNX | [Farid Inawan](https://github.com/frdteknikelektro) | -| [YOLOv8 .NET ONNX ImageSharp](https://github.com/dme-compunet/YOLOv8) | C#/ONNX/ImageSharp | [Compunet](https://github.com/dme-compunet) | +| [YOLO C# ONNX-Runtime](https://github.com/dme-compunet/YoloSharp) | .NET/ONNX-Runtime | [Compunet](https://github.com/dme-compunet) | | [YOLO .Net ONNX Detection C#](https://www.nuget.org/packages/Yolov8.Net) | C# .Net | [Samuel Stainback](https://github.com/sstainba) | | [YOLOv8 on NVIDIA Jetson(TensorRT and DeepStream)](https://wiki.seeedstudio.com/YOLOv8-DeepStream-TRT-Jetson/) | Python | [Lakshantha](https://github.com/lakshanthad) | | [YOLOv8 ONNXRuntime Python](./YOLOv8-ONNXRuntime) | Python/ONNXRuntime | [Semih Demirel](https://github.com/semihhdemirel) | From b8783cad24dc751a33b708f454d9d745d91578e6 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 31 Oct 2024 12:48:24 +0100 Subject: [PATCH 3/6] Patch MNN test order bug (#17290) Co-authored-by: UltralyticsAssistant --- tests/test_exports.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_exports.py b/tests/test_exports.py index 12443fa30..a05f0e059 100644 --- a/tests/test_exports.py +++ b/tests/test_exports.py @@ -193,14 +193,14 @@ def test_export_paddle(): @pytest.mark.slow -def test_export_ncnn(): - """Test YOLO exports to NCNN format.""" - file = YOLO(MODEL).export(format="ncnn", imgsz=32) +def test_export_mnn(): + """Test YOLO exports to MNN format (WARNING: MNN test must precede NCNN test or CI error on Windows).""" + file = YOLO(MODEL).export(format="mnn", imgsz=32) YOLO(file)(SOURCE, imgsz=32) # exported model inference @pytest.mark.slow -def test_export_mnn(): - """Test YOLO exports to MNN format.""" - file = YOLO(MODEL).export(format="mnn", imgsz=32) +def test_export_ncnn(): + """Test YOLO exports to NCNN format.""" + file = YOLO(MODEL).export(format="ncnn", imgsz=32) YOLO(file)(SOURCE, imgsz=32) # exported model inference From c943a3b747cd445a31c116690823d17a5996b0b9 Mon Sep 17 00:00:00 2001 From: Muhammad Rizwan Munawar Date: Thu, 31 Oct 2024 16:52:14 +0500 Subject: [PATCH 4/6] Case-insensitive optimizer name (#17287) Co-authored-by: UltralyticsAssistant Co-authored-by: Glenn Jocher --- ultralytics/engine/trainer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index 352067397..e82aed9e0 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -791,6 +791,8 @@ class BaseTrainer: else: # weight (with decay) g[0].append(param) + optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "auto"} + name = {x.lower(): x for x in optimizers}.get(name.lower(), None) if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}: optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0) elif name == "RMSProp": @@ -799,9 +801,8 @@ class BaseTrainer: optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True) else: raise NotImplementedError( - f"Optimizer '{name}' not found in list of available optimizers " - f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]." - "To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics." + f"Optimizer '{name}' not found in list of available optimizers {optimizers}. " + "Request support for addition optimizers at https://github.com/ultralytics/ultralytics." ) optimizer.add_param_group({"params": g[0], "weight_decay": decay}) # add g0 with weight_decay From e8743f2ac9f43d83143e6575598282a8ec55cd88 Mon Sep 17 00:00:00 2001 From: Muhammad Rizwan Munawar Date: Thu, 31 Oct 2024 17:35:26 +0500 Subject: [PATCH 5/6] Auto annotation new parameters for SAM models (#17288) Co-authored-by: UltralyticsAssistant --- docs/en/models/sam-2.md | 5 ++++- docs/en/models/sam.md | 5 ++++- ultralytics/data/annotator.py | 9 +++++++-- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/docs/en/models/sam-2.md b/docs/en/models/sam-2.md index 9083899ea..de5881c42 100644 --- a/docs/en/models/sam-2.md +++ b/docs/en/models/sam-2.md @@ -256,9 +256,12 @@ To auto-annotate your dataset using SAM 2, follow this example: | Argument | Type | Description | Default | | ------------ | ----------------------- | ------------------------------------------------------------------------------------------------------- | -------------- | | `data` | `str` | Path to a folder containing images to be annotated. | | -| `det_model` | `str`, optional | Pre-trained YOLO detection model. Defaults to 'yolo11x.pt'. | `'yolov8x.pt'` | +| `det_model` | `str`, optional | Pre-trained YOLO detection model. Defaults to 'yolo11x.pt'. | `'yolo11x.pt'` | | `sam_model` | `str`, optional | Pre-trained SAM 2 segmentation model. Defaults to 'sam2_b.pt'. | `'sam2_b.pt'` | | `device` | `str`, optional | Device to run the models on. Defaults to an empty string (CPU or GPU, if available). | | +| `conf` | `float`, optional | Confidence threshold for detection model; default is 0.25. | `0.25` | +| `iou` | `float`, optional | IoU threshold for filtering overlapping boxes in detection results; default is 0.45. | `0.45` | +| `imgsz` | `int`, optional | Input image resize dimension; default is 640. | `640` | | `output_dir` | `str`, `None`, optional | Directory to save the annotated results. Defaults to a 'labels' folder in the same directory as 'data'. | `None` | This function facilitates the rapid creation of high-quality segmentation datasets, ideal for researchers and developers aiming to accelerate their projects. diff --git a/docs/en/models/sam.md b/docs/en/models/sam.md index c38b06e35..fe4c01bd8 100644 --- a/docs/en/models/sam.md +++ b/docs/en/models/sam.md @@ -211,9 +211,12 @@ To auto-annotate your dataset with the Ultralytics framework, use the `auto_anno | Argument | Type | Description | Default | | ------------ | --------------------- | ------------------------------------------------------------------------------------------------------- | -------------- | | `data` | `str` | Path to a folder containing images to be annotated. | | -| `det_model` | `str`, optional | Pre-trained YOLO detection model. Defaults to 'yolo11x.pt'. | `'yolov8x.pt'` | +| `det_model` | `str`, optional | Pre-trained YOLO detection model. Defaults to 'yolo11x.pt'. | `'yolo11x.pt'` | | `sam_model` | `str`, optional | Pre-trained SAM segmentation model. Defaults to 'sam_b.pt'. | `'sam_b.pt'` | | `device` | `str`, optional | Device to run the models on. Defaults to an empty string (CPU or GPU, if available). | | +| `conf` | `float`, optional | Confidence threshold for detection model; default is 0.25. | `0.25` | +| `iou` | `float`, optional | IoU threshold for filtering overlapping boxes in detection results; default is 0.45. | `0.45` | +| `imgsz` | `int`, optional | Input image resize dimension; default is 640. | `640` | | `output_dir` | `str`, None, optional | Directory to save the annotated results. Defaults to a 'labels' folder in the same directory as 'data'. | `None` | The `auto_annotate` function takes the path to your images, with optional arguments for specifying the pre-trained detection and SAM segmentation models, the device to run the models on, and the output directory for saving the annotated results. diff --git a/ultralytics/data/annotator.py b/ultralytics/data/annotator.py index 3880741d3..64ee9af6c 100644 --- a/ultralytics/data/annotator.py +++ b/ultralytics/data/annotator.py @@ -5,7 +5,9 @@ from pathlib import Path from ultralytics import SAM, YOLO -def auto_annotate(data, det_model="yolo11x.pt", sam_model="sam_b.pt", device="", output_dir=None): +def auto_annotate( + data, det_model="yolo11x.pt", sam_model="sam_b.pt", device="", conf=0.25, iou=0.45, imgsz=640, output_dir=None +): """ Automatically annotates images using a YOLO object detection model and a SAM segmentation model. @@ -17,6 +19,9 @@ def auto_annotate(data, det_model="yolo11x.pt", sam_model="sam_b.pt", device="", det_model (str): Path or name of the pre-trained YOLO detection model. sam_model (str): Path or name of the pre-trained SAM segmentation model. device (str): Device to run the models on (e.g., 'cpu', 'cuda', '0'). + conf (float): Confidence threshold for detection model; default is 0.25. + iou (float): IoU threshold for filtering overlapping boxes in detection results; default is 0.45. + imgsz (int): Input image resize dimension; default is 640. output_dir (str | None): Directory to save the annotated results. If None, a default directory is created. Examples: @@ -36,7 +41,7 @@ def auto_annotate(data, det_model="yolo11x.pt", sam_model="sam_b.pt", device="", output_dir = data.parent / f"{data.stem}_auto_annotate_labels" Path(output_dir).mkdir(exist_ok=True, parents=True) - det_results = det_model(data, stream=True, device=device) + det_results = det_model(data, stream=True, device=device, conf=conf, iou=iou, imgsz=imgsz) for result in det_results: class_ids = result.boxes.cls.int().tolist() # noqa From f4e7756bff5c1b0c246afb491f42e9f9de84dc84 Mon Sep 17 00:00:00 2001 From: Laughing <61612323+Laughing-q@users.noreply.github.com> Date: Thu, 31 Oct 2024 20:36:27 +0800 Subject: [PATCH 6/6] `ultralytics 8.3.26` EdgeTPU Pose models fix (#17281) Co-authored-by: Glenn Jocher Co-authored-by: UltralyticsAssistant --- ultralytics/__init__.py | 2 +- ultralytics/nn/autobackend.py | 3 +++ ultralytics/nn/modules/head.py | 18 +++++++++++++++--- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index c847dd4d1..fedf8629a 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.3.25" +__version__ = "8.3.26" import os diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py index 245e42c4e..cef05a357 100644 --- a/ultralytics/nn/autobackend.py +++ b/ultralytics/nn/autobackend.py @@ -663,6 +663,9 @@ class AutoBackend(nn.Module): else: x[:, [0, 2]] *= w x[:, [1, 3]] *= h + if self.task == "pose": + x[:, 5::3] *= w + x[:, 6::3] *= h y.append(x) # TF segment fixes: export is reversed vs ONNX export and protos are transposed if len(y) == 2: # segment with (det, proto) output order reversed diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index 4bc1fa25e..84c31709c 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -246,9 +246,21 @@ class Pose(Detect): def kpts_decode(self, bs, kpts): """Decodes keypoints.""" ndim = self.kpt_shape[1] - if self.export: # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug - y = kpts.view(bs, *self.kpt_shape, -1) - a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides + if self.export: + if self.format in { + "tflite", + "edgetpu", + }: # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug + # Precompute normalization factor to increase numerical stability + y = kpts.view(bs, *self.kpt_shape, -1) + grid_h, grid_w = self.shape[2], self.shape[3] + grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1) + norm = self.strides / (self.stride[0] * grid_size) + a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm + else: + # NCNN fix + y = kpts.view(bs, *self.kpt_shape, -1) + a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides if ndim == 3: a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2) return a.view(bs, self.nk, -1)