import argparse
from pathlib import Path

import cv2
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from sahi.utils.yolov8 import download_yolov8s_model
from ultralytics.utils.files import increment_path


def _draw_detection(frame, prediction):
    """Draw one prediction's bounding box and filled class-name label onto ``frame`` in place."""
    bbox = prediction.bbox
    x1, y1, x2, y2 = int(bbox.minx), int(bbox.miny), int(bbox.maxx), int(bbox.maxy)
    label = str(prediction.category.name)

    # Bounding box outline.
    cv2.rectangle(frame, (x1, y1), (x2, y2), (56, 56, 255), 2)

    # Filled rectangle sized to the label text, then the text itself on top of it.
    t_size = cv2.getTextSize(label, 0, fontScale=0.6, thickness=1)[0]
    cv2.rectangle(frame, (x1, y1 - t_size[1] - 3), (x1 + t_size[0], y1 + 3), (56, 56, 255), -1)
    cv2.putText(frame, label, (x1, y1 - 2), 0, 0.6, [255, 255, 255], thickness=1, lineType=cv2.LINE_AA)


def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False, exist_ok=False):
    """
    Run object detection on a video using YOLOv8 and SAHI.

    Every frame is passed through SAHI sliced prediction (512x512 slices with
    20% overlap), the detections are drawn on the frame, and the annotated
    frame is optionally displayed and/or appended to an output video.

    Args:
        weights (str): Model weights file name; resolved as ``models/<weights>``.
        source (str): Video file path.
        view_img (bool): Show annotated frames in a window; press ``q`` to stop early.
        save_img (bool): Write annotated frames to
            ``ultralytics_results_with_sahi/exp*/<source stem>.mp4``.
        exist_ok (bool): Forwarded to ``increment_path``; when True the existing
            ``exp`` directory is reused instead of incrementing the name.

    Raises:
        FileNotFoundError: If ``source`` does not exist.
    """
    # Check source path
    if not Path(source).exists():
        raise FileNotFoundError(f"Source path '{source}' does not exist.")

    yolov8_model_path = f'models/{weights}'
    # NOTE(review): the helper is named for the YOLOv8-s checkpoint but is called for
    # any `weights` value -- confirm it fetches the model the caller asked for.
    download_yolov8s_model(yolov8_model_path)
    detection_model = AutoDetectionModel.from_pretrained(model_type='yolov8',
                                                         model_path=yolov8_model_path,
                                                         confidence_threshold=0.3,
                                                         device='cpu')

    # Video setup
    videocapture = cv2.VideoCapture(source)
    video_writer = None
    window_name = Path(source).stem

    try:
        frame_width = int(videocapture.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep the fractional frame rate (e.g. 29.97); truncating it to an int
        # makes the written video play at a slightly different speed.
        fps = videocapture.get(cv2.CAP_PROP_FPS)

        # Output setup: only touch the filesystem when results are actually saved.
        if save_img:
            save_dir = increment_path(Path('ultralytics_results_with_sahi') / 'exp', exist_ok)
            save_dir.mkdir(parents=True, exist_ok=True)
            video_writer = cv2.VideoWriter(str(save_dir / f'{Path(source).stem}.mp4'),
                                           cv2.VideoWriter_fourcc(*'mp4v'),
                                           fps,
                                           (frame_width, frame_height))

        while videocapture.isOpened():
            success, frame = videocapture.read()
            if not success:
                break

            results = get_sliced_prediction(frame,
                                            detection_model,
                                            slice_height=512,
                                            slice_width=512,
                                            overlap_height_ratio=0.2,
                                            overlap_width_ratio=0.2)

            for prediction in results.object_prediction_list:
                _draw_detection(frame, prediction)

            if view_img:
                cv2.imshow(window_name, frame)
            if video_writer is not None:
                video_writer.write(frame)

            # Key polling needs a HighGUI window, so only do it when one is shown.
            if view_img and cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release handles even if inference or drawing raised part-way through.
        if video_writer is not None:
            video_writer.release()
        videocapture.release()
        if view_img:
            cv2.destroyAllWindows()


def parse_opt():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='yolov8n.pt', help='initial weights path')
    parser.add_argument('--source', type=str, required=True, help='video file path')
    parser.add_argument('--view-img', action='store_true', help='show results')
    parser.add_argument('--save-img', action='store_true', help='save results')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    return parser.parse_args()


def main(opt):
    """Run detection with the parsed command line options as keyword arguments."""
    run(**vars(opt))


if __name__ == '__main__':
    opt = parse_opt()
    main(opt)