`ultralytics 8.2.38` official YOLOv10 support (#13113)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
pull/13814/head^2 v8.2.38
Burhan 9 months ago committed by GitHub
parent 821e5fa477
commit ffb46fd7fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 3
      .github/workflows/ci.yaml
  2. 45
      docs/en/models/yolov10.md
  3. 24
      docs/en/reference/nn/modules/block.md
  4. 4
      docs/en/reference/nn/modules/head.md
  5. 4
      docs/en/reference/utils/loss.md
  6. 9
      tests/test_python.py
  7. 2
      ultralytics/__init__.py
  8. 42
      ultralytics/cfg/models/v10/yolov10b.yaml
  9. 42
      ultralytics/cfg/models/v10/yolov10l.yaml
  10. 42
      ultralytics/cfg/models/v10/yolov10m.yaml
  11. 42
      ultralytics/cfg/models/v10/yolov10n.yaml
  12. 42
      ultralytics/cfg/models/v10/yolov10s.yaml
  13. 42
      ultralytics/cfg/models/v10/yolov10x.yaml
  14. 1
      ultralytics/engine/exporter.py
  15. 14
      ultralytics/nn/modules/__init__.py
  16. 256
      ultralytics/nn/modules/block.py
  17. 117
      ultralytics/nn/modules/head.py
  18. 52
      ultralytics/nn/tasks.py
  19. 5
      ultralytics/utils/benchmarks.py
  20. 1
      ultralytics/utils/downloads.py
  21. 22
      ultralytics/utils/loss.py
  22. 3
      ultralytics/utils/metrics.py
  23. 3
      ultralytics/utils/ops.py

@ -179,6 +179,9 @@ jobs:
- name: Benchmark OBBModel
shell: bash
run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}-obb.pt' imgsz=160 verbose=0.472
- name: Benchmark YOLOv10Model
shell: bash
run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/yolov10n.pt' imgsz=160 verbose=0.178
- name: Merge Coverage Reports
run: |
coverage xml -o coverage-benchmarks.xml

