`ultralytics 8.1.39` add YOLO-World training (#9268)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/9447/head v8.1.39
Laughing 11 months ago committed by GitHub
parent 18036908d4
commit e9187c1296
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 1
      docs/en/datasets/detect/index.md
  2. 96
      docs/en/datasets/detect/lvis.md
  3. 1
      docs/en/datasets/index.md
  4. 2
      docs/en/models/fast-sam.md
  5. 86
      docs/en/models/yolo-world.md
  6. 4
      docs/en/reference/data/augment.md
  7. 4
      docs/en/reference/data/build.md
  8. 10
      docs/en/reference/data/dataset.md
  9. 8
      docs/en/reference/data/utils.md
  10. 15
      docs/en/reference/models/yolo/world/train.md
  11. 11
      docs/en/reference/models/yolo/world/train_world.md
  12. 1
      docs/mkdocs_github_authors.yaml
  13. 4
      mkdocs.yml
  14. 26
      tests/test_python.py
  15. 2
      ultralytics/__init__.py
  16. 1239
      ultralytics/cfg/datasets/lvis.yaml
  17. 20
      ultralytics/data/__init__.py
  18. 123
      ultralytics/data/augment.py
  19. 28
      ultralytics/data/build.py
  20. 26
      ultralytics/data/converter.py
  21. 161
      ultralytics/data/dataset.py
  22. 26
      ultralytics/data/utils.py
  23. 39
      ultralytics/engine/trainer.py
  24. 2
      ultralytics/models/fastsam/prompt.py
  25. 4
      ultralytics/models/yolo/__init__.py
  26. 53
      ultralytics/models/yolo/detect/val.py
  27. 1
      ultralytics/models/yolo/model.py
  28. 5
      ultralytics/models/yolo/world/__init__.py
  29. 91
      ultralytics/models/yolo/world/train.py
  30. 108
      ultralytics/models/yolo/world/train_world.py
  31. 6
      ultralytics/nn/modules/block.py
  32. 9
      ultralytics/nn/modules/head.py
  33. 42
      ultralytics/nn/tasks.py
  34. 2
      ultralytics/utils/loss.py

@ -74,6 +74,7 @@ Here is a list of the supported datasets and a brief description for each:
- [**Argoverse**](argoverse.md): A collection of sensor data collected from autonomous vehicles. It contains 3D tracking annotations for car objects.
- [**COCO**](coco.md): Common Objects in Context (COCO) is a large-scale object detection, segmentation, and captioning dataset with 80 object categories.
- [**LVIS**](lvis.md): LVIS is a large-scale object detection, segmentation, and captioning dataset with 1203 object categories.
- [**COCO8**](coco8.md): A smaller subset of the COCO dataset, COCO8 is more lightweight and faster to train.
- [**GlobalWheat2020**](globalwheat2020.md): A dataset containing images of wheat heads for the Global Wheat Challenge 2020.
- [**Objects365**](objects365.md): A large-scale object detection dataset with 365 object categories and 600k images, aimed at advancing object detection research.

@ -0,0 +1,96 @@
---
comments: true
description: Learn how LVIS, a leading dataset for object detection and segmentation, integrates with Ultralytics. Discover ways to use it for training YOLO models.
keywords: Ultralytics, LVIS dataset, object detection, YOLO, YOLO model training, image segmentation, computer vision, deep learning models
---
# LVIS Dataset
The [LVIS](https://www.lvisdataset.org/dataset) dataset is a large-scale, fine-grained vocabulary-level annotation dataset developed and released by Facebook AI Research (FAIR). It is primarily used as a research benchmark for object detection and instance segmentation with a large vocabulary of categories, aiming to drive further advancements in computer vision field.
## Key Features
- LVIS contains 160k images and 2M instance annotations for object detection, segmentation, and captioning tasks.
- The dataset comprises 1203 object categories, including common objects like cars, bicycles, and animals, as well as more specific categories such as umbrellas, handbags, and sports equipment.
- Annotations include object bounding boxes, segmentation masks, and captions for each image.
- LVIS provides standardized evaluation metrics like mean Average Precision (mAP) for object detection, and mean Average Recall (mAR) for segmentation tasks, making it suitable for comparing model performance.
- LVIS uses the exactly the same images as [COCO](./coco.md) dataset, but with different splits and different annotations.
## Dataset Structure
The LVIS dataset is split into three subsets:
1. **Train**: This subset contains 100k images for training object detection, segmentation, and captioning models.
2. **Val**: This subset has 20k images used for validation purposes during model training.
3. **Minival**: This subset is exactly the same as COCO val2017 set which has 5k images used for validation purposes during model training.
4. **Test**: This subset consists of 20k images used for testing and benchmarking the trained models. Ground truth annotations for this subset are not publicly available, and the results are submitted to the [LVIS evaluation server](https://eval.ai/web/challenges/challenge-page/675/overview) for performance evaluation.
## Applications
The LVIS dataset is widely used for training and evaluating deep learning models in object detection (such as YOLO, Faster R-CNN, and SSD), instance segmentation (such as Mask R-CNN). The dataset's diverse set of object categories, large number of annotated images, and standardized evaluation metrics make it an essential resource for computer vision researchers and practitioners.
## Dataset YAML
A YAML (Yet Another Markup Language) file is used to define the dataset configuration. It contains information about the dataset's paths, classes, and other relevant information. In the case of the LVIS dataset, the `lvis.yaml` file is maintained at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/lvis.yaml](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/lvis.yaml).
!!! Example "ultralytics/cfg/datasets/lvis.yaml"
```yaml
--8<-- "ultralytics/cfg/datasets/lvis.yaml"
```
## Usage
To train a YOLOv8n model on the LVIS dataset for 100 epochs with an image size of 640, you can use the following code snippets. For a comprehensive list of available arguments, refer to the model [Training](../../modes/train.md) page.
!!! Example "Train Example"
=== "Python"
```python
from ultralytics import YOLO
# Load a model
model = YOLO('yolov8n.pt') # load a pretrained model (recommended for training)
# Train the model
results = model.train(data='lvis.yaml', epochs=100, imgsz=640)
```
=== "CLI"
```bash
# Start training from a pretrained *.pt model
yolo detect train data=lvis.yaml model=yolov8n.pt epochs=100 imgsz=640
```
## Sample Images and Annotations
The LVIS dataset contains a diverse set of images with various object categories and complex scenes. Here are some examples of images from the dataset, along with their corresponding annotations:
![Dataset sample image](https://private-user-images.githubusercontent.com/61612323/316485965-a88c2e62-58d0-4f67-bc69-1418e42175e9.jpg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MTEzNjcyNjYsIm5iZiI6MTcxMTM2Njk2NiwicGF0aCI6Ii82MTYxMjMyMy8zMTY0ODU5NjUtYTg4YzJlNjItNThkMC00ZjY3LWJjNjktMTQxOGU0MjE3NWU5LmpwZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNDAzMjUlMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjQwMzI1VDExNDI0NlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWZmMTVlNzE5MTBkOTZmNDQwNzJjNWQzYzM2NmEyMGMxODQ4ZDEyMjYwYmMyY2JjZDU5YzBmMDIyZGEwMGEwZDAmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JmFjdG9yX2lkPTAma2V5X2lkPTAmcmVwb19pZD0wIn0.7thukPdnJKYuBmTk1ROUyqxxV3Ix5GeNLqyi4wSDYvA)
- **Mosaiced Image**: This image demonstrates a training batch composed of mosaiced dataset images. Mosaicing is a technique used during training that combines multiple images into a single image to increase the variety of objects and scenes within each training batch. This helps improve the model's ability to generalize to different object sizes, aspect ratios, and contexts.
The example showcases the variety and complexity of the images in the LVIS dataset and the benefits of using mosaicing during the training process.
## Citations and Acknowledgments
If you use the LVIS dataset in your research or development work, please cite the following paper:
!!! Quote ""
=== "BibTeX"
```bibtex
@inproceedings{gupta2019lvis,
title={{LVIS}: A Dataset for Large Vocabulary Instance Segmentation},
author={Gupta, Agrim and Dollar, Piotr and Girshick, Ross},
booktitle={Proceedings of the {IEEE} Conference on Computer Vision and Pattern Recognition},
year={2019}
}
```
We would like to acknowledge the LVIS Consortium for creating and maintaining this valuable resource for the computer vision community. For more information about the LVIS dataset and its creators, visit the [LVIS dataset website](https://www.lvisdataset.org/dataset).

@ -36,6 +36,7 @@ Bounding box object detection is a computer vision technique that involves detec
- [Argoverse](detect/argoverse.md): A dataset containing 3D tracking and motion forecasting data from urban environments with rich annotations.
- [COCO](detect/coco.md): A large-scale dataset designed for object detection, segmentation, and captioning with over 200K labeled images.
- [LVIS](lvis.md): A large-scale object detection, segmentation, and captioning dataset with 1203 object categories.
- [COCO8](detect/coco8.md): Contains the first 4 images from COCO train and COCO val, suitable for quick tests.
- [Global Wheat 2020](detect/globalwheat2020.md): A dataset of wheat head images collected from around the world for object detection and localization tasks.
- [Objects365](detect/objects365.md): A high-quality, large-scale dataset for object detection with 365 object categories and over 600K annotated images.

@ -147,7 +147,7 @@ FastSAM is also available directly from the [https://github.com/CASIA-IVA-Lab/Fa
4. Install the CLIP model:
```shell
pip install git+https://github.com/openai/CLIP.git
pip install git+https://github.com/ultralytics/CLIP.git
```
### Example Usage

@ -64,6 +64,39 @@ This section details the models available with their specific pre-trained weight
The YOLO-World models are easy to integrate into your Python applications. Ultralytics provides user-friendly Python API and CLI commands to streamline development.
### Train Usage
!!! Tip "Tip"
We strongly recommend to use `yolov8-worldv2` model for custom training, because it supports deterministic training and also easy to export other formats i.e onnx/tensorrt.
Object detection is straightforward with the `train` method, as illustrated below:
!!! Example
=== "Python"
PyTorch pretrained `*.pt` models as well as configuration `*.yaml` files can be passed to the `YOLOWorld()` class to create a model instance in python:
```python
from ultralytics import YOLOWorld
# Load a pretrained YOLOv8s-worldv2 model
model = YOLOWorld('yolov8s-worldv2.pt')
# Train the model on the COCO8 example dataset for 100 epochs
results = model.train(data='coco8.yaml', epochs=100, imgsz=640)
# Run inference with the YOLOv8n model on the 'bus.jpg' image
results = model('path/to/bus.jpg')
```
=== "CLI"
```bash
# Load a pretrained YOLOv8s-worldv2 model and train it on the COCO8 example dataset for 100 epochs
yolo train model=yolov8s-worldv2.yaml data=coco8.yaml epochs=100 imgsz=640
```
### Predict Usage
Object detection is straightforward with the `predict` method, as illustrated below:
@ -196,6 +229,59 @@ You can also save a model after setting custom classes. By doing this you create
This approach provides a powerful means of customizing state-of-the-art object detection models for specific tasks, making advanced AI more accessible and applicable to a broader range of practical applications.
## Reproduce official results from scratch(Experimental)
### Prepare datasets
- Train data
| Dataset | Type | Samples | Boxes | Annotation Files |
|-------------------------------------------------------------------|-----------|---------|-------|--------------------------------------------------------------------------------------------------------------------------------------------|
| [Objects365v1](https://opendatalab.com/OpenDataLab/Objects365_v1) | Detection | 609k | 9621k | [objects365_train.json](https://opendatalab.com/OpenDataLab/Objects365_v1) |
| [GQA](https://nlp.stanford.edu/data/gqa/images.zip) | Grounding | 621k | 3681k | [final_mixed_train_no_coco.json](https://huggingface.co/GLIPModel/GLIP/blob/main/mdetr_annotations/final_mixed_train_no_coco.json) |
| [Flickr30k](https://shannon.cs.illinois.edu/DenotationGraph/) | Grounding | 149k | 641k | [final_flickr_separateGT_train.json](https://huggingface.co/GLIPModel/GLIP/blob/main/mdetr_annotations/final_flickr_separateGT_train.json) |
- Val data
| Dataset | Type | Annotation Files |
|---------------------------------------------------------------------------------------------------------|-----------|--------------------------------------------------------------------------------------------------------|
| [LVIS minival](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/lvis.yaml) | Detection | [minival.txt](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/lvis.yaml) |
### Launch training from scratch
!!! Note
`WorldTrainerFromScratch` is highly customized to allow training yolo-world models on both detection datasets and grounding datasets simultaneously. More details please checkout [ultralytics.model.yolo.world.train_world.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train_world.py).
!!! Example
=== "Python"
```python
from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
from ultralytics import YOLOWorld
data = dict(
train=dict(
yolo_data=["Objects365.yaml"],
grounding_data=[
dict(
img_path="../datasets/flickr30k/images",
json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
),
dict(
img_path="../datasets/GQA/images",
json_file="../datasets/GQA/final_mixed_train_no_coco.json",
),
],
),
val=dict(yolo_data=["lvis.yaml"]),
)
model = YOLOWorld("yolov8s-worldv2.yaml")
model.train(data=data, batch=128, epochs=100, trainer=WorldTrainerFromScratch)
```
## Citations and Acknowledgements
We extend our gratitude to the [Tencent AILab Computer Vision Center](https://ai.tencent.com/) for their pioneering work in real-time open-vocabulary object detection with YOLO-World:

@ -59,6 +59,10 @@ keywords: Ultralytics, Data Augmentation, BaseTransform, MixUp, RandomHSV, Lette
<br><br>
## ::: ultralytics.data.augment.RandomLoadText
<br><br>
## ::: ultralytics.data.augment.ClassifyLetterBox
<br><br>

@ -27,6 +27,10 @@ keywords: Ultralytics, YOLO v3, Data build, DataLoader, InfiniteDataLoader, seed
<br><br>
## ::: ultralytics.data.build.build_grounding
<br><br>
## ::: ultralytics.data.build.build_dataloader
<br><br>

@ -19,14 +19,18 @@ keywords: Ultralytics, YOLO, YOLODataset, SemanticDataset, data handling, data m
<br><br>
## ::: ultralytics.data.dataset.SemanticDataset
## ::: ultralytics.data.dataset.YOLOMultiModalDataset
<br><br>
## ::: ultralytics.data.dataset.load_dataset_cache_file
## ::: ultralytics.data.dataset.GroundingDataset
<br><br>
## ::: ultralytics.data.dataset.save_dataset_cache_file
## ::: ultralytics.data.dataset.YOLOConcatDataset
<br><br>
## ::: ultralytics.data.dataset.SemanticDataset
<br><br>

@ -66,3 +66,11 @@ keywords: Ultralytics, data utils, YOLO, img2label_paths, exif_size, polygon2mas
## ::: ultralytics.data.utils.autosplit
<br><br>
## ::: ultralytics.data.utils.load_dataset_cache_file
<br><br>
## ::: ultralytics.data.utils.save_dataset_cache_file
<br><br>

@ -0,0 +1,15 @@
# Reference for `ultralytics/models/yolo/world/train.py`
!!! Note
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/yolo/world/train.py) 🛠. Thank you 🙏!
<br><br>
## ::: ultralytics.models.yolo.world.train.WorldTrainer
<br><br>
## ::: ultralytics.models.yolo.world.train.on_pretrain_routine_end
<br><br>

@ -0,0 +1,11 @@
# Reference for `ultralytics/models/yolo/world/train_world.py`
!!! Note
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train_world.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/yolo/world/train_world.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/yolo/world/train_world.py) 🛠. Thank you 🙏!
<br><br>
## ::: ultralytics.models.yolo.world.train_world.WorldTrainerFromScratch
<br><br>

@ -18,6 +18,7 @@ chr043416@gmail.com: RizwanMunawar
glenn.jocher@ultralytics.com: glenn-jocher
muhammadrizwanmunawar123@gmail.com: RizwanMunawar
not.committed.yet: null
plashchynski@gmail.com: plashchynski
priytosh.revolution@live.com: priytosh-tripathi
shuizhuyuanluo@126.com: null
xinwang614@gmail.com: GreatV

@ -240,6 +240,7 @@ nav:
- datasets/detect/index.md
- Argoverse: datasets/detect/argoverse.md
- COCO: datasets/detect/coco.md
- LVIS: datasets/detect/lvis.md
- COCO8: datasets/detect/coco8.md
- GlobalWheat2020: datasets/detect/globalwheat2020.md
- Objects365: datasets/detect/objects365.md
@ -492,6 +493,9 @@ nav:
- predict: reference/models/yolo/segment/predict.md
- train: reference/models/yolo/segment/train.md
- val: reference/models/yolo/segment/val.md
- world:
- train: reference/models/yolo/world/train.md
- train_world: reference/models/yolo/world/train_world.md
- nn:
- autobackend: reference/nn/autobackend.md
- modules:

@ -643,3 +643,29 @@ def test_yolo_world():
model = YOLO("yolov8s-world.pt") # no YOLOv8n-world model yet
model.set_classes(["tree", "window"])
model(ASSETS / "bus.jpg", conf=0.01)
# Training from yaml
model = YOLO("yolov8s-worldv2.yaml") # no YOLOv8n-world model yet
model.train(data="coco8.yaml", epochs=2, imgsz=32, cache="disk", batch=-1, close_mosaic=1, name="yolo-world")
model = YOLO("yolov8s-worldv2.pt") # no YOLOv8n-world model yet
# val
model.val(data="coco8.yaml", imgsz=32, save_txt=True, save_json=True)
# Training from pretrain
model.train(data="coco8.yaml", epochs=2, imgsz=32, cache="disk", batch=-1, close_mosaic=1, name="yolo-world")
# test WorWorldTrainerFromScratch
from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
model = YOLO("yolov8s-worldv2.yaml") # no YOLOv8n-world model yet
data = dict(train=dict(yolo_data=["coco8.yaml"]), val=dict(yolo_data=["coco8.yaml"]))
model.train(
data=data,
epochs=2,
imgsz=32,
cache="disk",
batch=-1,
close_mosaic=1,
name="yolo-world",
trainer=WorldTrainerFromScratch,
)

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.1.38"
__version__ = "8.1.39"
from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld

File diff suppressed because it is too large Load Diff

@ -1,15 +1,31 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .base import BaseDataset
from .build import build_dataloader, build_yolo_dataset, load_inference_source
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
from .build import (
build_dataloader,
build_yolo_dataset,
build_grounding,
load_inference_source,
)
from .dataset import (
ClassificationDataset,
SemanticDataset,
YOLODataset,
YOLOMultiModalDataset,
GroundingDataset,
YOLOConcatDataset,
)
__all__ = (
"BaseDataset",
"ClassificationDataset",
"SemanticDataset",
"YOLODataset",
"YOLOMultiModalDataset",
"YOLOConcatDataset",
"GroundingDataset",
"build_yolo_dataset",
"build_grounding",
"build_dataloader",
"load_inference_source",
)

@ -3,6 +3,7 @@
import math
import random
from copy import deepcopy
from typing import Tuple, Union
import cv2
import numpy as np
@ -66,7 +67,7 @@ class Compose:
def __init__(self, transforms):
"""Initializes the Compose object with a list of transforms."""
self.transforms = transforms
self.transforms = transforms if isinstance(transforms, list) else [transforms]
def __call__(self, data):
"""Applies a series of transformations to input data."""
@ -78,6 +79,29 @@ class Compose:
"""Appends a new transform to the existing list of transforms."""
self.transforms.append(transform)
def insert(self, index, transform):
"""Inserts a new transform to the existing list of transforms."""
self.transforms.insert(index, transform)
def __getitem__(self, index: Union[list, int]) -> "Compose":
"""Retrieve a specific transform or a set of transforms using indexing."""
assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}"
index = [index] if isinstance(index, int) else index
return Compose([self.transforms[i] for i in index])
def __setitem__(self, index: Union[list, int], value: Union[list, int]) -> None:
"""Retrieve a specific transform or a set of transforms using indexing."""
assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}"
if isinstance(index, list):
assert isinstance(
value, list
), f"The indices should be the same type as values, but got {type(index)} and {type(value)}"
if isinstance(index, int):
index, value = [index], [value]
for i, v in zip(index, value):
assert i < len(self.transforms), f"list index {i} out of range {len(self.transforms)}."
self.transforms[i] = v
def tolist(self):
"""Converts the list of transforms to a standard Python list."""
return self.transforms
@ -118,6 +142,8 @@ class BaseMixTransform:
mix_labels[i] = self.pre_transform(data)
labels["mix_labels"] = mix_labels
# Update cls and texts
labels = self._update_label_text(labels)
# Mosaic or MixUp
labels = self._mix_transform(labels)
labels.pop("mix_labels", None)
@ -131,6 +157,22 @@ class BaseMixTransform:
"""Gets a list of shuffled indexes for mosaic augmentation."""
raise NotImplementedError
def _update_label_text(self, labels):
"""Update label text."""
if "texts" not in labels:
return labels
mix_texts = sum([labels["texts"]] + [x["texts"] for x in labels["mix_labels"]], [])
mix_texts = list({tuple(x) for x in mix_texts})
text2id = {text: i for i, text in enumerate(mix_texts)}
for label in [labels] + labels["mix_labels"]:
for i, l in enumerate(label["cls"].squeeze(-1).tolist()):
text = label["texts"][int(l)]
label["cls"][i] = text2id[tuple(text)]
label["texts"] = mix_texts
return labels
class Mosaic(BaseMixTransform):
"""
@ -320,6 +362,8 @@ class Mosaic(BaseMixTransform):
final_labels["instances"].clip(imgsz, imgsz)
good = final_labels["instances"].remove_zero_area_boxes()
final_labels["cls"] = final_labels["cls"][good]
if "texts" in mosaic_labels[0]:
final_labels["texts"] = mosaic_labels[0]["texts"]
return final_labels
@ -970,6 +1014,83 @@ class Format:
return masks, instances, cls
class RandomLoadText:
"""
Randomly sample positive texts and negative texts and update the class indices accordingly to the number of samples.
Attributes:
prompt_format (str): Format for prompt. Default is '{}'.
neg_samples (tuple[int]): A ranger to randomly sample negative texts, Default is (80, 80).
max_samples (int): The max number of different text samples in one image, Default is 80.
padding (bool): Whether to pad texts to max_samples. Default is False.
padding_value (str): The padding text. Default is "".
"""
def __init__(
self,
prompt_format: str = "{}",
neg_samples: Tuple[int, int] = (80, 80),
max_samples: int = 80,
padding: bool = False,
padding_value: str = "",
) -> None:
"""Initializes the RandomLoadText class with given parameters."""
self.prompt_format = prompt_format
self.neg_samples = neg_samples
self.max_samples = max_samples
self.padding = padding
self.padding_value = padding_value
def __call__(self, labels: dict) -> dict:
"""Return updated classes and texts."""
assert "texts" in labels, "No texts found in labels."
class_texts = labels["texts"]
num_classes = len(class_texts)
cls = np.asarray(labels.pop("cls"), dtype=int)
pos_labels = np.unique(cls).tolist()
if len(pos_labels) > self.max_samples:
pos_labels = set(random.sample(pos_labels, k=self.max_samples))
neg_samples = min(min(num_classes, self.max_samples) - len(pos_labels), random.randint(*self.neg_samples))
neg_labels = []
for i in range(num_classes):
if i not in pos_labels:
neg_labels.append(i)
neg_labels = random.sample(neg_labels, k=neg_samples)
sampled_labels = pos_labels + neg_labels
random.shuffle(sampled_labels)
label2ids = {label: i for i, label in enumerate(sampled_labels)}
valid_idx = np.zeros(len(labels["instances"]), dtype=bool)
new_cls = []
for i, label in enumerate(cls.squeeze(-1).tolist()):
if label not in label2ids:
continue
valid_idx[i] = True
new_cls.append([label2ids[label]])
labels["instances"] = labels["instances"][valid_idx]
labels["cls"] = np.array(new_cls)
# Randomly select one prompt when there's more than one prompts
texts = []
for label in sampled_labels:
prompts = class_texts[label]
assert len(prompts) > 0
prompt = self.prompt_format.format(prompts[random.randrange(len(prompts))])
texts.append(prompt)
if self.padding:
valid_labels = len(pos_labels) + len(neg_labels)
num_padding = self.max_samples - valid_labels
if num_padding > 0:
texts += [self.padding_value] * num_padding
labels["texts"] = texts
return labels
def v8_transforms(dataset, imgsz, hyp, stretch=False):
"""Convert images to a size suitable for YOLOv8 training."""
pre_transform = Compose(

@ -22,7 +22,7 @@ from ultralytics.data.loaders import (
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.utils import RANK, colorstr
from ultralytics.utils.checks import check_file
from .dataset import YOLODataset
from .dataset import YOLODataset, YOLOMultiModalDataset, GroundingDataset
from .utils import PIN_MEMORY
@ -82,9 +82,10 @@ def seed_worker(worker_id): # noqa
random.seed(worker_seed)
def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32):
def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32, multi_modal=False):
"""Build YOLO Dataset."""
return YOLODataset(
dataset = YOLOMultiModalDataset if multi_modal else YOLODataset
return dataset(
img_path=img_path,
imgsz=cfg.imgsz,
batch_size=batch,
@ -103,6 +104,27 @@ def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, str
)
def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, stride=32):
"""Build YOLO Dataset."""
return GroundingDataset(
img_path=img_path,
json_file=json_file,
imgsz=cfg.imgsz,
batch_size=batch,
augment=mode == "train", # augmentation
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
rect=cfg.rect or rect, # rectangular batches
cache=cfg.cache or None,
single_cls=cfg.single_cls or False,
stride=int(stride),
pad=0.0 if mode == "train" else 0.5,
prefix=colorstr(f"{mode}: "),
task=cfg.task,
classes=cfg.classes,
fraction=cfg.fraction if mode == "train" else 1.0,
)
def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
"""Return an InfiniteDataLoader or DataLoader for training or validation set."""
batch = min(batch, len(dataset))

@ -219,6 +219,7 @@ def convert_coco(
use_segments=False,
use_keypoints=False,
cls91to80=True,
lvis=False,
):
"""
Converts COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.
@ -229,12 +230,14 @@ def convert_coco(
use_segments (bool, optional): Whether to include segmentation masks in the output.
use_keypoints (bool, optional): Whether to include keypoint annotations in the output.
cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs.
lvis (bool, optional): Whether to convert data in lvis dataset way.
Example:
```python
from ultralytics.data.converter import convert_coco
convert_coco('../datasets/coco/annotations/', use_segments=True, use_keypoints=False, cls91to80=True)
convert_coco('../datasets/lvis/annotations/', use_segments=True, use_keypoints=False, cls91to80=False, lvis=True)
```
Output:
@ -251,8 +254,14 @@ def convert_coco(
# Import json
for json_file in sorted(Path(labels_dir).resolve().glob("*.json")):
fn = Path(save_dir) / "labels" / json_file.stem.replace("instances_", "") # folder name
lname = "" if lvis else json_file.stem.replace("instances_", "")
fn = Path(save_dir) / "labels" / lname # folder name
fn.mkdir(parents=True, exist_ok=True)
if lvis:
# NOTE: create folders for both train and val in advance,
# since LVIS val set contains images from COCO 2017 train in addition to the COCO 2017 val split.
(fn / "train2017").mkdir(parents=True, exist_ok=True)
(fn / "val2017").mkdir(parents=True, exist_ok=True)
with open(json_file) as f:
data = json.load(f)
@ -263,16 +272,20 @@ def convert_coco(
for ann in data["annotations"]:
imgToAnns[ann["image_id"]].append(ann)
image_txt = []
# Write labels file
for img_id, anns in TQDM(imgToAnns.items(), desc=f"Annotations {json_file}"):
img = images[f"{img_id:d}"]
h, w, f = img["height"], img["width"], img["file_name"]
h, w = img["height"], img["width"]
f = str(Path(img["coco_url"]).relative_to("http://images.cocodataset.org")) if lvis else img["file_name"]
if lvis:
image_txt.append(str(Path("./images") / f))
bboxes = []
segments = []
keypoints = []
for ann in anns:
if ann["iscrowd"]:
if ann.get("iscrowd", False):
continue
# The COCO box format is [top left x, top left y, width, height]
box = np.array(ann["bbox"], dtype=np.float64)
@ -314,7 +327,12 @@ def convert_coco(
) # cls, box or segments
file.write(("%g " * len(line)).rstrip() % line + "\n")
LOGGER.info(f"COCO data converted successfully.\nResults saved to {save_dir.resolve()}")
if lvis:
with open((Path(save_dir) / json_file.name.replace("lvis_v1_", "").replace(".json", ".txt")), "a") as f:
for l in image_txt:
f.write(f"{l}\n")
LOGGER.info(f"{'LVIS' if lvis else 'COCO'} data converted successfully.\nResults saved to {save_dir.resolve()}")
def convert_dota_to_yolo_obb(dota_root_path: str):

@ -1,20 +1,41 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
from itertools import repeat
from collections import defaultdict
from multiprocessing.pool import ThreadPool
from pathlib import Path
import cv2
import json
import numpy as np
import torch
import torchvision
from PIL import Image
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable
from torch.utils.data import ConcatDataset
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
from ultralytics.utils.ops import resample_segments
from .augment import Compose, Format, Instances, LetterBox, classify_augmentations, classify_transforms, v8_transforms
from .augment import (
Compose,
Format,
Instances,
LetterBox,
RandomLoadText,
classify_augmentations,
classify_transforms,
v8_transforms,
)
from .base import BaseDataset
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
from .utils import (
HELP_URL,
LOGGER,
get_hash,
img2label_paths,
verify_image,
verify_image_label,
load_dataset_cache_file,
save_dataset_cache_file,
)
# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
DATASET_CACHE_VERSION = "1.0.3"
@ -105,7 +126,7 @@ class YOLODataset(BaseDataset):
x["hash"] = get_hash(self.label_files + self.im_files)
x["results"] = nf, nm, ne, nc, len(self.im_files)
x["msgs"] = msgs # warnings
save_dataset_cache_file(self.prefix, path, x)
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
return x
def get_labels(self):
@ -339,31 +360,125 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
x["hash"] = get_hash([x[0] for x in self.samples])
x["results"] = nf, nc, len(samples), samples
x["msgs"] = msgs # warnings
save_dataset_cache_file(self.prefix, path, x)
save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
return samples
def load_dataset_cache_file(path):
"""Load an Ultralytics *.cache dictionary from path."""
import gc
class YOLOMultiModalDataset(YOLODataset):
"""
Dataset class for loading object detection and/or segmentation labels in YOLO format.
Args:
data (dict, optional): A dataset YAML dictionary. Defaults to None.
task (str): An explicit arg to point current task, Defaults to 'detect'.
Returns:
(torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
"""
def __init__(self, *args, data=None, task="detect", **kwargs):
"""Initializes a dataset object for object detection tasks with optional specifications."""
super().__init__(*args, data=data, task=task, **kwargs)
def update_labels_info(self, label):
"""Add texts information for multi modal model training."""
labels = super().update_labels_info(label)
# NOTE: some categories are concatenated with its synonyms by `/`.
labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]
return labels
def build_transforms(self, hyp=None):
"""Enhances data transformations with optional text augmentation for multi-modal training."""
transforms = super().build_transforms(hyp)
if self.augment:
# NOTE: hard-coded the args for now.
transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True))
return transforms
class GroundingDataset(YOLODataset):
def __init__(self, *args, task="detect", json_file, **kwargs):
"""Initializes a GroundingDataset for object detection, loading annotations from a specified JSON file."""
assert task == "detect", "`GroundingDataset` only support `detect` task for now!"
self.json_file = json_file
super().__init__(*args, task=task, data={}, **kwargs)
def get_img_files(self, img_path):
"""The image files would be read in `get_labels` function, return empty list here."""
return []
def get_labels(self):
"""Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image."""
labels = []
LOGGER.info("Loading annotation file...")
with open(self.json_file, "r") as f:
annotations = json.load(f)
images = {f'{x["id"]:d}': x for x in annotations["images"]}
imgToAnns = defaultdict(list)
for ann in annotations["annotations"]:
imgToAnns[ann["image_id"]].append(ann)
for img_id, anns in TQDM(imgToAnns.items(), desc=f"Reading annotations {self.json_file}"):
img = images[f"{img_id:d}"]
h, w, f = img["height"], img["width"], img["file_name"]
im_file = Path(self.img_path) / f
if not im_file.exists():
continue
self.im_files.append(str(im_file))
bboxes = []
cat2id = {}
texts = []
for ann in anns:
if ann["iscrowd"]:
continue
box = np.array(ann["bbox"], dtype=np.float32)
box[:2] += box[2:] / 2
box[[0, 2]] /= float(w)
box[[1, 3]] /= float(h)
if box[2] <= 0 or box[3] <= 0:
continue
cat_name = " ".join([img["caption"][t[0] : t[1]] for t in ann["tokens_positive"]])
if cat_name not in cat2id:
cat2id[cat_name] = len(cat2id)
texts.append([cat_name])
cls = cat2id[cat_name] # class
box = [cls] + box.tolist()
if box not in bboxes:
bboxes.append(box)
lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
labels.append(
dict(
im_file=im_file,
shape=(h, w),
cls=lb[:, 0:1], # n, 1
bboxes=lb[:, 1:], # n, 4
normalized=True,
bbox_format="xywh",
texts=texts,
)
)
return labels
def build_transforms(self, hyp=None):
"""Configures augmentations for training with optional text loading; `hyp` adjusts augmentation intensity."""
transforms = super().build_transforms(hyp)
if self.augment:
# NOTE: hard-coded the args for now.
transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))
return transforms
gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
cache = np.load(str(path), allow_pickle=True).item() # load dict
gc.enable()
return cache
class YOLOConcatDataset(ConcatDataset):
"""
Dataset as a concatenation of multiple datasets.
def save_dataset_cache_file(prefix, path, x):
"""Save an Ultralytics dataset *.cache dictionary x to path."""
x["version"] = DATASET_CACHE_VERSION # add cache version
if is_dir_writeable(path.parent):
if path.exists():
path.unlink() # remove *.cache file if exists
np.save(str(path), x) # save cache for next time
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
LOGGER.info(f"{prefix}New cache created: {path}")
else:
LOGGER.warning(f"{prefix}WARNING ⚠ Cache directory {path.parent} is not writeable, cache not saved.")
This class is useful to assemble different existing datasets.
"""
@staticmethod
def collate_fn(batch):
"""Collates data samples into batches."""
return YOLODataset.collate_fn(batch)
# TODO: support semantic segmentation

@ -29,6 +29,7 @@ from ultralytics.utils import (
emojis,
yaml_load,
yaml_save,
is_dir_writeable,
)
from ultralytics.utils.checks import check_file, check_font, is_ascii
from ultralytics.utils.downloads import download, safe_download, unzip_file
@ -303,7 +304,7 @@ def check_det_dataset(dataset, autodownload=True):
# Set paths
data["path"] = path # download scripts
for k in "train", "val", "test":
for k in "train", "val", "test", "minival":
if data.get(k): # prepend path
if isinstance(data[k], str):
x = (path / data[k]).resolve()
@ -649,3 +650,26 @@ def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annot
if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
with open(path.parent / txt[i], "a") as f:
f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n") # add image to txt file
def load_dataset_cache_file(path):
"""Load an Ultralytics *.cache dictionary from path."""
import gc
gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
cache = np.load(str(path), allow_pickle=True).item() # load dict
gc.enable()
return cache
def save_dataset_cache_file(prefix, path, x, version):
"""Save an Ultralytics dataset *.cache dictionary x to path."""
x["version"] = version # add cache version
if is_dir_writeable(path.parent):
if path.exists():
path.unlink() # remove *.cache file if exists
np.save(str(path), x) # save cache for next time
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
LOGGER.info(f"{prefix}New cache created: {path}")
else:
LOGGER.warning(f"{prefix}WARNING ⚠ Cache directory {path.parent} is not writeable, cache not saved.")

@ -126,22 +126,7 @@ class BaseTrainer:
# Model and Dataset
self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolov8n -> yolov8n.pt
try:
if self.args.task == "classify":
self.data = check_cls_dataset(self.args.data)
elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
"detect",
"segment",
"pose",
"obb",
):
self.data = check_det_dataset(self.args.data)
if "yaml_file" in self.data:
self.args.data = self.data["yaml_file"] # for validating 'yolo train data=url.zip' usage
except Exception as e:
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
self.trainset, self.testset = self.get_dataset(self.data)
self.trainset, self.testset = self.get_dataset()
self.ema = None
# Optimization utils init
@ -509,13 +494,27 @@ class BaseTrainer:
if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'
@staticmethod
def get_dataset(data):
def get_dataset(self):
"""
Get train, val path from data dict if it exists.
Returns None if data format is not recognized.
"""
try:
if self.args.task == "classify":
data = check_cls_dataset(self.args.data)
elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
"detect",
"segment",
"pose",
"obb",
):
data = check_det_dataset(self.args.data)
if "yaml_file" in data:
self.args.data = data["yaml_file"] # for validating 'yolo train data=url.zip' usage
except Exception as e:
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
self.data = data
return data["train"], data.get("val") or data.get("test")
def setup_model(self):
@ -666,8 +665,8 @@ class BaseTrainer:
if ckpt is None:
return
best_fitness = 0.0
start_epoch = ckpt["epoch"] + 1
if ckpt["optimizer"] is not None:
start_epoch = ckpt.get("epoch", -1) + 1
if ckpt.get("optimizer", None) is not None:
self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer
best_fitness = ckpt["best_fitness"]
if self.ema and ckpt.get("ema"):

@ -35,7 +35,7 @@ class FastSAMPrompt:
except ImportError:
from ultralytics.utils.checks import check_requirements
check_requirements("git+https://github.com/openai/CLIP.git")
check_requirements("git+https://github.com/ultralytics/CLIP.git")
import clip
self.clip = clip

@ -1,7 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.models.yolo import classify, detect, obb, pose, segment
from ultralytics.models.yolo import classify, detect, obb, pose, segment, world
from .model import YOLO, YOLOWorld
__all__ = "classify", "segment", "detect", "pose", "obb", "YOLO", "YOLOWorld"
__all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld"

@ -33,6 +33,7 @@ class DetectionValidator(BaseValidator):
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.nt_per_class = None
self.is_coco = False
self.is_lvis = False
self.class_map = None
self.args.task = "detect"
self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
@ -66,8 +67,9 @@ class DetectionValidator(BaseValidator):
"""Initialize evaluation metrics for YOLO."""
val = self.data.get(self.args.split, "") # validation path
self.is_coco = isinstance(val, str) and "coco" in val and val.endswith(f"{os.sep}val2017.txt") # is COCO
self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1000))
self.args.save_json |= self.is_coco and not self.training # run on final val if training COCO
self.is_lvis = isinstance(val, str) and "lvis" in val and not self.is_coco # is LVIS
self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(len(model.names)))
self.args.save_json |= (self.is_coco or self.is_lvis) and not self.training # run on final val if training COCO
self.names = model.names
self.nc = len(model.names)
self.metrics.names = self.names
@ -266,7 +268,8 @@ class DetectionValidator(BaseValidator):
self.jdict.append(
{
"image_id": image_id,
"category_id": self.class_map[int(p[5])],
"category_id": self.class_map[int(p[5])]
+ (1 if self.is_lvis else 0), # index starts from 1 if it's lvis
"bbox": [round(x, 3) for x in b],
"score": round(p[4], 5),
}
@ -274,26 +277,42 @@ class DetectionValidator(BaseValidator):
def eval_json(self, stats):
"""Evaluates YOLO output in JSON format and returns performance statistics."""
if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data["path"] / "annotations/instances_val2017.json" # annotations
if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
pred_json = self.save_dir / "predictions.json" # predictions
LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
anno_json = (
self.data["path"]
/ "annotations"
/ ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
) # annotations
pkg = "pycocotools" if self.is_coco else "lvis"
LOGGER.info(f"\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...")
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements("pycocotools>=2.0.6")
from pycocotools.coco import COCO # noqa
from pycocotools.cocoeval import COCOeval # noqa
for x in anno_json, pred_json:
for x in pred_json, anno_json:
assert x.is_file(), f"{x} file not found"
anno = COCO(str(anno_json)) # init annotations api
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
eval = COCOeval(anno, pred, "bbox")
check_requirements("pycocotools>=2.0.6" if self.is_coco else "lvis>=0.5.3")
if self.is_coco:
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
from pycocotools.coco import COCO # noqa
from pycocotools.cocoeval import COCOeval # noqa
anno = COCO(str(anno_json)) # init annotations api
pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)
eval = COCOeval(anno, pred, "bbox")
else:
from lvis import LVIS, LVISEval
anno = LVIS(str(anno_json)) # init annotations api
pred = anno._load_json(str(pred_json)) # init predictions api (must pass string, not Path)
eval = LVISEval(anno, pred, "bbox")
eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
eval.evaluate()
eval.accumulate()
eval.summarize()
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50
if self.is_lvis:
eval.print_results() # explicitly call print_results
# update mAP50-95 and mAP50
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (
eval.stats[:2] if self.is_coco else [eval.results["AP50"], eval.results["AP"]]
)
except Exception as e:
LOGGER.warning(f"pycocotools unable to run: {e}")
LOGGER.warning(f"{pkg} unable to run: {e}")
return stats

