Merge branch 'main' into quan

test-quan
Francesco Mattioli 1 month ago committed by GitHub
commit eb4642abe8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 7
      docs/en/models/sam-2.md
  2. 7
      docs/en/models/sam.md
  3. 2
      examples/README.md
  4. 12
      tests/test_exports.py
  5. 2
      ultralytics/__init__.py
  6. 9
      ultralytics/data/annotator.py
  7. 7
      ultralytics/engine/trainer.py
  8. 3
      ultralytics/nn/autobackend.py
  9. 18
      ultralytics/nn/modules/head.py

@ -250,15 +250,18 @@ To auto-annotate your dataset using SAM 2, follow this example:
```python
from ultralytics.data.annotator import auto_annotate
auto_annotate(data="path/to/images", det_model="yolov8x.pt", sam_model="sam2_b.pt")
auto_annotate(data="path/to/images", det_model="yolo11x.pt", sam_model="sam2_b.pt")
```
| Argument | Type | Description | Default |
| ------------ | ----------------------- | ------------------------------------------------------------------------------------------------------- | -------------- |
| `data` | `str` | Path to a folder containing images to be annotated. | |
| `det_model` | `str`, optional | Pre-trained YOLO detection model. Defaults to 'yolov8x.pt'. | `'yolov8x.pt'` |
| `det_model` | `str`, optional | Pre-trained YOLO detection model. Defaults to 'yolo11x.pt'. | `'yolo11x.pt'` |
| `sam_model` | `str`, optional | Pre-trained SAM 2 segmentation model. Defaults to 'sam2_b.pt'. | `'sam2_b.pt'` |
| `device` | `str`, optional | Device to run the models on. Defaults to an empty string (CPU or GPU, if available). | |
| `conf` | `float`, optional | Confidence threshold for detection model; default is 0.25. | `0.25` |
| `iou` | `float`, optional | IoU threshold for filtering overlapping boxes in detection results; default is 0.45. | `0.45` |
| `imgsz` | `int`, optional | Input image resize dimension; default is 640. | `640` |
| `output_dir` | `str`, `None`, optional | Directory to save the annotated results. Defaults to a 'labels' folder in the same directory as 'data'. | `None` |
This function facilitates the rapid creation of high-quality segmentation datasets, ideal for researchers and developers aiming to accelerate their projects.

