|
|
@ -9,103 +9,94 @@ from sahi.predict import get_sliced_prediction |
|
|
|
from sahi.utils.yolov8 import download_yolov8s_model |
|
|
|
from sahi.utils.yolov8 import download_yolov8s_model |
|
|
|
|
|
|
|
|
|
|
|
from ultralytics.utils.files import increment_path |
|
|
|
from ultralytics.utils.files import increment_path |
|
|
|
|
|
|
|
from ultralytics.utils.plotting import Annotator, colors |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run(weights="yolov8n.pt", source="test.mp4", view_img=False, save_img=False, exist_ok=False): |
|
|
|
class SahiInference: |
|
|
|
""" |
|
|
|
def __init__(self): |
|
|
|
Run object detection on a video using YOLOv8 and SAHI. |
|
|
|
self.detection_model = None |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
|
|
|
weights (str): Model weights path. |
|
|
|
|
|
|
|
source (str): Video file path. |
|
|
|
|
|
|
|
view_img (bool): Show results. |
|
|
|
|
|
|
|
save_img (bool): Save results. |
|
|
|
|
|
|
|
exist_ok (bool): Overwrite existing files. |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Check source path |
|
|
|
|
|
|
|
if not Path(source).exists(): |
|
|
|
|
|
|
|
raise FileNotFoundError(f"Source path '{source}' does not exist.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
yolov8_model_path = f"models/{weights}" |
|
|
|
|
|
|
|
download_yolov8s_model(yolov8_model_path) |
|
|
|
|
|
|
|
detection_model = AutoDetectionModel.from_pretrained( |
|
|
|
|
|
|
|
model_type="yolov8", model_path=yolov8_model_path, confidence_threshold=0.3, device="cpu" |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Video setup |
|
|
|
|
|
|
|
videocapture = cv2.VideoCapture(source) |
|
|
|
|
|
|
|
frame_width, frame_height = int(videocapture.get(3)), int(videocapture.get(4)) |
|
|
|
|
|
|
|
fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*"mp4v") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Output setup |
|
|
|
|
|
|
|
save_dir = increment_path(Path("ultralytics_results_with_sahi") / "exp", exist_ok) |
|
|
|
|
|
|
|
save_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
|
|
video_writer = cv2.VideoWriter(str(save_dir / f"{Path(source).stem}.mp4"), fourcc, fps, (frame_width, frame_height)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
while videocapture.isOpened(): |
|
|
|
|
|
|
|
success, frame = videocapture.read() |
|
|
|
|
|
|
|
if not success: |
|
|
|
|
|
|
|
break |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
results = get_sliced_prediction( |
|
|
|
|
|
|
|
frame, detection_model, slice_height=512, slice_width=512, overlap_height_ratio=0.2, overlap_width_ratio=0.2 |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
object_prediction_list = results.object_prediction_list |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
boxes_list = [] |
|
|
|
|
|
|
|
clss_list = [] |
|
|
|
|
|
|
|
for ind, _ in enumerate(object_prediction_list): |
|
|
|
|
|
|
|
boxes = ( |
|
|
|
|
|
|
|
object_prediction_list[ind].bbox.minx, |
|
|
|
|
|
|
|
object_prediction_list[ind].bbox.miny, |
|
|
|
|
|
|
|
object_prediction_list[ind].bbox.maxx, |
|
|
|
|
|
|
|
object_prediction_list[ind].bbox.maxy, |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
clss = object_prediction_list[ind].category.name |
|
|
|
|
|
|
|
boxes_list.append(boxes) |
|
|
|
|
|
|
|
clss_list.append(clss) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for box, cls in zip(boxes_list, clss_list): |
|
|
|
|
|
|
|
x1, y1, x2, y2 = box |
|
|
|
|
|
|
|
cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (56, 56, 255), 2) |
|
|
|
|
|
|
|
label = str(cls) |
|
|
|
|
|
|
|
t_size = cv2.getTextSize(label, 0, fontScale=0.6, thickness=1)[0] |
|
|
|
|
|
|
|
cv2.rectangle( |
|
|
|
|
|
|
|
frame, (int(x1), int(y1) - t_size[1] - 3), (int(x1) + t_size[0], int(y1) + 3), (56, 56, 255), -1 |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
cv2.putText( |
|
|
|
|
|
|
|
frame, label, (int(x1), int(y1) - 2), 0, 0.6, [255, 255, 255], thickness=1, lineType=cv2.LINE_AA |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if view_img: |
|
|
|
|
|
|
|
cv2.imshow(Path(source).stem, frame) |
|
|
|
|
|
|
|
if save_img: |
|
|
|
|
|
|
|
video_writer.write(frame) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if cv2.waitKey(1) & 0xFF == ord("q"): |
|
|
|
def load_model(self, weights): |
|
|
|
break |
|
|
|
yolov8_model_path = f"models/{weights}" |
|
|
|
video_writer.release() |
|
|
|
download_yolov8s_model(yolov8_model_path) |
|
|
|
videocapture.release() |
|
|
|
self.detection_model = AutoDetectionModel.from_pretrained( |
|
|
|
cv2.destroyAllWindows() |
|
|
|
model_type="yolov8", model_path=yolov8_model_path, confidence_threshold=0.3, device="cpu" |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def parse_opt(): |
|
|
|
|
|
|
|
"""Parse command line arguments.""" |
|
|
|
|
|
|
|
parser = argparse.ArgumentParser() |
|
|
|
|
|
|
|
parser.add_argument("--weights", type=str, default="yolov8n.pt", help="initial weights path") |
|
|
|
|
|
|
|
parser.add_argument("--source", type=str, required=True, help="video file path") |
|
|
|
|
|
|
|
parser.add_argument("--view-img", action="store_true", help="show results") |
|
|
|
|
|
|
|
parser.add_argument("--save-img", action="store_true", help="save results") |
|
|
|
|
|
|
|
parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment") |
|
|
|
|
|
|
|
return parser.parse_args() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def inference( |
|
|
|
|
|
|
|
self, weights="yolov8n.pt", source="test.mp4", view_img=False, save_img=False, exist_ok=False, track=False |
|
|
|
|
|
|
|
): |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
Run object detection on a video using YOLOv8 and SAHI. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
|
|
|
weights (str): Model weights path. |
|
|
|
|
|
|
|
source (str): Video file path. |
|
|
|
|
|
|
|
view_img (bool): Show results. |
|
|
|
|
|
|
|
save_img (bool): Save results. |
|
|
|
|
|
|
|
exist_ok (bool): Overwrite existing files. |
|
|
|
|
|
|
|
track (bool): Enable object tracking with SAHI |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
# Video setup |
|
|
|
|
|
|
|
cap = cv2.VideoCapture(source) |
|
|
|
|
|
|
|
assert cap.isOpened(), "Error reading video file" |
|
|
|
|
|
|
|
frame_width, frame_height = int(cap.get(3)), int(cap.get(4)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Output setup |
|
|
|
|
|
|
|
save_dir = increment_path(Path("ultralytics_results_with_sahi") / "exp", exist_ok) |
|
|
|
|
|
|
|
save_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
|
|
video_writer = cv2.VideoWriter( |
|
|
|
|
|
|
|
str(save_dir / f"{Path(source).stem}.mp4"), |
|
|
|
|
|
|
|
cv2.VideoWriter_fourcc(*"mp4v"), |
|
|
|
|
|
|
|
int(cap.get(5)), |
|
|
|
|
|
|
|
(frame_width, frame_height), |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def main(opt): |
|
|
|
# Load model |
|
|
|
"""Main function.""" |
|
|
|
self.load_model(weights) |
|
|
|
run(**vars(opt)) |
|
|
|
while cap.isOpened(): |
|
|
|
|
|
|
|
success, frame = cap.read() |
|
|
|
|
|
|
|
if not success: |
|
|
|
|
|
|
|
break |
|
|
|
|
|
|
|
annotator = Annotator(frame) # Initialize annotator for plotting detection and tracking results |
|
|
|
|
|
|
|
results = get_sliced_prediction( |
|
|
|
|
|
|
|
frame, |
|
|
|
|
|
|
|
self.detection_model, |
|
|
|
|
|
|
|
slice_height=512, |
|
|
|
|
|
|
|
slice_width=512, |
|
|
|
|
|
|
|
overlap_height_ratio=0.2, |
|
|
|
|
|
|
|
overlap_width_ratio=0.2, |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
detection_data = [ |
|
|
|
|
|
|
|
(det.category.name, det.category.id, (det.bbox.minx, det.bbox.miny, det.bbox.maxx, det.bbox.maxy)) |
|
|
|
|
|
|
|
for det in results.object_prediction_list |
|
|
|
|
|
|
|
] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for det in detection_data: |
|
|
|
|
|
|
|
annotator.box_label(det[2], label=str(det[0]), color=colors(int(det[1]), True)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if view_img: |
|
|
|
|
|
|
|
cv2.imshow(Path(source).stem, frame) |
|
|
|
|
|
|
|
if save_img: |
|
|
|
|
|
|
|
video_writer.write(frame) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if cv2.waitKey(1) & 0xFF == ord("q"): |
|
|
|
|
|
|
|
break |
|
|
|
|
|
|
|
video_writer.release() |
|
|
|
|
|
|
|
cap.release() |
|
|
|
|
|
|
|
cv2.destroyAllWindows() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_opt(self): |
|
|
|
|
|
|
|
"""Parse command line arguments.""" |
|
|
|
|
|
|
|
parser = argparse.ArgumentParser() |
|
|
|
|
|
|
|
parser.add_argument("--weights", type=str, default="yolov8n.pt", help="initial weights path") |
|
|
|
|
|
|
|
parser.add_argument("--source", type=str, required=True, help="video file path") |
|
|
|
|
|
|
|
parser.add_argument("--view-img", action="store_true", help="show results") |
|
|
|
|
|
|
|
parser.add_argument("--save-img", action="store_true", help="save results") |
|
|
|
|
|
|
|
parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment") |
|
|
|
|
|
|
|
return parser.parse_args() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
if __name__ == "__main__": |
|
|
|
opt = parse_opt() |
|
|
|
inference = SahiInference() |
|
|
|
main(opt) |
|
|
|
inference.inference(**vars(inference.parse_opt())) |
|
|
|