|
|
|
@ -9,7 +9,7 @@ from ultralytics import Explorer |
|
|
|
|
from ultralytics.utils import ROOT, SETTINGS |
|
|
|
|
from ultralytics.utils.checks import check_requirements |
|
|
|
|
|
|
|
|
|
check_requirements(("streamlit>=1.29.0", "streamlit-select>=0.2")) |
|
|
|
|
check_requirements(("streamlit>=1.29.0", "streamlit-select>=0.3")) |
|
|
|
|
|
|
|
|
|
import streamlit as st |
|
|
|
|
from streamlit_select import image_select |
|
|
|
@ -94,6 +94,7 @@ def find_similar_imgs(imgs): |
|
|
|
|
similar = exp.get_similar(img=imgs, limit=st.session_state.get("limit"), return_type="arrow") |
|
|
|
|
paths = similar.to_pydict()["im_file"] |
|
|
|
|
st.session_state["imgs"] = paths |
|
|
|
|
st.session_state["res"] = similar |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def similarity_form(selected_imgs): |
|
|
|
@ -137,6 +138,7 @@ def run_sql_query(): |
|
|
|
|
exp = st.session_state["explorer"] |
|
|
|
|
res = exp.sql_query(query, return_type="arrow") |
|
|
|
|
st.session_state["imgs"] = res.to_pydict()["im_file"] |
|
|
|
|
st.session_state["res"] = res |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_ai_query(): |
|
|
|
@ -155,6 +157,7 @@ def run_ai_query(): |
|
|
|
|
st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it." |
|
|
|
|
return |
|
|
|
|
st.session_state["imgs"] = res["im_file"].to_list() |
|
|
|
|
st.session_state["res"] = res |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reset_explorer(): |
|
|
|
@ -195,7 +198,11 @@ def layout(): |
|
|
|
|
if st.session_state.get("error"): |
|
|
|
|
st.error(st.session_state["error"]) |
|
|
|
|
else: |
|
|
|
|
imgs = st.session_state.get("imgs") or exp.table.to_lance().to_table(columns=["im_file"]).to_pydict()["im_file"] |
|
|
|
|
if st.session_state.get("imgs"): |
|
|
|
|
imgs = st.session_state.get("imgs") |
|
|
|
|
else: |
|
|
|
|
imgs = exp.table.to_lance().to_table(columns=["im_file"]).to_pydict()["im_file"] |
|
|
|
|
st.session_state["res"] = exp.table.to_arrow() |
|
|
|
|
total_imgs, selected_imgs = len(imgs), [] |
|
|
|
|
with col1: |
|
|
|
|
subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5) |
|
|
|
@ -230,17 +237,30 @@ def layout(): |
|
|
|
|
query_form() |
|
|
|
|
ai_query_form() |
|
|
|
|
if total_imgs: |
|
|
|
|
labels, boxes, masks, kpts, classes = None, None, None, None, None |
|
|
|
|
task = exp.model.task |
|
|
|
|
if st.session_state.get("display_labels"): |
|
|
|
|
labels = st.session_state.get("res").to_pydict()["labels"][start_idx : start_idx + num] |
|
|
|
|
boxes = st.session_state.get("res").to_pydict()["bboxes"][start_idx : start_idx + num] |
|
|
|
|
masks = st.session_state.get("res").to_pydict()["masks"][start_idx : start_idx + num] |
|
|
|
|
kpts = st.session_state.get("res").to_pydict()["keypoints"][start_idx : start_idx + num] |
|
|
|
|
classes = st.session_state.get("res").to_pydict()["cls"][start_idx : start_idx + num] |
|
|
|
|
imgs_displayed = imgs[start_idx : start_idx + num] |
|
|
|
|
selected_imgs = image_select( |
|
|
|
|
f"Total samples: {total_imgs}", |
|
|
|
|
images=imgs_displayed, |
|
|
|
|
use_container_width=False, |
|
|
|
|
# indices=[i for i in range(num)] if select_all else None, |
|
|
|
|
labels=labels, |
|
|
|
|
classes=classes, |
|
|
|
|
bboxes=boxes, |
|
|
|
|
masks=masks if task == "segment" else None, |
|
|
|
|
kpts=kpts if task == "pose" else None, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
with col2: |
|
|
|
|
similarity_form(selected_imgs) |
|
|
|
|
# display_labels = st.checkbox("Labels", value=False, key="display_labels") |
|
|
|
|
display_labels = st.checkbox("Labels", value=False, key="display_labels") |
|
|
|
|
utralytics_explorer_docs_callback() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|