@ -205,15 +205,18 @@ To auto-annotate your dataset with the Ultralytics framework, use the `auto_anno
```python
from ultralytics.data.annotator import auto_annotate
auto_annotate(data="path/to/images", det_model="yolov8x.pt", sam_model="sam_b.pt")
auto_annotate(data="path/to/images", det_model="yolo11x.pt", sam_model="sam_b.pt")
```
| Argument | Type | Description | Default |
| ------------ | --------------------- | ------------------------------------------------------------------------------------------------------- | -------------- |
| `data` | `str` | Path to a folder containing images to be annotated. | |
| `det_model` | `str`, optional | Pre-trained YOLO detection model. Defaults to 'yolov8x.pt'. | `'yolov8x.pt'` |
| `det_model` | `str`, optional | Pre-trained YOLO detection model. Defaults to 'yolo11x.pt'. | `'yolo11x.pt'` |
| `sam_model` | `str`, optional | Pre-trained SAM segmentation model. Defaults to 'sam_b.pt'. | `'sam_b.pt'` |
| `device` | `str`, optional | Device to run the models on. Defaults to an empty string (CPU or GPU, if available). | |
| `conf` | `float`, optional | Confidence threshold for detection model; default is 0.25. | `0.25` |
| `iou` | `float`, optional | IoU threshold for filtering overlapping boxes in detection results; default is 0.45. | `0.45` |
| `imgsz` | `int`, optional | Input image resize dimension; default is 640. | `640` |
| `output_dir` | `str`, `None`, optional | Directory to save the annotated results. Defaults to a 'labels' folder in the same directory as 'data'. | `None` |
The `auto_annotate` function takes the path to your images, with optional arguments for specifying the pre-trained detection and SAM segmentation models, the device to run the models on, and the output directory for saving the annotated results.

@ -8,7 +8,7 @@ This directory features a collection of real-world applications and walkthroughs
| ----------------------------------------------------------------------------------------------------------------------------------------- | ------------------ | ----------------------------------------------------------------------------------------- |
| [YOLO ONNX Detection Inference with C++](./YOLOv8-CPP-Inference) | C++/ONNX | [Justas Bartnykas](https://github.com/JustasBart) |
| [YOLO OpenCV ONNX Detection Python](./YOLOv8-OpenCV-ONNX-Python) | OpenCV/Python/ONNX | [Farid Inawan](https://github.com/frdteknikelektro) |
| [YOLOv8 .NET ONNX ImageSharp](https://github.com/dme-compunet/YOLOv8) | C#/ONNX/ImageSharp | [Compunet](https://github.com/dme-compunet) |
| [YOLO C# ONNX-Runtime](https://github.com/dme-compunet/YoloSharp) | .NET/ONNX-Runtime | [Compunet](https://github.com/dme-compunet) |
| [YOLO .Net ONNX Detection C#](https://www.nuget.org/packages/Yolov8.Net) | C# .Net | [Samuel Stainback](https://github.com/sstainba) |
| [YOLOv8 on NVIDIA Jetson(TensorRT and DeepStream)](https://wiki.seeedstudio.com/YOLOv8-DeepStream-TRT-Jetson/) | Python | [Lakshantha](https://github.com/lakshanthad) |
| [YOLOv8 ONNXRuntime Python](./YOLOv8-ONNXRuntime) | Python/ONNXRuntime | [Semih Demirel](https://github.com/semihhdemirel) |

@ -193,16 +193,16 @@ def test_export_paddle():
@pytest.mark.slow
def test_export_ncnn():
"""Test YOLO exports to NCNN format."""
file = YOLO(MODEL).export(format="ncnn", imgsz=32)
def test_export_mnn():
"""Test YOLO exports to MNN format (WARNING: MNN test must precede NCNN test or CI error on Windows)."""
file = YOLO(MODEL).export(format="mnn", imgsz=32)
YOLO(file)(SOURCE, imgsz=32) # exported model inference
@pytest.mark.slow
def test_export_mnn():
"""Test YOLO exports to MNN format."""
file = YOLO(MODEL).export(format="mnn", imgsz=32)
def test_export_ncnn():
"""Test YOLO exports to NCNN format."""
file = YOLO(MODEL).export(format="ncnn", imgsz=32)
YOLO(file)(SOURCE, imgsz=32) # exported model inference

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.3.25"
__version__ = "8.3.26"
import os

@ -5,7 +5,9 @@ from pathlib import Path
from ultralytics import SAM, YOLO
def auto_annotate(data, det_model="yolov8x.pt", sam_model="sam_b.pt", device="", output_dir=None):
def auto_annotate(
data, det_model="yolo11x.pt", sam_model="sam_b.pt", device="", conf=0.25, iou=0.45, imgsz=640, output_dir=None
):
"""
Automatically annotates images using a YOLO object detection model and a SAM segmentation model.
@ -17,6 +19,9 @@ def auto_annotate(data, det_model="yolov8x.pt", sam_model="sam_b.pt", device="",
det_model (str): Path or name of the pre-trained YOLO detection model.
sam_model (str): Path or name of the pre-trained SAM segmentation model.
device (str): Device to run the models on (e.g., 'cpu', 'cuda', '0').
conf (float): Confidence threshold for detection model; default is 0.25.
iou (float): IoU threshold for filtering overlapping boxes in detection results; default is 0.45.
imgsz (int): Input image resize dimension; default is 640.
output_dir (str | None): Directory to save the annotated results. If None, a default directory is created.
Examples:
@ -36,7 +41,7 @@ def auto_annotate(data, det_model="yolov8x.pt", sam_model="sam_b.pt", device="",
output_dir = data.parent / f"{data.stem}_auto_annotate_labels"
Path(output_dir).mkdir(exist_ok=True, parents=True)
det_results = det_model(data, stream=True, device=device)
det_results = det_model(data, stream=True, device=device, conf=conf, iou=iou, imgsz=imgsz)
for result in det_results:
class_ids = result.boxes.cls.int().tolist() # noqa

@ -791,6 +791,8 @@ class BaseTrainer:
else: # weight (with decay)
g[0].append(param)
optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "auto"}
name = {x.lower(): x for x in optimizers}.get(name.lower(), None)
if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
elif name == "RMSProp":
@ -799,9 +801,8 @@ class BaseTrainer:
optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
else:
raise NotImplementedError(
f"Optimizer '{name}' not found in list of available optimizers "
f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]."
"To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics."
f"Optimizer '{name}' not found in list of available optimizers {optimizers}. "
"Request support for addition optimizers at https://github.com/ultralytics/ultralytics."
)
optimizer.add_param_group({"params": g[0], "weight_decay": decay}) # add g0 with weight_decay

@ -684,6 +684,9 @@ class AutoBackend(nn.Module):
else:
x[:, [0, 2]] *= w
x[:, [1, 3]] *= h
if self.task == "pose":
x[:, 5::3] *= w
x[:, 6::3] *= h
y.append(x)
# TF segment fixes: export is reversed vs ONNX export and protos are transposed
if len(y) == 2: # segment with (det, proto) output order reversed

@ -252,9 +252,21 @@ class Pose(Detect):
def kpts_decode(self, bs, kpts):
"""Decodes keypoints."""
ndim = self.kpt_shape[1]
if self.export: # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
y = kpts.view(bs, *self.kpt_shape, -1)
a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
if self.export:
if self.format in {
"tflite",
"edgetpu",
}: # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
# Precompute normalization factor to increase numerical stability
y = kpts.view(bs, *self.kpt_shape, -1)
grid_h, grid_w = self.shape[2], self.shape[3]
grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1)
norm = self.strides / (self.stride[0] * grid_size)
a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm
else:
# NCNN fix
y = kpts.view(bs, *self.kpt_shape, -1)
a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
if ndim == 3:
a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
return a.view(bs, self.nk, -1)

Loading…
Cancel
Save