@ -83,6 +83,7 @@ class YOLOWorld(Model):
"model": WorldModel,
"validator": yolo.detect.DetectionValidator,
"predictor": yolo.detect.DetectionPredictor,
"trainer": yolo.world.WorldTrainer,
}
}

@ -0,0 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .train import WorldTrainer
__all__ = ["WorldTrainer"]

@ -0,0 +1,91 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.models import yolo
from ultralytics.nn.tasks import WorldModel
from ultralytics.utils import DEFAULT_CFG, RANK
from ultralytics.data import build_yolo_dataset
from ultralytics.utils.torch_utils import de_parallel
from ultralytics.utils.checks import check_requirements
import itertools
try:
import clip
except ImportError:
check_requirements("git+https://github.com/ultralytics/CLIP.git")
import clip
def on_pretrain_routine_end(trainer):
"""Callback."""
if RANK in (-1, 0):
# NOTE: for evaluation
names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
device = next(trainer.model.parameters()).device
text_model, _ = clip.load("ViT-B/32", device=device)
for p in text_model.parameters():
p.requires_grad_(False)
trainer.text_model = text_model
class WorldTrainer(yolo.detect.DetectionTrainer):
"""
A class to fine-tune a world model on a close-set dataset.
Example:
```python
from ultralytics.models.yolo.world import WorldModel
args = dict(model='yolov8s-world.pt', data='coco8.yaml', epochs=3)
trainer = WorldTrainer(overrides=args)
trainer.train()
```
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a WorldTrainer object with given arguments."""
if overrides is None:
overrides = {}
super().__init__(cfg, overrides, _callbacks)
def get_model(self, cfg=None, weights=None, verbose=True):
"""Return WorldModel initialized with specified config and weights."""
# NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
# NOTE: Following the official config, nc hard-coded to 80 for now.
model = WorldModel(
cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
ch=3,
nc=min(self.data["nc"], 80),
verbose=verbose and RANK == -1,
)
if weights:
model.load(weights)
self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)
return model
def build_dataset(self, img_path, mode="train", batch=None):
"""
Build YOLO Dataset.
Args:
img_path (str): Path to the folder containing images.
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
"""
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
return build_yolo_dataset(
self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
)
def preprocess_batch(self, batch):
"""Preprocesses a batch of images for YOLOWorld training, adjusting formatting and dimensions as needed."""
batch = super().preprocess_batch(batch)
# NOTE: add text features
texts = list(itertools.chain(*batch["texts"]))
text_token = clip.tokenize(texts).to(batch["img"].device)
txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype) # torch.float32
txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
return batch

