Merge branch 'main' into yolo-serve

yolo-serve
Glenn Jocher authored 4 months ago, committed via GitHub
commit 2c5a478a6e
17 changed files:

  1. .github/workflows/docker.yaml (1 changed line)
  2. docker/Dockerfile-runner (4 changed lines)
  3. docs/en/guides/distance-calculation.md (33 changed lines)
  4. docs/en/guides/heatmaps.md (1 changed line)
  5. docs/en/guides/raspberry-pi.md (192 changed lines)
  6. docs/en/guides/workouts-monitoring.md (1 changed line)
  7. docs/en/models/sam-2.md (15 changed lines)
  8. docs/en/models/sam.md (8 changed lines)
  9. mkdocs.yml (2 changed lines)
  10. ultralytics/__init__.py (2 changed lines)
  11. ultralytics/engine/exporter.py (2 changed lines)
  12. ultralytics/models/sam/predict.py (114 changed lines)
  13. ultralytics/models/yolo/classify/train.py (3 changed lines)
  14. ultralytics/solutions/distance_calculation.py (87 changed lines)
  15. ultralytics/solutions/object_counter.py (4 changed lines)
  16. ultralytics/utils/metrics.py (2 changed lines)
  17. ultralytics/utils/plotting.py (29 changed lines)

@@ -184,6 +184,7 @@ jobs:
env:
GH_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }}
run: |
sleep 60
gh workflow run deploy_cloud_run.yml \
--repo ultralytics/assistant \
--ref main

@@ -17,8 +17,8 @@ ENV PYTHONUNBUFFERED=1 \
WORKDIR /actions-runner
# Download and unpack the latest runner from https://github.com/actions/runner
RUN FILENAME=actions-runner-linux-x64-2.317.0.tar.gz && \
curl -o $FILENAME -L https://github.com/actions/runner/releases/download/v2.317.0/$FILENAME && \
RUN FILENAME=actions-runner-linux-x64-2.320.0.tar.gz && \
curl -o $FILENAME -L https://github.com/actions/runner/releases/download/v2.320.0/$FILENAME && \
tar xzf $FILENAME && \
rm $FILENAME

@@ -43,12 +43,9 @@ Measuring the gap between two objects is known as distance calculation within a
```python
import cv2
from ultralytics import YOLO, solutions
from ultralytics import solutions
model = YOLO("yolo11n.pt")
names = model.model.names
cap = cv2.VideoCapture("path/to/video/file.mp4")
cap = cv2.VideoCapture("Path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
@@ -56,16 +53,14 @@ Measuring the gap between two objects is known as distance calculation within a
video_writer = cv2.VideoWriter("distance_calculation.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
# Init distance-calculation obj
dist_obj = solutions.DistanceCalculation(names=names, view_img=True)
distance = solutions.DistanceCalculation(model="yolo11n.pt", show=True)
while cap.isOpened():
success, im0 = cap.read()
if not success:
print("Video frame is empty or video processing has been successfully completed.")
break
tracks = model.track(im0, persist=True, show=False)
im0 = dist_obj.start_process(im0, tracks)
im0 = distance.calculate(im0)
video_writer.write(im0)
cap.release()
@@ -84,13 +79,11 @@ Measuring the gap between two objects is known as distance calculation within a
### Arguments `DistanceCalculation()`
| `Name` | `Type` | `Default` | Description |
| ---------------- | ------- | --------------- | --------------------------------------------------------- |
| `names` | `dict` | `None` | Dictionary of classes names. |
| `view_img` | `bool` | `False` | Flag to indicate if the video stream should be displayed. |
| `line_thickness` | `int` | `2` | Thickness of the lines drawn on the image. |
| `line_color` | `tuple` | `(255, 255, 0)` | Color of the lines drawn on the image (BGR format). |
| `centroid_color` | `tuple` | `(255, 0, 255)` | Color of the centroids drawn (BGR format). |
| `Name` | `Type` | `Default` | Description |
| ------------ | ------ | --------- | ---------------------------------------------------- |
| `model` | `str` | `None` | Path to Ultralytics YOLO Model File |
| `line_width` | `int` | `2` | Line thickness for bounding boxes. |
| `show` | `bool` | `False` | Flag to control whether to display the video stream. |
### Arguments `model.track`
@@ -122,10 +115,8 @@ To delete points drawn during distance calculation with Ultralytics YOLO11, you
The key arguments for initializing the `DistanceCalculation` class in Ultralytics YOLO11 include:
- `names`: Dictionary mapping class indices to class names.
- `view_img`: Flag to indicate if the video stream should be displayed.
- `line_thickness`: Thickness of the lines drawn on the image.
- `line_color`: Color of the lines drawn on the image (BGR format).
- `centroid_color`: Color of the centroids (BGR format).
- `model`: Model file path.
- `show`: Flag to indicate if the video stream should be displayed.
- `line_width`: Thickness of bounding box and the lines drawn on the image.
For an exhaustive list and default values, see the [arguments of DistanceCalculation](#arguments-distancecalculation).
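As a quick illustration, these arguments are passed directly to the constructor; a minimal sketch using the values documented above:
```python
from ultralytics import solutions

# Minimal sketch: initialize the distance-calculation solution
distance = solutions.DistanceCalculation(
    model="yolo11n.pt",  # path to an Ultralytics YOLO model file
    show=True,           # display the annotated video stream
    line_width=2,        # thickness of bounding boxes and distance lines
)
# Each frame is then processed with `distance.calculate(im0)` as shown in the example above.
```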

@@ -222,6 +222,7 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
| Name | Type | Default | Description |
| ------------ | ------ | ------------------ | ----------------------------------------------------------------- |
| `model` | `str` | `None` | Path to Ultralytics YOLO Model File |
| `colormap` | `int` | `cv2.COLORMAP_JET` | Colormap to use for the heatmap. |
| `show` | `bool` | `False` | Whether to display the image with the heatmap overlay. |
| `show_in` | `bool` | `True` | Whether to display the count of objects entering the region. |
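For illustration, a minimal sketch of passing these arguments (the `generate_heatmap` per-frame call is an assumption based on the solutions API and is not shown in this excerpt):
```python
import cv2

from ultralytics import solutions

# Sketch: configure the heatmap solution with the arguments documented above
heatmap = solutions.Heatmap(
    model="yolo11n.pt",         # path to an Ultralytics YOLO model file
    colormap=cv2.COLORMAP_JET,  # OpenCV colormap used for the heatmap overlay
    show=True,                  # display the image with the heatmap overlay
)

# frame = ...  # a BGR frame read with cv2.VideoCapture
# frame = heatmap.generate_heatmap(frame)  # assumed per-frame processing call
```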

@@ -1,12 +1,12 @@
---
comments: true
description: Learn how to deploy Ultralytics YOLOv8 on Raspberry Pi with our comprehensive guide. Get performance benchmarks, setup instructions, and best practices.
keywords: Ultralytics, YOLOv8, Raspberry Pi, setup, guide, benchmarks, computer vision, object detection, NCNN, Docker, camera modules
description: Learn how to deploy Ultralytics YOLO11 on Raspberry Pi with our comprehensive guide. Get performance benchmarks, setup instructions, and best practices.
keywords: Ultralytics, YOLO11, Raspberry Pi, setup, guide, benchmarks, computer vision, object detection, NCNN, Docker, camera modules
---
# Quick Start Guide: Raspberry Pi with Ultralytics YOLOv8
# Quick Start Guide: Raspberry Pi with Ultralytics YOLO11
This comprehensive guide provides a detailed walkthrough for deploying Ultralytics YOLOv8 on [Raspberry Pi](https://www.raspberrypi.com/) devices. Additionally, it showcases performance benchmarks to demonstrate the capabilities of YOLOv8 on these small and powerful devices.
This comprehensive guide provides a detailed walkthrough for deploying Ultralytics YOLO11 on [Raspberry Pi](https://www.raspberrypi.com/) devices. Additionally, it showcases performance benchmarks to demonstrate the capabilities of YOLO11 on these small and powerful devices.
<p align="center">
<br>
@@ -56,7 +56,7 @@ There are two ways of setting up Ultralytics package on Raspberry Pi to build yo
### Start with Docker
The fastest way to get started with Ultralytics YOLOv8 on Raspberry Pi is to run with pre-built docker image for Raspberry Pi.
The fastest way to get started with Ultralytics YOLO11 on Raspberry Pi is to run with the pre-built Docker image for Raspberry Pi.
Execute the command below to pull the Docker container and run it on Raspberry Pi. This is based on the [arm64v8/debian](https://hub.docker.com/r/arm64v8/debian) Docker image, which contains Debian 12 (Bookworm) in a Python3 environment.
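For reference, the pull-and-run step commonly looks like the following sketch (the exact image tag, `ultralytics/ultralytics:latest-arm64`, is an assumption and is not shown in this excerpt):
```bash
# Assumed ARM64 image tag for the Ultralytics package
t=ultralytics/ultralytics:latest-arm64

# Pull the image and start an interactive container on the Raspberry Pi
sudo docker pull $t
sudo docker run -it --ipc=host $t
```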
@@ -98,7 +98,7 @@ Out of all the model export formats supported by Ultralytics, [NCNN](https://doc
## Convert Model to NCNN and Run Inference
The YOLOv8n model in PyTorch format is converted to NCNN to run inference with the exported model.
The YOLO11n model in PyTorch format is converted to NCNN to run inference with the exported model.
!!! example
@@ -107,14 +107,14 @@ The YOLOv8n model in PyTorch format is converted to NCNN to run inference with t
```python
from ultralytics import YOLO
# Load a YOLOv8n PyTorch model
model = YOLO("yolov8n.pt")
# Load a YOLO11n PyTorch model
model = YOLO("yolo11n.pt")
# Export the model to NCNN format
model.export(format="ncnn") # creates 'yolov8n_ncnn_model'
model.export(format="ncnn") # creates 'yolo11n_ncnn_model'
# Load the exported NCNN model
ncnn_model = YOLO("yolov8n_ncnn_model")
ncnn_model = YOLO("yolo11n_ncnn_model")
# Run inference
results = ncnn_model("https://ultralytics.com/images/bus.jpg")
@@ -123,102 +123,62 @@ The YOLOv8n model in PyTorch format is converted to NCNN to run inference with t
=== "CLI"
```bash
# Export a YOLOv8n PyTorch model to NCNN format
yolo export model=yolov8n.pt format=ncnn # creates 'yolov8n_ncnn_model'
# Export a YOLO11n PyTorch model to NCNN format
yolo export model=yolo11n.pt format=ncnn # creates 'yolo11n_ncnn_model'
# Run inference with the exported model
yolo predict model='yolov8n_ncnn_model' source='https://ultralytics.com/images/bus.jpg'
yolo predict model='yolo11n_ncnn_model' source='https://ultralytics.com/images/bus.jpg'
```
!!! tip
For more details about supported export options, visit the [Ultralytics documentation page on deployment options](https://docs.ultralytics.com/guides/model-deployment-options/).
## Raspberry Pi 5 vs Raspberry Pi 4 YOLOv8 Benchmarks
## Raspberry Pi 5 YOLO11 Benchmarks
YOLOv8 benchmarks were run by the Ultralytics team on nine different model formats measuring speed and [accuracy](https://www.ultralytics.com/glossary/accuracy): PyTorch, TorchScript, ONNX, OpenVINO, TF SavedModel, TF GraphDef, TF Lite, PaddlePaddle, NCNN. Benchmarks were run on both Raspberry Pi 5 and Raspberry Pi 4 at FP32 [precision](https://www.ultralytics.com/glossary/precision) with default input image size of 640.
!!! note
We have only included benchmarks for YOLOv8n and YOLOv8s models because other models sizes are too big to run on the Raspberry Pis and does not offer decent performance.
YOLO11 benchmarks were run by the Ultralytics team on nine different model formats measuring speed and [accuracy](https://www.ultralytics.com/glossary/accuracy): PyTorch, TorchScript, ONNX, OpenVINO, TF SavedModel, TF GraphDef, TF Lite, PaddlePaddle, NCNN. Benchmarks were run on a Raspberry Pi 5 at FP32 [precision](https://www.ultralytics.com/glossary/precision) with default input image size of 640.
### Comparison Chart
!!! tip "Performance"
=== "YOLOv8n"
<div style="text-align: center;">
<img width="800" src="https://github.com/ultralytics/docs/releases/download/0/yolov8n-benchmark-comparison.avif" alt="NVIDIA Jetson Ecosystem">
</div>
=== "YOLOv8s"
We have only included benchmarks for YOLO11n and YOLO11s models because other model sizes are too big to run on the Raspberry Pis and do not offer decent performance.
<div style="text-align: center;">
<img width="800" src="https://github.com/ultralytics/docs/releases/download/0/yolov8s-performance-comparison.avif" alt="NVIDIA Jetson Ecosystem">
</div>
<div style="text-align: center;">
<img width="800" src="https://github.com/ultralytics/docs/releases/download/0/rpi-yolo11-benchmarks.avif" alt="YOLO11 benchmarks on RPi 5">
</div>
### Detailed Comparison Table
The below table represents the benchmark results for two different models (YOLOv8n, YOLOv8s) across nine different formats (PyTorch, TorchScript, ONNX, OpenVINO, TF SavedModel, TF GraphDef, TF Lite, PaddlePaddle, NCNN), running on both Raspberry Pi 4 and Raspberry Pi 5, giving us the status, size, mAP50-95(B) metric, and inference time for each combination.
The below table represents the benchmark results for two different models (YOLO11n, YOLO11s) across nine different formats (PyTorch, TorchScript, ONNX, OpenVINO, TF SavedModel, TF GraphDef, TF Lite, PaddlePaddle, NCNN), running on a Raspberry Pi 5, giving us the status, size, mAP50-95(B) metric, and inference time for each combination.
!!! tip "Performance"
=== "YOLOv8n on RPi5"
| Format | Status | Size on disk (MB) | mAP50-95(B) | Inference time (ms/im) |
|---------------|--------|-------------------|-------------|------------------------|
| PyTorch | ✅ | 6.2 | 0.6381 | 508.61 |
| TorchScript | ✅ | 12.4 | 0.6092 | 558.38 |
| ONNX | ✅ | 12.2 | 0.6092 | 198.69 |
| OpenVINO | ✅ | 12.3 | 0.6092 | 704.70 |
| TF SavedModel | ✅ | 30.6 | 0.6092 | 367.64 |
| TF GraphDef | ✅ | 12.3 | 0.6092 | 473.22 |
| TF Lite | ✅ | 12.3 | 0.6092 | 380.67 |
| PaddlePaddle | ✅ | 24.4 | 0.6092 | 703.51 |
| NCNN | ✅ | 12.2 | 0.6034 | 94.28 |
=== "YOLOv8s on RPi5"
| Format | Status | Size on disk (MB) | mAP50-95(B) | Inference time (ms/im) |
|---------------|--------|-------------------|-------------|------------------------|
| PyTorch | ✅ | 21.5 | 0.6967 | 969.49 |
| TorchScript | ✅ | 43.0 | 0.7136 | 1110.04 |
| ONNX | ✅ | 42.8 | 0.7136 | 451.37 |
| OpenVINO | ✅ | 42.9 | 0.7136 | 873.51 |
| TF SavedModel | ✅ | 107.0 | 0.7136 | 658.15 |
| TF GraphDef | ✅ | 42.8 | 0.7136 | 946.01 |
| TF Lite | ✅ | 42.8 | 0.7136 | 1013.27 |
| PaddlePaddle | ✅ | 85.5 | 0.7136 | 1560.23 |
| NCNN | ✅ | 42.7 | 0.7204 | 211.26 |
=== "YOLOv8n on RPi4"
=== "YOLO11n"
| Format | Status | Size on disk (MB) | mAP50-95(B) | Inference time (ms/im) |
|---------------|--------|-------------------|-------------|------------------------|
| PyTorch | ✅ | 6.2 | 0.6381 | 1068.42 |
| TorchScript | ✅ | 12.4 | 0.6092 | 1248.01 |
| ONNX | ✅ | 12.2 | 0.6092 | 560.04 |
| OpenVINO | ✅ | 12.3 | 0.6092 | 534.93 |
| TF SavedModel | ✅ | 30.6 | 0.6092 | 816.50 |
| TF GraphDef | ✅ | 12.3 | 0.6092 | 1007.57 |
| TF Lite | ✅ | 12.3 | 0.6092 | 950.29 |
| PaddlePaddle | ✅ | 24.4 | 0.6092 | 1507.75 |
| NCNN | ✅ | 12.2 | 0.6092 | 414.73 |
=== "YOLOv8s on RPi4"
| PyTorch | ✅ | 5.4 | 0.61 | 524.828 |
| TorchScript | ✅ | 10.5 | 0.6082 | 666.874 |
| ONNX | ✅ | 10.2 | 0.6082 | 181.818 |
| OpenVINO | ✅ | 10.4 | 0.6082 | 530.224 |
| TF SavedModel | ✅ | 25.8 | 0.6082 | 405.964 |
| TF GraphDef | ✅ | 10.3 | 0.6082 | 473.558 |
| TF Lite | ✅ | 10.3 | 0.6082 | 324.158 |
| PaddlePaddle | ✅ | 20.4 | 0.6082 | 644.312 |
| NCNN | ✅ | 10.2 | 0.6106 | 93.938 |
=== "YOLO11s"
| Format | Status | Size on disk (MB) | mAP50-95(B) | Inference time (ms/im) |
|---------------|--------|-------------------|-------------|------------------------|
| PyTorch | ✅ | 21.5 | 0.6967 | 2589.58 |
| TorchScript | ✅ | 43.0 | 0.7136 | 2901.33 |
| ONNX | ✅ | 42.8 | 0.7136 | 1436.33 |
| OpenVINO | ✅ | 42.9 | 0.7136 | 1225.19 |
| TF SavedModel | ✅ | 107.0 | 0.7136 | 1770.95 |
| TF GraphDef | ✅ | 42.8 | 0.7136 | 2146.66 |
| TF Lite | ✅ | 42.8 | 0.7136 | 2945.03 |
| PaddlePaddle | ✅ | 85.5 | 0.7136 | 3962.62 |
| NCNN | ✅ | 42.7 | 0.7136 | 1042.39 |
| PyTorch | ✅ | 18.4 | 0.7526 | 1226.426 |
| TorchScript | ✅ | 36.5 | 0.7416 | 1507.95 |
| ONNX | ✅ | 36.3 | 0.7416 | 415.24 |
| OpenVINO | ✅ | 36.4 | 0.7416 | 1167.102 |
| TF SavedModel | ✅ | 91.1 | 0.7416 | 776.14 |
| TF GraphDef | ✅ | 36.4 | 0.7416 | 1014.396 |
| TF Lite | ✅ | 36.4 | 0.7416 | 845.934 |
| PaddlePaddle | ✅ | 72.5 | 0.7416 | 1567.824 |
| NCNN | ✅ | 36.2 | 0.7419 | 197.358 |
## Reproduce Our Results
@@ -231,25 +191,25 @@ To reproduce the above Ultralytics benchmarks on all [export formats](../modes/e
```python
from ultralytics import YOLO
# Load a YOLOv8n PyTorch model
model = YOLO("yolov8n.pt")
# Load a YOLO11n PyTorch model
model = YOLO("yolo11n.pt")
# Benchmark YOLOv8n speed and accuracy on the COCO8 dataset for all all export formats
# Benchmark YOLO11n speed and accuracy on the COCO8 dataset for all export formats
results = model.benchmarks(data="coco8.yaml", imgsz=640)
```
=== "CLI"
```bash
# Benchmark YOLOv8n speed and accuracy on the COCO8 dataset for all all export formats
yolo benchmark model=yolov8n.pt data=coco8.yaml imgsz=640
# Benchmark YOLO11n speed and accuracy on the COCO8 dataset for all export formats
yolo benchmark model=yolo11n.pt data=coco8.yaml imgsz=640
```
Note that benchmarking results might vary based on the exact hardware and software configuration of a system, as well as the current workload of the system at the time the benchmarks are run. For the most reliable results, use a dataset with a large number of images, i.e. `data='coco8.yaml'` (4 val images) or `data='coco.yaml'` (5000 val images).
## Use Raspberry Pi Camera
When using Raspberry Pi for Computer Vision projects, it can be essentially to grab real-time video feeds to perform inference. The onboard MIPI CSI connector on the Raspberry Pi allows you to connect official Raspberry PI camera modules. In this guide, we have used a [Raspberry Pi Camera Module 3](https://www.raspberrypi.com/products/camera-module-3/) to grab the video feeds and perform inference using YOLOv8 models.
When using Raspberry Pi for Computer Vision projects, it can be essential to grab real-time video feeds to perform inference. The onboard MIPI CSI connector on the Raspberry Pi allows you to connect official Raspberry Pi camera modules. In this guide, we have used a [Raspberry Pi Camera Module 3](https://www.raspberrypi.com/products/camera-module-3/) to grab the video feeds and perform inference using YOLO11 models.
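Before running inference, it helps to confirm the camera is detected and streaming; a minimal check using the `rpicam-hello` utility referenced below (the `--list-cameras` flag is assumed from Raspberry Pi OS Bookworm's rpicam-apps):
```bash
# List detected CSI cameras (assumed rpicam-apps flag)
rpicam-hello --list-cameras

# Show a short preview from the default camera
rpicam-hello
```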
!!! tip
@@ -273,13 +233,13 @@ rpicam-hello
### Inference with Camera
There are 2 methods of using the Raspberry Pi Camera to inference YOLOv8 models.
There are 2 methods of using the Raspberry Pi Camera to run inference with YOLO11 models.
!!! usage
=== "Method 1"
We can use `picamera2`which comes pre-installed with Raspberry Pi OS to access the camera and inference YOLOv8 models.
We can use `picamera2`, which comes pre-installed with Raspberry Pi OS, to access the camera and run inference with YOLO11 models.
!!! example
@@ -299,14 +259,14 @@ There are 2 methods of using the Raspberry Pi Camera to inference YOLOv8 models.
picam2.configure("preview")
picam2.start()
# Load the YOLOv8 model
model = YOLO("yolov8n.pt")
# Load the YOLO11 model
model = YOLO("yolo11n.pt")
while True:
# Capture frame-by-frame
frame = picam2.capture_array()
# Run YOLOv8 inference on the frame
# Run YOLO11 inference on the frame
results = model(frame)
# Visualize the results on the frame
@@ -340,8 +300,8 @@ There are 2 methods of using the Raspberry Pi Camera to inference YOLOv8 models.
```python
from ultralytics import YOLO
# Load a YOLOv8n PyTorch model
model = YOLO("yolov8n.pt")
# Load a YOLO11n PyTorch model
model = YOLO("yolo11n.pt")
# Run inference
results = model("tcp://127.0.0.1:8888")
@@ -350,7 +310,7 @@ There are 2 methods of using the Raspberry Pi Camera to inference YOLOv8 models.
=== "CLI"
```bash
yolo predict model=yolov8n.pt source="tcp://127.0.0.1:8888"
yolo predict model=yolo11n.pt source="tcp://127.0.0.1:8888"
```
!!! tip
@@ -359,7 +319,7 @@ There are 2 methods of using the Raspberry Pi Camera to inference YOLOv8 models.
## Best Practices when using Raspberry Pi
There are a couple of best practices to follow in order to enable maximum performance on Raspberry Pis running YOLOv8.
There are a couple of best practices to follow in order to enable maximum performance on Raspberry Pis running YOLO11.
1. Use an SSD
@@ -371,7 +331,7 @@ There are a couple of best practices to follow in order to enable maximum perfor
## Next Steps
Congratulations on successfully setting up YOLO on your Raspberry Pi! For further learning and support, visit [Ultralytics YOLOv8 Docs](../index.md) and [Kashmir World Foundation](https://www.kashmirworldfoundation.org/).
Congratulations on successfully setting up YOLO on your Raspberry Pi! For further learning and support, visit [Ultralytics YOLO11 Docs](../index.md) and [Kashmir World Foundation](https://www.kashmirworldfoundation.org/).
## Acknowledgements and Citations
@@ -381,9 +341,9 @@ For more information about Kashmir World Foundation's activities, you can visit
## FAQ
### How do I set up Ultralytics YOLOv8 on a Raspberry Pi without using Docker?
### How do I set up Ultralytics YOLO11 on a Raspberry Pi without using Docker?
To set up Ultralytics YOLOv8 on a Raspberry Pi without Docker, follow these steps:
To set up Ultralytics YOLO11 on a Raspberry Pi without Docker, follow these steps:
1. Update the package list and install `pip`:
```bash
@@ -402,13 +362,13 @@ To set up Ultralytics YOLOv8 on a Raspberry Pi without Docker, follow these step
For detailed instructions, refer to the [Start without Docker](#start-without-docker) section.
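A minimal sketch of those setup steps (commands assumed to mirror the linked section; verify against it):
```bash
# Update the package list and install pip (assumed typical commands)
sudo apt update
sudo apt install python3-pip -y
pip install -U pip

# Install the Ultralytics package with optional export dependencies
pip install ultralytics[export]

# Reboot so the changes take effect
sudo reboot
```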
### Why should I use Ultralytics YOLOv8's NCNN format on Raspberry Pi for AI tasks?
### Why should I use Ultralytics YOLO11's NCNN format on Raspberry Pi for AI tasks?
Ultralytics YOLOv8's NCNN format is highly optimized for mobile and embedded platforms, making it ideal for running AI tasks on Raspberry Pi devices. NCNN maximizes inference performance by leveraging ARM architecture, providing faster and more efficient processing compared to other formats. For more details on supported export options, visit the [Ultralytics documentation page on deployment options](../modes/export.md).
Ultralytics YOLO11's NCNN format is highly optimized for mobile and embedded platforms, making it ideal for running AI tasks on Raspberry Pi devices. NCNN maximizes inference performance by leveraging ARM architecture, providing faster and more efficient processing compared to other formats. For more details on supported export options, visit the [Ultralytics documentation page on deployment options](../modes/export.md).
### How can I convert a YOLOv8 model to NCNN format for use on Raspberry Pi?
### How can I convert a YOLO11 model to NCNN format for use on Raspberry Pi?
You can convert a PyTorch YOLOv8 model to NCNN format using either Python or CLI commands:
You can convert a PyTorch YOLO11 model to NCNN format using either Python or CLI commands:
!!! example
@@ -417,14 +377,14 @@ You can convert a PyTorch YOLOv8 model to NCNN format using either Python or CLI
```python
from ultralytics import YOLO
# Load a YOLOv8n PyTorch model
model = YOLO("yolov8n.pt")
# Load a YOLO11n PyTorch model
model = YOLO("yolo11n.pt")
# Export the model to NCNN format
model.export(format="ncnn") # creates 'yolov8n_ncnn_model'
model.export(format="ncnn") # creates 'yolo11n_ncnn_model'
# Load the exported NCNN model
ncnn_model = YOLO("yolov8n_ncnn_model")
ncnn_model = YOLO("yolo11n_ncnn_model")
# Run inference
results = ncnn_model("https://ultralytics.com/images/bus.jpg")
@@ -433,16 +393,16 @@ You can convert a PyTorch YOLOv8 model to NCNN format using either Python or CLI
=== "CLI"
```bash
# Export a YOLOv8n PyTorch model to NCNN format
yolo export model=yolov8n.pt format=ncnn # creates 'yolov8n_ncnn_model'
# Export a YOLO11n PyTorch model to NCNN format
yolo export model=yolo11n.pt format=ncnn # creates 'yolo11n_ncnn_model'
# Run inference with the exported model
yolo predict model='yolov8n_ncnn_model' source='https://ultralytics.com/images/bus.jpg'
yolo predict model='yolo11n_ncnn_model' source='https://ultralytics.com/images/bus.jpg'
```
For more details, see the [Use NCNN on Raspberry Pi](#use-ncnn-on-raspberry-pi) section.
### What are the hardware differences between Raspberry Pi 4 and Raspberry Pi 5 relevant to running YOLOv8?
### What are the hardware differences between Raspberry Pi 4 and Raspberry Pi 5 relevant to running YOLO11?
Key differences include:
@@ -450,11 +410,11 @@ Key differences include:
- **Max CPU Frequency**: Raspberry Pi 4 has a max frequency of 1.8GHz, whereas Raspberry Pi 5 reaches 2.4GHz.
- **Memory**: Raspberry Pi 4 offers up to 8GB of LPDDR4-3200 SDRAM, while Raspberry Pi 5 features LPDDR4X-4267 SDRAM, available in 4GB and 8GB variants.
These enhancements contribute to better performance benchmarks for YOLOv8 models on Raspberry Pi 5 compared to Raspberry Pi 4. Refer to the [Raspberry Pi Series Comparison](#raspberry-pi-series-comparison) table for more details.
These enhancements contribute to better performance benchmarks for YOLO11 models on Raspberry Pi 5 compared to Raspberry Pi 4. Refer to the [Raspberry Pi Series Comparison](#raspberry-pi-series-comparison) table for more details.
### How can I set up a Raspberry Pi Camera Module to work with Ultralytics YOLOv8?
### How can I set up a Raspberry Pi Camera Module to work with Ultralytics YOLO11?
There are two methods to set up a Raspberry Pi Camera for YOLOv8 inference:
There are two methods to set up a Raspberry Pi Camera for YOLO11 inference:
1. **Using `picamera2`**:
@@ -471,7 +431,7 @@ There are two methods to set up a Raspberry Pi Camera for YOLOv8 inference:
picam2.configure("preview")
picam2.start()
model = YOLO("yolov8n.pt")
model = YOLO("yolo11n.pt")
while True:
frame = picam2.capture_array()
@@ -494,7 +454,7 @@ There are two methods to set up a Raspberry Pi Camera for YOLOv8 inference:
```python
from ultralytics import YOLO
model = YOLO("yolov8n.pt")
model = YOLO("yolo11n.pt")
results = model("tcp://127.0.0.1:8888")
```

@@ -106,6 +106,7 @@ Monitoring workouts through pose estimation with [Ultralytics YOLO11](https://gi
| `show` | `bool` | `False` | Flag to display the image. |
| `up_angle` | `float` | `145.0` | Angle threshold for the 'up' pose. |
| `down_angle` | `float` | `90.0` | Angle threshold for the 'down' pose. |
| `model` | `str` | `None` | Path to Ultralytics YOLO Pose Model File |
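For illustration, a rough sketch of how these arguments are typically passed to the workouts solution (the `AIGym` class name and the `monitor` call are assumptions based on the solutions API and are not shown in this excerpt):
```python
from ultralytics import solutions

# Sketch: configure workout monitoring with the arguments documented above
gym = solutions.AIGym(
    model="yolo11n-pose.pt",  # path to an Ultralytics YOLO pose model file
    show=True,                # display the annotated frame
    up_angle=145.0,           # angle threshold for the 'up' pose
    down_angle=90.0,          # angle threshold for the 'down' pose
)

# frame = ...                 # a BGR frame read with cv2.VideoCapture
# frame = gym.monitor(frame)  # assumed per-frame processing call
```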
### Arguments `model.predict`

@@ -142,11 +142,20 @@ SAM 2 can be utilized across a broad spectrum of tasks, including real-time vide
# Display model information (optional)
model.info()
# Segment with bounding box prompt
# Run inference with bboxes prompt
results = model("path/to/image.jpg", bboxes=[100, 100, 200, 200])
# Segment with point prompt
results = model("path/to/image.jpg", points=[150, 150], labels=[1])
# Run inference with single point
results = model(points=[900, 370], labels=[1])
# Run inference with multiple points
results = model(points=[[400, 370], [900, 370]], labels=[1, 1])
# Run inference with multiple points prompt per object
results = model(points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
# Run inference with negative points prompt
results = model(points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
```
#### Segment Everything

@@ -59,16 +59,16 @@ The Segment Anything Model can be employed for a multitude of downstream tasks t
results = model("ultralytics/assets/zidane.jpg", bboxes=[439, 437, 524, 709])
# Run inference with single point
results = predictor(points=[900, 370], labels=[1])
results = model(points=[900, 370], labels=[1])
# Run inference with multiple points
results = predictor(points=[[400, 370], [900, 370]], labels=[1, 1])
results = model(points=[[400, 370], [900, 370]], labels=[1, 1])
# Run inference with multiple points prompt per object
results = predictor(points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
results = model(points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
# Run inference with negative points prompt
results = predictor(points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
results = model(points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
```
!!! example "Segment everything"

@@ -162,7 +162,7 @@ nav:
- solutions/index.md
- Guides:
- guides/index.md
- Live Inference 🚀 NEW: guides/streamlit-live-inference.md # for promotion of new pages
- YOLO11 🚀 NEW: models/yolo11.md # for promotion of new pages
- Languages:
- 🇬🇧&nbsp English: https://ultralytics.com/docs/
- 🇨🇳&nbsp 简体中文: https://docs.ultralytics.com/zh/

@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.3.12"
__version__ = "8.3.13"
import os

@@ -965,7 +965,7 @@ class Exporter:
f'--out_dir "{Path(f).parent}" '
"--show_operations "
"--search_delegate "
"--delegate_search_step 3 "
"--delegate_search_step 30 "
"--timeout_sec 180 "
f'"{tflite_model}"'
)

@@ -235,7 +235,42 @@ class Predictor(BasePredictor):
"""
features = self.get_im_features(im) if self.features is None else self.features
src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
points = (points, labels) if points is not None else None
# Embed prompts
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)
# Predict masks
pred_masks, pred_scores = self.model.mask_decoder(
image_embeddings=features,
image_pe=self.model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
)
# (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
# `d` could be 1 or 3 depends on `multimask_output`.
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
"""
Prepares and transforms the input prompts for processing based on the destination shape.
Args:
dst_shape (tuple): The target shape (height, width) for the prompts.
bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array.
Raises:
AssertionError: If the number of points don't match the number of labels, in case labels were passed.
Returns:
(tuple): A tuple containing transformed bounding boxes, points, labels, and masks.
"""
src_shape = self.batch[1][0].shape[:2]
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
# Transform input prompts
if points is not None:
@@ -258,23 +293,7 @@ class Predictor(BasePredictor):
bboxes *= r
if masks is not None:
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
points = (points, labels) if points is not None else None
# Embed prompts
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)
# Predict masks
pred_masks, pred_scores = self.model.mask_decoder(
image_embeddings=features,
image_pe=self.model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
)
# (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
# `d` could be 1 or 3 depends on `multimask_output`.
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
return bboxes, points, labels, masks
def generate(
self,
@@ -693,34 +712,7 @@ class SAM2Predictor(Predictor):
"""
features = self.get_im_features(im) if self.features is None else self.features
src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
# Transform input prompts
if points is not None:
points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
points = points[None] if points.ndim == 1 else points
# Assuming labels are all positive if users don't pass labels.
if labels is None:
labels = torch.ones(points.shape[0])
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
points *= r
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
points, labels = points[:, None], labels[:, None]
if bboxes is not None:
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
bboxes = bboxes.view(-1, 2, 2) * r
bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
# NOTE: merge "boxes" and "points" into a single "points" input
# (where boxes are added at the beginning) to model.sam_prompt_encoder
if points is not None:
points = torch.cat([bboxes, points], dim=1)
labels = torch.cat([bbox_labels, labels], dim=1)
else:
points, labels = bboxes, bbox_labels
if masks is not None:
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
points = (points, labels) if points is not None else None
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
@@ -744,6 +736,36 @@ class SAM2Predictor(Predictor):
# `d` could be 1 or 3 depends on `multimask_output`.
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
"""
Prepares and transforms the input prompts for processing based on the destination shape.
Args:
dst_shape (tuple): The target shape (height, width) for the prompts.
bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array.
Raises:
AssertionError: If the number of points don't match the number of labels, in case labels were passed.
Returns:
(tuple): A tuple containing transformed bounding boxes, points, labels, and masks.
"""
bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks)
if bboxes is not None:
bboxes = bboxes.view(-1, 2, 2)
bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
# NOTE: merge "boxes" and "points" into a single "points" input
# (where boxes are added at the beginning) to model.sam_prompt_encoder
if points is not None:
points = torch.cat([bboxes, points], dim=1)
labels = torch.cat([bbox_labels, labels], dim=1)
else:
points, labels = bboxes, bbox_labels
return bboxes, points, labels, masks
def set_image(self, image):
"""
Preprocesses and sets a single image for inference using the SAM2 model.

@@ -8,7 +8,7 @@ from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import ClassificationModel
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
from ultralytics.utils.plotting import plot_images, plot_results
from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
@@ -141,7 +141,6 @@ class ClassificationTrainer(BaseTrainer):
self.metrics = self.validator(model=f)
self.metrics.pop("fitness", None)
self.run_callbacks("on_fit_epoch_end")
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
def plot_training_samples(self, batch, ni):
"""Plots training samples with their annotations."""

@@ -4,55 +4,21 @@ import math
import cv2
from ultralytics.utils.checks import check_imshow
from ultralytics.solutions.solutions import BaseSolution # Import a parent class
from ultralytics.utils.plotting import Annotator, colors
class DistanceCalculation:
class DistanceCalculation(BaseSolution):
"""A class to calculate distance between two objects in a real-time video stream based on their tracks."""
def __init__(
self,
names,
view_img=False,
line_thickness=2,
line_color=(255, 0, 255),
centroid_color=(104, 31, 17),
):
"""
Initializes the DistanceCalculation class with the given parameters.
Args:
names (dict): Dictionary of classes names.
view_img (bool, optional): Flag to indicate if the video stream should be displayed. Defaults to False.
line_thickness (int, optional): Thickness of the lines drawn on the image. Defaults to 2.
line_color (tuple, optional): Color of the lines drawn on the image (BGR format). Defaults to (255, 255, 0).
centroid_color (tuple, optional): Color of the centroids drawn (BGR format). Defaults to (255, 0, 255).
"""
# Visual & image information
self.im0 = None
self.annotator = None
self.view_img = view_img
self.line_color = line_color
self.centroid_color = centroid_color
# Prediction & tracking information
self.names = names
self.boxes = None
self.line_thickness = line_thickness
self.trk_ids = None
# Distance calculation information
self.centroids = []
def __init__(self, **kwargs):
"""Initializes the DistanceCalculation class with the given parameters."""
super().__init__(**kwargs)
# Mouse event information
self.left_mouse_count = 0
self.selected_boxes = {}
# Check if environment supports imshow
self.env_check = check_imshow(warn=True)
self.window_name = "Ultralytics Solutions"
def mouse_event_for_distance(self, event, x, y, flags, param):
"""
Handles mouse events to select regions in a real-time video stream.
@@ -67,7 +33,7 @@ class DistanceCalculation:
if event == cv2.EVENT_LBUTTONDOWN:
self.left_mouse_count += 1
if self.left_mouse_count <= 2:
for box, track_id in zip(self.boxes, self.trk_ids):
for box, track_id in zip(self.boxes, self.track_ids):
if box[0] < x < box[2] and box[1] < y < box[3] and track_id not in self.selected_boxes:
self.selected_boxes[track_id] = box
@@ -75,30 +41,21 @@ class DistanceCalculation:
self.selected_boxes = {}
self.left_mouse_count = 0
def start_process(self, im0, tracks):
def calculate(self, im0):
"""
Processes the video frame and calculates the distance between two bounding boxes.
Args:
im0 (ndarray): The image frame.
tracks (list): List of tracks obtained from the object tracking process.
Returns:
(ndarray): The processed image frame.
"""
self.im0 = im0
if tracks[0].boxes.id is None:
if self.view_img:
self.display_frames()
return im0
self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator
self.extract_tracks(im0) # Extract tracks
self.boxes = tracks[0].boxes.xyxy.cpu()
clss = tracks[0].boxes.cls.cpu().tolist()
self.trk_ids = tracks[0].boxes.id.int().cpu().tolist()
self.annotator = Annotator(self.im0, line_width=self.line_thickness)
for box, cls, track_id in zip(self.boxes, clss, self.trk_ids):
# Iterate over bounding boxes, track ids and classes index
for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
self.annotator.box_label(box, color=colors(int(cls), True), label=self.names[int(cls)])
if len(self.selected_boxes) == 2:
@@ -115,25 +72,11 @@ class DistanceCalculation:
pixels_distance = math.sqrt(
(self.centroids[0][0] - self.centroids[1][0]) ** 2 + (self.centroids[0][1] - self.centroids[1][1]) ** 2
)
self.annotator.plot_distance_and_line(pixels_distance, self.centroids, self.line_color, self.centroid_color)
self.annotator.plot_distance_and_line(pixels_distance, self.centroids)
self.centroids = []
if self.view_img and self.env_check:
self.display_frames()
return im0
def display_frames(self):
"""Displays the current frame with annotations."""
cv2.namedWindow(self.window_name)
cv2.setMouseCallback(self.window_name, self.mouse_event_for_distance)
cv2.imshow(self.window_name, self.im0)
if cv2.waitKey(1) & 0xFF == ord("q"):
return
self.display_output(im0) # display output with base class function
cv2.setMouseCallback("Ultralytics Solutions", self.mouse_event_for_distance)
if __name__ == "__main__":
names = {0: "person", 1: "car"} # example class names
distance_calculation = DistanceCalculation(names)
return im0 # return output image for more usage

@@ -112,13 +112,13 @@ class ObjectCounter(BaseSolution):
# Iterate over bounding boxes, track ids and classes index
for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
# Draw bounding box and counting region
self.annotator.box_label(box, label=self.names[cls], color=colors(track_id, True))
self.annotator.box_label(box, label=self.names[cls], color=colors(cls, True))
self.store_tracking_history(track_id, box) # Store track history
self.store_classwise_counts(cls) # store classwise counts in dict
# Draw tracks of objects
self.annotator.draw_centroid_and_tracks(
self.track_line, color=colors(int(track_id), True), track_thickness=self.line_width
self.track_line, color=colors(int(cls), True), track_thickness=self.line_width
)
# store previous position of track for object counting

@@ -598,7 +598,7 @@ def ap_per_class(
# AP from recall-precision curve
for j in range(tp.shape[1]):
ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
if plot and j == 0:
if j == 0:
prec_values.append(np.interp(x, mrec, mpre)) # precision at mAP@0.5
prec_values = np.array(prec_values) # (nc, 1000)

@@ -804,31 +804,30 @@ class Annotator:
self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1])), 0, self.sf, txt_color, self.tf
)
def plot_distance_and_line(self, pixels_distance, centroids, line_color, centroid_color):
def plot_distance_and_line(
self, pixels_distance, centroids, line_color=(104, 31, 17), centroid_color=(255, 0, 255)
):
"""
Plot the distance and line on frame.
Args:
pixels_distance (float): Pixels distance between two bbox centroids.
centroids (list): Bounding box centroids data.
line_color (tuple): RGB distance line color.
centroid_color (tuple): RGB bounding box centroid color.
line_color (tuple, optional): Distance line color.
centroid_color (tuple, optional): Bounding box centroid color.
"""
# Get the text size
(text_width_m, text_height_m), _ = cv2.getTextSize(
f"Pixels Distance: {pixels_distance:.2f}", 0, self.sf, self.tf
)
text = f"Pixels Distance: {pixels_distance:.2f}"
(text_width_m, text_height_m), _ = cv2.getTextSize(text, 0, self.sf, self.tf)
# Define corners with 10-pixel margin and draw rectangle
top_left = (15, 25)
bottom_right = (15 + text_width_m + 20, 25 + text_height_m + 20)
cv2.rectangle(self.im, top_left, bottom_right, centroid_color, -1)
cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 20, 25 + text_height_m + 20), line_color, -1)
# Calculate the position for the text with a 10-pixel margin and draw text
text_position = (top_left[0] + 10, top_left[1] + text_height_m + 10)
text_position = (25, 25 + text_height_m + 10)
cv2.putText(
self.im,
f"Pixels Distance: {pixels_distance:.2f}",
text,
text_position,
0,
self.sf,
@@ -1156,16 +1155,16 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
save_dir = Path(file).parent if file else Path(dir)
if classify:
fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)
index = [1, 4, 2, 3]
index = [2, 5, 3, 4]
elif segment:
fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
index = [1, 2, 3, 4, 5, 6, 9, 10, 13, 14, 15, 16, 7, 8, 11, 12]
index = [2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 8, 9, 12, 13]
elif pose:
fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True)
index = [1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 18, 8, 9, 12, 13]
index = [2, 3, 4, 5, 6, 7, 8, 11, 12, 15, 16, 17, 18, 19, 9, 10, 13, 14]
else:
fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
index = [1, 2, 3, 4, 5, 8, 9, 10, 6, 7]
index = [2, 3, 4, 5, 6, 9, 10, 11, 7, 8]
ax = ax.ravel()
files = list(save_dir.glob("results*.csv"))
assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
