|
|
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
import cv2
|
|
|
|
from sahi import AutoDetectionModel
|
|
|
|
from sahi.predict import get_sliced_prediction
|
|
|
|
from sahi.utils.yolov8 import download_yolov8s_model
|
|
|
|
|
|
|
|
from ultralytics.utils.files import increment_path
|
|
|
|
from ultralytics.utils.plotting import Annotator, colors
|
|
|
|
|
|
|
|
|
|
|
|
class SAHIInference:
    """Runs YOLOv8 with SAHI sliced inference for object detection on video, with options to view and save results."""

    def __init__(self):
        """Initializes the SAHIInference class; the SAHI detection model is created later by `load_model`."""
        self.detection_model = None

    @staticmethod
    def _resolve_weights(weights):
        """
        Resolve the file path that model weights should be loaded from.

        `models/<weights>` keeps precedence, since that is the original lookup location. If it does not exist but
        `weights` itself points to an existing file, that file is used instead. This way local or custom weights
        are honoured rather than shadowed by a download.

        Args:
            weights (str): Weights file name or path.

        Returns:
            (Path): Path to load (or download) the weights from.
        """
        candidate = Path("models") / weights
        if not candidate.exists() and Path(weights).is_file():
            return Path(weights)
        return candidate

    def load_model(self, weights, confidence_threshold=0.3, device="cpu"):
        """
        Loads a YOLOv8 model with the specified weights as a SAHI detection model.

        Args:
            weights (str): Weights file name (looked up under `models/`) or path to an existing local weights file.
            confidence_threshold (float): Minimum confidence for a prediction to be kept.
            device (str): Device to run inference on, e.g. "cpu" or "cuda:0".
        """
        model_path = self._resolve_weights(weights)
        # `download_yolov8s_model` fetches the YOLOv8s checkpoint whatever the destination name is. Only use it when
        # YOLOv8s was actually requested. Otherwise a request for e.g. yolov8n.pt, or a missing custom model, would
        # silently run YOLOv8s weights stored under the requested name.
        if model_path.name == "yolov8s.pt":
            download_yolov8s_model(str(model_path))
        # NOTE(review): other official checkpoint names (e.g. yolov8n.pt) are presumably fetched by the Ultralytics
        # loader when missing; confirm for the installed ultralytics/sahi versions.
        self.detection_model = AutoDetectionModel.from_pretrained(
            model_type="yolov8",
            model_path=str(model_path),
            confidence_threshold=confidence_threshold,
            device=device,
        )

    @staticmethod
    def _create_video_writer(cap, source, exist_ok):
        """
        Create an mp4 writer in a new `ultralytics_results_with_sahi/exp*` directory matching the input video.

        Args:
            cap (cv2.VideoCapture): Opened capture whose frame size and FPS are reused for the output.
            source (str): Input video path; its stem names the output file.
            exist_ok (bool): Reuse an existing `exp` directory instead of incrementing the name.

        Returns:
            (cv2.VideoWriter): Writer for the annotated output video.
        """
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep the fractional FPS (e.g. 29.97) so the output duration matches; fall back when the source reports 0.
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        save_dir = increment_path(Path("ultralytics_results_with_sahi") / "exp", exist_ok)
        save_dir.mkdir(parents=True, exist_ok=True)
        return cv2.VideoWriter(
            str(save_dir / f"{Path(source).stem}.mp4"),
            cv2.VideoWriter_fourcc(*"mp4v"),
            fps,
            (frame_width, frame_height),
        )

    def _annotate_frame(self, frame):
        """
        Run SAHI sliced prediction on one frame and draw the detections onto it.

        Args:
            frame (np.ndarray): BGR frame as returned by `cv2.VideoCapture.read`.

        Returns:
            (np.ndarray): The BGR frame with bounding boxes and class-name labels drawn.
        """
        # OpenCV decodes frames as BGR, while SAHI's numpy-image convention is RGB, so convert before predicting.
        results = get_sliced_prediction(
            cv2.cvtColor(frame, cv2.COLOR_BGR2RGB),
            self.detection_model,
            slice_height=512,
            slice_width=512,
            overlap_height_ratio=0.2,
            overlap_width_ratio=0.2,
        )
        annotator = Annotator(frame)  # draws on the original BGR frame
        for det in results.object_prediction_list:
            box = (det.bbox.minx, det.bbox.miny, det.bbox.maxx, det.bbox.maxy)
            annotator.box_label(box, label=str(det.category.name), color=colors(int(det.category.id), True))
        return annotator.result()

    def inference(
        self, weights="yolov8n.pt", source="test.mp4", view_img=False, save_img=False, exist_ok=False, track=False
    ):
        """
        Run object detection on a video using YOLOv8 and SAHI.

        Args:
            weights (str): Model weights file name (under `models/`) or path.
            source (str): Video file path.
            view_img (bool): Show annotated frames in a window; press "q" to stop early.
            save_img (bool): Save the annotated video under `ultralytics_results_with_sahi/exp*`.
            exist_ok (bool): Reuse an existing output directory instead of incrementing its name.
            track (bool): Reserved for object tracking. Tracking is not implemented and this flag is ignored.
        """
        cap = cv2.VideoCapture(source)
        assert cap.isOpened(), "Error reading video file"
        video_writer = None
        try:
            # Load the model before creating any output so a loading failure leaves no empty result files behind.
            self.load_model(weights)
            if save_img:
                video_writer = self._create_video_writer(cap, source, exist_ok)

            while cap.isOpened():
                success, frame = cap.read()
                if not success:
                    break

                annotated = self._annotate_frame(frame)

                if view_img:
                    cv2.imshow(Path(source).stem, annotated)
                if video_writer is not None:
                    video_writer.write(annotated)
                # Only poll the keyboard when a window exists; GUI calls can fail on headless OpenCV builds.
                if view_img and cv2.waitKey(1) & 0xFF == ord("q"):
                    break
        finally:
            # Release capture/writer even if prediction raises mid-video.
            if video_writer is not None:
                video_writer.release()
            cap.release()
            if view_img:
                cv2.destroyAllWindows()

    def parse_opt(self):
        """Parse command line arguments."""
        parser = argparse.ArgumentParser()
        parser.add_argument("--weights", type=str, default="yolov8n.pt", help="initial weights path")
        parser.add_argument("--source", type=str, required=True, help="video file path")
        parser.add_argument("--view-img", action="store_true", help="show results")
        parser.add_argument("--save-img", action="store_true", help="save results")
        parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment")
        return parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the CLI flags and run sliced video inference with them.
    sahi_runner = SAHIInference()
    cli_args = sahi_runner.parse_opt()
    sahi_runner.inference(**vars(cli_args))
|