From 3e4a581c35c4ecec538a2d8eb818d59e2234f361 Mon Sep 17 00:00:00 2001 From: Muhammad Rizwan Munawar Date: Sun, 11 Aug 2024 09:46:19 +0500 Subject: [PATCH] Optimized SAHI video inference (#15183) Co-authored-by: UltralyticsAssistant Co-authored-by: Glenn Jocher --- .../yolov8_sahi.py | 175 +++++++++--------- 1 file changed, 83 insertions(+), 92 deletions(-) diff --git a/examples/YOLOv8-SAHI-Inference-Video/yolov8_sahi.py b/examples/YOLOv8-SAHI-Inference-Video/yolov8_sahi.py index f2b8274c9b..ae140bc57a 100644 --- a/examples/YOLOv8-SAHI-Inference-Video/yolov8_sahi.py +++ b/examples/YOLOv8-SAHI-Inference-Video/yolov8_sahi.py @@ -9,103 +9,94 @@ from sahi.predict import get_sliced_prediction from sahi.utils.yolov8 import download_yolov8s_model from ultralytics.utils.files import increment_path +from ultralytics.utils.plotting import Annotator, colors -def run(weights="yolov8n.pt", source="test.mp4", view_img=False, save_img=False, exist_ok=False): - """ - Run object detection on a video using YOLOv8 and SAHI. - - Args: - weights (str): Model weights path. - source (str): Video file path. - view_img (bool): Show results. - save_img (bool): Save results. - exist_ok (bool): Overwrite existing files. - """ - - # Check source path - if not Path(source).exists(): - raise FileNotFoundError(f"Source path '{source}' does not exist.") - - yolov8_model_path = f"models/{weights}" - download_yolov8s_model(yolov8_model_path) - detection_model = AutoDetectionModel.from_pretrained( - model_type="yolov8", model_path=yolov8_model_path, confidence_threshold=0.3, device="cpu" - ) - - # Video setup - videocapture = cv2.VideoCapture(source) - frame_width, frame_height = int(videocapture.get(3)), int(videocapture.get(4)) - fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*"mp4v") - - # Output setup - save_dir = increment_path(Path("ultralytics_results_with_sahi") / "exp", exist_ok) - save_dir.mkdir(parents=True, exist_ok=True) - video_writer = cv2.VideoWriter(str(save_dir / f"{Path(source).stem}.mp4"), fourcc, fps, (frame_width, frame_height)) - - while videocapture.isOpened(): - success, frame = videocapture.read() - if not success: - break - - results = get_sliced_prediction( - frame, detection_model, slice_height=512, slice_width=512, overlap_height_ratio=0.2, overlap_width_ratio=0.2 - ) - object_prediction_list = results.object_prediction_list - - boxes_list = [] - clss_list = [] - for ind, _ in enumerate(object_prediction_list): - boxes = ( - object_prediction_list[ind].bbox.minx, - object_prediction_list[ind].bbox.miny, - object_prediction_list[ind].bbox.maxx, - object_prediction_list[ind].bbox.maxy, - ) - clss = object_prediction_list[ind].category.name - boxes_list.append(boxes) - clss_list.append(clss) - - for box, cls in zip(boxes_list, clss_list): - x1, y1, x2, y2 = box - cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (56, 56, 255), 2) - label = str(cls) - t_size = cv2.getTextSize(label, 0, fontScale=0.6, thickness=1)[0] - cv2.rectangle( - frame, (int(x1), int(y1) - t_size[1] - 3), (int(x1) + t_size[0], int(y1) + 3), (56, 56, 255), -1 - ) - cv2.putText( - frame, label, (int(x1), int(y1) - 2), 0, 0.6, [255, 255, 255], thickness=1, lineType=cv2.LINE_AA - ) - - if view_img: - cv2.imshow(Path(source).stem, frame) - if save_img: - video_writer.write(frame) +class SahiInference: + def __init__(self): + self.detection_model = None - if cv2.waitKey(1) & 0xFF == ord("q"): - break - video_writer.release() - videocapture.release() - cv2.destroyAllWindows() - - -def parse_opt(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser() - parser.add_argument("--weights", type=str, default="yolov8n.pt", help="initial weights path") - parser.add_argument("--source", type=str, required=True, help="video file path") - parser.add_argument("--view-img", action="store_true", help="show results") - parser.add_argument("--save-img", action="store_true", help="save results") - parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment") - return parser.parse_args() + def load_model(self, weights): + yolov8_model_path = f"models/{weights}" + download_yolov8s_model(yolov8_model_path) + self.detection_model = AutoDetectionModel.from_pretrained( + model_type="yolov8", model_path=yolov8_model_path, confidence_threshold=0.3, device="cpu" + ) + def inference( + self, weights="yolov8n.pt", source="test.mp4", view_img=False, save_img=False, exist_ok=False, track=False + ): + """ + Run object detection on a video using YOLOv8 and SAHI. + + Args: + weights (str): Model weights path. + source (str): Video file path. + view_img (bool): Show results. + save_img (bool): Save results. + exist_ok (bool): Overwrite existing files. + track (bool): Enable object tracking with SAHI + """ + # Video setup + cap = cv2.VideoCapture(source) + assert cap.isOpened(), "Error reading video file" + frame_width, frame_height = int(cap.get(3)), int(cap.get(4)) + + # Output setup + save_dir = increment_path(Path("ultralytics_results_with_sahi") / "exp", exist_ok) + save_dir.mkdir(parents=True, exist_ok=True) + video_writer = cv2.VideoWriter( + str(save_dir / f"{Path(source).stem}.mp4"), + cv2.VideoWriter_fourcc(*"mp4v"), + int(cap.get(5)), + (frame_width, frame_height), + ) -def main(opt): - """Main function.""" - run(**vars(opt)) + # Load model + self.load_model(weights) + while cap.isOpened(): + success, frame = cap.read() + if not success: + break + annotator = Annotator(frame) # Initialize annotator for plotting detection and tracking results + results = get_sliced_prediction( + frame, + self.detection_model, + slice_height=512, + slice_width=512, + overlap_height_ratio=0.2, + overlap_width_ratio=0.2, + ) + detection_data = [ + (det.category.name, det.category.id, (det.bbox.minx, det.bbox.miny, det.bbox.maxx, det.bbox.maxy)) + for det in results.object_prediction_list + ] + + for det in detection_data: + annotator.box_label(det[2], label=str(det[0]), color=colors(int(det[1]), True)) + + if view_img: + cv2.imshow(Path(source).stem, frame) + if save_img: + video_writer.write(frame) + + if cv2.waitKey(1) & 0xFF == ord("q"): + break + video_writer.release() + cap.release() + cv2.destroyAllWindows() + + def parse_opt(self): + """Parse command line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument("--weights", type=str, default="yolov8n.pt", help="initial weights path") + parser.add_argument("--source", type=str, required=True, help="video file path") + parser.add_argument("--view-img", action="store_true", help="show results") + parser.add_argument("--save-img", action="store_true", help="save results") + parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment") + return parser.parse_args() if __name__ == "__main__": - opt = parse_opt() - main(opt) + inference = SahiInference() + inference.inference(**vars(inference.parse_opt()))