Add SAHI with YOLOv8 Video Inference Example (#4847)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/4857/head
Muhammad Rizwan Munawar 1 year ago committed by GitHub
parent 3c88bebc95
commit fdf08d823e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 21
      examples/README.md
  2. 62
      examples/YOLOv8-SAHI-Inference-Video/readme.md
  3. 107
      examples/YOLOv8-SAHI-Inference-Video/yolov8_sahi.py

@ -4,16 +4,17 @@ This repository features a collection of real-world applications and walkthrough
### Ultralytics YOLO Example Applications ### Ultralytics YOLO Example Applications
| Title | Format | Contributor | | Title | Format | Contributor |
| -------------------------------------------------------------------------------------------------------------- | ------------------ | ----------------------------------------------------------------------------------------- | | ----------------------------------------------------------------------------------------------------------------------------------------- | ------------------ | ----------------------------------------------------------------------------------------- |
| [YOLO ONNX Detection Inference with C++](./YOLOv8-CPP-Inference) | C++/ONNX | [Justas Bartnykas](https://github.com/JustasBart) | | [YOLO ONNX Detection Inference with C++](./YOLOv8-CPP-Inference) | C++/ONNX | [Justas Bartnykas](https://github.com/JustasBart) |
| [YOLO OpenCV ONNX Detection Python](./YOLOv8-OpenCV-ONNX-Python) | OpenCV/Python/ONNX | [Farid Inawan](https://github.com/frdteknikelektro) | | [YOLO OpenCV ONNX Detection Python](./YOLOv8-OpenCV-ONNX-Python) | OpenCV/Python/ONNX | [Farid Inawan](https://github.com/frdteknikelektro) |
| [YOLOv8 .NET ONNX ImageSharp](https://github.com/dme-compunet/YOLOv8) | C#/ONNX/ImageSharp | [Compunet](https://github.com/dme-compunet) | | [YOLOv8 .NET ONNX ImageSharp](https://github.com/dme-compunet/YOLOv8) | C#/ONNX/ImageSharp | [Compunet](https://github.com/dme-compunet) |
| [YOLO .Net ONNX Detection C#](https://www.nuget.org/packages/Yolov8.Net) | C# .Net | [Samuel Stainback](https://github.com/sstainba) | | [YOLO .Net ONNX Detection C#](https://www.nuget.org/packages/Yolov8.Net) | C# .Net | [Samuel Stainback](https://github.com/sstainba) |
| [YOLOv8 on NVIDIA Jetson(TensorRT and DeepStream)](https://wiki.seeedstudio.com/YOLOv8-DeepStream-TRT-Jetson/) | Python | [Lakshantha](https://github.com/lakshanthad) | | [YOLOv8 on NVIDIA Jetson(TensorRT and DeepStream)](https://wiki.seeedstudio.com/YOLOv8-DeepStream-TRT-Jetson/) | Python | [Lakshantha](https://github.com/lakshanthad) |
| [YOLOv8 ONNXRuntime Python](./YOLOv8-ONNXRuntime) | Python/ONNXRuntime | [Semih Demirel](https://github.com/semihhdemirel) | | [YOLOv8 ONNXRuntime Python](./YOLOv8-ONNXRuntime) | Python/ONNXRuntime | [Semih Demirel](https://github.com/semihhdemirel) |
| [YOLOv8-ONNXRuntime-CPP](./YOLOv8-ONNXRuntime-CPP) | C++/ONNXRuntime | [DennisJcy](https://github.com/DennisJcy), [Onuralp Sezer](https://github.com/onuralpszr) | | [YOLOv8 ONNXRuntime CPP](./YOLOv8-ONNXRuntime-CPP) | C++/ONNXRuntime | [DennisJcy](https://github.com/DennisJcy), [Onuralp Sezer](https://github.com/onuralpszr) |
| [RTDETR ONNXRuntime C#](https://github.com/Kayzwer/yolo-cs/blob/master/RTDETR.cs) | C#/ONNX | [Kayzwer](https://github.com/Kayzwer) | | [RTDETR ONNXRuntime C#](https://github.com/Kayzwer/yolo-cs/blob/master/RTDETR.cs) | C#/ONNX | [Kayzwer](https://github.com/Kayzwer) |
| [YOLOv8 SAHI Video Inference](https://github.com/RizwanMunawar/ultralytics/blob/main/examples/YOLOv8-SAHI-Inference-Video/yolov8_sahi.py) | Python | [Muhammad Rizwan Munawar](https://github.com/RizwanMunawar) |
### How to Contribute ### How to Contribute

@ -0,0 +1,62 @@
# YOLOv8 with SAHI (Inference on Video)
[SAHI](https://docs.ultralytics.com/guides/sahi-tiled-inference/) is designed to optimize object detection algorithms for large-scale and high-resolution imagery. It partitions images into manageable slices, performs object detection on each slice, and then stitches the results back together. This tutorial will guide you through the process of running YOLOv8 inference on video files with the aid of SAHI.
## Table of Contents
- [Step 1: Install the Required Libraries](#step-1-install-the-required-libraries)
- [Step 2: Run the Inference with SAHI using Ultralytics YOLOv8](#step-2-run-the-inference-with-sahi-using-ultralytics-yolov8)
- [Usage Options](#usage-options)
- [FAQ](#faq)
## Step 1: Install the Required Libraries
Clone the repository and install the dependencies:
```bash
pip install sahi ultralytics
```
## Step 2: Run the Inference with SAHI using Ultralytics YOLOv8
Here are the basic commands for running the inference:
```bash
#if you want to save results
python yolov8_sahi.py --source "path/to/video.mp4" --save-img
#if you want to change model file
python yolov8_sahi.py --source "path/to/video.mp4" --save-img --weights "yolov8n.pt"
```
## Usage Options
- `--source`: Specifies the path to the video file you want to run inference on.
- `--save-img`: Flag to save the detection results as images.
- `--weights`: Specifies a different YOLOv8 model file (e.g., `yolov8n.pt`, `yolov8s.pt`, `yolov8m.pt`, `yolov8l.pt`, `yolov8x.pt`).
## FAQ
**1. What is SAHI?**
SAHI stands for Slicing, Analysis, and Healing of Images. It is a library designed to optimize object detection algorithms for large-scale and high-resolution images. The library source code is available on [GitHub](https://github.com/obss/sahi).
**2. Why use SAHI with YOLOv8?**
SAHI can handle large-scale images by slicing them into smaller, more manageable sizes without compromising the detection quality. This makes it a great companion to YOLOv8, especially when working with high-resolution videos.
**3. How do I debug issues?**
You can add the `--debug` flag to your command to print out more information during inference:
```bash
python yolov8_sahi.py --source "path to video file" --debug
```
**4. Can I use other YOLO versions?**
Yes, you can specify different YOLO model weights using the `--weights` option.
**5. Where can I find more information?**
For a full guide to YOLOv8 with SAHI see [https://docs.ultralytics.com/guides/sahi-tiled-inference](https://docs.ultralytics.com/guides/sahi-tiled-inference/).

@ -0,0 +1,107 @@
import argparse
from pathlib import Path
import cv2
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from sahi.utils.yolov8 import download_yolov8s_model
from ultralytics.utils.files import increment_path
def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False, exist_ok=False):
"""
Run object detection on a video using YOLOv8 and SAHI.
Args:
weights (str): Model weights path.
source (str): Video file path.
view_img (bool): Show results.
save_img (bool): Save results.
exist_ok (bool): Overwrite existing files.
"""
yolov8_model_path = f'models/{weights}'
download_yolov8s_model(yolov8_model_path)
detection_model = AutoDetectionModel.from_pretrained(model_type='yolov8',
model_path=yolov8_model_path,
confidence_threshold=0.3,
device='cpu')
# Video setup
videocapture = cv2.VideoCapture(source)
frame_width, frame_height = int(videocapture.get(3)), int(videocapture.get(4))
fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*'mp4v')
# Output setup
save_dir = increment_path(Path('ultralytics_results_with_sahi') / 'exp', exist_ok)
save_dir.mkdir(parents=True, exist_ok=True)
video_writer = cv2.VideoWriter(str(save_dir / f'{Path(source).stem}.mp4'), fourcc, fps, (frame_width, frame_height))
while videocapture.isOpened():
success, frame = videocapture.read()
if not success:
break
results = get_sliced_prediction(frame,
detection_model,
slice_height=512,
slice_width=512,
overlap_height_ratio=0.2,
overlap_width_ratio=0.2)
object_prediction_list = results.object_prediction_list
boxes_list = []
clss_list = []
for ind, _ in enumerate(object_prediction_list):
boxes = object_prediction_list[ind].bbox.minx, object_prediction_list[ind].bbox.miny, \
object_prediction_list[ind].bbox.maxx, object_prediction_list[ind].bbox.maxy
clss = object_prediction_list[ind].category.name
boxes_list.append(boxes)
clss_list.append(clss)
for box, cls in zip(boxes_list, clss_list):
x1, y1, x2, y2 = box
cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (56, 56, 255), 2)
label = str(cls)
t_size = cv2.getTextSize(label, 0, fontScale=0.6, thickness=1)[0]
cv2.rectangle(frame, (int(x1), int(y1) - t_size[1] - 3), (int(x1) + t_size[0], int(y1) + 3), (56, 56, 255),
-1)
cv2.putText(frame,
label, (int(x1), int(y1) - 2),
0,
0.6, [255, 255, 255],
thickness=1,
lineType=cv2.LINE_AA)
if view_img:
cv2.imshow(Path(source).stem, frame)
if save_img:
video_writer.write(frame)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
video_writer.release()
videocapture.release()
cv2.destroyAllWindows()
def parse_opt():
"""Parse command line arguments."""
parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='yolov8n.pt', help='initial weights path')
parser.add_argument('--source', type=str, required=True, help='video file path')
parser.add_argument('--view-img', action='store_true', help='show results')
parser.add_argument('--save-img', action='store_true', help='save results')
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
return parser.parse_args()
def main(opt):
"""Main function."""
run(**vars(opt))
if __name__ == '__main__':
opt = parse_opt()
main(opt)
Loading…
Cancel
Save