`ultralytics 8.1.39` add YOLO-World training (#9268)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/9447/head v8.1.39
parent
18036908d4
commit
e9187c1296
34 changed files with 2161 additions and 95 deletions
@ -0,0 +1,96 @@ |
||||
--- |
||||
comments: true |
||||
description: Learn how LVIS, a leading dataset for object detection and segmentation, integrates with Ultralytics. Discover ways to use it for training YOLO models. |
||||
keywords: Ultralytics, LVIS dataset, object detection, YOLO, YOLO model training, image segmentation, computer vision, deep learning models |
||||
--- |
||||
|
||||
# LVIS Dataset |
||||
|
||||
The [LVIS](https://www.lvisdataset.org/dataset) dataset is a large-scale, fine-grained vocabulary-level annotation dataset developed and released by Facebook AI Research (FAIR). It is primarily used as a research benchmark for object detection and instance segmentation with a large vocabulary of categories, aiming to drive further advancements in computer vision field. |
||||
|
||||
## Key Features |
||||
|
||||
- LVIS contains 160k images and 2M instance annotations for object detection, segmentation, and captioning tasks. |
||||
- The dataset comprises 1203 object categories, including common objects like cars, bicycles, and animals, as well as more specific categories such as umbrellas, handbags, and sports equipment. |
||||
- Annotations include object bounding boxes, segmentation masks, and captions for each image. |
||||
- LVIS provides standardized evaluation metrics like mean Average Precision (mAP) for object detection, and mean Average Recall (mAR) for segmentation tasks, making it suitable for comparing model performance. |
||||
- LVIS uses the exactly the same images as [COCO](./coco.md) dataset, but with different splits and different annotations. |
||||
|
||||
## Dataset Structure |
||||
|
||||
The LVIS dataset is split into three subsets: |
||||
|
||||
1. **Train**: This subset contains 100k images for training object detection, segmentation, and captioning models. |
||||
2. **Val**: This subset has 20k images used for validation purposes during model training. |
||||
3. **Minival**: This subset is exactly the same as COCO val2017 set which has 5k images used for validation purposes during model training. |
||||
4. **Test**: This subset consists of 20k images used for testing and benchmarking the trained models. Ground truth annotations for this subset are not publicly available, and the results are submitted to the [LVIS evaluation server](https://eval.ai/web/challenges/challenge-page/675/overview) for performance evaluation. |
||||
|
||||
|
||||
## Applications |
||||
|
||||
The LVIS dataset is widely used for training and evaluating deep learning models in object detection (such as YOLO, Faster R-CNN, and SSD), instance segmentation (such as Mask R-CNN). The dataset's diverse set of object categories, large number of annotated images, and standardized evaluation metrics make it an essential resource for computer vision researchers and practitioners. |
||||
|
||||
## Dataset YAML |
||||
|
||||
A YAML (Yet Another Markup Language) file is used to define the dataset configuration. It contains information about the dataset's paths, classes, and other relevant information. In the case of the LVIS dataset, the `lvis.yaml` file is maintained at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/lvis.yaml](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/lvis.yaml). |
||||
|
||||
!!! Example "ultralytics/cfg/datasets/lvis.yaml" |
||||
|
||||
```yaml |
||||
--8<-- "ultralytics/cfg/datasets/lvis.yaml" |
||||
``` |
||||
|
||||
## Usage |
||||
|
||||
To train a YOLOv8n model on the LVIS dataset for 100 epochs with an image size of 640, you can use the following code snippets. For a comprehensive list of available arguments, refer to the model [Training](../../modes/train.md) page. |
||||
|
||||
!!! Example "Train Example" |
||||
|
||||
=== "Python" |
||||
|
||||
```python |
||||
from ultralytics import YOLO |
||||
|
||||
# Load a model |
||||
model = YOLO('yolov8n.pt') # load a pretrained model (recommended for training) |
||||
|
||||
# Train the model |
||||
results = model.train(data='lvis.yaml', epochs=100, imgsz=640) |
||||
``` |
||||
|
||||
=== "CLI" |
||||
|
||||
```bash |
||||
# Start training from a pretrained *.pt model |
||||
yolo detect train data=lvis.yaml model=yolov8n.pt epochs=100 imgsz=640 |
||||
``` |
||||
|
||||
## Sample Images and Annotations |
||||
|
||||
The LVIS dataset contains a diverse set of images with various object categories and complex scenes. Here are some examples of images from the dataset, along with their corresponding annotations: |
||||
|
||||
data:image/s3,"s3://crabby-images/1b44c/1b44c4be1552376de0deef5b86ff9d3c7377ddd7" alt="Dataset sample image" |
||||
|
||||
|
||||
- **Mosaiced Image**: This image demonstrates a training batch composed of mosaiced dataset images. Mosaicing is a technique used during training that combines multiple images into a single image to increase the variety of objects and scenes within each training batch. This helps improve the model's ability to generalize to different object sizes, aspect ratios, and contexts. |
||||
|
||||
The example showcases the variety and complexity of the images in the LVIS dataset and the benefits of using mosaicing during the training process. |
||||
|
||||
## Citations and Acknowledgments |
||||
|
||||
If you use the LVIS dataset in your research or development work, please cite the following paper: |
||||
|
||||
!!! Quote "" |
||||
|
||||
=== "BibTeX" |
||||
|
||||
```bibtex |
||||
@inproceedings{gupta2019lvis, |
||||
title={{LVIS}: A Dataset for Large Vocabulary Instance Segmentation}, |
||||
author={Gupta, Agrim and Dollar, Piotr and Girshick, Ross}, |
||||
booktitle={Proceedings of the {IEEE} Conference on Computer Vision and Pattern Recognition}, |
||||
year={2019} |
||||
} |
||||
``` |
||||
|
||||
We would like to acknowledge the LVIS Consortium for creating and maintaining this valuable resource for the computer vision community. For more information about the LVIS dataset and its creators, visit the [LVIS dataset website](https://www.lvisdataset.org/dataset). |
File diff suppressed because it is too large
Load Diff
@ -1,15 +1,31 @@ |
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
||||
|
||||
from .base import BaseDataset |
||||
from .build import build_dataloader, build_yolo_dataset, load_inference_source |
||||
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset |
||||
from .build import ( |
||||
build_dataloader, |
||||
build_yolo_dataset, |
||||
build_grounding, |
||||
load_inference_source, |
||||
) |
||||
from .dataset import ( |
||||
ClassificationDataset, |
||||
SemanticDataset, |
||||
YOLODataset, |
||||
YOLOMultiModalDataset, |
||||
GroundingDataset, |
||||
YOLOConcatDataset, |
||||
) |
||||
|
||||
__all__ = ( |
||||
"BaseDataset", |
||||
"ClassificationDataset", |
||||
"SemanticDataset", |
||||
"YOLODataset", |
||||
"YOLOMultiModalDataset", |
||||
"YOLOConcatDataset", |
||||
"GroundingDataset", |
||||
"build_yolo_dataset", |
||||
"build_grounding", |
||||
"build_dataloader", |
||||
"load_inference_source", |
||||
) |
||||
|
@ -1,7 +1,7 @@ |
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
||||
|
||||
from ultralytics.models.yolo import classify, detect, obb, pose, segment |
||||
from ultralytics.models.yolo import classify, detect, obb, pose, segment, world |
||||
|
||||
from .model import YOLO, YOLOWorld |
||||
|
||||
__all__ = "classify", "segment", "detect", "pose", "obb", "YOLO", "YOLOWorld" |
||||
__all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld" |
||||
|
@ -0,0 +1,5 @@ |
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
||||
|
||||
from .train import WorldTrainer |
||||
|
||||
__all__ = ["WorldTrainer"] |
@ -0,0 +1,91 @@ |
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
||||
|
||||
from ultralytics.models import yolo |
||||
from ultralytics.nn.tasks import WorldModel |
||||
from ultralytics.utils import DEFAULT_CFG, RANK |
||||
from ultralytics.data import build_yolo_dataset |
||||
from ultralytics.utils.torch_utils import de_parallel |
||||
from ultralytics.utils.checks import check_requirements |
||||
import itertools |
||||
|
||||
try: |
||||
import clip |
||||
except ImportError: |
||||
check_requirements("git+https://github.com/ultralytics/CLIP.git") |
||||
import clip |
||||
|
||||
|
||||
def on_pretrain_routine_end(trainer): |
||||
"""Callback.""" |
||||
if RANK in (-1, 0): |
||||
# NOTE: for evaluation |
||||
names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())] |
||||
de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False) |
||||
device = next(trainer.model.parameters()).device |
||||
text_model, _ = clip.load("ViT-B/32", device=device) |
||||
for p in text_model.parameters(): |
||||
p.requires_grad_(False) |
||||
trainer.text_model = text_model |
||||
|
||||
|
||||
class WorldTrainer(yolo.detect.DetectionTrainer): |
||||
""" |
||||
A class to fine-tune a world model on a close-set dataset. |
||||
|
||||
Example: |
||||
```python |
||||
from ultralytics.models.yolo.world import WorldModel |
||||
|
||||
args = dict(model='yolov8s-world.pt', data='coco8.yaml', epochs=3) |
||||
trainer = WorldTrainer(overrides=args) |
||||
trainer.train() |
||||
``` |
||||
""" |
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): |
||||
"""Initialize a WorldTrainer object with given arguments.""" |
||||
if overrides is None: |
||||
overrides = {} |
||||
super().__init__(cfg, overrides, _callbacks) |
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True): |
||||
"""Return WorldModel initialized with specified config and weights.""" |
||||
# NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`. |
||||
# NOTE: Following the official config, nc hard-coded to 80 for now. |
||||
model = WorldModel( |
||||
cfg["yaml_file"] if isinstance(cfg, dict) else cfg, |
||||
ch=3, |
||||
nc=min(self.data["nc"], 80), |
||||
verbose=verbose and RANK == -1, |
||||
) |
||||
if weights: |
||||
model.load(weights) |
||||
self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end) |
||||
|
||||
return model |
||||
|
||||
def build_dataset(self, img_path, mode="train", batch=None): |
||||
""" |
||||
Build YOLO Dataset. |
||||
|
||||
Args: |
||||
img_path (str): Path to the folder containing images. |
||||
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. |
||||
batch (int, optional): Size of batches, this is for `rect`. Defaults to None. |
||||
""" |
||||
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) |
||||
return build_yolo_dataset( |
||||
self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train" |
||||
) |
||||
|
||||
def preprocess_batch(self, batch): |
||||
"""Preprocesses a batch of images for YOLOWorld training, adjusting formatting and dimensions as needed.""" |
||||
batch = super().preprocess_batch(batch) |
||||
|
||||
# NOTE: add text features |
||||
texts = list(itertools.chain(*batch["texts"])) |
||||
text_token = clip.tokenize(texts).to(batch["img"].device) |
||||
txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype) # torch.float32 |
||||
txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True) |
||||
batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1]) |
||||
return batch |
@ -0,0 +1,108 @@ |
||||
from ultralytics.data import build_yolo_dataset, build_grounding, YOLOConcatDataset |
||||
from ultralytics.data.utils import check_det_dataset |
||||
from ultralytics.models.yolo.world import WorldTrainer |
||||
from ultralytics.utils.torch_utils import de_parallel |
||||
from ultralytics.utils import DEFAULT_CFG |
||||
|
||||
|
||||
class WorldTrainerFromScratch(WorldTrainer): |
||||
""" |
||||
A class extending the WorldTrainer class for training a world model from scratch on open-set dataset. |
||||
|
||||
Example: |
||||
```python |
||||
from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch |
||||
from ultralytics import YOLOWorld |
||||
|
||||
data = dict( |
||||
train=dict( |
||||
yolo_data=["Objects365.yaml"], |
||||
grounding_data=[ |
||||
dict( |
||||
img_path="../datasets/flickr30k/images", |
||||
json_file="../datasets/flickr30k/final_flickr_separateGT_train.json", |
||||
), |
||||
dict( |
||||
img_path="../datasets/GQA/images", |
||||
json_file="../datasets/GQA/final_mixed_train_no_coco.json", |
||||
), |
||||
], |
||||
), |
||||
val=dict(yolo_data=["lvis.yaml"]), |
||||
) |
||||
|
||||
model = YOLOWorld("yolov8s-worldv2.yaml") |
||||
model.train(data=data, trainer=WorldTrainerFromScratch) |
||||
``` |
||||
""" |
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): |
||||
"""Initialize a WorldTrainer object with given arguments.""" |
||||
if overrides is None: |
||||
overrides = {} |
||||
super().__init__(cfg, overrides, _callbacks) |
||||
|
||||
def build_dataset(self, img_path, mode="train", batch=None): |
||||
""" |
||||
Build YOLO Dataset. |
||||
|
||||
Args: |
||||
img_path (List[str] | str): Path to the folder containing images. |
||||
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. |
||||
batch (int, optional): Size of batches, this is for `rect`. Defaults to None. |
||||
""" |
||||
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) |
||||
if mode == "train": |
||||
dataset = [ |
||||
build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True) |
||||
if isinstance(im_path, str) |
||||
else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs) |
||||
for im_path in img_path |
||||
] |
||||
return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0] |
||||
else: |
||||
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs) |
||||
|
||||
def get_dataset(self): |
||||
""" |
||||
Get train, val path from data dict if it exists. |
||||
|
||||
Returns None if data format is not recognized. |
||||
""" |
||||
final_data = dict() |
||||
data_yaml = self.args.data |
||||
assert data_yaml.get("train", False) # object365.yaml |
||||
assert data_yaml.get("val", False) # lvis.yaml |
||||
data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()} |
||||
assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}." |
||||
val_split = "minival" if "lvis" in data["val"][0]["val"] else "val" |
||||
for d in data["val"]: |
||||
if d.get("minival") is None: # for lvis dataset |
||||
continue |
||||
d["minival"] = str(d["path"] / d["minival"]) |
||||
for s in ["train", "val"]: |
||||
final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]] |
||||
# save grounding data if there's one |
||||
grounding_data = data_yaml[s].get("grounding_data") |
||||
if grounding_data is None: |
||||
continue |
||||
grounding_data = [grounding_data] if not isinstance(grounding_data, list) else grounding_data |
||||
for g in grounding_data: |
||||
assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}" |
||||
final_data[s] += grounding_data |
||||
# NOTE: to make training work properly, set `nc` and `names` |
||||
final_data["nc"] = data["val"][0]["nc"] |
||||
final_data["names"] = data["val"][0]["names"] |
||||
self.data = final_data |
||||
return final_data["train"], final_data["val"][0] |
||||
|
||||
def plot_training_labels(self): |
||||
"""DO NOT plot labels.""" |
||||
pass |
||||
|
||||
def final_eval(self): |
||||
"""Performs final evaluation and validation for object detection YOLO-World model.""" |
||||
val = self.args.data["val"]["yolo_data"][0] |
||||
self.validator.args.data = val |
||||
self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val" |
||||
return super().final_eval() |
Loading…
Reference in new issue