`ultralytics 8.0.238` Explorer Ask AI feature and fixes (#7408)

Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
Co-authored-by: uwer <uwe.rosebrock@gmail.com>
Co-authored-by: Uwe Rosebrock <ro260@csiro.au>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing-q <1182102784@qq.com>
Co-authored-by: Muhammad Rizwan Munawar <chr043416@gmail.com>
Co-authored-by: AdamP <adamp87hun@gmail.com>
pull/5673/merge v8.0.238
Glenn Jocher 10 months ago committed by GitHub
parent e76754eab0
commit 783033fa6b
  1. 30
      docs/en/datasets/explorer/api.md
  2. 0
      docs/en/datasets/explorer/dash.md
  3. 58
      docs/en/datasets/explorer/dashboard.md
  4. 70
      docs/en/datasets/explorer/explorer.ipynb
  5. 17
      docs/en/datasets/explorer/index.md
  6. 2
      docs/en/guides/heatmaps.md
  7. 20
      docs/en/guides/instance-segmentation-and-tracking.md
  8. 2
      docs/mkdocs.yml
  9. 2
      ultralytics/__init__.py
  10. 3
      ultralytics/data/explorer/__init__.py
  11. 59
      ultralytics/data/explorer/explorer.py
  12. 68
      ultralytics/data/explorer/gui/dash.py
  13. 70
      ultralytics/data/explorer/utils.py
  14. 2
      ultralytics/engine/model.py
  15. 3
      ultralytics/models/yolo/obb/predict.py
  16. 14
      ultralytics/nn/modules/head.py
  17. 6
      ultralytics/solutions/distance_calculation.py
  18. 36
      ultralytics/solutions/heatmap.py
  19. 1
      ultralytics/utils/__init__.py

@ -119,7 +119,31 @@ You can also plot the similar images using the `plot_similar` method. This metho
plt.show()
```
## 2. SQL Querying
## 2. Ask AI (Natural Language Querying)
This allows you to write how you want to filter your dataset using natural language. You don't have to be proficient in writing SQL queries. Our AI-powered query generator will automatically do that under the hood. For example, you can say "show me 100 images with exactly one person and 2 dogs. There can be other objects too" and it'll internally generate the query and show you those results.
Note: This works using LLMs under the hood, so the results are probabilistic and might get things wrong sometimes.
!!! Example "Ask AI"
```python
from ultralytics import Explorer
from ultralytics.data.explorer import plot_query_result
from PIL import Image
# create an Explorer object
exp = Explorer(data='coco128.yaml', model='yolov8n.pt')
exp.create_embeddings_table()
df = exp.ask_ai("show me 100 images with exactly one person and 2 dogs. There can be other objects too")
print(df.head())
# plot the results (plot_query_result returns an image array)
plt = plot_query_result(df)
Image.fromarray(plt).show()
```
## 3. SQL Querying
You can run SQL queries on your dataset using the `sql_query` method. This method takes a SQL query as input and returns a pandas dataframe with the results.
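As a minimal sketch, assuming the `coco128.yaml` setup used throughout this page; a bare `WHERE` clause is expanded to `SELECT * FROM 'table' ...` internally:
```python
from ultralytics import Explorer

# create an Explorer object and build the embeddings table
exp = Explorer(data='coco128.yaml', model='yolov8n.pt')
exp.create_embeddings_table()

# pass a full query or just the WHERE clause; results come back as a pandas dataframe
df = exp.sql_query("WHERE labels LIKE '%person%' AND labels LIKE '%dog%'")
print(df.head())
```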
@ -153,7 +177,7 @@ You can also plot the results of a SQL query using the `plot_sql_query` method.
print(df.head())
```
## 3. Working with embeddings Table (Advanced)
## 4. Working with embeddings Table (Advanced)
You can also work with the embeddings table directly. Once the embeddings table is created, you can access it using `Explorer.table`.
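As a hedged sketch of direct access: `Explorer.table` is a LanceDB table, and `to_arrow()` is the same export that `sql_query` uses internally; the column names follow the embeddings table schema:
```python
from ultralytics import Explorer

exp = Explorer(data='coco128.yaml', model='yolov8n.pt')
exp.create_embeddings_table()

# export the LanceDB table to a PyArrow table for direct inspection
arrow_table = exp.table.to_arrow()
print(arrow_table.column_names)  # im_file, labels, cls, bboxes, masks, keypoints, vector
```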
@ -210,7 +234,7 @@ When using large datasets, you can also create a dedicated vector index for fast
Find more details on the types of vector indices available and their parameters [here](https://lancedb.github.io/lancedb/ann_indexes/#types-of-index). In the future, we will add support for creating vector indices directly from the Explorer API.
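Until then, here is a hedged sketch of creating an ANN index directly on the underlying LanceDB table; the `create_index` parameters are illustrative and should be tuned to your dataset size:
```python
from ultralytics import Explorer

exp = Explorer(data='coco128.yaml', model='yolov8n.pt')
exp.create_embeddings_table()

# build an IVF-PQ index on the embedding vectors (LanceDB API; parameters illustrative)
exp.table.create_index(num_partitions=256, num_sub_vectors=96)
```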
## 4. Embeddings Applications
## 5. Embeddings Applications
You can use the embeddings table to perform a variety of exploratory analyses. Here are some examples:

@ -0,0 +1,58 @@
---
comments: true
description: Learn about Ultralytics Explorer GUI for semantic search, SQL queries, and AI-powered natural language search in CV datasets.
keywords: Ultralytics, Explorer GUI, semantic search, vector similarity search, AI queries, SQL queries, computer vision, dataset exploration, image search, OpenAI integration
---
# Explorer GUI
Explorer GUI is like a playground built using the [Ultralytics Explorer API](api.md). It allows you to run semantic/vector similarity search, SQL queries, and even search using natural language with our Ask AI feature powered by LLMs.
### Installation
```bash
pip install ultralytics[explorer]
```
!!! note "Note"
The Ask AI feature works using OpenAI, so you'll be prompted to set the OpenAI API key when you first run the GUI.
You can set it like this: `yolo settings openai_api_key="..."`
## Semantic Search / Vector Similarity Search
Semantic search is a technique for finding similar images to a given image. It is based on the idea that similar images will have similar embeddings. In the UI, you can select one or more images and search for the images similar to them. This can be useful when you want to find images similar to a given image or a set of images that don't perform as expected.
For example:
In this VOC Exploration dashboard, the user selects a couple of aeroplane images like this:
<p>
<img width="1710" alt="Screenshot 2024-01-08 at 8 46 33PM" src="https://github.com/AyushExel/assets/assets/15766192/da5f1b0a-9eb5-4712-919c-7d5512240dd8">
</p>
On performing similarity search, you should see a similar result:
<p>
<img width="1710" alt="Screenshot 2024-01-08 at 8 46 46PM" src="https://github.com/AyushExel/assets/assets/15766192/5e4c6445-8e4e-48bb-a15a-9fb6c6994af8">
</p>
## Ask AI
This allows you to write how you want to filter your dataset using natural language. You don't have to be proficient in writing SQL queries. Our AI-powered query generator will automatically do that under the hood. For example, you can say "show me 100 images with exactly one person and 2 dogs. There can be other objects too" and it'll internally generate the query and show you those results. Here's an example output when asked to "Show 10 images with exactly 5 persons":
<p>
<img width="1709" alt="Screenshot 2024-01-08 at 7 19 48PM (1)" src="https://github.com/AyushExel/assets/assets/15766192/e536b0eb-6bce-43fe-b800-3e79510d2e5b">
</p>
Note: This works using LLMs under the hood, so the results are probabilistic and might get things wrong sometimes.
## Run SQL queries on your CV datasets
You can run SQL queries on your dataset to filter it. It also works if you only provide the WHERE clause. This example SQL query would show only the images that have at least one person and one dog in them:
```sql
WHERE labels LIKE '%person%' AND labels LIKE '%dog%'
```
<p>
<img width="1707" alt="Screenshot 2024-01-08 at 8 57 49PM" src="https://github.com/AyushExel/assets/assets/15766192/71619e16-4db9-4fdb-b951-0d1fdbf59a6a">
</p>
This is a demo built using the Explorer API. You can use the API to build your own exploratory notebooks or scripts to get insights into your datasets. Learn more about the Explorer API [here](api.md).
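For instance, a minimal sketch of scripting the dashboard's Ask AI flow through the API (the prompt text is just an example):
```python
from ultralytics import Explorer

exp = Explorer(data='coco128.yaml', model='yolov8n.pt')
exp.create_embeddings_table()

# ask_ai returns a pandas dataframe, or None if the AI-generated query is invalid
df = exp.ask_ai('show me images with at least one person and one dog')
if df is not None:
    print(df.head())
```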

@ -109,7 +109,10 @@
"metadata": {},
"source": [
"You can use the also plot the similar samples directly using the `plot_similar` util\n",
"<img width=\"689\" alt=\"Screenshot 2024-01-06 at 9 46 48PM\" src=\"https://github.com/AyushExel/assets/assets/15766192/70e1a4c4-6c67-4664-b77a-ad27b1fba8f8\">\n"
"<p>\n",
"\n",
" <img src=\"https://github.com/AyushExel/assets/assets/15766192/a3c9247b-9271-47df-aaa5-36d96c5034b1\" />\n",
"</p>\n"
]
},
{
@ -139,17 +142,74 @@
"metadata": {},
"source": [
"<p>\n",
"<img width=\"766\" alt=\"Screenshot 2024-01-06 at 10 05 10PM\" src=\"https://github.com/AyushExel/assets/assets/15766192/faa9c544-d96b-4528-a2ea-95c5d8856744\">\n",
"<img src=\"https://github.com/AyushExel/assets/assets/15766192/8e011195-b0da-43ef-b3cd-5fb6f383037e\">\n",
"\n",
"</p>"
]
},
{
"cell_type": "markdown",
"id": "0cea63f1-71f1-46da-af2b-b1b7d8f73553",
"metadata": {},
"source": [
"## 2. Ask AI: Search or filter with Natural Language\n",
"You can prompt the Explorer object with the kind of data points you want to see and it'll try to return a dataframe with those. Because it is powered by LLMs, it doesn't always get it right. In that case, it'll return None.\n",
"<p>\n",
"<img width=\"1131\" alt=\"Screenshot 2024-01-07 at 2 34 53PM\" src=\"https://github.com/AyushExel/assets/assets/15766192/c4a69fd9-e54f-4d6a-aba5-dc9cfae1bc04\">\n",
"\n",
"</p>\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "92fb92ac-7f76-465a-a9ba-ea7492498d9c",
"metadata": {},
"outputs": [],
"source": [
"df = exp.ask_ai(\"show me images containing more than 10 objects with at least 2 persons\")\n",
"df.head(5)"
]
},
{
"cell_type": "markdown",
"id": "f2a7d26e-0ce5-4578-ad1a-b1253805280f",
"metadata": {},
"source": [
"for plotting these results you can use `plot_query_result` util\n",
"Example:\n",
"```\n",
"plt = plot_query_result(exp.ask_ai(\"show me 10 images containing exactly 2 persons\"))\n",
"Image.fromarray(plt)\n",
"```\n",
"<p>\n",
" <img src=\"https://github.com/AyushExel/assets/assets/15766192/2cb780de-d05b-4412-a526-7f7f0f10e669\">\n",
"\n",
"</p>"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1cfab84-9835-4da0-8e9a-42b30cf84511",
"metadata": {},
"outputs": [],
"source": [
"# plot\n",
"from ultralytics.data.explorer import plot_query_result\n",
"from PIL import Image\n",
"\n",
"plt = plot_query_result(exp.ask_ai(\"show me 10 images containing exactly 2 persons\"))\n",
"Image.fromarray(plt)"
]
},
{
"cell_type": "markdown",
"id": "35315ae6-d827-40e4-8813-279f97a83b34",
"metadata": {},
"source": [
"## 2. Run SQL queries on your Dataset!\n",
"## 3. Run SQL queries on your Dataset!\n",
"Sometimes you might want to investigate a certain type of entries in your dataset. For this Explorer allows you to execute SQL queries.\n",
"It accepts either of the formats:\n",
"- Queries beginning with \"WHERE\" will automatically select all columns. This can be thought of as a short-hand query\n",
@ -179,7 +239,7 @@
"metadata": {},
"source": [
"Just like similarity search, you also get a util to directly plot the sql queries using `exp.plot_sql_query`\n",
"<img width=\"771\" alt=\"Screenshot 2024-01-06 at 9 48 08PM\" src=\"https://github.com/AyushExel/assets/assets/15766192/332f5acd-3a4e-462d-a281-5d5effd1886e\">\n"
"<img src=\"https://github.com/AyushExel/assets/assets/15766192/f8b66629-8dd0-419e-8f44-9837969ba678\">\n"
]
},
{
@ -419,7 +479,7 @@
"metadata": {},
"source": [
"You should see something like this\n",
"<img width=\"897\" alt=\"Screenshot 2024-01-06 at 9 50 48PM\" src=\"https://github.com/AyushExel/assets/assets/15766192/5d3f0e35-2ad4-4a67-8df7-3a4c17867b72\">\n"
"<img src=\"https://github.com/AyushExel/assets/assets/15766192/649bc366-ca2d-46ea-bfd9-3097cf575584\">\n"
]
},
{

@ -16,6 +16,12 @@ Explorer depends on external libraries for some of its functionality. These are
pip install ultralytics[explorer]
```
### Explorer API
This is a Python API for exploring your datasets. It also powers the GUI Explorer. You can use this to create your own exploratory notebooks or scripts to get insights into your datasets.
Learn more about the Explorer API [here](api.md).
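As a minimal sketch of such a script, assuming the default pandas return type and reusing the example image from the API reference:
```python
from ultralytics import Explorer

# build an embeddings table for the dataset
exp = Explorer(data='coco128.yaml', model='yolov8n.pt')
exp.create_embeddings_table()

# semantic search: find the 10 entries most similar to a given image
similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg', limit=10)
print(similar.head())
```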
## GUI Explorer Usage
The GUI demo runs in your browser, allowing you to create embeddings for your dataset, search for similar images, run SQL queries, and perform semantic search. It can be run using the following command:
@ -24,8 +30,11 @@ The GUI demo runs in your browser allowing you to create embeddings for your dat
yolo explorer
```
### Explorer API
!!! note "Note"
The Ask AI feature works using OpenAI, so you'll be prompted to set the OpenAI API key when you first run the GUI.
You can set it like this: `yolo settings openai_api_key="..."`
This is a Python API for exploring your datasets. It also powers the GUI Explorer. You can use this to create your own exploratory notebooks or scripts to get insights into your datasets.
Learn more about the Explorer API [here](api.md).
Example
<p>
<img width="1709" alt="Screenshot 2024-01-08 at 7 19 48PM (1)" src="https://github.com/AyushExel/assets/assets/15766192/e536b0eb-6bce-43fe-b800-3e79510d2e5b">
</p>

@ -99,7 +99,7 @@ A heatmap generated with [Ultralytics YOLOv8](https://github.com/ultralytics/ult
fps,
(w, h))
line_points = [(256, 409), (694, 532)] # line for object counting
line_points = [(20, 400), (1080, 404)] # line for object counting
# Init heatmap
heatmap_obj = heatmap.Heatmap()

@ -31,7 +31,7 @@ There are two types of instance segmentation tracking available in the Ultralyti
from ultralytics import YOLO
from ultralytics.utils.plotting import Annotator, colors
model = YOLO("yolov8n-seg.pt")
model = YOLO("yolov8n-seg.pt") # segmentation model
names = model.model.names
cap = cv2.VideoCapture("path/to/video/file.mp4")
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
@ -45,15 +45,15 @@ There are two types of instance segmentation tracking available in the Ultralyti
break
results = model.predict(im0)
clss = results[0].boxes.cls.cpu().tolist()
masks = results[0].masks.xy
annotator = Annotator(im0, line_width=2)
for mask, cls in zip(masks, clss):
annotator.seg_bbox(mask=mask,
mask_color=colors(int(cls), True),
det_label=names[int(cls)])
if results[0].masks is not None:
clss = results[0].boxes.cls.cpu().tolist()
masks = results[0].masks.xy
for mask, cls in zip(masks, clss):
annotator.seg_bbox(mask=mask,
mask_color=colors(int(cls), True),
det_label=names[int(cls)])
out.write(im0)
cv2.imshow("instance-segmentation", im0)
@ -77,7 +77,7 @@ There are two types of instance segmentation tracking available in the Ultralyti
track_history = defaultdict(lambda: [])
model = YOLO("yolov8n-seg.pt")
model = YOLO("yolov8n-seg.pt") # segmentation model
cap = cv2.VideoCapture("path/to/video/file.mp4")
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
@ -93,7 +93,7 @@ There are two types of instance segmentation tracking available in the Ultralyti
results = model.track(im0, persist=True)
if results[0].boxes.id is not None:
if results[0].boxes.id is not None and results[0].masks is not None:
masks = results[0].masks.xy
track_ids = results[0].boxes.id.int().cpu().tolist()

@ -222,7 +222,7 @@ nav:
- Explorer:
- datasets/explorer/index.md
- Explorer API: datasets/explorer/api.md
- GUI Dashboard Demo: datasets/explorer/dash.md
- Explorer Dashboard: datasets/explorer/dashboard.md
- VOC Exploration Example: datasets/explorer/explorer.ipynb
- Detection:
- datasets/detect/index.md

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.237'
__version__ = '8.0.238'
from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO

@ -0,0 +1,3 @@
from .utils import plot_query_result
__all__ = ['plot_query_result']

@ -16,7 +16,7 @@ from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.model import YOLO
from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks
from .utils import get_sim_index_schema, get_table_schema, plot_similar_images, sanitize_batch
from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch
class ExplorerDataset(YOLODataset):
@ -58,7 +58,7 @@ class Explorer:
data: Union[str, Path] = 'coco128.yaml',
model: str = 'yolov8n.pt',
uri: str = '~/ultralytics/explorer') -> None:
checks.check_requirements(['lancedb', 'duckdb'])
checks.check_requirements(['lancedb>=0.4.3', 'duckdb'])
import lancedb
self.connection = lancedb.connect(uri)
@ -112,8 +112,7 @@ class Explorer:
# Create the table schema
batch = dataset[0]
vector_size = self.model.embed(batch['im_file'], verbose=False)[0].shape[0]
Schema = get_table_schema(vector_size)
table = self.connection.create_table(self.table_name, schema=Schema, mode='overwrite')
table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode='overwrite')
table.add(
self._yield_batches(dataset,
data_info,
@ -159,10 +158,7 @@ class Explorer:
raise ValueError('Table is not created. Please create the table first.')
if isinstance(imgs, str):
imgs = [imgs]
elif isinstance(imgs, list):
pass
else:
raise ValueError(f'img must be a string or a list of strings. Got {type(imgs)}')
assert isinstance(imgs, list), f'img must be a string or a list of strings. Got {type(imgs)}'
embeds = self.model.embed(imgs)
# Get avg if multiple images are passed (len > 1)
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
@ -189,16 +185,19 @@ class Explorer:
result = exp.sql_query(query)
```
"""
assert return_type in ['pandas',
'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}'
import duckdb
if self.table is None:
raise ValueError('Table is not created. Please create the table first.')
# Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
table = self.table.to_arrow() # noqa
table = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB
if not query.startswith('SELECT') and not query.startswith('WHERE'):
raise ValueError(
'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause.')
f'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}'
)
if query.startswith('WHERE'):
query = f"SELECT * FROM 'table' {query}"
LOGGER.info(f'Running query: {query}')
@ -228,7 +227,10 @@ class Explorer:
```
"""
result = self.sql_query(query, return_type='arrow')
img = plot_similar_images(result, plot_labels=labels)
if len(result) == 0:
LOGGER.info('No results found.')
return None
img = plot_query_result(result, plot_labels=labels)
return Image.fromarray(img)
def get_similar(self,
@ -255,6 +257,8 @@ class Explorer:
similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
```
"""
assert return_type in ['pandas',
'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}'
img = self._check_imgs_or_idxs(img, idx)
similar = self.query(img, limit=limit)
@ -288,7 +292,10 @@ class Explorer:
```
"""
similar = self.get_similar(img, idx, limit, return_type='arrow')
img = plot_similar_images(similar, plot_labels=labels)
if len(similar) == 0:
LOGGER.info('No results found.')
return None
img = plot_query_result(similar, plot_labels=labels)
return Image.fromarray(img)
def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
@ -299,7 +306,7 @@ class Explorer:
Args:
max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when running
vector search. Defaults to 0.01.
vector search. Defaults: None.
force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.
Returns:
@ -401,6 +408,32 @@ class Explorer:
return img if isinstance(img, list) else [img]
def ask_ai(self, query):
"""
Ask AI a question.
Args:
query (str): Question to ask.
Returns:
Answer from AI as a pandas DataFrame, or None if the generated query is invalid.
Example:
```python
exp = Explorer()
exp.create_embeddings_table()
answer = exp.ask_ai('Show images with 1 person and 2 dogs')
```
"""
result = prompt_sql_query(query)
try:
df = self.sql_query(result)
except Exception as e:
LOGGER.error('AI generated query is not valid. Please try again with a different prompt')
LOGGER.error(e)
return None
return df
def visualize(self, result):
"""
Visualize the results of a query.

@ -1,11 +1,13 @@
import time
from threading import Thread
import pandas as pd
from ultralytics import Explorer
from ultralytics.utils import ROOT
from ultralytics.utils import ROOT, SETTINGS
from ultralytics.utils.checks import check_requirements
check_requirements('streamlit')
check_requirements('streamlit>=1.29.0')
check_requirements('streamlit-select>=0.2')
import streamlit as st
from streamlit_select import image_select
@ -35,9 +37,9 @@ def init_explorer_form():
with st.form(key='explorer_init_form'):
col1, col2 = st.columns(2)
with col1:
dataset = st.selectbox('Select dataset', ds, key='dataset', index=ds.index('coco128.yaml'))
st.selectbox('Select dataset', ds, key='dataset', index=ds.index('coco128.yaml'))
with col2:
model = st.selectbox('Select model', models, key='model')
st.selectbox('Select model', models, key='model')
st.checkbox('Force recreate embeddings', key='force_recreate_embeddings')
st.form_submit_button('Explore', on_click=_get_explorer)
@ -47,11 +49,23 @@ def query_form():
with st.form('query_form'):
col1, col2 = st.columns([0.8, 0.2])
with col1:
query = st.text_input('Query', '', label_visibility='collapsed', key='query')
st.text_input('Query',
"WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
label_visibility='collapsed',
key='query')
with col2:
st.form_submit_button('Query', on_click=run_sql_query)
def ai_query_form():
with st.form('ai_query_form'):
col1, col2 = st.columns([0.8, 0.2])
with col1:
st.text_input('Query', 'Show images with 1 person and 1 dog', label_visibility='collapsed', key='ai_query')
with col2:
st.form_submit_button('Ask AI', on_click=run_ai_query)
def find_similar_imgs(imgs):
exp = st.session_state['explorer']
similar = exp.get_similar(img=imgs, limit=st.session_state.get('limit'), return_type='arrow')
@ -64,12 +78,12 @@ def similarity_form(selected_imgs):
with st.form('similarity_form'):
subcol1, subcol2 = st.columns([1, 1])
with subcol1:
limit = st.number_input('limit',
min_value=None,
max_value=None,
value=25,
label_visibility='collapsed',
key='limit')
st.number_input('limit',
min_value=None,
max_value=None,
value=25,
label_visibility='collapsed',
key='limit')
with subcol2:
disabled = not len(selected_imgs)
@ -95,6 +109,7 @@ def similarity_form(selected_imgs):
def run_sql_query():
st.session_state['error'] = None
query = st.session_state.get('query')
if query.rstrip().lstrip():
exp = st.session_state['explorer']
@ -102,9 +117,26 @@ def run_sql_query():
st.session_state['imgs'] = res.to_pydict()['im_file']
def run_ai_query():
if not SETTINGS['openai_api_key']:
st.session_state[
'error'] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
return
st.session_state['error'] = None
query = st.session_state.get('ai_query')
if query.rstrip().lstrip():
exp = st.session_state['explorer']
res = exp.ask_ai(query)
if not isinstance(res, pd.DataFrame) or res.empty:
st.session_state['error'] = 'No results found using AI generated query. Try another query or rerun it.'
return
st.session_state['imgs'] = res['im_file'].to_list()
def reset_explorer():
st.session_state['explorer'] = None
st.session_state['imgs'] = None
st.session_state['error'] = None
def utralytics_explorer_docs_callback():
@ -112,10 +144,10 @@ def utralytics_explorer_docs_callback():
st.image('https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg',
width=100)
st.markdown(
"<p>This demo is built using Ultralytics Explorer API. Visit <a href=''>API docs</a> to try examples & learn more</p>",
"<p>This demo is built using Ultralytics Explorer API. Visit <a href='https://docs.ultralytics.com/datasets/explorer/'>API docs</a> to try examples & learn more</p>",
unsafe_allow_html=True,
help=None)
st.link_button('Ultralytics Explorer API', 'https://docs.ultralytics.com/')
st.link_button('Ultralytics Explorer API', 'https://docs.ultralytics.com/datasets/explorer/')
def layout():
@ -129,9 +161,12 @@ def layout():
st.button(':arrow_backward: Select Dataset', on_click=reset_explorer)
exp = st.session_state.get('explorer')
col1, col2 = st.columns([0.75, 0.25], gap='small')
imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file']
total_imgs = len(imgs)
imgs = []
if st.session_state.get('error'):
st.error(st.session_state['error'])
else:
imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file']
total_imgs, selected_imgs = len(imgs), []
with col1:
subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
with subcol1:
@ -159,6 +194,7 @@ def layout():
st.experimental_rerun()
query_form()
ai_query_form()
if total_imgs:
imgs_displayed = imgs[start_idx:start_idx + num]
selected_imgs = image_select(

@ -1,9 +1,14 @@
import getpass
from typing import List
import cv2
import numpy as np
import pandas as pd
from ultralytics.data.augment import LetterBox
from ultralytics.utils import LOGGER as logger
from ultralytics.utils import SETTINGS
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.ops import xyxy2xywh
from ultralytics.utils.plotting import plot_images
@ -47,15 +52,16 @@ def sanitize_batch(batch, dataset_info):
return batch
def plot_similar_images(similar_set, plot_labels=True):
def plot_query_result(similar_set, plot_labels=True):
"""
Plot images from the similar set.
Args:
similar_set (list): Pyarrow table containing the similar data points
similar_set (list): Pyarrow or pandas object containing the similar data points
plot_labels (bool): Whether to plot labels or not
"""
similar_set = similar_set.to_pydict()
similar_set = similar_set.to_dict(
orient='list') if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
empty_masks = [[[]]]
empty_boxes = [[]]
images = similar_set.get('im_file', [])
@ -102,3 +108,61 @@ def plot_similar_images(similar_set, plot_labels=True):
max_subplots=len(images),
save=False,
threaded=False)
def prompt_sql_query(query):
check_requirements('openai>=1.6.1')
from openai import OpenAI
if not SETTINGS['openai_api_key']:
logger.warning('OpenAI API key not found in settings. Please enter your API key below.')
openai_api_key = getpass.getpass('OpenAI API key: ')
SETTINGS.update({'openai_api_key': openai_api_key})
openai = OpenAI(api_key=SETTINGS['openai_api_key'])
messages = [
{
'role':
'system',
'content':
'''
You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on
the following schema and a user request. You only need to output the format with fixed selection
statement that selects everything from "'table'", like `SELECT * from 'table'`
Schema:
im_file: string not null
labels: list<item: string> not null
child 0, item: string
cls: list<item: int64> not null
child 0, item: int64
bboxes: list<item: list<item: double>> not null
child 0, item: list<item: double>
child 0, item: double
masks: list<item: list<item: list<item: int64>>> not null
child 0, item: list<item: list<item: int64>>
child 0, item: list<item: int64>
child 0, item: int64
keypoints: list<item: list<item: list<item: double>>> not null
child 0, item: list<item: list<item: double>>
child 0, item: list<item: double>
child 0, item: double
vector: fixed_size_list<item: float>[256] not null
child 0, item: float
Some details about the schema:
- the "labels" column contains the string values like 'person' and 'dog' for the respective objects
in each image
- the "cls" column contains the integer values on these classes that map them the labels
Example of a correct query:
request - Get all data points that contain 2 or more people and at least one dog
correct query-
SELECT * FROM 'table' WHERE ARRAY_LENGTH(cls) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1;
'''},
{
'role': 'user',
'content': f'{query}'}, ]
response = openai.chat.completions.create(model='gpt-3.5-turbo', messages=messages)
return response.choices[0].message.content

@ -246,7 +246,7 @@ class Model(nn.Module):
prompts = args.pop('prompts', None) # for SAM-type models
if not self.predictor:
self.predictor = (predictor or self._smart_load('predictor'))(overrides=args, _callbacks=self.callbacks)
self.predictor = predictor or self._smart_load('predictor')(overrides=args, _callbacks=self.callbacks)
self.predictor.setup_model(model=self.model, verbose=is_cli)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, args)

@ -41,8 +41,7 @@ class OBBPredictor(DetectionPredictor):
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
results = []
for i, pred in enumerate(preds):
orig_img = orig_imgs[i]
for i, (pred, orig_img) in enumerate(zip(preds, orig_imgs)):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape, xywh=True)
img_path = self.batch[0][i]
# xywh, r, conf, cls

@ -61,13 +61,13 @@ class Detect(nn.Module):
dbox = self.decode_bboxes(box)
if self.export and self.format in ('tflite', 'edgetpu'):
# Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5:
# https://github.com/ultralytics/yolov5/blob/0c8de3fca4a702f8ff5c435e67f378d1fce70243/models/tf.py#L307-L309
# See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695
img_h = shape[2] * self.stride[0]
img_w = shape[3] * self.stride[0]
img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1)
dbox /= img_size
# Precompute normalization factor to increase numerical stability
# See https://github.com/ultralytics/ultralytics/issues/7371
img_h = shape[2]
img_w = shape[3]
img_size = torch.tensor([img_w, img_h, img_w, img_h], device=box.device).reshape(1, 4, 1)
norm = self.strides / (self.stride[0] * img_size)
dbox = dist2bbox(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2], xywh=True, dim=1)
y = torch.cat((dbox, cls.sigmoid()), 1)
return y if self.export else (y, x)

@ -4,6 +4,7 @@ import math
import cv2
from ultralytics.utils.checks import check_imshow
from ultralytics.utils.plotting import Annotator, colors
@ -37,6 +38,9 @@ class DistanceCalculation:
self.left_mouse_count = 0
self.selected_boxes = {}
# Check if environment support imshow
self.env_check = check_imshow(warn=True)
def set_args(self,
names,
pixels_per_meter=10,
@ -168,7 +172,7 @@ class DistanceCalculation:
self.centroids = []
if self.view_img:
if self.view_img and self.env_check:
self.display_frames()
return im0

@ -28,6 +28,8 @@ class Heatmap:
self.imw = None
self.imh = None
self.im0 = None
self.view_in_counts = True
self.view_out_counts = True
# Heatmap colormap and heatmap np array
self.colormap = None
@ -67,6 +69,8 @@ class Heatmap:
colormap=cv2.COLORMAP_JET,
heatmap_alpha=0.5,
view_img=False,
view_in_counts=True,
view_out_counts=True,
count_reg_pts=None,
count_txt_thickness=2,
count_txt_color=(0, 0, 0),
@ -85,6 +89,8 @@ class Heatmap:
imh (int): The height of the frame.
heatmap_alpha (float): alpha value for heatmap display
view_img (bool): Flag indicating frame display
view_in_counts (bool): Flag to control whether to display the incounts on video stream.
view_out_counts (bool): Flag to control whether to display the outcounts on video stream.
count_reg_pts (list): Object counting region points
count_txt_thickness (int): Text thickness for object counting display
count_txt_color (RGB color): count text color value
@ -99,6 +105,8 @@ class Heatmap:
self.imh = imh
self.heatmap_alpha = heatmap_alpha
self.view_img = view_img
self.view_in_counts = view_in_counts
self.view_out_counts = view_out_counts
self.colormap = colormap
# Region and line selection
@ -171,9 +179,10 @@ class Heatmap:
if self.count_reg_pts is not None:
# Draw counting region
self.annotator.draw_region(reg_pts=self.count_reg_pts,
color=self.region_color,
thickness=self.region_thickness)
if self.view_in_counts or self.view_out_counts:
self.annotator.draw_region(reg_pts=self.count_reg_pts,
color=self.region_color,
thickness=self.region_thickness)
for box, cls, track_id in zip(self.boxes, self.clss, self.track_ids):
@ -235,11 +244,22 @@ class Heatmap:
heatmap_normalized = cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX)
heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), self.colormap)
if self.count_reg_pts is not None:
incount_label = 'InCount : ' + f'{self.in_counts}'
outcount_label = 'OutCount : ' + f'{self.out_counts}'
self.annotator.count_labels(in_count=incount_label,
out_count=outcount_label,
incount_label = 'In Count : ' + f'{self.in_counts}'
outcount_label = 'OutCount : ' + f'{self.out_counts}'
# Display counts based on user choice
counts_label = None
if not self.view_in_counts and not self.view_out_counts:
counts_label = None
elif not self.view_in_counts:
counts_label = outcount_label
elif not self.view_out_counts:
counts_label = incount_label
else:
counts_label = incount_label + ' ' + outcount_label
if self.count_reg_pts is not None and counts_label is not None:
self.annotator.count_labels(counts=counts_label,
count_txt_size=self.count_txt_thickness,
txt_color=self.count_txt_color,
color=self.count_color)

@ -856,6 +856,7 @@ class SettingsManager(dict):
'uuid': hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(),
'sync': True,
'api_key': '',
'openai_api_key': '',
'clearml': True, # integrations
'comet': True,
'dvc': True,
