From 946e18f79cfa7eb3e6b5a9aa6d34258434c8540c Mon Sep 17 00:00:00 2001 From: Laughing <61612323+Laughing-q@users.noreply.github.com> Date: Mon, 4 Mar 2024 05:56:57 +0800 Subject: [PATCH] `ultralytics 8.1.21` Add YOLOv8-World-v2 models (#8580) Signed-off-by: Glenn Jocher Co-authored-by: Glenn Jocher Co-authored-by: UltralyticsAssistant --- .github/workflows/ci.yaml | 6 ++-- docs/en/models/yolo-world.md | 32 ++++++++++------- ultralytics/__init__.py | 2 +- ultralytics/cfg/models/v8/yolov8-world.yaml | 2 +- ...ov8-world-t2i.yaml => yolov8-worldv2.yaml} | 10 +++--- ultralytics/engine/exporter.py | 13 +++++-- ultralytics/engine/model.py | 14 ++++++-- ultralytics/engine/trainer.py | 2 ++ ultralytics/models/yolo/model.py | 6 ++-- ultralytics/utils/benchmarks.py | 36 ++++++++++++++----- ultralytics/utils/downloads.py | 3 +- ultralytics/utils/patches.py | 20 ++++++----- 12 files changed, 98 insertions(+), 48 deletions(-) rename ultralytics/cfg/models/v8/{yolov8-world-t2i.yaml => yolov8-worldv2.yaml} (83%) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 78148249b..9e1b2b086 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -118,9 +118,9 @@ jobs: run: | yolo checks pip list - # - name: Benchmark DetectionModel - # shell: bash - # run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}.pt' imgsz=160 verbose=0.318 + - name: Benchmark World DetectionModel + shell: bash + run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/yolov8s-worldv2.pt' imgsz=160 verbose=0.318 - name: Benchmark SegmentationModel shell: bash run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}-seg.pt' imgsz=160 verbose=0.281 diff --git a/docs/en/models/yolo-world.md b/docs/en/models/yolo-world.md index 954b5dd18..116d62dfb 100644 --- a/docs/en/models/yolo-world.md +++ b/docs/en/models/yolo-world.md @@ -36,21 +36,29 @@ This section details the models available with their specific pre-trained weight All the YOLOv8-World weights have been directly migrated from the official [YOLO-World](https://github.com/AILab-CVC/YOLO-World) repository, highlighting their excellent contributions. -| Model Type | Pre-trained Weights | Tasks Supported | Inference | Validation | Training | Export | -|---------------|-----------------------------------------------------------------------------------------------------|----------------------------------------|-----------|------------|----------|--------| -| YOLOv8s-world | [yolov8s-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8s-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ | -| YOLOv8m-world | [yolov8m-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8m-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ | -| YOLOv8l-world | [yolov8l-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8l-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ | -| YOLOv8x-world | [yolov8x-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8x-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ | +| Model Type | Pre-trained Weights | Tasks Supported | Inference | Validation | Training | Export | +|-----------------|-------------------------------------------------------------------------------------------------------|----------------------------------------|-----------|------------|----------|--------| +| YOLOv8s-world | [yolov8s-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8s-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ | +| YOLOv8s-worldv2 | [yolov8s-worldv2.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8s-worldv2.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ✅ | +| YOLOv8m-world | [yolov8m-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8m-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ | +| YOLOv8m-worldv2 | [yolov8m-worldv2.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8m-worldv2.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ✅ | +| YOLOv8l-world | [yolov8l-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8l-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ | +| YOLOv8l-worldv2 | [yolov8l-worldv2.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8l-worldv2.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ✅ | +| YOLOv8x-world | [yolov8x-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8x-world.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ❌ | +| YOLOv8x-worldv2 | [yolov8x-worldv2.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8x-worldv2.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ❌ | ✅ | ## Zero-shot Transfer on COCO Dataset -| Model Type | mAP | mAP50 | mAP75 | -|---------------|------|-------|-------| -| yolov8s-world | 37.4 | 52.0 | 40.6 | -| yolov8m-world | 42.0 | 57.0 | 45.6 | -| yolov8l-world | 45.7 | 61.3 | 49.8 | -| yolov8x-world | 47.0 | 63.0 | 51.2 | +| Model Type | mAP | mAP50 | mAP75 | +|-----------------|------|-------|-------| +| yolov8s-world | 37.4 | 52.0 | 40.6 | +| yolov8s-worldv2 | 37.7 | 52.2 | 41.0 | +| yolov8m-world | 42.0 | 57.0 | 45.6 | +| yolov8m-worldv2 | 43.0 | 58.4 | 46.8 | +| yolov8l-world | 45.7 | 61.3 | 49.8 | +| yolov8l-worldv2 | 45.8 | 61.3 | 49.8 | +| yolov8x-world | 47.0 | 63.0 | 51.2 | +| yolov8x-worldv2 | 47.1 | 62.8 | 51.4 | ## Usage Examples diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 666aea71f..81a263efd 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.1.20" +__version__ = "8.1.21" from ultralytics.data.explorer.explorer import Explorer from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld diff --git a/ultralytics/cfg/models/v8/yolov8-world.yaml b/ultralytics/cfg/models/v8/yolov8-world.yaml index 611ea1a9f..c21a7f002 100644 --- a/ultralytics/cfg/models/v8/yolov8-world.yaml +++ b/ultralytics/cfg/models/v8/yolov8-world.yaml @@ -1,5 +1,5 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect +# YOLOv8-World object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/tasks/detect # Parameters nc: 80 # number of classes diff --git a/ultralytics/cfg/models/v8/yolov8-world-t2i.yaml b/ultralytics/cfg/models/v8/yolov8-worldv2.yaml similarity index 83% rename from ultralytics/cfg/models/v8/yolov8-world-t2i.yaml rename to ultralytics/cfg/models/v8/yolov8-worldv2.yaml index 6b654adbc..322b97d4b 100644 --- a/ultralytics/cfg/models/v8/yolov8-world-t2i.yaml +++ b/ultralytics/cfg/models/v8/yolov8-worldv2.yaml @@ -1,5 +1,5 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect +# YOLOv8-World-v2 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/tasks/detect # Parameters nc: 80 # number of classes @@ -29,18 +29,18 @@ backbone: head: - [-1, 1, nn.Upsample, [None, 2, "nearest"]] - [[-1, 6], 1, Concat, [1]] # cat backbone P4 - - [-1, 2, C2fAttn, [512, 256, 8]] # 12 + - [-1, 3, C2fAttn, [512, 256, 8]] # 12 - [-1, 1, nn.Upsample, [None, 2, "nearest"]] - [[-1, 4], 1, Concat, [1]] # cat backbone P3 - - [-1, 2, C2fAttn, [256, 128, 4]] # 15 (P3/8-small) + - [-1, 3, C2fAttn, [256, 128, 4]] # 15 (P3/8-small) - [15, 1, Conv, [256, 3, 2]] - [[-1, 12], 1, Concat, [1]] # cat head P4 - - [-1, 2, C2fAttn, [512, 256, 8]] # 18 (P4/16-medium) + - [-1, 3, C2fAttn, [512, 256, 8]] # 18 (P4/16-medium) - [-1, 1, Conv, [512, 3, 2]] - [[-1, 9], 1, Concat, [1]] # cat head P5 - - [-1, 2, C2fAttn, [1024, 512, 16]] # 21 (P5/32-large) + - [-1, 3, C2fAttn, [1024, 512, 16]] # 21 (P5/32-large) - [[15, 18, 21], 1, WorldDetect, [nc, 512, True]] # Detect(P3, P4, P5) diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py index 9dae52789..173983e21 100644 --- a/ultralytics/engine/exporter.py +++ b/ultralytics/engine/exporter.py @@ -68,7 +68,7 @@ from ultralytics.data.dataset import YOLODataset from ultralytics.data.utils import check_det_dataset from ultralytics.nn.autobackend import check_class_names, default_class_names from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder -from ultralytics.nn.tasks import DetectionModel, SegmentationModel +from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel from ultralytics.utils import ( ARM64, DEFAULT_CFG, @@ -201,6 +201,14 @@ class Exporter: assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'" if edgetpu and not LINUX: raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/") + print(type(model)) + if isinstance(model, WorldModel): + LOGGER.warning( + "WARNING ⚠️ YOLOWorld (original version) export is not supported to any format.\n" + "WARNING ⚠️ YOLOWorldv2 models (i.e. 'yolov8s-worldv2.pt') only support export to " + "(torchscript, onnx, openvino, engine, coreml) formats. " + "See https://docs.ultralytics.com/models/yolo-world for details." + ) # Input im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device) @@ -252,9 +260,10 @@ class Exporter: self.metadata = { "description": description, "author": "Ultralytics", - "license": "AGPL-3.0 https://ultralytics.com/license", "date": datetime.now().isoformat(), "version": __version__, + "license": "AGPL-3.0 License (https://ultralytics.com/license)", + "docs": "https://docs.ultralytics.com", "stride": int(max(model.stride)), "task": model.task, "batch": self.args.batch, diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py index 061dfe5f1..dd000cf12 100644 --- a/ultralytics/engine/model.py +++ b/ultralytics/engine/model.py @@ -295,7 +295,7 @@ class Model(nn.Module): self.model.load(weights) return self - def save(self, filename: Union[str, Path] = "saved_model.pt") -> None: + def save(self, filename: Union[str, Path] = "saved_model.pt", use_dill=True) -> None: """ Saves the current model state to a file. @@ -303,12 +303,22 @@ class Model(nn.Module): Args: filename (str | Path): The name of the file to save the model to. Defaults to 'saved_model.pt'. + use_dill (bool): Whether to try using dill for serialization if available. Defaults to True. Raises: AssertionError: If the model is not a PyTorch model. """ self._check_is_pytorch_model() - torch.save(self.ckpt, filename) + from ultralytics import __version__ + from datetime import datetime + + updates = { + "date": datetime.now().isoformat(), + "version": __version__, + "license": "AGPL-3.0 License (https://ultralytics.com/license)", + "docs": "https://docs.ultralytics.com", + } + torch.save({**self.ckpt, **updates}, filename, use_dill=use_dill) def info(self, detailed: bool = False, verbose: bool = True): """ diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index 33821171e..f005f3418 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -488,6 +488,8 @@ class BaseTrainer: "train_results": results, "date": datetime.now().isoformat(), "version": __version__, + "license": "AGPL-3.0 (https://ultralytics.com/license)", + "docs": "https://docs.ultralytics.com", } # Save last and best diff --git a/ultralytics/models/yolo/model.py b/ultralytics/models/yolo/model.py index 44b0d9e8b..5a2dc24f1 100644 --- a/ultralytics/models/yolo/model.py +++ b/ultralytics/models/yolo/model.py @@ -13,8 +13,8 @@ class YOLO(Model): def __init__(self, model="yolov8n.pt", task=None, verbose=False): """Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'.""" - stem = Path(model).stem # filename stem without suffix, i.e. "yolov8n" - if "-world" in stem: + model = Path(model) + if "-world" in model.stem and model.suffix in {".pt", ".yaml", ".yml"}: # if YOLOWorld PyTorch model new_instance = YOLOWorld(model) self.__class__ = type(new_instance) self.__dict__ = new_instance.__dict__ @@ -67,7 +67,7 @@ class YOLOWorld(Model): Initializes the YOLOv8-World model with the given pre-trained model file. Supports *.pt and *.yaml formats. Args: - model (str): Path to the pre-trained model. Defaults to 'yolov8s-world.pt'. + model (str | Path): Path to the pre-trained model. Defaults to 'yolov8s-world.pt'. """ super().__init__(model=model, task="detect") diff --git a/ultralytics/utils/benchmarks.py b/ultralytics/utils/benchmarks.py index b98448da5..3bc63510e 100644 --- a/ultralytics/utils/benchmarks.py +++ b/ultralytics/utils/benchmarks.py @@ -32,7 +32,7 @@ from pathlib import Path import numpy as np import torch.cuda -from ultralytics import YOLO +from ultralytics import YOLO, YOLOWorld from ultralytics.cfg import TASK2DATA, TASK2METRIC from ultralytics.engine.exporter import export_formats from ultralytics.utils import ASSETS, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR @@ -84,14 +84,20 @@ def benchmark( emoji, filename = "❌", None # export defaults try: # Checks - if i == 9: + if i == 9: # Edge TPU assert LINUX, "Edge TPU export only supported on Linux" - elif i == 7: + elif i == 7: # TF GraphDef assert model.task != "obb", "TensorFlow GraphDef not supported for OBB task" elif i in {5, 10}: # CoreML and TF.js assert MACOS or LINUX, "export only supported on macOS and Linux" if i in {3, 5}: # CoreML and OpenVINO assert not IS_PYTHON_3_12, "CoreML and OpenVINO not supported on Python 3.12" + if i in {6, 7, 8, 9, 10}: # All TF formats + assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet" + if i in {11}: # Paddle + assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet" + if i in {12}: # NCNN + assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet" if "cpu" in device.type: assert cpu, "inference not supported on CPU" if "cuda" in device.type: @@ -261,7 +267,8 @@ class ProfileModels: """ return 0.0, 0.0, 0.0, 0.0 # return (num_layers, num_params, num_gradients, num_flops) - def iterative_sigma_clipping(self, data, sigma=2, max_iters=3): + @staticmethod + def iterative_sigma_clipping(data, sigma=2, max_iters=3): """Applies an iterative sigma clipping algorithm to the given data times number of iterations.""" data = np.array(data) for _ in range(max_iters): @@ -359,9 +366,13 @@ class ProfileModels: def generate_table_row(self, model_name, t_onnx, t_engine, model_info): """Generates a formatted string for a table row that includes model performance and metric details.""" layers, params, gradients, flops = model_info - return f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |" + return ( + f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± " + f"{t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |" + ) - def generate_results_dict(self, model_name, t_onnx, t_engine, model_info): + @staticmethod + def generate_results_dict(model_name, t_onnx, t_engine, model_info): """Generates a dictionary of model details including name, parameters, GFLOPS and speed metrics.""" layers, params, gradients, flops = model_info return { @@ -372,11 +383,18 @@ class ProfileModels: "model/speed_TensorRT(ms)": round(t_engine[0], 3), } - def print_table(self, table_rows): + @staticmethod + def print_table(table_rows): """Formats and prints a comparison table for different models with given statistics and performance data.""" gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU" - header = f"| Model | size
(pixels) | mAPval
50-95 | Speed
CPU ONNX
(ms) | Speed
{gpu} TensorRT
(ms) | params
(M) | FLOPs
(B) |" - separator = "|-------------|---------------------|--------------------|------------------------------|-----------------------------------|------------------|-----------------|" + header = ( + f"| Model | size
(pixels) | mAPval
50-95 | Speed
CPU ONNX
(ms) | " + f"Speed
{gpu} TensorRT
(ms) | params
(M) | FLOPs
(B) |" + ) + separator = ( + "|-------------|---------------------|--------------------|------------------------------|" + "-----------------------------------|------------------|-----------------|" + ) print(f"\n\n{header}") print(separator) diff --git a/ultralytics/utils/downloads.py b/ultralytics/utils/downloads.py index 213145624..470fa83dc 100644 --- a/ultralytics/utils/downloads.py +++ b/ultralytics/utils/downloads.py @@ -20,7 +20,8 @@ GITHUB_ASSETS_NAMES = ( [f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")] + [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")] + [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")] - + [f"yolov8{k}-world.pt" for k in "sml"] + + [f"yolov8{k}-world.pt" for k in "smlx"] + + [f"yolov8{k}-worldv2.pt" for k in "smlx"] + [f"yolo_nas_{k}.pt" for k in "sml"] + [f"sam_{k}.pt" for k in "bl"] + [f"FastSAM-{k}.pt" for k in "sx"] diff --git a/ultralytics/utils/patches.py b/ultralytics/utils/patches.py index 703ec19d7..acbf5a99f 100644 --- a/ultralytics/utils/patches.py +++ b/ultralytics/utils/patches.py @@ -60,27 +60,29 @@ def imshow(winname: str, mat: np.ndarray): _torch_save = torch.save # copy to avoid recursion errors -def torch_save(*args, **kwargs): +def torch_save(*args, use_dill=True, **kwargs): """ - Use dill (if exists) to serialize the lambda functions where pickle does not do this. Also adds 3 retries with - exponential standoff in case of save failure to improve robustness to transient issues. + Optionally use dill to serialize lambda functions where pickle does not, adding robustness with 3 retries and + exponential standoff in case of save failure. Args: *args (tuple): Positional arguments to pass to torch.save. + use_dill (bool): Whether to try using dill for serialization if available. Defaults to True. **kwargs (dict): Keyword arguments to pass to torch.save. """ try: - import dill as pickle # noqa - except ImportError: + assert use_dill + import dill as pickle + except (AssertionError, ImportError): import pickle if "pickle_module" not in kwargs: - kwargs["pickle_module"] = pickle # noqa + kwargs["pickle_module"] = pickle for i in range(4): # 3 retries try: return _torch_save(*args, **kwargs) - except RuntimeError: # unable to save, possibly waiting for device to flush or anti-virus to finish scanning + except RuntimeError as e: # unable to save, possibly waiting for device to flush or antivirus scan if i == 3: - raise - time.sleep((2**i) / 2) # exponential standoff 0.5s, 1.0s, 2.0s + raise e + time.sleep((2**i) / 2) # exponential standoff: 0.5s, 1.0s, 2.0s