Merge branch 'main' into yolov9

pull/8571/head
Glenn Jocher 9 months ago committed by GitHub
commit 0b77c6e69d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 6
      .github/workflows/ci.yaml
  2. 32
      docs/en/models/yolo-world.md
  3. 4
      ultralytics/__init__.py
  4. 2
      ultralytics/cfg/models/v8/yolov8-world.yaml
  5. 10
      ultralytics/cfg/models/v8/yolov8-worldv2.yaml
  6. 13
      ultralytics/engine/exporter.py
  7. 14
      ultralytics/engine/model.py
  8. 2
      ultralytics/engine/trainer.py
  9. 6
      ultralytics/models/yolo/model.py
  10. 36
      ultralytics/utils/benchmarks.py
  11. 3
      ultralytics/utils/downloads.py
  12. 20
      ultralytics/utils/patches.py

@ -118,9 +118,9 @@ jobs:
run: |
yolo checks
pip list
# - name: Benchmark DetectionModel
# shell: bash
# run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}.pt' imgsz=160 verbose=0.318
- name: Benchmark World DetectionModel
shell: bash
run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/yolov8s-worldv2.pt' imgsz=160 verbose=0.318
- name: Benchmark SegmentationModel
shell: bash
run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}-seg.pt' imgsz=160 verbose=0.281

