`ultralytics 8.1.23` add YOLOv9-C and E models (#8571)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/8638/head v8.1.23
Authored by Laughing 9 months ago, committed by GitHub
parent e138d701a0
commit 2071776a36
Changed files (lines changed):

1. docs/en/models/yolov9.md (57)
2. docs/en/reference/nn/modules/block.md (32)
3. ultralytics/__init__.py (2)
4. ultralytics/cfg/models/v9/yolov9c.yaml (36)
5. ultralytics/cfg/models/v9/yolov9e.yaml (60)
6. ultralytics/engine/exporter.py (1)
7. ultralytics/nn/modules/__init__.py (12)
8. ultralytics/nn/modules/block.py (152)
9. ultralytics/nn/tasks.py (17)
10. ultralytics/utils/downloads.py (1)

@@ -76,15 +76,62 @@ The YOLOv9-C model, in particular, highlights the effectiveness of the architect
These results showcase YOLOv9's strategic advancements in model design, emphasizing its enhanced efficiency without compromising on the precision essential for real-time object detection tasks. The model not only pushes the boundaries of performance metrics but also emphasizes the importance of computational efficiency, making it a pivotal development in the field of computer vision.
## Integration and Future Directions
YOLOv9 embodies the spirit of open-source collaboration that is central to the advancement of AI technology. With plans for future integration into the Ultralytics package, YOLOv9 is poised to become an accessible tool for researchers and practitioners alike, further enhancing its impact on the field of computer vision.
## Conclusion
YOLOv9 represents a pivotal development in real-time object detection, offering significant improvements in terms of efficiency, accuracy, and adaptability. By addressing critical challenges through innovative solutions like PGI and GELAN, YOLOv9 sets a new precedent for future research and application in the field. As the AI community continues to evolve, YOLOv9 stands as a testament to the power of collaboration and innovation in driving technological progress.
Stay tuned for updates on Ultralytics package integration and explore the possibilities that YOLOv9 brings to the realm of computer vision.
## Usage Examples
This section provides simple YOLOv9 training and inference examples. For full documentation on these and other [modes](../modes/index.md) see the [Predict](../modes/predict.md), [Train](../modes/train.md), [Val](../modes/val.md) and [Export](../modes/export.md) docs pages.
!!! Example
=== "Python"
PyTorch pretrained `*.pt` models as well as configuration `*.yaml` files can be passed to the `YOLO()` class to create a model instance in Python:
```python
from ultralytics import YOLO

# Build a YOLOv9c model from scratch
model = YOLO('yolov9c.yaml')

# Build a YOLOv9c model from pretrained weights
model = YOLO('yolov9c.pt')

# Display model information (optional)
model.info()

# Train the model on the COCO8 example dataset for 100 epochs
results = model.train(data='coco8.yaml', epochs=100, imgsz=640)

# Run inference with the YOLOv9c model on the 'bus.jpg' image
results = model('path/to/bus.jpg')
```
=== "CLI"
CLI commands are available to directly run the models:
```bash
# Build a YOLOv9c model from scratch and train it on the COCO8 example dataset for 100 epochs
yolo train model=yolov9c.yaml data=coco8.yaml epochs=100 imgsz=640

# Build a YOLOv9c model from scratch and run inference on the 'bus.jpg' image
yolo predict model=yolov9c.yaml source=path/to/bus.jpg
```
## Supported Tasks and Modes
The YOLOv9 series offers a range of models, each optimized for high-performance [Object Detection](../tasks/detect.md). These models cater to varying computational needs and accuracy requirements, making them versatile for a wide array of applications.
| Model Type | Pre-trained Weights | Tasks Supported | Inference | Validation | Training | Export |
|------------|-----------------------------------------------------------------------------------------|----------------------------------------|-----------|------------|----------|--------|
| YOLOv9-C | [yolov9c.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov9c.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ✅ | ✅ |
| YOLOv9-E | [yolov9e.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov9e.pt) | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ✅ | ✅ |
This table provides a detailed overview of the YOLOv9 model variants, highlighting their capabilities in object detection tasks and their compatibility with various operational modes such as [Inference](../modes/predict.md), [Validation](../modes/val.md), [Training](../modes/train.md), and [Export](../modes/export.md). This comprehensive support ensures that users can fully leverage the capabilities of YOLOv9 models in a broad range of object detection scenarios.
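The Validation and Export columns above map onto the same `YOLO` API used in the Usage Examples. A minimal sketch, assuming a pretrained `yolov9c.pt` and a local `coco8.yaml` dataset:

```python
from ultralytics import YOLO

# Load the pretrained YOLOv9-C detection model
model = YOLO('yolov9c.pt')

# Validate on the COCO8 example dataset and read the mAP50-95 score
metrics = model.val(data='coco8.yaml', imgsz=640)
print(metrics.box.map)

# Export to ONNX; export() returns the path of the saved file
onnx_path = model.export(format='onnx')
print(onnx_path)
```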
## Citations and Acknowledgements

@@ -106,3 +106,35 @@ keywords: YOLO, Ultralytics, neural network, nn.modules.block, Proto, HGBlock, S
## ::: ultralytics.nn.modules.block.BNContrastiveHead
<br><br>
## ::: ultralytics.nn.modules.block.RepBottleneck
<br><br>
## ::: ultralytics.nn.modules.block.RepCSP
<br><br>
## ::: ultralytics.nn.modules.block.RepNCSPELAN4
<br><br>
## ::: ultralytics.nn.modules.block.ADown
<br><br>
## ::: ultralytics.nn.modules.block.SPPELAN
<br><br>
## ::: ultralytics.nn.modules.block.Silence
<br><br>
## ::: ultralytics.nn.modules.block.CBLinear
<br><br>
## ::: ultralytics.nn.modules.block.CBFuse
<br><br>

@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.1.22"
__version__ = "8.1.23"
from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld

@@ -0,0 +1,36 @@
# YOLOv9
# parameters
nc: 80 # number of classes
# gelan backbone
backbone:
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]] # 2
- [-1, 1, ADown, [256]] # 3-P3/8
- [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]] # 4
- [-1, 1, ADown, [512]] # 5-P4/16
- [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 6
- [-1, 1, ADown, [512]] # 7-P5/32
- [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 8
- [-1, 1, SPPELAN, [512, 256]] # 9
head:
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 12
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]] # 15 (P3/8-small)
- [-1, 1, ADown, [256]]
- [[-1, 12], 1, Concat, [1]] # cat head P4
- [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 18 (P4/16-medium)
- [-1, 1, ADown, [512]]
- [[-1, 9], 1, Concat, [1]] # cat head P5
- [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 21 (P5/32-large)
- [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)
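Each row in this config is `[from, repeats, module, args]`, and `parse_model` prepends the incoming channel count to `args` when it builds each layer. As a rough sketch (not part of the PR, and ignoring any width/depth scaling), the first backbone stages can be wired by hand to see how the channel and stride comments line up:

```python
import torch

from ultralytics.nn.modules import ADown, Conv, RepNCSPELAN4

# Hand-wired equivalent of backbone layers 0-3 for a 640x640 RGB input.
# parse_model() normally supplies the first argument (incoming channels).
x = torch.zeros(1, 3, 640, 640)
x = Conv(3, 64, 3, 2)(x)                   # 0-P1/2  -> (1, 64, 320, 320)
x = Conv(64, 128, 3, 2)(x)                 # 1-P2/4  -> (1, 128, 160, 160)
x = RepNCSPELAN4(128, 256, 128, 64, 1)(x)  # 2       -> (1, 256, 160, 160)
x = ADown(256, 256)(x)                     # 3-P3/8  -> (1, 256, 80, 80)
print(x.shape)  # torch.Size([1, 256, 80, 80])
```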

@@ -0,0 +1,60 @@
# YOLOv9
# parameters
nc: 80 # number of classes
# gelan backbone
backbone:
- [-1, 1, Silence, []]
- [-1, 1, Conv, [64, 3, 2]] # 1-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 2-P2/4
- [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]] # 3
- [-1, 1, ADown, [256]] # 4-P3/8
- [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]] # 5
- [-1, 1, ADown, [512]] # 6-P4/16
- [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 7
- [-1, 1, ADown, [1024]] # 8-P5/32
- [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 9
- [1, 1, CBLinear, [[64]]] # 10
- [3, 1, CBLinear, [[64, 128]]] # 11
- [5, 1, CBLinear, [[64, 128, 256]]] # 12
- [7, 1, CBLinear, [[64, 128, 256, 512]]] # 13
- [9, 1, CBLinear, [[64, 128, 256, 512, 1024]]] # 14
- [0, 1, Conv, [64, 3, 2]] # 15-P1/2
- [[10, 11, 12, 13, 14, -1], 1, CBFuse, [[0, 0, 0, 0, 0]]] # 16
- [-1, 1, Conv, [128, 3, 2]] # 17-P2/4
- [[11, 12, 13, 14, -1], 1, CBFuse, [[1, 1, 1, 1]]] # 18
- [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]] # 19
- [-1, 1, ADown, [256]] # 20-P3/8
- [[12, 13, 14, -1], 1, CBFuse, [[2, 2, 2]]] # 21
- [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]] # 22
- [-1, 1, ADown, [512]] # 23-P4/16
- [[13, 14, -1], 1, CBFuse, [[3, 3]]] # 24
- [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 25
- [-1, 1, ADown, [1024]] # 26-P5/32
- [[14, -1], 1, CBFuse, [[4]]] # 27
- [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 28
- [-1, 1, SPPELAN, [512, 256]] # 29
# gelan head
head:
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 25], 1, Concat, [1]] # cat backbone P4
- [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]] # 32
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 22], 1, Concat, [1]] # cat backbone P3
- [-1, 1, RepNCSPELAN4, [256, 256, 128, 2]] # 35 (P3/8-small)
- [-1, 1, ADown, [256]]
- [[-1, 32], 1, Concat, [1]] # cat head P4
- [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]] # 38 (P4/16-medium)
- [-1, 1, ADown, [512]]
- [[-1, 29], 1, Concat, [1]] # cat head P5
- [-1, 1, RepNCSPELAN4, [512, 1024, 512, 2]] # 41 (P5/32-large)
# detect
- [[35, 38, 41], 1, Detect, [nc]] # Detect(P3, P4, P5)
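The `CBLinear`/`CBFuse` rows implement YOLOv9-E's auxiliary branch: each `CBLinear` maps a backbone feature into one or more channel splits, and each `CBFuse` picks the split at a given index from every source, resizes it to the spatial size of its last input, and sums everything element-wise. A minimal sketch of a single hop (roughly layers 10 and 16 above), using the new modules but with simplified wiring:

```python
import torch

from ultralytics.nn.modules import CBFuse, CBLinear, Conv

# CBLinear at layer 10 turns the 64-channel P1/2 feature into a single 64-channel split.
p1 = torch.zeros(1, 64, 320, 320)
splits = CBLinear(64, [64])(p1)  # tuple with one (1, 64, 320, 320) tensor

# Layer 15 restarts a second stem from the raw image; CBFuse at layer 16 selects
# split index 0 from each CBLinear output, resizes it to the stem feature's size,
# and sums it with the stem feature (the last element sets the target size).
stem = Conv(3, 64, 3, 2)(torch.zeros(1, 3, 640, 640))
fused = CBFuse([0])([splits, stem])
print(fused.shape)  # torch.Size([1, 64, 320, 320])
```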

@@ -201,7 +201,6 @@ class Exporter:
assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
if edgetpu and not LINUX:
raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/")
print(type(model))
if isinstance(model, WorldModel):
LOGGER.warning(
"WARNING ⚠ YOLOWorld (original version) export is not supported to any format.\n"

@@ -40,6 +40,12 @@ from .block import (
ResNetLayer,
ContrastiveHead,
BNContrastiveHead,
RepNCSPELAN4,
ADown,
SPPELAN,
CBFuse,
CBLinear,
Silence,
)
from .conv import (
CBAM,
@@ -123,4 +129,10 @@ __all__ = (
"ImagePoolingAttn",
"ContrastiveHead",
"BNContrastiveHead",
"RepNCSPELAN4",
"ADown",
"SPPELAN",
"CBFuse",
"CBLinear",
"Silence",
)

@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .conv import Conv, DWConv, GhostConv, LightConv, RepConv
from .conv import Conv, DWConv, GhostConv, LightConv, RepConv, autopad
from .transformer import TransformerBlock
__all__ = (
@@ -31,6 +31,12 @@ __all__ = (
"Proto",
"RepC3",
"ResNetLayer",
"RepNCSPELAN4",
"ADown",
"SPPELAN",
"CBFuse",
"CBLinear",
"Silence",
)
@@ -531,7 +537,6 @@ class BNContrastiveHead(nn.Module):
Args:
embed_dims (int): Embed dimensions of text and image features.
norm_cfg (dict): Normalization parameters.
"""
def __init__(self, embed_dims: int):
@@ -548,3 +553,146 @@
w = F.normalize(w, dim=-1, p=2)
x = torch.einsum("bchw,bkc->bkhw", x, w)
return x * self.logit_scale.exp() + self.bias
class RepBottleneck(nn.Module):
"""Rep bottleneck."""
def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
"""Initializes a RepBottleneck module with customizable in/out channels, shortcut option, groups and expansion
ratio.
"""
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = RepConv(c1, c_, k[0], 1)
self.cv2 = Conv(c_, c2, k[1], 1, g=g)
self.add = shortcut and c1 == c2
def forward(self, x):
"""Forward pass through RepBottleneck layer."""
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
class RepCSP(nn.Module):
"""Rep CSP Bottleneck with 3 convolutions."""
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
"""Initializes RepCSP layer with given channels, repetitions, shortcut, groups and expansion ratio."""
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c1, c_, 1, 1)
self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
self.m = nn.Sequential(*(RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
def forward(self, x):
"""Forward pass through RepCSP layer."""
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
class RepNCSPELAN4(nn.Module):
"""CSP-ELAN."""
def __init__(self, c1, c2, c3, c4, n=1):
"""Initializes CSP-ELAN layer with specified channel sizes, repetitions, and convolutions."""
super().__init__()
self.c = c3 // 2
self.cv1 = Conv(c1, c3, 1, 1)
self.cv2 = nn.Sequential(RepCSP(c3 // 2, c4, n), Conv(c4, c4, 3, 1))
self.cv3 = nn.Sequential(RepCSP(c4, c4, n), Conv(c4, c4, 3, 1))
self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1)
def forward(self, x):
"""Forward pass through RepNCSPELAN4 layer."""
y = list(self.cv1(x).chunk(2, 1))
y.extend((m(y[-1])) for m in [self.cv2, self.cv3])
return self.cv4(torch.cat(y, 1))
def forward_split(self, x):
"""Forward pass using split() instead of chunk()."""
y = list(self.cv1(x).split((self.c, self.c), 1))
y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
return self.cv4(torch.cat(y, 1))
class ADown(nn.Module):
"""ADown."""
def __init__(self, c1, c2):
"""Initializes ADown module with convolution layers to downsample input from channels c1 to c2."""
super().__init__()
self.c = c2 // 2
self.cv1 = Conv(c1 // 2, self.c, 3, 2, 1)
self.cv2 = Conv(c1 // 2, self.c, 1, 1, 0)
def forward(self, x):
"""Forward pass through ADown layer."""
x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
x1, x2 = x.chunk(2, 1)
x1 = self.cv1(x1)
x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1)
x2 = self.cv2(x2)
return torch.cat((x1, x2), 1)
class SPPELAN(nn.Module):
"""SPP-ELAN."""
def __init__(self, c1, c2, c3, k=5):
"""Initializes SPP-ELAN block with convolution and max pooling layers for spatial pyramid pooling."""
super().__init__()
self.c = c3
self.cv1 = Conv(c1, c3, 1, 1)
self.cv2 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
self.cv3 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
self.cv4 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
self.cv5 = Conv(4 * c3, c2, 1, 1)
def forward(self, x):
"""Forward pass through SPPELAN layer."""
y = [self.cv1(x)]
y.extend(m(y[-1]) for m in [self.cv2, self.cv3, self.cv4])
return self.cv5(torch.cat(y, 1))
class Silence(nn.Module):
"""Silence."""
def __init__(self):
"""Initializes the Silence module."""
super(Silence, self).__init__()
def forward(self, x):
"""Forward pass through Silence layer."""
return x
class CBLinear(nn.Module):
"""CBLinear."""
def __init__(self, c1, c2s, k=1, s=1, p=None, g=1):
"""Initializes the CBLinear module, passing inputs unchanged."""
super(CBLinear, self).__init__()
self.c2s = c2s
self.conv = nn.Conv2d(c1, sum(c2s), k, s, autopad(k, p), groups=g, bias=True)
def forward(self, x):
"""Forward pass through CBLinear layer."""
outs = self.conv(x).split(self.c2s, dim=1)
return outs
class CBFuse(nn.Module):
"""CBFuse."""
def __init__(self, idx):
"""Initializes CBFuse module with layer index for selective feature fusion."""
super(CBFuse, self).__init__()
self.idx = idx
def forward(self, xs):
"""Forward pass through CBFuse layer."""
target_size = xs[-1].shape[2:]
res = [F.interpolate(x[self.idx[i]], size=target_size, mode="nearest") for i, x in enumerate(xs[:-1])]
out = torch.sum(torch.stack(res + xs[-1:]), dim=0)
return out
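A quick shape check for the new blocks, on illustrative inputs only (not part of the PR):

```python
import torch

from ultralytics.nn.modules.block import ADown, RepNCSPELAN4, Silence, SPPELAN

x = torch.zeros(2, 512, 40, 40)

# SPPELAN keeps the spatial size: 1x1 reduce to c3, three stacked 5x5 max-pools,
# then a 1x1 projection from 4 * c3 back to c2 channels.
print(SPPELAN(512, 512, 256)(x).shape)  # torch.Size([2, 512, 40, 40])

# ADown halves the spatial resolution via strided-conv and max-pool branches.
print(ADown(512, 512)(x).shape)  # torch.Size([2, 512, 20, 20])

# RepNCSPELAN4 preserves the spatial size; arguments are (c1, c2, c3, c4, n).
print(RepNCSPELAN4(512, 512, 512, 256, 1)(x).shape)  # torch.Size([2, 512, 40, 40])

# Silence is an identity placeholder (layer 0 of yolov9e.yaml).
print(Silence()(x) is x)  # True
```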

@@ -43,6 +43,12 @@ from ultralytics.nn.modules import (
RTDETRDecoder,
Segment,
WorldDetect,
RepNCSPELAN4,
ADown,
SPPELAN,
CBFuse,
CBLinear,
Silence,
)
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
@@ -570,7 +576,7 @@ class WorldModel(DetectionModel):
text_token = clip.tokenize(text).to(device)
txt_feats = model.encode_text(text_token).to(dtype=torch.float32)
txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach()
self.model[-1].nc = len(text)
def init_criterion(self):
@@ -850,6 +856,9 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
C1,
C2,
C2f,
RepNCSPELAN4,
ADown,
SPPELAN,
C2fAttn,
C3,
C3TR,
@@ -892,6 +901,12 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1
args.insert(1, [ch[x] for x in f])
elif m is CBLinear:
c2 = args[0]
c1 = ch[f]
args = [c1, c2, *args[1:]]
elif m is CBFuse:
c2 = ch[f[-1]]
else:
c2 = ch[f]
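For the two new branches, the channel bookkeeping can be traced on one `yolov9e.yaml` row each; a standalone sketch using a toy `ch` mapping rather than the real channel list (not the actual `parse_model` code):

```python
# CBLinear row from yolov9e.yaml: [1, 1, CBLinear, [[64]]]
ch = {1: 64, 15: 64}        # illustrative layer-index -> output-channels mapping
f, args = 1, [[64]]
c2 = args[0]                # [64]: the per-split channel sizes become this layer's "channels"
c1 = ch[f]                  # 64: channels coming in from layer 1
args = [c1, c2, *args[1:]]  # CBLinear(64, [64]) is what gets constructed

# CBFuse row: [[10, 11, 12, 13, 14, -1], 1, CBFuse, [[0, 0, 0, 0, 0]]]
f = [10, 11, 12, 13, 14, 15]  # -1 resolved to the previous layer (15)
c2 = ch[f[-1]]                # 64: CBFuse inherits the channel count of its last input
print(args, c2)               # [64, [64]] 64
```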

@@ -22,6 +22,7 @@ GITHUB_ASSETS_NAMES = (
+ [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")]
+ [f"yolov8{k}-world.pt" for k in "smlx"]
+ [f"yolov8{k}-worldv2.pt" for k in "smlx"]
+ [f"yolov9{k}.pt" for k in "ce"]
+ [f"yolo_nas_{k}.pt" for k in "sml"]
+ [f"sam_{k}.pt" for k in "bl"]
+ [f"FastSAM-{k}.pt" for k in "sx"]