@ -19,7 +19,7 @@ Real-time object detection aims to accurately predict object categories and posi
The architecture of YOLOv10 builds upon the strengths of previous YOLO models while introducing several key innovations. The model architecture consists of the following components:
1. **Backbone**: Responsible for feature extraction, the backbone in YOLOv10 uses an enhanced version of CSPNet (Cross Stage Partial Network) to improve gradient flow and reduce computational redundancy.
2. **Neck**: The neck is designed to aggregate features from different scales and passes them to the head. It includes PAN (Path Aggregation Network) layers for effective multiscale feature fusion.
2. **Neck**: The neck is designed to aggregate features from different scales and passes them to the head. It includes PAN (Path Aggregation Network) layers for effective multi-scale feature fusion.
3. **One-to-Many Head**: Generates multiple predictions per object during training to provide rich supervisory signals and improve learning accuracy.
4. **One-to-One Head**: Generates a single best prediction per object during inference to eliminate the need for NMS, thereby reducing latency and improving efficiency.
@ -113,23 +113,19 @@ Here is a detailed comparison of YOLOv10 variants with other state-of-the-art mo
| YOLOv8-L | 43.7 | 165.2 | 52.9 | 12.39 | 8.06 |
| RT-DETR-R50 | 42.0 | 136.0 | 53.1 | 9.20 | 9.07 |
| **YOLOv10-L** | **24.4** | **120.3** | **53.4** | **7.28** | **7.21** |
| | | | | |
| | | | | | |
| YOLOv8-X | 68.2 | 257.8 | 53.9 | 16.86 | 12.83 |
| RT-DETR-R101 | 76.0 | 259.0 | 54.3 | 13.71 | 13.58 |
| **YOLOv10-X** | **29.5** | **160.4** | **54.4** | **10.70** | **10.60** |
## Usage Examples
!!! tip "Coming Soon"
The Ultralytics team is actively working on officially integrating the YOLOv10 models into the `ultralytics` package. Once the integration is complete, the usage examples shown below will be fully functional. Please stay tuned by following our social media and [GitHub repository](https://github.com/ultralytics/ultralytics) for the latest updates on YOLOv10 integration. We appreciate your patience and excitement! 🚀
For predicting new images with YOLOv10:
```python
from ultralytics import YOLO
# Load a pretrained YOLOv10n model
# Load a pre-trained YOLOv10n model
model = YOLO("yolov10n.pt")
# Perform object detection on an image
@ -151,6 +147,34 @@ model = YOLO("yolov10n.yaml")
model.train(data="coco8.yaml", epochs=100, imgsz=640)
```
## Supported Tasks and Modes
The YOLOv10 models series offers a range of models, each optimized for high-performance [Object Detection](../tasks/detect.md). These models cater to varying computational needs and accuracy requirements, making them versatile for a wide array of applications.
| Model | Filenames | Tasks | Inference | Validation | Training | Export |
|---------|------------------------------------------------------------------------|----------------------------------------------|-----------|------------|----------|--------|
| YOLOv10 | `yolov10n.pt` `yolov10s.pt` `yolov10m.pt` `yolov10l.pt` `yolov10x.pt` | [Object Detection](../tasks/detect.md) | ✅ | ✅ | ✅ | ✅ |
## Exporting YOLOv10
Due to the new operations introduced with YOLOv10, not all export formats provided by Ultralytics are currently supported. The following table outlines which formats have been successfully converted using Ultralytics for YOLOv10. Feel free to open a pull request if you're able to [provide a contribution change](../help/contributing.md) for adding export support of additional formats for YOLOv10.
| Export Format | Supported |
| ------------------------------------------------- | --------- |
| [TorchScript](../integrations/torchscript.md) | ✅ |
| [ONNX](../integrations/onnx.md) | ✅ |
| [OpenVINO](../integrations/openvino.md) | ✅ |
| [TensorRT](../integrations/tensorrt.md) | ✅ |
| [CoreML](../integrations/coreml.md) | ❌ |
| [TF SavedModel](../integrations/tf-savedmodel.md) | ❌ |
| [TF GraphDef](../integrations/tf-graphdef.md) | ❌ |
| [TF Lite](../integrations/tflite.md) | ❌ |
| [TF Edge TPU](../integrations/edge-tpu.md) | ❌ |
| [TF.js](../integrations/tfjs.md) | ❌ |
| [PaddlePaddle](../integrations/paddlepaddle.md) | ❌ |
| [NCNN](../integrations/ncnn.md) | ❌ |
## Conclusion
YOLOv10 sets a new standard in real-time object detection by addressing the shortcomings of previous YOLO versions and incorporating innovative design strategies. Its ability to deliver high accuracy with low computational cost makes it an ideal choice for a wide range of real-world applications.
@ -175,3 +199,10 @@ We would like to acknowledge the YOLOv10 authors from [Tsinghua University](http
```
For detailed implementation, architectural innovations, and experimental results, please refer to the YOLOv10 [research paper](https://arxiv.org/pdf/2405.14458) and [GitHub repository](https://github.com/THU-MIG/yolov10) by the Tsinghua University team.
[1]: https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov10n.pt
[2]: https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov10s.pt
[3]: https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov10m.pt
[4]: https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov10b.pt
[5]: https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov10l.pt
[6]: https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov10x.pt

@ -142,3 +142,27 @@ keywords: Ultralytics, YOLO, neural networks, block modules, DFL, Proto, HGStem,
## ::: ultralytics.nn.modules.block.CBFuse
<br><br>
## ::: ultralytics.nn.modules.block.RepVGGDW
<br><br>
## ::: ultralytics.nn.modules.block.CIB
<br><br>
## ::: ultralytics.nn.modules.block.C2fCIB
<br><br>
## ::: ultralytics.nn.modules.block.Attention
<br><br>
## ::: ultralytics.nn.modules.block.PSA
<br><br>
## ::: ultralytics.nn.modules.block.SCDown
<br><br>

@ -38,3 +38,7 @@ keywords: Ultralytics, YOLO, Detection, Pose, RTDETRDecoder, nn modules, guides
## ::: ultralytics.nn.modules.head.RTDETRDecoder
<br><br>
## ::: ultralytics.nn.modules.head.v10Detect
<br><br>

@ -50,3 +50,7 @@ keywords: Ultralytics, loss functions, Varifocal Loss, Focal Loss, Bbox Loss, Ro
## ::: ultralytics.utils.loss.v8OBBLoss
<br><br>
## ::: ultralytics.utils.loss.E2EDetectLoss
<br><br>

@ -577,3 +577,12 @@ def test_yolo_world():
close_mosaic=1,
trainer=WorldTrainerFromScratch,
)
def test_yolov10():
"""A simple test for yolov10 for now."""
model = YOLO("yolov10n.yaml")
# train/val/predict
model.train(data="coco8.yaml", epochs=1, imgsz=32, close_mosaic=1, cache="disk")
model.val(data="coco8.yaml", imgsz=32)
model(SOURCE)

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.37"
__version__ = "8.2.38"
import os

@ -0,0 +1,42 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv10 object detection model. For Usage examples see https://docs.ultralytics.com/tasks/detect
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
# [depth, width, max_channels]
b: [0.67, 1.00, 512]
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2fCIB, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
- [-1, 1, PSA, [1024]] # 10
# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2fCIB, [512, True]] # 13
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 16 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 13], 1, Concat, [1]] # cat head P4
- [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium)
- [-1, 1, SCDown, [512, 3, 2]]
- [[-1, 10], 1, Concat, [1]] # cat head P5
- [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large)
- [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)

@ -0,0 +1,42 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv10 object detection model. For Usage examples see https://docs.ultralytics.com/tasks/detect
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
# [depth, width, max_channels]
l: [1.00, 1.00, 512]
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2fCIB, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
- [-1, 1, PSA, [1024]] # 10
# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2fCIB, [512, True]] # 13
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 16 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 13], 1, Concat, [1]] # cat head P4
- [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium)
- [-1, 1, SCDown, [512, 3, 2]]
- [[-1, 10], 1, Concat, [1]] # cat head P5
- [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large)
- [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)

@ -0,0 +1,42 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv10 object detection model. For Usage examples see https://docs.ultralytics.com/tasks/detect
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
# [depth, width, max_channels]
m: [0.67, 0.75, 768]
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2fCIB, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
- [-1, 1, PSA, [1024]] # 10
# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2f, [512]] # 13
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 16 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 13], 1, Concat, [1]] # cat head P4
- [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium)
- [-1, 1, SCDown, [512, 3, 2]]
- [[-1, 10], 1, Concat, [1]] # cat head P5
- [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large)
- [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)

@ -0,0 +1,42 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv10 object detection model. For Usage examples see https://docs.ultralytics.com/tasks/detect
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.33, 0.25, 1024]
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
- [-1, 1, PSA, [1024]] # 10
# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2f, [512]] # 13
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 16 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 13], 1, Concat, [1]] # cat head P4
- [-1, 3, C2f, [512]] # 19 (P4/16-medium)
- [-1, 1, SCDown, [512, 3, 2]]
- [[-1, 10], 1, Concat, [1]] # cat head P5
- [-1, 3, C2fCIB, [1024, True, True]] # 22 (P5/32-large)
- [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)

@ -0,0 +1,42 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv10 object detection model. For Usage examples see https://docs.ultralytics.com/tasks/detect
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
# [depth, width, max_channels]
s: [0.33, 0.50, 1024]
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2fCIB, [1024, True, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
- [-1, 1, PSA, [1024]] # 10
# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2f, [512]] # 13
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 16 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 13], 1, Concat, [1]] # cat head P4
- [-1, 3, C2f, [512]] # 19 (P4/16-medium)
- [-1, 1, SCDown, [512, 3, 2]]
- [[-1, 10], 1, Concat, [1]] # cat head P5
- [-1, 3, C2fCIB, [1024, True, True]] # 22 (P5/32-large)
- [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)

@ -0,0 +1,42 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv10 object detection model. For Usage examples see https://docs.ultralytics.com/tasks/detect
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
# [depth, width, max_channels]
x: [1.00, 1.25, 512]
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2fCIB, [512, True]]
- [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2fCIB, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
- [-1, 1, PSA, [1024]] # 10
# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2fCIB, [512, True]] # 13
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 16 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 13], 1, Concat, [1]] # cat head P4
- [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium)
- [-1, 1, SCDown, [512, 3, 2]]
- [[-1, 10], 1, Concat, [1]] # cat head P5
- [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large)
- [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)

@ -920,6 +920,7 @@ class Exporter:
@try_export
def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr("TensorFlow Lite:")):
"""YOLOv8 TensorFlow Lite export."""
# BUG https://github.com/ultralytics/ultralytics/issues/13436
import tensorflow as tf # noqa
LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")

@ -22,18 +22,22 @@ from .block import (
C2,
C3,
C3TR,
CIB,
DFL,
ELAN1,
PSA,
SPP,
SPPELAN,
SPPF,
AConv,
ADown,
Attention,
BNContrastiveHead,
Bottleneck,
BottleneckCSP,
C2f,
C2fAttn,
C2fCIB,
C3Ghost,
C3x,
CBFuse,
@ -46,7 +50,9 @@ from .block import (
Proto,
RepC3,
RepNCSPELAN4,
RepVGGDW,
ResNetLayer,
SCDown,
)
from .conv import (
CBAM,
@ -63,7 +69,7 @@ from .conv import (
RepConv,
SpatialAttention,
)
from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment, WorldDetect
from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment, WorldDetect, v10Detect
from .transformer import (
AIFI,
MLP,
@ -137,4 +143,10 @@ __all__ = (
"CBLinear",
"AConv",
"ELAN1",
"RepVGGDW",
"CIB",
"C2fCIB",
"Attention",
"PSA",
"SCDown",
)

@ -5,6 +5,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics.utils.torch_utils import fuse_conv_and_bn
from .conv import Conv, DWConv, GhostConv, LightConv, RepConv, autopad
from .transformer import TransformerBlock
@ -39,6 +41,12 @@ __all__ = (
"CBFuse",
"CBLinear",
"Silence",
"RepVGGDW",
"CIB",
"C2fCIB",
"Attention",
"PSA",
"SCDown",
)
@ -699,3 +707,251 @@ class CBFuse(nn.Module):
target_size = xs[-1].shape[2:]
res = [F.interpolate(x[self.idx[i]], size=target_size, mode="nearest") for i, x in enumerate(xs[:-1])]
return torch.sum(torch.stack(res + xs[-1:]), dim=0)
class RepVGGDW(torch.nn.Module):
"""RepVGGDW is a class that represents a depth wise separable convolutional block in RepVGG architecture."""
def __init__(self, ed) -> None:
super().__init__()
self.conv = Conv(ed, ed, 7, 1, 3, g=ed, act=False)
self.conv1 = Conv(ed, ed, 3, 1, 1, g=ed, act=False)
self.dim = ed
self.act = nn.SiLU()
def forward(self, x):
"""
Performs a forward pass of the RepVGGDW block.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Output tensor after applying the depth wise separable convolution.
"""
return self.act(self.conv(x) + self.conv1(x))
def forward_fuse(self, x):
"""
Performs a forward pass of the RepVGGDW block without fusing the convolutions.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Output tensor after applying the depth wise separable convolution.
"""
return self.act(self.conv(x))
@torch.no_grad()
def fuse(self):
"""
Fuses the convolutional layers in the RepVGGDW block.
This method fuses the convolutional layers and updates the weights and biases accordingly.
"""
conv = fuse_conv_and_bn(self.conv.conv, self.conv.bn)
conv1 = fuse_conv_and_bn(self.conv1.conv, self.conv1.bn)
conv_w = conv.weight
conv_b = conv.bias
conv1_w = conv1.weight
conv1_b = conv1.bias
conv1_w = torch.nn.functional.pad(conv1_w, [2, 2, 2, 2])
final_conv_w = conv_w + conv1_w
final_conv_b = conv_b + conv1_b
conv.weight.data.copy_(final_conv_w)
conv.bias.data.copy_(final_conv_b)
self.conv = conv
del self.conv1
class CIB(nn.Module):
"""
Conditional Identity Block (CIB) module.
Args:
c1 (int): Number of input channels.
c2 (int): Number of output channels.
shortcut (bool, optional): Whether to add a shortcut connection. Defaults to True.
e (float, optional): Scaling factor for the hidden channels. Defaults to 0.5.
lk (bool, optional): Whether to use RepVGGDW for the third convolutional layer. Defaults to False.
"""
def __init__(self, c1, c2, shortcut=True, e=0.5, lk=False):
"""Initializes the custom model with optional shortcut, scaling factor, and RepVGGDW layer."""
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = nn.Sequential(
Conv(c1, c1, 3, g=c1),
Conv(c1, 2 * c_, 1),
Conv(2 * c_, 2 * c_, 3, g=2 * c_) if not lk else RepVGGDW(2 * c_),
Conv(2 * c_, c2, 1),
Conv(c2, c2, 3, g=c2),
)
self.add = shortcut and c1 == c2
def forward(self, x):
"""
Forward pass of the CIB module.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Output tensor.
"""
return x + self.cv1(x) if self.add else self.cv1(x)
class C2fCIB(C2f):
"""
C2fCIB class represents a convolutional block with C2f and CIB modules.
Args:
c1 (int): Number of input channels.
c2 (int): Number of output channels.
n (int, optional): Number of CIB modules to stack. Defaults to 1.
shortcut (bool, optional): Whether to use shortcut connection. Defaults to False.
lk (bool, optional): Whether to use local key connection. Defaults to False.
g (int, optional): Number of groups for grouped convolution. Defaults to 1.
e (float, optional): Expansion ratio for CIB modules. Defaults to 0.5.
"""
def __init__(self, c1, c2, n=1, shortcut=False, lk=False, g=1, e=0.5):
"""Initializes the module with specified parameters for channel, shortcut, local key, groups, and expansion."""
super().__init__(c1, c2, n, shortcut, g, e)
self.m = nn.ModuleList(CIB(self.c, self.c, shortcut, e=1.0, lk=lk) for _ in range(n))
class Attention(nn.Module):
"""
Attention module that performs self-attention on the input tensor.
Args:
dim (int): The input tensor dimension.
num_heads (int): The number of attention heads.
attn_ratio (float): The ratio of the attention key dimension to the head dimension.
Attributes:
num_heads (int): The number of attention heads.
head_dim (int): The dimension of each attention head.
key_dim (int): The dimension of the attention key.
scale (float): The scaling factor for the attention scores.
qkv (Conv): Convolutional layer for computing the query, key, and value.
proj (Conv): Convolutional layer for projecting the attended values.
pe (Conv): Convolutional layer for positional encoding.
"""
def __init__(self, dim, num_heads=8, attn_ratio=0.5):
"""Initializes multi-head attention module with query, key, and value convolutions and positional encoding."""
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.key_dim = int(self.head_dim * attn_ratio)
self.scale = self.key_dim**-0.5
nh_kd = nh_kd = self.key_dim * num_heads
h = dim + nh_kd * 2
self.qkv = Conv(dim, h, 1, act=False)
self.proj = Conv(dim, dim, 1, act=False)
self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)
def forward(self, x):
"""
Forward pass of the Attention module.
Args:
x (torch.Tensor): The input tensor.
Returns:
(torch.Tensor): The output tensor after self-attention.
"""
B, C, H, W = x.shape
N = H * W
qkv = self.qkv(x)
q, k, v = qkv.view(B, self.num_heads, self.key_dim * 2 + self.head_dim, N).split(
[self.key_dim, self.key_dim, self.head_dim], dim=2
)
attn = (q.transpose(-2, -1) @ k) * self.scale
attn = attn.softmax(dim=-1)
x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W))
x = self.proj(x)
return x
class PSA(nn.Module):
"""
Position-wise Spatial Attention module.
Args:
c1 (int): Number of input channels.
c2 (int): Number of output channels.
e (float): Expansion factor for the intermediate channels. Default is 0.5.
Attributes:
c (int): Number of intermediate channels.
cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.
cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.
attn (Attention): Attention module for spatial attention.
ffn (nn.Sequential): Feed-forward network module.
"""
def __init__(self, c1, c2, e=0.5):
"""Initializes convolution layers, attention module, and feed-forward network with channel reduction."""
super().__init__()
assert c1 == c2
self.c = int(c1 * e)
self.cv1 = Conv(c1, 2 * self.c, 1, 1)
self.cv2 = Conv(2 * self.c, c1, 1)
self.attn = Attention(self.c, attn_ratio=0.5, num_heads=self.c // 64)
self.ffn = nn.Sequential(Conv(self.c, self.c * 2, 1), Conv(self.c * 2, self.c, 1, act=False))
def forward(self, x):
"""
Forward pass of the PSA module.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Output tensor.
"""
a, b = self.cv1(x).split((self.c, self.c), dim=1)
b = b + self.attn(b)
b = b + self.ffn(b)
return self.cv2(torch.cat((a, b), 1))
class SCDown(nn.Module):
def __init__(self, c1, c2, k, s):
"""
Spatial Channel Downsample (SCDown) module.
Args:
c1 (int): Number of input channels.
c2 (int): Number of output channels.
k (int): Kernel size for the convolutional layer.
s (int): Stride for the convolutional layer.
"""
super().__init__()
self.cv1 = Conv(c1, c2, 1, 1)
self.cv2 = Conv(c2, c2, k=k, s=s, g=c2, act=False)
def forward(self, x):
"""
Forward pass of the SCDown module.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Output tensor after applying the SCDown module.
"""
return self.cv2(self.cv1(x))

@ -1,6 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""Model head modules."""
import copy
import math
import torch
@ -14,7 +15,7 @@ from .conv import Conv
from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
from .utils import bias_init_with_prob, linear_init
__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"
__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect"
class Detect(nn.Module):
@ -22,6 +23,8 @@ class Detect(nn.Module):
dynamic = False # force grid reconstruction
export = False # export mode
end2end = False # end2end
max_det = 300 # max_det
shape = None
anchors = torch.empty(0) # init
strides = torch.empty(0) # init
@ -41,13 +44,48 @@ class Detect(nn.Module):
self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
if self.end2end:
self.one2one_cv2 = copy.deepcopy(self.cv2)
self.one2one_cv3 = copy.deepcopy(self.cv3)
def forward(self, x):
"""Concatenates and returns predicted bounding boxes and class probabilities."""
if self.end2end:
return self.forward_end2end(x)
for i in range(self.nl):
x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
if self.training: # Training path
return x
y = self._inference(x)
return y if self.export else (y, x)
def forward_end2end(self, x):
"""
Performs forward pass of the v10Detect module.
Args:
x (tensor): Input tensor.
Returns:
(dict, tensor): If not in training mode, returns a dictionary containing the outputs of both one2many and one2one detections.
If in training mode, returns a dictionary containing the outputs of one2many and one2one detections separately.
"""
x_detach = [xi.detach() for xi in x]
one2one = [
torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)
]
for i in range(self.nl):
x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
if self.training: # Training path
return {"one2many": x, "one2one": one2one}
y = self._inference(one2one)
y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
return y if self.export else (y, {"one2many": x, "one2one": one2one})
def _inference(self, x):
"""Decode predicted bounding boxes and class probabilities based on multiple-level feature maps."""
# Inference path
shape = x[0].shape # BCHW
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
@ -73,7 +111,7 @@ class Detect(nn.Module):
dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
y = torch.cat((dbox, cls.sigmoid()), 1)
return y if self.export else (y, x)
return y
def bias_init(self):
"""Initialize Detect() biases, WARNING: requires stride availability."""
@ -83,10 +121,47 @@ class Detect(nn.Module):
for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
a[-1].bias.data[:] = 1.0 # box
b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
if self.end2end:
for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride): # from
a[-1].bias.data[:] = 1.0 # box
b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
def decode_bboxes(self, bboxes, anchors):
"""Decode bounding boxes."""
return dist2bbox(bboxes, anchors, xywh=True, dim=1)
return dist2bbox(bboxes, anchors, xywh=not self.end2end, dim=1)
@staticmethod
def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
"""
Post-processes the predictions obtained from a YOLOv10 model.
Args:
preds (torch.Tensor): The predictions obtained from the model. It should have a shape of (batch_size, num_boxes, 4 + num_classes).
max_det (int): The maximum number of detections to keep.
nc (int, optional): The number of classes. Defaults to 80.
Returns:
(torch.Tensor): The post-processed predictions with shape (batch_size, max_det, 6),
including bounding boxes, scores and cls.
"""
assert 4 + nc == preds.shape[-1]
boxes, scores = preds.split([4, nc], dim=-1)
max_scores = scores.amax(dim=-1)
max_scores, index = torch.topk(max_scores, min(max_det, max_scores.shape[1]), axis=-1)
index = index.unsqueeze(-1)
boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1]))
scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1]))
# NOTE: simplify but result slightly lower mAP
# scores, labels = scores.max(dim=-1)
# return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
scores, index = torch.topk(scores.flatten(1), max_det, axis=-1)
labels = index % nc
index = index // nc
boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1)
class Segment(Detect):
@ -487,3 +562,39 @@ class RTDETRDecoder(nn.Module):
xavier_uniform_(self.query_pos_head.layers[1].weight)
for layer in self.input_proj:
xavier_uniform_(layer[0].weight)
class v10Detect(Detect):
"""
v10 Detection head from https://arxiv.org/pdf/2405.14458
Args:
nc (int): Number of classes.
ch (tuple): Tuple of channel sizes.
Attributes:
max_det (int): Maximum number of detections.
Methods:
__init__(self, nc=80, ch=()): Initializes the v10Detect object.
forward(self, x): Performs forward pass of the v10Detect module.
bias_init(self): Initializes biases of the Detect module.
"""
end2end = True
def __init__(self, nc=80, ch=()):
"""Initializes the v10Detect object with the specified number of classes and input channels."""
super().__init__(nc, ch)
c3 = max(ch[0], min(self.nc, 100)) # channels
# Light cls head
self.cv3 = nn.ModuleList(
nn.Sequential(
nn.Sequential(Conv(x, x, 3, g=x), Conv(x, c3, 1)),
nn.Sequential(Conv(c3, c3, 3, g=c3), Conv(c3, c3, 1)),
nn.Conv2d(c3, self.nc, 1),
)
for x in ch
)
self.one2one_cv3 = copy.deepcopy(self.cv3)

@ -15,6 +15,7 @@ from ultralytics.nn.modules import (
C3TR,
ELAN1,
OBB,
PSA,
SPP,
SPPELAN,
SPPF,
@ -24,6 +25,7 @@ from ultralytics.nn.modules import (
BottleneckCSP,
C2f,
C2fAttn,
C2fCIB,
C3Ghost,
C3x,
CBFuse,
@ -46,14 +48,24 @@ from ultralytics.nn.modules import (
RepC3,
RepConv,
RepNCSPELAN4,
RepVGGDW,
ResNetLayer,
RTDETRDecoder,
SCDown,
Segment,
WorldDetect,
v10Detect,
)
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss
from ultralytics.utils.loss import (
E2EDetectLoss,
v8ClassificationLoss,
v8DetectionLoss,
v8OBBLoss,
v8PoseLoss,
v8SegmentationLoss,
)
from ultralytics.utils.plotting import feature_visualization
from ultralytics.utils.torch_utils import (
fuse_conv_and_bn,
@ -192,6 +204,9 @@ class BaseModel(nn.Module):
if isinstance(m, RepConv):
m.fuse_convs()
m.forward = m.forward_fuse # update forward
if isinstance(m, RepVGGDW):
m.fuse()
m.forward = m.forward_fuse
self.info(verbose=verbose)
return self
@ -294,6 +309,7 @@ class DetectionModel(BaseModel):
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict
self.inplace = self.yaml.get("inplace", True)
self.end2end = getattr(self.model[-1], "end2end", False)
# Build strides
m = self.model[-1] # Detect()
@ -303,6 +319,8 @@ class DetectionModel(BaseModel):
def _forward(x):
"""Performs a forward pass through the model, handling different Detect subclass types accordingly."""
if self.end2end:
return self.forward(x)["one2many"]
return self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))]) # forward
@ -355,7 +373,7 @@ class DetectionModel(BaseModel):
def init_criterion(self):
"""Initialize the loss criterion for the DetectionModel."""
return v8DetectionLoss(self)
return E2EDetectLoss(self) if self.end2end else v8DetectionLoss(self)
class OBBModel(DetectionModel):
@ -689,8 +707,8 @@ def temporary_modules(modules={}, attributes={}):
Example:
```python
with temporary_modules({'old.module.path': 'new.module.path'}, {'old.module.attribute': 'new.module.attribute'}):
import old.module.path # this will now import new.module.path
with temporary_modules({'old.module': 'new.module'}, {'old.module.attribute': 'new.module.attribute'}):
import old.module # this will now import new.module
from old.module import attribute # this will now import new.module.attribute
```
@ -700,23 +718,19 @@ def temporary_modules(modules={}, attributes={}):
applications or libraries. Use this function with caution.
"""
import importlib
import sys
from importlib import import_module
try:
# Set attributes in sys.modules under their old name
for old, new in attributes.items():
old_module, old_attr = old.rsplit(".", 1)
new_module, new_attr = new.rsplit(".", 1)
setattr(
importlib.import_module(old_module),
old_attr,
getattr(importlib.import_module(new_module), new_attr),
)
setattr(import_module(old_module), old_attr, getattr(import_module(new_module), new_attr))
# Set modules in sys.modules under their old name
for old, new in modules.items():
sys.modules[old] = importlib.import_module(new)
sys.modules[old] = import_module(new)
yield
finally:
@ -750,9 +764,10 @@ def torch_safe_load(weight):
"ultralytics.yolo.data": "ultralytics.data",
},
attributes={
"ultralytics.nn.modules.block.Silence": "torch.nn.Identity",
"ultralytics.nn.modules.block.Silence": "torch.nn.Identity", # YOLOv9e
"ultralytics.nn.tasks.YOLOv10DetectionModel": "ultralytics.nn.tasks.DetectionModel", # YOLOv10
},
): # for legacy 8.0 Classify and Pose models
):
ckpt = torch.load(file, map_location="cpu")
except ModuleNotFoundError as e: # e.name is missing module name
@ -911,6 +926,9 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
DWConvTranspose2d,
C3x,
RepC3,
PSA,
SCDown,
C2fCIB,
}:
c1, c2 = ch[f], args[0]
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
@ -922,7 +940,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
) # num heads
args = [c1, c2, *args[1:]]
if m in {BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3}:
if m in {BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3, C2fCIB}:
args.insert(2, n) # number of repeats
n = 1
elif m is AIFI:
@ -939,7 +957,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
args = [ch[f]]
elif m is Concat:
c2 = sum(ch[x] for x in f)
elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn}:
elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}:
args.append([ch[x] for x in f])
if m is Segment:
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
@ -1024,7 +1042,7 @@ def guess_model_task(model):
m = cfg["head"][-1][-2].lower() # output module name
if m in {"classify", "classifier", "cls", "fc"}:
return "classify"
if m == "detect":
if "detect" in m:
return "detect"
if m == "segment":
return "segment"
@ -1056,7 +1074,7 @@ def guess_model_task(model):
return "pose"
elif isinstance(m, OBB):
return "obb"
elif isinstance(m, (Detect, WorldDetect)):
elif isinstance(m, (Detect, WorldDetect, v10Detect)):
return "detect"
# Guess from model filename

@ -81,6 +81,7 @@ def benchmark(
device = select_device(device, verbose=False)
if isinstance(model, (str, Path)):
model = YOLO(model)
is_end2end = getattr(model.model.model[-1], "end2end", False)
y = []
t0 = time.time()
@ -96,14 +97,18 @@ def benchmark(
assert MACOS or LINUX, "CoreML and TF.js export only supported on macOS and Linux"
assert not IS_RASPBERRYPI, "CoreML and TF.js export not supported on Raspberry Pi"
assert not IS_JETSON, "CoreML and TF.js export not supported on NVIDIA Jetson"
assert not is_end2end, "End-to-end models not supported by CoreML and TF.js yet"
if i in {3, 5}: # CoreML and OpenVINO
assert not IS_PYTHON_3_12, "CoreML and OpenVINO not supported on Python 3.12"
if i in {6, 7, 8, 9, 10}: # All TF formats
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
assert not is_end2end, "End-to-end models not supported by onnx2tf yet"
if i in {11}: # Paddle
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
assert not is_end2end, "End-to-end models not supported by PaddlePaddle yet"
if i in {12}: # NCNN
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
assert not is_end2end, "End-to-end models not supported by NCNN yet"
if "cpu" in device.type:
assert cpu, "inference not supported on CPU"
if "cuda" in device.type:

@ -23,6 +23,7 @@ GITHUB_ASSETS_NAMES = (
+ [f"yolov8{k}-world.pt" for k in "smlx"]
+ [f"yolov8{k}-worldv2.pt" for k in "smlx"]
+ [f"yolov9{k}.pt" for k in "ce"]
+ [f"yolov10{k}.pt" for k in "nsmblx"]
+ [f"yolo_nas_{k}.pt" for k in "sml"]
+ [f"sam_{k}.pt" for k in "bl"]
+ [f"FastSAM-{k}.pt" for k in "sx"]

@ -148,7 +148,7 @@ class KeypointLoss(nn.Module):
class v8DetectionLoss:
"""Criterion class for computing training losses."""
def __init__(self, model): # model must be de-paralleled
def __init__(self, model, tal_topk=10): # model must be de-paralleled
"""Initializes v8DetectionLoss with the model, defining model-related properties and BCE loss function."""
device = next(model.parameters()).device # get model device
h = model.args # hyperparameters
@ -164,7 +164,7 @@ class v8DetectionLoss:
self.use_dfl = m.reg_max > 1
self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
@ -714,3 +714,21 @@ class v8OBBLoss(v8DetectionLoss):
b, a, c = pred_dist.shape # batch, anchors, channels
pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
class E2EDetectLoss:
"""Criterion class for computing training losses."""
def __init__(self, model):
"""Initialize E2EDetectLoss with one-to-many and one-to-one detection losses using the provided model."""
self.one2many = v8DetectionLoss(model, tal_topk=10)
self.one2one = v8DetectionLoss(model, tal_topk=1)
def __call__(self, preds, batch):
"""Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
preds = preds[1] if isinstance(preds, tuple) else preds
one2many = preds["one2many"]
loss_one2many = self.one2many(one2many, batch)
one2one = preds["one2one"]
loss_one2one = self.one2one(one2one, batch)
return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]

@ -64,8 +64,9 @@ def box_iou(box1, box2, eps=1e-7):
(torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
"""
# NOTE: Need .float() to get accurate iou values
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
(a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
(a1, a2), (b1, b2) = box1.float().unsqueeze(1).chunk(2, 2), box2.float().unsqueeze(0).chunk(2, 2)
inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp_(0).prod(2)
# IoU = inter / (area1 + area2 - inter)

@ -213,6 +213,9 @@ def non_max_suppression(
if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out)
prediction = prediction[0] # select only inference output
if prediction.shape[-1] == 6: # end-to-end model
return [pred[pred[:, 4] > conf_thres] for pred in prediction]
bs = prediction.shape[0] # batch size
nc = nc or (prediction.shape[1] - 4) # number of classes
nm = prediction.shape[1] - nc - 4 # number of masks

Loading…
Cancel
Save