@ -0,0 +1,108 @@
from ultralytics.data import build_yolo_dataset, build_grounding, YOLOConcatDataset
from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.world import WorldTrainer
from ultralytics.utils.torch_utils import de_parallel
from ultralytics.utils import DEFAULT_CFG
class WorldTrainerFromScratch(WorldTrainer):
"""
A class extending the WorldTrainer class for training a world model from scratch on open-set dataset.
Example:
```python
from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
from ultralytics import YOLOWorld
data = dict(
train=dict(
yolo_data=["Objects365.yaml"],
grounding_data=[
dict(
img_path="../datasets/flickr30k/images",
json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
),
dict(
img_path="../datasets/GQA/images",
json_file="../datasets/GQA/final_mixed_train_no_coco.json",
),
],
),
val=dict(yolo_data=["lvis.yaml"]),
)
model = YOLOWorld("yolov8s-worldv2.yaml")
model.train(data=data, trainer=WorldTrainerFromScratch)
```
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a WorldTrainer object with given arguments."""
if overrides is None:
overrides = {}
super().__init__(cfg, overrides, _callbacks)
def build_dataset(self, img_path, mode="train", batch=None):
"""
Build YOLO Dataset.
Args:
img_path (List[str] | str): Path to the folder containing images.
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
"""
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
if mode == "train":
dataset = [
build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True)
if isinstance(im_path, str)
else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs)
for im_path in img_path
]
return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0]
else:
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
def get_dataset(self):
"""
Get train, val path from data dict if it exists.
Returns None if data format is not recognized.
"""
final_data = dict()
data_yaml = self.args.data
assert data_yaml.get("train", False) # object365.yaml
assert data_yaml.get("val", False) # lvis.yaml
data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}
assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}."
val_split = "minival" if "lvis" in data["val"][0]["val"] else "val"
for d in data["val"]:
if d.get("minival") is None: # for lvis dataset
continue
d["minival"] = str(d["path"] / d["minival"])
for s in ["train", "val"]:
final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]]
# save grounding data if there's one
grounding_data = data_yaml[s].get("grounding_data")
if grounding_data is None:
continue
grounding_data = [grounding_data] if not isinstance(grounding_data, list) else grounding_data
for g in grounding_data:
assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
final_data[s] += grounding_data
# NOTE: to make training work properly, set `nc` and `names`
final_data["nc"] = data["val"][0]["nc"]
final_data["names"] = data["val"][0]["names"]
self.data = final_data
return final_data["train"], final_data["val"][0]
def plot_training_labels(self):
"""DO NOT plot labels."""
pass
def final_eval(self):
"""Performs final evaluation and validation for object detection YOLO-World model."""
val = self.args.data["val"]["yolo_data"][0]
self.validator.args.data = val
self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val"
return super().final_eval()

