You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
108 lines
4.1 KiB
108 lines
4.1 KiB
1 year ago
|
import argparse
|
||
|
from pathlib import Path
|
||
|
|
||
|
import cv2
|
||
|
from sahi import AutoDetectionModel
|
||
|
from sahi.predict import get_sliced_prediction
|
||
|
from sahi.utils.yolov8 import download_yolov8s_model
|
||
|
|
||
|
from ultralytics.utils.files import increment_path
|
||
|
|
||
|
|
||
|
def run(weights='yolov8n.pt', source='test.mp4', view_img=False, save_img=False, exist_ok=False):
|
||
|
"""
|
||
|
Run object detection on a video using YOLOv8 and SAHI.
|
||
|
|
||
|
Args:
|
||
|
weights (str): Model weights path.
|
||
|
source (str): Video file path.
|
||
|
view_img (bool): Show results.
|
||
|
save_img (bool): Save results.
|
||
|
exist_ok (bool): Overwrite existing files.
|
||
|
"""
|
||
|
|
||
|
yolov8_model_path = f'models/{weights}'
|
||
|
download_yolov8s_model(yolov8_model_path)
|
||
|
detection_model = AutoDetectionModel.from_pretrained(model_type='yolov8',
|
||
|
model_path=yolov8_model_path,
|
||
|
confidence_threshold=0.3,
|
||
|
device='cpu')
|
||
|
|
||
|
# Video setup
|
||
|
videocapture = cv2.VideoCapture(source)
|
||
|
frame_width, frame_height = int(videocapture.get(3)), int(videocapture.get(4))
|
||
|
fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*'mp4v')
|
||
|
|
||
|
# Output setup
|
||
|
save_dir = increment_path(Path('ultralytics_results_with_sahi') / 'exp', exist_ok)
|
||
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||
|
video_writer = cv2.VideoWriter(str(save_dir / f'{Path(source).stem}.mp4'), fourcc, fps, (frame_width, frame_height))
|
||
|
|
||
|
while videocapture.isOpened():
|
||
|
success, frame = videocapture.read()
|
||
|
if not success:
|
||
|
break
|
||
|
|
||
|
results = get_sliced_prediction(frame,
|
||
|
detection_model,
|
||
|
slice_height=512,
|
||
|
slice_width=512,
|
||
|
overlap_height_ratio=0.2,
|
||
|
overlap_width_ratio=0.2)
|
||
|
object_prediction_list = results.object_prediction_list
|
||
|
|
||
|
boxes_list = []
|
||
|
clss_list = []
|
||
|
for ind, _ in enumerate(object_prediction_list):
|
||
|
boxes = object_prediction_list[ind].bbox.minx, object_prediction_list[ind].bbox.miny, \
|
||
|
object_prediction_list[ind].bbox.maxx, object_prediction_list[ind].bbox.maxy
|
||
|
clss = object_prediction_list[ind].category.name
|
||
|
boxes_list.append(boxes)
|
||
|
clss_list.append(clss)
|
||
|
|
||
|
for box, cls in zip(boxes_list, clss_list):
|
||
|
x1, y1, x2, y2 = box
|
||
|
cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (56, 56, 255), 2)
|
||
|
label = str(cls)
|
||
|
t_size = cv2.getTextSize(label, 0, fontScale=0.6, thickness=1)[0]
|
||
|
cv2.rectangle(frame, (int(x1), int(y1) - t_size[1] - 3), (int(x1) + t_size[0], int(y1) + 3), (56, 56, 255),
|
||
|
-1)
|
||
|
cv2.putText(frame,
|
||
|
label, (int(x1), int(y1) - 2),
|
||
|
0,
|
||
|
0.6, [255, 255, 255],
|
||
|
thickness=1,
|
||
|
lineType=cv2.LINE_AA)
|
||
|
|
||
|
if view_img:
|
||
|
cv2.imshow(Path(source).stem, frame)
|
||
|
if save_img:
|
||
|
video_writer.write(frame)
|
||
|
|
||
|
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||
|
break
|
||
|
video_writer.release()
|
||
|
videocapture.release()
|
||
|
cv2.destroyAllWindows()
|
||
|
|
||
|
|
||
|
def parse_opt():
|
||
|
"""Parse command line arguments."""
|
||
|
parser = argparse.ArgumentParser()
|
||
|
parser.add_argument('--weights', type=str, default='yolov8n.pt', help='initial weights path')
|
||
|
parser.add_argument('--source', type=str, required=True, help='video file path')
|
||
|
parser.add_argument('--view-img', action='store_true', help='show results')
|
||
|
parser.add_argument('--save-img', action='store_true', help='save results')
|
||
|
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
|
||
|
return parser.parse_args()
|
||
|
|
||
|
|
||
|
def main(opt):
|
||
|
"""Main function."""
|
||
|
run(**vars(opt))
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
|
||
|
opt = parse_opt()
|
||
|
main(opt)
|