|
|
|
@ -6,6 +6,8 @@ import time |
|
|
|
|
import cv2 |
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
|
from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def inference(): |
|
|
|
|
"""Runs real-time object detection on video input using Ultralytics YOLOv8 in a Streamlit application.""" |
|
|
|
@ -65,28 +67,12 @@ def inference(): |
|
|
|
|
vid_file_name = 0 |
|
|
|
|
|
|
|
|
|
# Add dropdown menu for model selection |
|
|
|
|
yolov8_model = st.sidebar.selectbox( |
|
|
|
|
"Model", |
|
|
|
|
( |
|
|
|
|
"YOLOv8n", |
|
|
|
|
"YOLOv8s", |
|
|
|
|
"YOLOv8m", |
|
|
|
|
"YOLOv8l", |
|
|
|
|
"YOLOv8x", |
|
|
|
|
"YOLOv8n-Seg", |
|
|
|
|
"YOLOv8s-Seg", |
|
|
|
|
"YOLOv8m-Seg", |
|
|
|
|
"YOLOv8l-Seg", |
|
|
|
|
"YOLOv8x-Seg", |
|
|
|
|
"YOLOv8n-Pose", |
|
|
|
|
"YOLOv8s-Pose", |
|
|
|
|
"YOLOv8m-Pose", |
|
|
|
|
"YOLOv8l-Pose", |
|
|
|
|
"YOLOv8x-Pose", |
|
|
|
|
), |
|
|
|
|
) |
|
|
|
|
model = YOLO(f"{yolov8_model.lower()}.pt") # Load the yolov8 model |
|
|
|
|
class_names = list(model.names.values()) # Convert dictionary to list of class names |
|
|
|
|
available_models = (x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolov8")) |
|
|
|
|
selected_model = st.sidebar.selectbox("Model", available_models) |
|
|
|
|
with st.spinner("Model is downloading..."): |
|
|
|
|
model = YOLO(f"{selected_model.lower()}.pt") # Load the YOLO model |
|
|
|
|
class_names = list(model.names.values()) # Convert dictionary to list of class names |
|
|
|
|
st.success("Model loaded successfully!") |
|
|
|
|
|
|
|
|
|
# Multiselect box with class names and get indices of selected classes |
|
|
|
|
selected_classes = st.sidebar.multiselect("Classes", class_names, default=class_names[:3]) |
|
|
|
@ -95,8 +81,9 @@ def inference(): |
|
|
|
|
if not isinstance(selected_ind, list): # Ensure selected_options is a list |
|
|
|
|
selected_ind = list(selected_ind) |
|
|
|
|
|
|
|
|
|
conf_thres = st.sidebar.slider("Confidence Threshold", 0.0, 1.0, 0.25, 0.01) |
|
|
|
|
nms_thres = st.sidebar.slider("NMS Threshold", 0.0, 1.0, 0.45, 0.01) |
|
|
|
|
enable_trk = st.sidebar.radio("Enable Tracking", ("Yes", "No")) |
|
|
|
|
conf = float(st.sidebar.slider("Confidence Threshold", 0.0, 1.0, 0.25, 0.01)) |
|
|
|
|
iou = float(st.sidebar.slider("IoU Threshold", 0.0, 1.0, 0.45, 0.01)) |
|
|
|
|
|
|
|
|
|
col1, col2 = st.columns(2) |
|
|
|
|
org_frame = col1.empty() |
|
|
|
@ -124,7 +111,10 @@ def inference(): |
|
|
|
|
prev_time = curr_time |
|
|
|
|
|
|
|
|
|
# Store model predictions |
|
|
|
|
results = model(frame, conf=float(conf_thres), iou=float(nms_thres), classes=selected_ind) |
|
|
|
|
if enable_trk: |
|
|
|
|
results = model.track(frame, conf=conf, iou=iou, classes=selected_ind, persist=True) |
|
|
|
|
else: |
|
|
|
|
results = model(frame, conf=conf, iou=iou, classes=selected_ind) |
|
|
|
|
annotated_frame = results[0].plot() # Add annotations on frame |
|
|
|
|
|
|
|
|
|
# display frame |
|
|
|
|