@ -519,7 +519,8 @@ class ContrastiveHead(nn.Module):
def __init__(self):
"""Initializes ContrastiveHead with specified region-text similarity parameters."""
super().__init__()
self.bias = nn.Parameter(torch.zeros([]))
# NOTE: use -10.0 to keep the init cls loss consistency with other losses
self.bias = nn.Parameter(torch.tensor([-10.0]))
self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log())
def forward(self, x, w):
@ -542,7 +543,8 @@ class BNContrastiveHead(nn.Module):
"""Initialize ContrastiveHead with region-text similarity parameters."""
super().__init__()
self.norm = nn.BatchNorm2d(embed_dims)
self.bias = nn.Parameter(torch.zeros([]))
# NOTE: use -10.0 to keep the init cls loss consistency with other losses
self.bias = nn.Parameter(torch.tensor([-10.0]))
# use -1.0 is more stable
self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))

@ -250,6 +250,15 @@ class WorldDetect(Detect):
y = torch.cat((dbox, cls.sigmoid()), 1)
return y if self.export else (y, x)
def bias_init(self):
"""Initialize Detect() biases, WARNING: requires stride availability."""
m = self # self.model[-1] # Detect() module
# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
# ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
a[-1].bias.data[:] = 1.0 # box
# b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
class RTDETRDecoder(nn.Module):
"""

@ -564,28 +564,28 @@ class WorldModel(DetectionModel):
self.clip_model = None # CLIP model placeholder
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
def set_classes(self, text):
"""Perform a forward pass with optional profiling, visualization, and embedding extraction."""
def set_classes(self, text, batch=80, cache_clip_model=True):
"""Set classes in advance so that model could do offline-inference without clip model."""
try:
import clip
except ImportError:
check_requirements("git+https://github.com/openai/CLIP.git")
check_requirements("git+https://github.com/ultralytics/CLIP.git")
import clip
if not getattr(self, "clip_model", None): # for backwards compatibility of models lacking clip_model attribute
if (
not getattr(self, "clip_model", None) and cache_clip_model
): # for backwards compatibility of models lacking clip_model attribute
self.clip_model = clip.load("ViT-B/32")[0]
device = next(self.clip_model.parameters()).device
model = self.clip_model if cache_clip_model else clip.load("ViT-B/32")[0]
device = next(model.parameters()).device
text_token = clip.tokenize(text).to(device)
txt_feats = self.clip_model.encode_text(text_token).to(dtype=torch.float32)
txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach()
self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
self.model[-1].nc = len(text)
def init_criterion(self):
"""Initialize the loss criterion for the model."""
raise NotImplementedError
def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
"""
Perform a forward pass through the model.
@ -593,13 +593,14 @@ class WorldModel(DetectionModel):
x (torch.Tensor): The input tensor.
profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
txt_feats (torch.Tensor): The text features, use it if it's given. Defaults to None.
augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
embed (list, optional): A list of feature vectors/embeddings to return.
Returns:
(torch.Tensor): Model's output tensor.
"""
txt_feats = self.txt_feats.to(device=x.device, dtype=x.dtype)
txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
if len(txt_feats) != len(x):
txt_feats = txt_feats.repeat(len(x), 1, 1)
ori_txt_feats = txt_feats.clone()
@ -627,6 +628,21 @@ class WorldModel(DetectionModel):
return torch.unbind(torch.cat(embeddings, 1), dim=0)
return x
def loss(self, batch, preds=None):
"""
Compute loss.
Args:
batch (dict): Batch to compute loss on.
preds (torch.Tensor | List[torch.Tensor]): Predictions.
"""
if not hasattr(self, "criterion"):
self.criterion = self.init_criterion()
if preds is None:
preds = self.forward(batch["img"], txt_feats=batch["txt_feats"])
return self.criterion(preds, batch)
class Ensemble(nn.ModuleList):
"""Ensemble of models."""

@ -157,7 +157,7 @@ class v8DetectionLoss:
self.hyp = h
self.stride = m.stride # model strides
self.nc = m.nc # number of classes
self.no = m.no
self.no = m.nc + m.reg_max * 4
self.reg_max = m.reg_max
self.device = device

Loading…
Cancel
Save