@ -36,21 +36,29 @@ This section details the models available with their specific pre-trained weight
All the YOLOv8-World weights have been directly migrated from the official [YOLO-World](https://github.com/AILab-CVC/YOLO-World) repository, highlighting their excellent contributions.
| Model Type | Pre-trained Weights | Tasks Supported | Inference | Validation | Training | Export |
|---------------|-----------------------------------------------------------------------------------------------------|----------------------------------------|-----------|------------|----------|--------|
| YOLOv8s-world | [yolov8s-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8s-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ |
| YOLOv8m-world | [yolov8m-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8m-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ |
| YOLOv8l-world | [yolov8l-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8l-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ |
| YOLOv8x-world | [yolov8x-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8x-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ |
| Model Type | Pre-trained Weights | Tasks Supported | Inference | Validation | Training | Export |
|-----------------|-------------------------------------------------------------------------------------------------------|----------------------------------------|-----------|------------|----------|--------|
| YOLOv8s-world | [yolov8s-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8s-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ |
| YOLOv8s-worldv2 | [yolov8s-worldv2.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8s-worldv2.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ✅ |
| YOLOv8m-world | [yolov8m-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8m-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ |
| YOLOv8m-worldv2 | [yolov8m-worldv2.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8m-worldv2.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ✅ |
| YOLOv8l-world | [yolov8l-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8l-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ |
| YOLOv8l-worldv2 | [yolov8l-worldv2.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8l-worldv2.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ✅ |
| YOLOv8x-world | [yolov8x-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8x-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ |
| YOLOv8x-worldv2 | [yolov8x-worldv2.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8x-worldv2.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ✅ |
## Zero-shot Transfer on COCO Dataset
| Model Type | mAP | mAP50 | mAP75 |
|---------------|------|-------|-------|
| yolov8s-world | 37.4 | 52.0 | 40.6 |
| yolov8m-world | 42.0 | 57.0 | 45.6 |
| yolov8l-world | 45.7 | 61.3 | 49.8 |
| yolov8x-world | 47.0 | 63.0 | 51.2 |
| Model Type | mAP | mAP50 | mAP75 |
|-----------------|------|-------|-------|
| yolov8s-world | 37.4 | 52.0 | 40.6 |
| yolov8s-worldv2 | 37.7 | 52.2 | 41.0 |
| yolov8m-world | 42.0 | 57.0 | 45.6 |
| yolov8m-worldv2 | 43.0 | 58.4 | 46.8 |
| yolov8l-world | 45.7 | 61.3 | 49.8 |
| yolov8l-worldv2 | 45.8 | 61.3 | 49.8 |
| yolov8x-world | 47.0 | 63.0 | 51.2 |
| yolov8x-worldv2 | 47.1 | 62.8 | 51.4 |
## Usage Examples

@ -1,7 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.1.22"
from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
from ultralytics.models.fastsam import FastSAM
@ -10,6 +8,7 @@ from ultralytics.utils import ASSETS, SETTINGS as settings
from ultralytics.utils.checks import check_yolo as checks
from ultralytics.utils.downloads import download
__all__ = (
"__version__",
"ASSETS",
@ -24,3 +23,4 @@ __all__ = (
"settings",
"Explorer",
)
__version__ = "8.1.22"

@ -1,5 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# YOLOv8-World object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/tasks/detect
# Parameters
nc: 80 # number of classes

@ -1,5 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# YOLOv8-World-v2 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/tasks/detect
# Parameters
nc: 80 # number of classes
@ -29,18 +29,18 @@ backbone:
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 2, C2fAttn, [512, 256, 8]] # 12
- [-1, 3, C2fAttn, [512, 256, 8]] # 12
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 2, C2fAttn, [256, 128, 4]] # 15 (P3/8-small)
- [-1, 3, C2fAttn, [256, 128, 4]] # 15 (P3/8-small)
- [15, 1, Conv, [256, 3, 2]]
- [[-1, 12], 1, Concat, [1]] # cat head P4
- [-1, 2, C2fAttn, [512, 256, 8]] # 18 (P4/16-medium)
- [-1, 3, C2fAttn, [512, 256, 8]] # 18 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 9], 1, Concat, [1]] # cat head P5
- [-1, 2, C2fAttn, [1024, 512, 16]] # 21 (P5/32-large)
- [-1, 3, C2fAttn, [1024, 512, 16]] # 21 (P5/32-large)
- [[15, 18, 21], 1, WorldDetect, [nc, 512, True]] # Detect(P3, P4, P5)

@ -68,7 +68,7 @@ from ultralytics.data.dataset import YOLODataset
from ultralytics.data.utils import check_det_dataset
from ultralytics.nn.autobackend import check_class_names, default_class_names
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
from ultralytics.nn.tasks import DetectionModel, SegmentationModel
from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel
from ultralytics.utils import (
ARM64,
DEFAULT_CFG,
@ -201,6 +201,14 @@ class Exporter:
assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
if edgetpu and not LINUX:
raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/")
print(type(model))
if isinstance(model, WorldModel):
LOGGER.warning(
"WARNING ⚠ YOLOWorld (original version) export is not supported to any format.\n"
"WARNING ⚠ YOLOWorldv2 models (i.e. 'yolov8s-worldv2.pt') only support export to "
"(torchscript, onnx, openvino, engine, coreml) formats. "
"See https://docs.ultralytics.com/models/yolo-world for details."
)
# Input
im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
@ -252,9 +260,10 @@ class Exporter:
self.metadata = {
"description": description,
"author": "Ultralytics",
"license": "AGPL-3.0 https://ultralytics.com/license",
"date": datetime.now().isoformat(),
"version": __version__,
"license": "AGPL-3.0 License (https://ultralytics.com/license)",
"docs": "https://docs.ultralytics.com",
"stride": int(max(model.stride)),
"task": model.task,
"batch": self.args.batch,

@ -295,7 +295,7 @@ class Model(nn.Module):
self.model.load(weights)
return self
def save(self, filename: Union[str, Path] = "saved_model.pt") -> None:
def save(self, filename: Union[str, Path] = "saved_model.pt", use_dill=True) -> None:
"""
Saves the current model state to a file.
@ -303,12 +303,22 @@ class Model(nn.Module):
Args:
filename (str | Path): The name of the file to save the model to. Defaults to 'saved_model.pt'.
use_dill (bool): Whether to try using dill for serialization if available. Defaults to True.
Raises:
AssertionError: If the model is not a PyTorch model.
"""
self._check_is_pytorch_model()
torch.save(self.ckpt, filename)
from ultralytics import __version__
from datetime import datetime
updates = {
"date": datetime.now().isoformat(),
"version": __version__,
"license": "AGPL-3.0 License (https://ultralytics.com/license)",
"docs": "https://docs.ultralytics.com",
}
torch.save({**self.ckpt, **updates}, filename, use_dill=use_dill)
def info(self, detailed: bool = False, verbose: bool = True):
"""

@ -488,6 +488,8 @@ class BaseTrainer:
"train_results": results,
"date": datetime.now().isoformat(),
"version": __version__,
"license": "AGPL-3.0 (https://ultralytics.com/license)",
"docs": "https://docs.ultralytics.com",
}
# Save last and best

@ -13,8 +13,8 @@ class YOLO(Model):
def __init__(self, model="yolov8n.pt", task=None, verbose=False):
"""Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'."""
stem = Path(model).stem # filename stem without suffix, i.e. "yolov8n"
if "-world" in stem:
model = Path(model)
if "-world" in model.stem and model.suffix in {".pt", ".yaml", ".yml"}: # if YOLOWorld PyTorch model
new_instance = YOLOWorld(model)
self.__class__ = type(new_instance)
self.__dict__ = new_instance.__dict__
@ -67,7 +67,7 @@ class YOLOWorld(Model):
Initializes the YOLOv8-World model with the given pre-trained model file. Supports *.pt and *.yaml formats.
Args:
model (str): Path to the pre-trained model. Defaults to 'yolov8s-world.pt'.
model (str | Path): Path to the pre-trained model. Defaults to 'yolov8s-world.pt'.
"""
super().__init__(model=model, task="detect")

@ -32,7 +32,7 @@ from pathlib import Path
import numpy as np
import torch.cuda
from ultralytics import YOLO
from ultralytics import YOLO, YOLOWorld
from ultralytics.cfg import TASK2DATA, TASK2METRIC
from ultralytics.engine.exporter import export_formats
from ultralytics.utils import ASSETS, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR
@ -84,14 +84,20 @@ def benchmark(
emoji, filename = "", None # export defaults
try:
# Checks
if i == 9:
if i == 9: # Edge TPU
assert LINUX, "Edge TPU export only supported on Linux"
elif i == 7:
elif i == 7: # TF GraphDef
assert model.task != "obb", "TensorFlow GraphDef not supported for OBB task"
elif i in {5, 10}: # CoreML and TF.js
assert MACOS or LINUX, "export only supported on macOS and Linux"
if i in {3, 5}: # CoreML and OpenVINO
assert not IS_PYTHON_3_12, "CoreML and OpenVINO not supported on Python 3.12"
if i in {6, 7, 8, 9, 10}: # All TF formats
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
if i in {11}: # Paddle
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
if i in {12}: # NCNN
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
if "cpu" in device.type:
assert cpu, "inference not supported on CPU"
if "cuda" in device.type:
@ -261,7 +267,8 @@ class ProfileModels:
"""
return 0.0, 0.0, 0.0, 0.0 # return (num_layers, num_params, num_gradients, num_flops)
def iterative_sigma_clipping(self, data, sigma=2, max_iters=3):
@staticmethod
def iterative_sigma_clipping(data, sigma=2, max_iters=3):
"""Applies an iterative sigma clipping algorithm to the given data times number of iterations."""
data = np.array(data)
for _ in range(max_iters):
@ -359,9 +366,13 @@ class ProfileModels:
def generate_table_row(self, model_name, t_onnx, t_engine, model_info):
"""Generates a formatted string for a table row that includes model performance and metric details."""
layers, params, gradients, flops = model_info
return f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |"
return (
f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± "
f"{t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |"
)
def generate_results_dict(self, model_name, t_onnx, t_engine, model_info):
@staticmethod
def generate_results_dict(model_name, t_onnx, t_engine, model_info):
"""Generates a dictionary of model details including name, parameters, GFLOPS and speed metrics."""
layers, params, gradients, flops = model_info
return {
@ -372,11 +383,18 @@ class ProfileModels:
"model/speed_TensorRT(ms)": round(t_engine[0], 3),
}
def print_table(self, table_rows):
@staticmethod
def print_table(table_rows):
"""Formats and prints a comparison table for different models with given statistics and performance data."""
gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU"
header = f"| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |"
separator = "|-------------|---------------------|--------------------|------------------------------|-----------------------------------|------------------|-----------------|"
header = (
f"| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | "
f"Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |"
)
separator = (
"|-------------|---------------------|--------------------|------------------------------|"
"-----------------------------------|------------------|-----------------|"
)
print(f"\n\n{header}")
print(separator)

@ -20,7 +20,8 @@ GITHUB_ASSETS_NAMES = (
[f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")]
+ [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")]
+ [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")]
+ [f"yolov8{k}-world.pt" for k in "sml"]
+ [f"yolov8{k}-world.pt" for k in "smlx"]
+ [f"yolov8{k}-worldv2.pt" for k in "smlx"]
+ [f"yolov9{k}.pt" for k in "ce"]
+ [f"yolo_nas_{k}.pt" for k in "sml"]
+ [f"sam_{k}.pt" for k in "bl"]

@ -60,27 +60,29 @@ def imshow(winname: str, mat: np.ndarray):
_torch_save = torch.save # copy to avoid recursion errors
def torch_save(*args, **kwargs):
def torch_save(*args, use_dill=True, **kwargs):
"""
Use dill (if exists) to serialize the lambda functions where pickle does not do this. Also adds 3 retries with
exponential standoff in case of save failure to improve robustness to transient issues.
Optionally use dill to serialize lambda functions where pickle does not, adding robustness with 3 retries and
exponential standoff in case of save failure.
Args:
*args (tuple): Positional arguments to pass to torch.save.
use_dill (bool): Whether to try using dill for serialization if available. Defaults to True.
**kwargs (dict): Keyword arguments to pass to torch.save.
"""
try:
import dill as pickle # noqa
except ImportError:
assert use_dill
import dill as pickle
except (AssertionError, ImportError):
import pickle
if "pickle_module" not in kwargs:
kwargs["pickle_module"] = pickle # noqa
kwargs["pickle_module"] = pickle
for i in range(4): # 3 retries
try:
return _torch_save(*args, **kwargs)
except RuntimeError: # unable to save, possibly waiting for device to flush or anti-virus to finish scanning
except RuntimeError as e: # unable to save, possibly waiting for device to flush or antivirus scan
if i == 3:
raise
time.sleep((2**i) / 2) # exponential standoff 0.5s, 1.0s, 2.0s
raise e
time.sleep((2**i) / 2) # exponential standoff: 0.5s, 1.0s, 2.0s

Loading…
Cancel
Save