Add SAHI with YOLOv8 Video Inference Example (#4847)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/4857/head
parent
3c88bebc95
commit
fdf08d823e
3 changed files with 180 additions and 10 deletions
@ -0,0 +1,62 @@ |
||||
# YOLOv8 with SAHI (Inference on Video) |
||||
|
||||
[SAHI](https://docs.ultralytics.com/guides/sahi-tiled-inference/) is designed to optimize object detection algorithms for large-scale and high-resolution imagery. It partitions images into manageable slices, performs object detection on each slice, and then stitches the results back together. This tutorial will guide you through the process of running YOLOv8 inference on video files with the aid of SAHI. |
||||
|
||||
## Table of Contents |
||||
|
||||
- [Step 1: Install the Required Libraries](#step-1-install-the-required-libraries) |
||||
- [Step 2: Run the Inference with SAHI using Ultralytics YOLOv8](#step-2-run-the-inference-with-sahi-using-ultralytics-yolov8) |
||||
- [Usage Options](#usage-options) |
||||
- [FAQ](#faq) |
||||
|
||||
## Step 1: Install the Required Libraries |
||||
|
||||
Clone the repository and install the dependencies: |
||||
|
||||
```bash |
||||
pip install sahi ultralytics |
||||
``` |
||||
|
||||
## Step 2: Run the Inference with SAHI using Ultralytics YOLOv8 |
||||
|
||||
Here are the basic commands for running the inference: |
||||
|
||||
```bash |
||||
#if you want to save results |
||||
python yolov8_sahi.py --source "path/to/video.mp4" --save-img |
||||
|
||||
#if you want to change model file |
||||
python yolov8_sahi.py --source "path/to/video.mp4" --save-img --weights "yolov8n.pt" |
||||
``` |
||||
|
||||
## Usage Options |
||||
|
||||
- `--source`: Specifies the path to the video file you want to run inference on. |
||||
- `--save-img`: Flag to save the detection results as images. |
||||
- `--weights`: Specifies a different YOLOv8 model file (e.g., `yolov8n.pt`, `yolov8s.pt`, `yolov8m.pt`, `yolov8l.pt`, `yolov8x.pt`). |
||||
|
||||
## FAQ |
||||
|
||||
**1. What is SAHI?** |
||||
|
||||
SAHI stands for Slicing, Analysis, and Healing of Images. It is a library designed to optimize object detection algorithms for large-scale and high-resolution images. The library source code is available on [GitHub](https://github.com/obss/sahi). |
||||
|
||||
**2. Why use SAHI with YOLOv8?** |
||||
|
||||
SAHI can handle large-scale images by slicing them into smaller, more manageable sizes without compromising the detection quality. This makes it a great companion to YOLOv8, especially when working with high-resolution videos. |
||||
|
||||
**3. How do I debug issues?** |
||||
|
||||
You can add the `--debug` flag to your command to print out more information during inference: |
||||
|
||||
```bash |
||||
python yolov8_sahi.py --source "path to video file" --debug |
||||
``` |
||||
|
||||
**4. Can I use other YOLO versions?** |
||||
|
||||
Yes, you can specify different YOLO model weights using the `--weights` option. |
||||
|
||||
**5. Where can I find more information?** |
||||
|
||||
For a full guide to YOLOv8 with SAHI see [https://docs.ultralytics.com/guides/sahi-tiled-inference](https://docs.ultralytics.com/guides/sahi-tiled-inference/). |
@ -0,0 +1,107 @@ |
||||
import argparse |
||||
from pathlib import Path |
||||
|
||||
import cv2 |
||||
from sahi import AutoDetectionModel |
||||
from sahi.predict import get_sliced_prediction |
||||
from sahi.utils.yolov8 import download_yolov8s_model |
||||
|
||||
from ultralytics.utils.files import increment_path |
||||
|
||||
|
||||
def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False, exist_ok=False): |
||||
""" |
||||
Run object detection on a video using YOLOv8 and SAHI. |
||||
|
||||
Args: |
||||
weights (str): Model weights path. |
||||
source (str): Video file path. |
||||
view_img (bool): Show results. |
||||
save_img (bool): Save results. |
||||
exist_ok (bool): Overwrite existing files. |
||||
""" |
||||
|
||||
yolov8_model_path = f'models/{weights}' |
||||
download_yolov8s_model(yolov8_model_path) |
||||
detection_model = AutoDetectionModel.from_pretrained(model_type='yolov8', |
||||
model_path=yolov8_model_path, |
||||
confidence_threshold=0.3, |
||||
device='cpu') |
||||
|
||||
# Video setup |
||||
videocapture = cv2.VideoCapture(source) |
||||
frame_width, frame_height = int(videocapture.get(3)), int(videocapture.get(4)) |
||||
fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*'mp4v') |
||||
|
||||
# Output setup |
||||
save_dir = increment_path(Path('ultralytics_results_with_sahi') / 'exp', exist_ok) |
||||
save_dir.mkdir(parents=True, exist_ok=True) |
||||
video_writer = cv2.VideoWriter(str(save_dir / f'{Path(source).stem}.mp4'), fourcc, fps, (frame_width, frame_height)) |
||||
|
||||
while videocapture.isOpened(): |
||||
success, frame = videocapture.read() |
||||
if not success: |
||||
break |
||||
|
||||
results = get_sliced_prediction(frame, |
||||
detection_model, |
||||
slice_height=512, |
||||
slice_width=512, |
||||
overlap_height_ratio=0.2, |
||||
overlap_width_ratio=0.2) |
||||
object_prediction_list = results.object_prediction_list |
||||
|
||||
boxes_list = [] |
||||
clss_list = [] |
||||
for ind, _ in enumerate(object_prediction_list): |
||||
boxes = object_prediction_list[ind].bbox.minx, object_prediction_list[ind].bbox.miny, \ |
||||
object_prediction_list[ind].bbox.maxx, object_prediction_list[ind].bbox.maxy |
||||
clss = object_prediction_list[ind].category.name |
||||
boxes_list.append(boxes) |
||||
clss_list.append(clss) |
||||
|
||||
for box, cls in zip(boxes_list, clss_list): |
||||
x1, y1, x2, y2 = box |
||||
cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (56, 56, 255), 2) |
||||
label = str(cls) |
||||
t_size = cv2.getTextSize(label, 0, fontScale=0.6, thickness=1)[0] |
||||
cv2.rectangle(frame, (int(x1), int(y1) - t_size[1] - 3), (int(x1) + t_size[0], int(y1) + 3), (56, 56, 255), |
||||
-1) |
||||
cv2.putText(frame, |
||||
label, (int(x1), int(y1) - 2), |
||||
0, |
||||
0.6, [255, 255, 255], |
||||
thickness=1, |
||||
lineType=cv2.LINE_AA) |
||||
|
||||
if view_img: |
||||
cv2.imshow(Path(source).stem, frame) |
||||
if save_img: |
||||
video_writer.write(frame) |
||||
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'): |
||||
break |
||||
video_writer.release() |
||||
videocapture.release() |
||||
cv2.destroyAllWindows() |
||||
|
||||
|
||||
def parse_opt(): |
||||
"""Parse command line arguments.""" |
||||
parser = argparse.ArgumentParser() |
||||
parser.add_argument('--weights', type=str, default='yolov8n.pt', help='initial weights path') |
||||
parser.add_argument('--source', type=str, required=True, help='video file path') |
||||
parser.add_argument('--view-img', action='store_true', help='show results') |
||||
parser.add_argument('--save-img', action='store_true', help='save results') |
||||
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') |
||||
return parser.parse_args() |
||||
|
||||
|
||||
def main(opt): |
||||
"""Main function.""" |
||||
run(**vars(opt)) |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
opt = parse_opt() |
||||
main(opt) |
Loading…
Reference in new issue