`ultralytics 8.0.236` dataset semantic & SQL search API (#7136)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Laughing-q <1182102784@qq.com>
pull/7363/head^2 v8.0.236
parent
40a5c0abe7
commit
aca8eb1fd4
27 changed files with 1749 additions and 192 deletions
@@ -0,0 +1,297 @@
||||
--- |
||||
comments: true |
||||
description: Explore and analyze CV datasets with Ultralytics Explorer API, offering SQL, vector similarity, and semantic searches for efficient dataset insights. |
||||
keywords: Ultralytics Explorer API, Dataset Exploration, SQL Queries, Vector Similarity Search, Semantic Search, Embeddings Table, Image Similarity, Python API for Datasets, CV Dataset Analysis, LanceDB Integration |
||||
--- |
||||
|
||||
# Ultralytics Explorer API |
||||
|
||||
## Introduction |
||||
|
||||
The Explorer API is a Python API for exploring your datasets. It supports filtering and searching your dataset using SQL queries, vector similarity search and semantic search. |
||||
|
||||
## Installation |
||||
|
||||
Explorer depends on external libraries for some of its functionality. These are automatically installed on usage. To manually install these dependencies, use the following command: |
||||
|
||||
```bash |
||||
pip install ultralytics[explorer] |
||||
``` |
||||
|
||||
## Usage |
||||
|
||||
```python |
||||
from ultralytics import Explorer |
||||
|
||||
# Create an Explorer object |
||||
explorer = Explorer(data='coco128.yaml', model='yolov8n.pt') |
||||
|
||||
# Create embeddings for your dataset |
||||
explorer.create_embeddings_table() |
||||
|
||||
# Search for similar images to a given image/images |
||||
dataframe = explorer.get_similar(img='path/to/image.jpg') |
||||
|
||||
# Or search for similar images to a given index/indices |
||||
dataframe = explorer.get_similar(idx=0)
||||
``` |
||||
|
||||
## 1. Similarity Search |
||||
|
||||
Similarity search is a technique for finding images similar to a given image. It is based on the idea that similar images will have similar embeddings.
Once the embeddings table is built, you can run semantic search in any of the following ways:

- On a given index or list of indices in the dataset, e.g. `exp.get_similar(idx=[1,10], limit=10)`
- On any image or list of images not in the dataset, e.g. `exp.get_similar(img=["path/to/img1", "path/to/img2"], limit=10)`

When multiple inputs are passed, the mean of their embeddings is used.

You get a pandas dataframe with the `limit` most similar data points to the input, along with their distance in the embedding space. You can use this dataframe to perform further filtering.
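For example, here is a minimal sketch of such filtering, using the `_distance` column that LanceDB attaches to search results (the same column the similarity index uses internally); the `0.2` threshold is illustrative:

```python
from ultralytics import Explorer

exp = Explorer(data='coco128.yaml', model='yolov8n.pt')
exp.create_embeddings_table()

similar = exp.get_similar(img='https://ultralytics.com/images/bus.jpg', limit=25)

# Keep only the results within a chosen distance of the query embedding
close_matches = similar[similar['_distance'] <= 0.2]
print(close_matches[['im_file', '_distance']].head())
```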
||||
|
||||
!!! Example "Semantic Search" |
||||
|
||||
=== "Using Images" |
||||
|
||||
```python |
||||
from ultralytics import Explorer |
||||
|
||||
# create an Explorer object |
||||
exp = Explorer(data='coco128.yaml', model='yolov8n.pt') |
||||
exp.create_embeddings_table() |
||||
|
||||
similar = exp.get_similar(img='https://ultralytics.com/images/bus.jpg', limit=10) |
||||
print(similar.head()) |
||||
|
||||
# Search using multiple images
||||
similar = exp.get_similar( |
||||
img=['https://ultralytics.com/images/bus.jpg', |
||||
'https://ultralytics.com/images/bus.jpg'], |
||||
limit=10 |
||||
) |
||||
print(similar.head()) |
||||
``` |
||||
|
||||
=== "Using Dataset Indices" |
||||
|
||||
```python |
||||
from ultralytics import Explorer |
||||
|
||||
# create an Explorer object |
||||
exp = Explorer(data='coco128.yaml', model='yolov8n.pt') |
||||
exp.create_embeddings_table() |
||||
|
||||
similar = exp.get_similar(idx=1, limit=10) |
||||
print(similar.head()) |
||||
|
||||
# Search using multiple indices |
||||
similar = exp.get_similar(idx=[1,10], limit=10) |
||||
print(similar.head()) |
||||
``` |
||||
|
||||
### Plotting Similar Images |
||||
|
||||
You can also plot the similar images using the `plot_similar` method. This method takes the same arguments as `get_similar` and plots the similar images in a grid. |
||||
|
||||
!!! Example "Plotting Similar Images" |
||||
|
||||
=== "Using Images" |
||||
|
||||
```python |
||||
from ultralytics import Explorer |
||||
|
||||
# create an Explorer object |
||||
exp = Explorer(data='coco128.yaml', model='yolov8n.pt') |
||||
exp.create_embeddings_table() |
||||
|
||||
plt = exp.plot_similar(img='https://ultralytics.com/images/bus.jpg', limit=10) |
||||
plt.show() |
||||
``` |
||||
|
||||
=== "Using Dataset Indices" |
||||
|
||||
```python |
||||
from ultralytics import Explorer |
||||
|
||||
# create an Explorer object |
||||
exp = Explorer(data='coco128.yaml', model='yolov8n.pt') |
||||
exp.create_embeddings_table() |
||||
|
||||
plt = exp.plot_similar(idx=1, limit=10) |
||||
plt.show() |
||||
``` |
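The object returned by `plot_similar` is a PIL image, so besides showing it you can also save the grid to disk. A minimal sketch, reusing the `exp` object from the examples above (the filename is illustrative):

```python
plt = exp.plot_similar(idx=1, limit=10)
plt.save('similar_grid.jpg')
```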
||||
|
||||
## 2. SQL Querying |
||||
|
||||
You can run SQL queries on your dataset using the `sql_query` method. This method takes a SQL query as input and returns a pandas dataframe with the results. |
||||
|
||||
!!! Example "SQL Query" |
||||
|
||||
```python |
||||
from ultralytics import Explorer |
||||
|
||||
# create an Explorer object |
||||
exp = Explorer(data='coco128.yaml', model='yolov8n.pt') |
||||
exp.create_embeddings_table() |
||||
|
||||
df = exp.sql_query("WHERE labels LIKE '%person%' AND labels LIKE '%dog%'") |
||||
print(df.head()) |
||||
``` |
||||
|
||||
### Plotting SQL Query Results |
||||
|
||||
You can also plot the results of a SQL query using the `plot_sql_query` method. This method takes the same arguments as `sql_query` and plots the results in a grid. |
||||
|
||||
!!! Example "Plotting SQL Query Results" |
||||
|
||||
```python |
||||
from ultralytics import Explorer |
||||
|
||||
# create an Explorer object |
||||
exp = Explorer(data='coco128.yaml', model='yolov8n.pt') |
||||
exp.create_embeddings_table() |
||||
|
||||
img = exp.plot_sql_query("WHERE labels LIKE '%person%' AND labels LIKE '%dog%'")
img.show()
||||
``` |
||||
|
||||
## 3. Working with Embeddings Table (Advanced)

You can also work with the embeddings table directly. Once the embeddings table is created, you can access it using the `Explorer.table` attribute.

!!! Tip "Explorer works on [LanceDB](https://lancedb.github.io/lancedb/) tables internally. You can access this table directly, using the `Explorer.table` object, to run raw queries, push down pre- and post-filters, etc."
||||
|
||||
```python |
||||
from ultralytics import Explorer |
||||
|
||||
exp = Explorer() |
||||
exp.create_embeddings_table() |
||||
table = exp.table |
||||
``` |
||||
|
||||
Here are some examples of what you can do with the table: |
||||
|
||||
### Get raw Embeddings |
||||
|
||||
!!! Example |
||||
|
||||
```python |
||||
from ultralytics import Explorer |
||||
|
||||
exp = Explorer() |
||||
exp.create_embeddings_table() |
||||
table = exp.table |
||||
|
||||
embeddings = table.to_pandas()["vector"] |
||||
print(embeddings) |
||||
``` |
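If you want to work with the vectors numerically, for example for the PCA visualization later on this page, you can stack them into a single NumPy matrix. A minimal sketch, assuming each row's `vector` is a fixed-length list or array:

```python
import numpy as np

# Stack the per-image vectors into a (num_images, vector_size) matrix
embeddings = np.stack(table.to_pandas()['vector'].to_list())
print(embeddings.shape)
```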
||||
|
||||
### Advanced Querying with pre- and post-filters
||||
|
||||
!!! Example |
||||
|
||||
```python |
||||
from ultralytics import Explorer |
||||
|
||||
exp = Explorer(model="yolov8n.pt") |
||||
exp.create_embeddings_table() |
||||
table = exp.table |
||||
|
||||
# Dummy embedding |
||||
embedding = [i for i in range(256)] |
||||
rs = table.search(embedding).metric("cosine").where("").limit(10) |
||||
``` |
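As a follow-up, here is a sketch of the same raw query with an actual `where()` filter applied and the results materialized through Arrow into pandas; the filter string and selected columns are illustrative:

```python
# Same raw query with a label filter, materialized as a pandas dataframe
rs = table.search(embedding).metric('cosine').where("labels LIKE '%person%'").limit(10).to_arrow()
df = rs.to_pandas()
print(df[['im_file', 'labels']].head())
```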
||||
|
||||
### Create Vector Index |
||||
|
||||
When using large datasets, you can also create a dedicated vector index for faster querying. This is done using the `create_index` method on the LanceDB table.
||||
|
||||
```python |
||||
table.create_index(num_partitions=..., num_sub_vectors=...) |
||||
``` |
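For example, a minimal sketch with illustrative parameter values; the right numbers depend on your dataset size and embedding dimension, so treat these as placeholders rather than recommendations:

```python
# Build an ANN index; the partition and sub-vector counts below are placeholders only
table.create_index(num_partitions=64, num_sub_vectors=16)
```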
||||
|
||||
Find more details on the types of vector indices available and their parameters [here](https://lancedb.github.io/lancedb/ann_indexes/#types-of-index).
In the future, we will add support for creating vector indices directly from the Explorer API.
||||
|
||||
## 4. Embeddings Applications |
||||
|
||||
You can use the embeddings table to perform a variety of exploratory analyses. Here are some examples:
||||
|
||||
### Similarity Index |
||||
|
||||
Explorer comes with a `similarity_index` operation: |
||||
|
||||
* It tries to estimate how similar each data point is to the rest of the dataset.
||||
* It does that by counting how many image embeddings lie closer than `max_dist` to the current image in the generated embedding space, considering `top_k` similar images at a time. |
||||
|
||||
It returns a pandas dataframe with the following columns: |
||||
|
||||
* `idx`: Index of the image in the dataset |
||||
* `im_file`: Path to the image file |
||||
* `count`: Number of images in the dataset that are closer than `max_dist` to the current image |
||||
* `sim_im_files`: List of paths to the `count` similar images |
||||
|
||||
!!! Tip |
||||
|
||||
For a given dataset, model, `max_dist` and `top_k`, the similarity index, once generated, will be reused. If your dataset has changed or you simply need to regenerate the similarity index, you can pass `force=True`.
||||
|
||||
!!! Example "Similarity Index" |
||||
|
||||
```python |
||||
from ultralytics import Explorer |
||||
|
||||
exp = Explorer() |
||||
exp.create_embeddings_table() |
||||
|
||||
sim_idx = exp.similarity_index() |
||||
``` |
||||
|
||||
You can use the similarity index to build custom conditions for filtering the dataset. For example, you can select images that are similar to more than 30 other images in the dataset using the following code:
||||
|
||||
```python |
||||
import numpy as np |
||||
|
||||
sim_count = np.array(sim_idx["count"]) |
||||
sim_idx['im_file'][sim_count > 30] |
||||
``` |
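Explorer also provides a `plot_similarity_index` method, which renders the same counts as a bar chart and returns a PIL image. It accepts the same `max_dist`, `top_k` and `force` arguments. A minimal sketch:

```python
from ultralytics import Explorer

exp = Explorer()
exp.create_embeddings_table()

img = exp.plot_similarity_index(max_dist=0.2)
img.show()
```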
||||
|
||||
### Visualize Embedding Space |
||||
|
||||
You can also visualize the embedding space using the plotting tool of your choice. For example, here is a simple example using Matplotlib, where `embeddings` is the 2D array of vectors built from the table above:
||||
|
||||
```python |
||||
import numpy as np |
||||
from sklearn.decomposition import PCA |
||||
import matplotlib.pyplot as plt |
||||
from mpl_toolkits.mplot3d import Axes3D |
||||
|
||||
# Reduce dimensions using PCA to 3 components for visualization in 3D |
||||
pca = PCA(n_components=3) |
||||
reduced_data = pca.fit_transform(embeddings) |
||||
|
||||
# Create a 3D scatter plot using Matplotlib Axes3D |
||||
fig = plt.figure(figsize=(8, 6)) |
||||
ax = fig.add_subplot(111, projection='3d') |
||||
|
||||
# Scatter plot |
||||
ax.scatter(reduced_data[:, 0], reduced_data[:, 1], reduced_data[:, 2], alpha=0.5) |
||||
ax.set_title('3D Scatter Plot of Reduced 256-Dimensional Data (PCA)') |
||||
ax.set_xlabel('Component 1') |
||||
ax.set_ylabel('Component 2') |
||||
ax.set_zlabel('Component 3') |
||||
|
||||
plt.show() |
||||
``` |
||||
|
||||
Start creating your own CV dataset exploration reports using the Explorer API. For inspiration, check out the apps built using Ultralytics Explorer, listed below.
||||
|
||||
# Apps Built Using Ultralytics Explorer |
||||
|
||||
Try our GUI demo based on the Explorer API.
||||
|
||||
# Coming Soon |
||||
|
||||
- [ ] Merge specific labels from datasets. Example - Import all `person` labels from COCO and `car` labels from Cityscapes |
||||
- [ ] Remove images that have a higher similarity index than the given threshold |
||||
- [ ] Automatically persist new datasets after merging/removing entries |
||||
- [ ] Advanced Dataset Visualizations |
@@ -0,0 +1,31 @@
||||
--- |
||||
comments: true |
||||
description: Discover the Ultralytics Explorer, a versatile tool and Python API for CV dataset exploration, enabling semantic search, SQL queries, and vector similarity searches. |
||||
keywords: Ultralytics Explorer, CV Dataset Tools, Semantic Search, SQL Dataset Queries, Vector Similarity, Python API, GUI Explorer, Dataset Analysis, YOLO Explorer, Data Insights |
||||
--- |
||||
|
||||
# Ultralytics Explorer |
||||
|
||||
Ultralytics Explorer is a tool for exploring CV datasets using semantic search, SQL queries and vector similarity search. It is also a Python API for accessing the same functionality. |
||||
|
||||
### Installation of optional dependencies |
||||
|
||||
Explorer depends on external libraries for some of its functionality. These are automatically installed on usage. To manually install these dependencies, use the following command: |
||||
|
||||
```bash |
||||
pip install ultralytics[explorer] |
||||
``` |
||||
|
||||
## GUI Explorer Usage |
||||
|
||||
The GUI demo runs in your browser, allowing you to create embeddings for your dataset, search for similar images, run SQL queries, and perform semantic search. It can be run using the following command:
||||
|
||||
```bash |
||||
yolo explorer |
||||
``` |
||||
|
||||
### Explorer API |
||||
|
||||
This is a Python API for exploring your datasets. It also powers the GUI Explorer. You can use it to create your own exploratory notebooks or scripts to get insights into your datasets.
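For example, a minimal sketch of a script built on the API, mirroring the examples in the API docs:

```python
from ultralytics import Explorer

# Build (or reuse) the embeddings table, then run a semantic search and a SQL query
exp = Explorer(data='coco128.yaml', model='yolov8n.pt')
exp.create_embeddings_table()

similar = exp.get_similar(img='https://ultralytics.com/images/bus.jpg', limit=10)
print(similar.head())

df = exp.sql_query("WHERE labels LIKE '%person%'")
print(df.head())
```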
||||
|
||||
Learn more about the Explorer API [here](api.md). |
@@ -0,0 +1,50 @@
||||
from ultralytics import Explorer |
||||
|
||||
|
||||
def test_similarity(): |
||||
exp = Explorer() |
||||
exp.create_embeddings_table() |
||||
similar = exp.get_similar(idx=1) |
||||
assert len(similar) == 25 |
||||
similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg') |
||||
assert len(similar) == 25 |
||||
similar = exp.get_similar(idx=[1, 2], limit=10) |
||||
assert len(similar) == 10 |
||||
sim_idx = exp.similarity_index() |
||||
assert len(sim_idx) > 0 |
||||
sql = exp.sql_query("WHERE labels LIKE '%person%'") |
||||
assert len(sql) > 0
||||
|
||||
|
||||
def test_det(): |
||||
exp = Explorer(data='coco8.yaml', model='yolov8n.pt') |
||||
exp.create_embeddings_table(force=True) |
||||
assert len(exp.table.head()['bboxes']) > 0 |
||||
similar = exp.get_similar(idx=[1, 2], limit=10) |
||||
assert len(similar) > 0 |
||||
# This is a loose test, just checks errors not correctness |
||||
similar = exp.plot_similar(idx=[1, 2], limit=10) |
||||
assert similar is not None |
||||
similar.show() |
||||
|
||||
|
||||
def test_seg(): |
||||
exp = Explorer(data='coco8-seg.yaml', model='yolov8n-seg.pt') |
||||
exp.create_embeddings_table(force=True) |
||||
assert len(exp.table.head()['masks']) > 0 |
||||
similar = exp.get_similar(idx=[1, 2], limit=10) |
||||
assert len(similar) > 0 |
||||
similar = exp.plot_similar(idx=[1, 2], limit=10) |
||||
assert similar is not None |
||||
similar.show() |
||||
|
||||
|
||||
def test_pose(): |
||||
exp = Explorer(data='coco8-pose.yaml', model='yolov8n-pose.pt') |
||||
exp.create_embeddings_table(force=True) |
||||
assert len(exp.table.head()['keypoints']) > 0 |
||||
similar = exp.get_similar(idx=[1, 2], limit=10) |
||||
assert len(similar) > 0 |
||||
similar = exp.plot_similar(idx=[1, 2], limit=10) |
||||
assert similar is not None |
||||
similar.show() |
@@ -0,0 +1,403 @@
||||
from io import BytesIO |
||||
from pathlib import Path |
||||
from typing import List |
||||
|
||||
import cv2 |
||||
import numpy as np |
||||
import torch |
||||
from matplotlib import pyplot as plt |
||||
from PIL import Image |
||||
from tqdm import tqdm |
||||
|
||||
from ultralytics.data.augment import Format |
||||
from ultralytics.data.dataset import YOLODataset |
||||
from ultralytics.data.utils import check_det_dataset |
||||
from ultralytics.models.yolo.model import YOLO |
||||
from ultralytics.utils import LOGGER, checks |
||||
|
||||
from .utils import get_sim_index_schema, get_table_schema, plot_similar_images, sanitize_batch |
||||
|
||||
|
||||
class ExplorerDataset(YOLODataset): |
||||
|
||||
def __init__(self, *args, data=None, **kwargs): |
||||
super().__init__(*args, data=data, **kwargs) |
||||
|
||||
# NOTE: Load the image directly without any resize operations. |
||||
def load_image(self, i): |
||||
"""Loads 1 image from dataset index 'i', returns (im, resized hw).""" |
||||
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i] |
||||
if im is None: # not cached in RAM |
||||
if fn.exists(): # load npy |
||||
im = np.load(fn) |
||||
else: # read image |
||||
im = cv2.imread(f) # BGR |
||||
if im is None: |
||||
raise FileNotFoundError(f'Image Not Found {f}') |
||||
h0, w0 = im.shape[:2] # orig hw |
||||
return im, (h0, w0), im.shape[:2] |
||||
|
||||
return self.ims[i], self.im_hw0[i], self.im_hw[i] |
||||
|
||||
def build_transforms(self, hyp=None): |
||||
transforms = Format( |
||||
bbox_format='xyxy', |
||||
normalize=False, |
||||
return_mask=self.use_segments, |
||||
return_keypoint=self.use_keypoints, |
||||
batch_idx=True, |
||||
mask_ratio=hyp.mask_ratio, |
||||
mask_overlap=hyp.overlap_mask, |
||||
) |
||||
return transforms |
||||
|
||||
|
||||
class Explorer: |
||||
|
||||
def __init__(self, data='coco128.yaml', model='yolov8n.pt', uri='~/ultralytics/explorer') -> None: |
||||
checks.check_requirements(['lancedb', 'duckdb']) |
||||
import lancedb |
||||
|
||||
self.connection = lancedb.connect(uri) |
||||
self.table_name = Path(data).name.lower() + '_' + model.lower() |
||||
self.sim_idx_base_name = f'{self.table_name}_sim_idx'.lower()  # Use this name and append thres and top_k to reuse the table
||||
self.model = YOLO(model) |
||||
self.data = data  # None or path to a dataset YAML
||||
self.choice_set = None |
||||
|
||||
self.table = None |
||||
self.progress = 0 |
||||
|
||||
def create_embeddings_table(self, force=False, split='train'): |
||||
""" |
||||
Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it |
||||
already exists. Pass force=True to overwrite the existing table. |
||||
|
||||
Args: |
||||
force (bool): Whether to overwrite the existing table or not. Defaults to False. |
||||
split (str): Split of the dataset to use. Defaults to 'train'. |
||||
|
||||
Example: |
||||
```python |
||||
exp = Explorer() |
||||
exp.create_embeddings_table() |
||||
``` |
||||
""" |
||||
if self.table is not None and not force: |
||||
LOGGER.info('Table already exists. Reusing it. Pass force=True to overwrite it.') |
||||
return |
||||
if self.table_name in self.connection.table_names() and not force: |
||||
LOGGER.info(f'Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.') |
||||
self.table = self.connection.open_table(self.table_name) |
||||
self.progress = 1 |
||||
return |
||||
if self.data is None: |
||||
raise ValueError('Data must be provided to create embeddings table') |
||||
|
||||
data_info = check_det_dataset(self.data) |
||||
if split not in data_info: |
||||
raise ValueError( |
||||
f'Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}' |
||||
) |
||||
|
||||
choice_set = data_info[split] |
||||
choice_set = choice_set if isinstance(choice_set, list) else [choice_set] |
||||
self.choice_set = choice_set |
||||
dataset = ExplorerDataset(img_path=choice_set, data=data_info, augment=False, cache=False, task=self.model.task) |
||||
|
||||
# Create the table schema |
||||
batch = dataset[0] |
||||
vector_size = self.model.embed(batch['im_file'], verbose=False)[0].shape[0] |
||||
Schema = get_table_schema(vector_size) |
||||
table = self.connection.create_table(self.table_name, schema=Schema, mode='overwrite') |
||||
table.add( |
||||
self._yield_batches(dataset, |
||||
data_info, |
||||
self.model, |
||||
exclude_keys=['img', 'ratio_pad', 'resized_shape', 'ori_shape', 'batch_idx'])) |
||||
|
||||
self.table = table |
||||
|
||||
def _yield_batches(self, dataset, data_info, model, exclude_keys: List): |
||||
# TODO: implement proper batching; currently embeds one image at a time
||||
for i in tqdm(range(len(dataset))): |
||||
self.progress = float(i + 1) / len(dataset) |
||||
batch = dataset[i] |
||||
for k in exclude_keys: |
||||
batch.pop(k, None) |
||||
batch = sanitize_batch(batch, data_info) |
||||
batch['vector'] = model.embed(batch['im_file'], verbose=False)[0].detach().tolist() |
||||
yield [batch] |
||||
|
||||
def query(self, imgs=None, limit=25): |
||||
""" |
||||
Query the table for similar images. Accepts a single image or a list of images. |
||||
|
||||
Args: |
||||
imgs (str or list): Path to the image or a list of paths to the images. |
||||
limit (int): Number of results to return. |
||||
|
||||
Returns: |
||||
An arrow table containing the results. Supports converting to: |
||||
- pandas dataframe: `result.to_pandas()` |
||||
- dict of lists: `result.to_pydict()` |
||||
|
||||
Example: |
||||
```python |
||||
exp = Explorer() |
||||
exp.create_embeddings_table() |
||||
similar = exp.query(imgs='https://ultralytics.com/images/zidane.jpg')
||||
``` |
||||
""" |
||||
if self.table is None: |
||||
raise ValueError('Table is not created. Please create the table first.') |
||||
if isinstance(imgs, str): |
||||
imgs = [imgs] |
||||
elif isinstance(imgs, list): |
||||
pass |
||||
else: |
||||
raise ValueError(f'imgs must be a string or a list of strings. Got {type(imgs)}')
||||
embeds = self.model.embed(imgs) |
||||
# Get avg if multiple images are passed (len > 1) |
||||
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy() |
||||
query = self.table.search(embeds).limit(limit).to_arrow() |
||||
return query |
||||
|
||||
def sql_query(self, query, return_type='pandas'): |
||||
""" |
||||
Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown. |
||||
|
||||
Args: |
||||
query (str): SQL query to run. |
||||
return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'. |
||||
|
||||
Returns: |
||||
A pandas dataframe or an arrow table containing the results, depending on `return_type`.
||||
|
||||
Example: |
||||
```python |
||||
exp = Explorer() |
||||
exp.create_embeddings_table() |
||||
query = 'SELECT * FROM table WHERE labels LIKE "%person%"' |
||||
result = exp.sql_query(query) |
||||
``` |
||||
""" |
||||
import duckdb |
||||
|
||||
if self.table is None: |
||||
raise ValueError('Table is not created. Please create the table first.') |
||||
|
||||
# Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this. |
||||
table = self.table.to_arrow() # noqa |
||||
if not query.startswith('SELECT') and not query.startswith('WHERE'): |
||||
raise ValueError( |
||||
'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause.') |
||||
if query.startswith('WHERE'): |
||||
query = f"SELECT * FROM 'table' {query}" |
||||
LOGGER.info(f'Running query: {query}') |
||||
|
||||
rs = duckdb.sql(query) |
||||
if return_type == 'pandas': |
||||
return rs.df() |
||||
elif return_type == 'arrow': |
||||
return rs.arrow() |
||||
|
||||
def plot_sql_query(self, query, labels=True): |
||||
""" |
||||
Plot the results of a SQL-Like query on the table. |
||||
Args: |
||||
query (str): SQL query to run. |
||||
labels (bool): Whether to plot the labels or not. |
||||
|
||||
Returns: |
||||
PIL Image containing the plot. |
||||
|
||||
Example: |
||||
```python |
||||
exp = Explorer() |
||||
exp.create_embeddings_table() |
||||
query = 'SELECT * FROM table WHERE labels LIKE "%person%"' |
||||
result = exp.plot_sql_query(query) |
||||
``` |
||||
""" |
||||
result = self.sql_query(query, return_type='arrow') |
||||
img = plot_similar_images(result, plot_labels=labels) |
||||
img = Image.fromarray(img) |
||||
return img |
||||
|
||||
def get_similar(self, img=None, idx=None, limit=25, return_type='pandas'): |
||||
""" |
||||
Query the table for similar images. Accepts a single image or a list of images. |
||||
|
||||
Args: |
||||
img (str or list): Path to the image or a list of paths to the images. |
||||
idx (int or list): Index of the image in the table or a list of indexes. |
||||
limit (int): Number of results to return. Defaults to 25. |
||||
return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'. |
||||
|
||||
Returns: |
||||
A table or pandas dataframe containing the results. |
||||
|
||||
Example: |
||||
```python |
||||
exp = Explorer() |
||||
exp.create_embeddings_table() |
||||
similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg') |
||||
``` |
||||
""" |
||||
img = self._check_imgs_or_idxs(img, idx) |
||||
similar = self.query(img, limit=limit) |
||||
|
||||
if return_type == 'pandas': |
||||
return similar.to_pandas() |
||||
elif return_type == 'arrow': |
||||
return similar |
||||
|
||||
def plot_similar(self, img=None, idx=None, limit=25, labels=True): |
||||
""" |
||||
Plot the similar images. Accepts images or indexes. |
||||
|
||||
Args: |
||||
img (str or list): Path to the image or a list of paths to the images. |
||||
idx (int or list): Index of the image in the table or a list of indexes. |
||||
labels (bool): Whether to plot the labels or not. |
||||
limit (int): Number of results to return. Defaults to 25. |
||||
|
||||
Returns: |
||||
PIL Image containing the plot. |
||||
|
||||
Example: |
||||
```python |
||||
exp = Explorer() |
||||
exp.create_embeddings_table() |
||||
similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg') |
||||
``` |
||||
""" |
||||
similar = self.get_similar(img, idx, limit, return_type='arrow') |
||||
img = plot_similar_images(similar, plot_labels=labels) |
||||
img = Image.fromarray(img) |
||||
return img |
||||
|
||||
def similarity_index(self, max_dist=0.2, top_k=None, force=False): |
||||
""" |
||||
Calculate the similarity index of all the images in the table. Here, the index will contain the data points that |
||||
are max_dist or closer to the image in the embedding space at a given index. |
||||
|
||||
Args: |
||||
max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2. |
||||
top_k (float): Fraction of the closest data points to consider when counting. Used to apply a limit when running
vector search. Defaults to None (the entire table is considered).
force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.
||||
|
||||
Returns: |
||||
A pandas dataframe containing the similarity index. |
||||
|
||||
Example: |
||||
```python |
||||
exp = Explorer() |
||||
exp.create_embeddings_table() |
||||
sim_idx = exp.similarity_index() |
||||
``` |
||||
""" |
||||
if self.table is None: |
||||
raise ValueError('Table is not created. Please create the table first.') |
||||
sim_idx_table_name = f'{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}'.lower() |
||||
if sim_idx_table_name in self.connection.table_names() and not force: |
||||
LOGGER.info('Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.') |
||||
return self.connection.open_table(sim_idx_table_name).to_pandas() |
||||
|
||||
if top_k and not (1.0 >= top_k >= 0.0): |
||||
raise ValueError(f'top_k must be between 0.0 and 1.0. Got {top_k}') |
||||
if max_dist < 0.0: |
||||
raise ValueError(f'max_dist must be greater than or equal to 0. Got {max_dist}')
||||
|
||||
top_k = int(top_k * len(self.table)) if top_k else len(self.table) |
||||
top_k = max(top_k, 1) |
||||
features = self.table.to_lance().to_table(columns=['vector', 'im_file']).to_pydict() |
||||
im_files = features['im_file'] |
||||
embeddings = features['vector'] |
||||
|
||||
sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode='overwrite') |
||||
|
||||
def _yield_sim_idx(): |
||||
for i in tqdm(range(len(embeddings))): |
||||
sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f'_distance <= {max_dist}') |
||||
yield [{ |
||||
'idx': i, |
||||
'im_file': im_files[i], |
||||
'count': len(sim_idx), |
||||
'sim_im_files': sim_idx['im_file'].tolist()}] |
||||
|
||||
sim_table.add(_yield_sim_idx()) |
||||
self.sim_index = sim_table |
||||
|
||||
return sim_table.to_pandas() |
||||
|
||||
def plot_similarity_index(self, max_dist=0.2, top_k=None, force=False): |
||||
""" |
||||
Plot the similarity index of all the images in the table. Here, the index will contain the data points that are |
||||
max_dist or closer to the image in the embedding space at a given index. |
||||
|
||||
Args: |
||||
max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2. |
||||
top_k (float): Fraction of closest data points to consider when counting. Used to apply a limit when
running vector search. Defaults to None (the entire table is considered).
force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.
||||
|
||||
Returns: |
||||
PIL Image containing the plot. |
||||
|
||||
Example: |
||||
```python |
||||
exp = Explorer() |
||||
exp.create_embeddings_table() |
||||
exp.plot_similarity_index() |
||||
``` |
||||
""" |
||||
sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force) |
||||
sim_count = sim_idx['count'].tolist() |
||||
sim_count = np.array(sim_count) |
||||
|
||||
indices = np.arange(len(sim_count)) |
||||
|
||||
# Create the bar plot |
||||
plt.bar(indices, sim_count) |
||||
|
||||
# Customize the plot (optional) |
||||
plt.xlabel('data idx') |
||||
plt.ylabel('Count') |
||||
plt.title('Similarity Count') |
||||
buffer = BytesIO() |
||||
plt.savefig(buffer, format='png') |
||||
buffer.seek(0) |
||||
|
||||
# Use Pillow to open the image from the buffer |
||||
image = Image.open(buffer) |
||||
return image |
||||
|
||||
def _check_imgs_or_idxs(self, img, idx): |
||||
if img is None and idx is None: |
||||
raise ValueError('Either img or idx must be provided.') |
||||
if img is not None and idx is not None: |
||||
raise ValueError('Only one of img or idx must be provided.') |
||||
if idx is not None: |
||||
idx = idx if isinstance(idx, list) else [idx] |
||||
img = self.table.to_lance().take(idx, columns=['im_file']).to_pydict()['im_file'] |
||||
|
||||
img = img if isinstance(img, list) else [img] |
||||
return img |
||||
|
||||
def visualize(self, result): |
||||
""" |
||||
Visualize the results of a query. |
||||
|
||||
Args: |
||||
result (arrow table): Arrow table containing the results of a query. |
||||
""" |
||||
# TODO: |
||||
pass |
||||
|
||||
def generate_report(self, result): |
||||
"""Generate a report of the dataset.""" |
||||
pass |
@@ -0,0 +1,178 @@
||||
import time |
||||
from threading import Thread |
||||
|
||||
from ultralytics import Explorer |
||||
from ultralytics.utils import ROOT |
||||
from ultralytics.utils.checks import check_requirements |
||||
|
||||
check_requirements('streamlit') |
||||
check_requirements('streamlit-select>=0.2') |
||||
import streamlit as st |
||||
from streamlit_select import image_select |
||||
|
||||
|
||||
def _get_explorer(): |
||||
exp = Explorer(data=st.session_state.get('dataset'), model=st.session_state.get('model')) |
||||
thread = Thread(target=exp.create_embeddings_table, |
||||
kwargs={'force': st.session_state.get('force_recreate_embeddings')}) |
||||
thread.start() |
||||
progress_bar = st.progress(0, text='Creating embeddings table...') |
||||
while exp.progress < 1: |
||||
time.sleep(0.1) |
||||
progress_bar.progress(exp.progress, text=f'Progress: {exp.progress * 100}%') |
||||
thread.join() |
||||
st.session_state['explorer'] = exp |
||||
progress_bar.empty() |
||||
|
||||
|
||||
def init_explorer_form(): |
||||
datasets = ROOT / 'cfg' / 'datasets' |
||||
ds = [d.name for d in datasets.glob('*.yaml')] |
||||
models = [ |
||||
'yolov8n.pt', 'yolov8s.pt', 'yolov8m.pt', 'yolov8l.pt', 'yolov8x.pt', 'yolov8n-seg.pt', 'yolov8s-seg.pt', |
||||
'yolov8m-seg.pt', 'yolov8l-seg.pt', 'yolov8x-seg.pt', 'yolov8n-pose.pt', 'yolov8s-pose.pt', 'yolov8m-pose.pt', |
||||
'yolov8l-pose.pt', 'yolov8x-pose.pt'] |
||||
with st.form(key='explorer_init_form'): |
||||
col1, col2 = st.columns(2) |
||||
with col1: |
||||
dataset = st.selectbox('Select dataset', ds, key='dataset', index=ds.index('coco128.yaml')) |
||||
with col2: |
||||
model = st.selectbox('Select model', models, key='model') |
||||
st.checkbox('Force recreate embeddings', key='force_recreate_embeddings') |
||||
|
||||
st.form_submit_button('Explore', on_click=_get_explorer) |
||||
|
||||
|
||||
def query_form(): |
||||
with st.form('query_form'): |
||||
col1, col2 = st.columns([0.8, 0.2]) |
||||
with col1: |
||||
query = st.text_input('Query', '', label_visibility='collapsed', key='query') |
||||
with col2: |
||||
st.form_submit_button('Query', on_click=run_sql_query) |
||||
|
||||
|
||||
def find_similar_imgs(imgs): |
||||
exp = st.session_state['explorer'] |
||||
similar = exp.get_similar(img=imgs, limit=st.session_state.get('limit'), return_type='arrow') |
||||
paths = similar.to_pydict()['im_file'] |
||||
st.session_state['imgs'] = paths |
||||
|
||||
|
||||
def similarity_form(selected_imgs): |
||||
st.write('Similarity Search') |
||||
with st.form('similarity_form'): |
||||
subcol1, subcol2 = st.columns([1, 1]) |
||||
with subcol1: |
||||
limit = st.number_input('limit', |
||||
min_value=None, |
||||
max_value=None, |
||||
value=25, |
||||
label_visibility='collapsed', |
||||
key='limit') |
||||
|
||||
with subcol2: |
||||
disabled = not len(selected_imgs) |
||||
st.write('Selected: ', len(selected_imgs)) |
||||
st.form_submit_button( |
||||
'Search', |
||||
disabled=disabled, |
||||
on_click=find_similar_imgs, |
||||
args=(selected_imgs, ), |
||||
) |
||||
if disabled: |
||||
st.error('Select at least one image to search.') |
||||
|
||||
|
||||
# def persist_reset_form(): |
||||
# with st.form("persist_reset"): |
||||
# col1, col2 = st.columns([1, 1]) |
||||
# with col1: |
||||
# st.form_submit_button("Reset", on_click=reset) |
||||
# |
||||
# with col2: |
||||
# st.form_submit_button("Persist", on_click=update_state, args=("PERSISTING", True)) |
||||
|
||||
|
||||
def run_sql_query(): |
||||
query = st.session_state.get('query') |
||||
if query.strip():
||||
exp = st.session_state['explorer'] |
||||
res = exp.sql_query(query, return_type='arrow') |
||||
st.session_state['imgs'] = res.to_pydict()['im_file'] |
||||
|
||||
|
||||
def reset_explorer(): |
||||
st.session_state['explorer'] = None |
||||
st.session_state['imgs'] = None |
||||
|
||||
|
||||
def utralytics_explorer_docs_callback(): |
||||
with st.container(border=True): |
||||
st.image('https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg', |
||||
width=100) |
||||
st.markdown( |
||||
"<p>This demo is built using Ultralytics Explorer API. Visit <a href=''>API docs</a> to try examples & learn more</p>", |
||||
unsafe_allow_html=True, |
||||
help=None) |
||||
st.link_button('Ultralytics Explorer API', 'https://docs.ultralytics.com/')
||||
|
||||
|
||||
def layout(): |
||||
st.set_page_config(layout='wide', initial_sidebar_state='collapsed') |
||||
st.markdown("<h1 style='text-align: center;'>Ultralytics Explorer Demo</h1>", unsafe_allow_html=True) |
||||
|
||||
if st.session_state.get('explorer') is None: |
||||
init_explorer_form() |
||||
return |
||||
|
||||
st.button(':arrow_backward: Select Dataset', on_click=reset_explorer) |
||||
exp = st.session_state.get('explorer') |
||||
col1, col2 = st.columns([0.75, 0.25], gap='small') |
||||
|
||||
imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file'] |
||||
total_imgs = len(imgs) |
||||
with col1: |
||||
subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5) |
||||
with subcol1: |
||||
st.write('Max Images Displayed:') |
||||
with subcol2: |
||||
num = st.number_input('Max Images Displayed', |
||||
min_value=0, |
||||
max_value=total_imgs, |
||||
value=min(500, total_imgs), |
||||
key='num_imgs_displayed', |
||||
label_visibility='collapsed') |
||||
with subcol3: |
||||
st.write('Start Index:') |
||||
with subcol4: |
||||
start_idx = st.number_input('Start Index', |
||||
min_value=0, |
||||
max_value=total_imgs, |
||||
value=0, |
||||
key='start_index', |
||||
label_visibility='collapsed') |
||||
with subcol5: |
||||
reset = st.button('Reset', use_container_width=False, key='reset') |
||||
if reset: |
||||
st.session_state['imgs'] = None |
||||
st.experimental_rerun() |
||||
|
||||
query_form() |
||||
if total_imgs: |
||||
imgs_displayed = imgs[start_idx:start_idx + num] |
||||
selected_imgs = image_select( |
||||
f'Total samples: {total_imgs}', |
||||
images=imgs_displayed, |
||||
use_container_width=False, |
||||
# indices=[i for i in range(num)] if select_all else None, |
||||
) |
||||
|
||||
with col2: |
||||
similarity_form(selected_imgs) |
||||
# display_labels = st.checkbox("Labels", value=False, key="display_labels") |
||||
utralytics_explorer_docs_callback() |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
layout() |
@@ -0,0 +1,103 @@
||||
from pathlib import Path |
||||
from typing import List |
||||
|
||||
import cv2 |
||||
import numpy as np |
||||
|
||||
from ultralytics.data.augment import LetterBox |
||||
from ultralytics.utils.ops import xyxy2xywh |
||||
from ultralytics.utils.plotting import plot_images |
||||
|
||||
|
||||
def get_table_schema(vector_size): |
||||
from lancedb.pydantic import LanceModel, Vector |
||||
|
||||
class Schema(LanceModel): |
||||
im_file: str |
||||
labels: List[str] |
||||
cls: List[int] |
||||
bboxes: List[List[float]] |
||||
masks: List[List[List[int]]] |
||||
keypoints: List[List[List[float]]] |
||||
vector: Vector(vector_size) |
||||
|
||||
return Schema |
||||
|
||||
|
||||
def get_sim_index_schema(): |
||||
from lancedb.pydantic import LanceModel |
||||
|
||||
class Schema(LanceModel): |
||||
idx: int |
||||
im_file: str |
||||
count: int |
||||
sim_im_files: List[str] |
||||
|
||||
return Schema |
||||
|
||||
|
||||
def sanitize_batch(batch, dataset_info): |
||||
batch['cls'] = batch['cls'].flatten().int().tolist() |
||||
box_cls_pair = sorted(zip(batch['bboxes'].tolist(), batch['cls']), key=lambda x: x[1]) |
||||
batch['bboxes'] = [box for box, _ in box_cls_pair] |
||||
batch['cls'] = [cls for _, cls in box_cls_pair] |
||||
batch['labels'] = [dataset_info['names'][i] for i in batch['cls']] |
||||
batch['masks'] = batch['masks'].tolist() if 'masks' in batch else [[[]]] |
||||
batch['keypoints'] = batch['keypoints'].tolist() if 'keypoints' in batch else [[[]]] |
||||
|
||||
return batch |
||||
|
||||
|
||||
def plot_similar_images(similar_set, plot_labels=True): |
||||
""" |
||||
Plot images from the similar set. |
||||
|
||||
Args: |
||||
similar_set (pyarrow.Table): Pyarrow table containing the similar data points
||||
plot_labels (bool): Whether to plot labels or not |
||||
""" |
||||
similar_set = similar_set.to_pydict() |
||||
empty_masks = [[[]]] |
||||
empty_boxes = [[]] |
||||
images = similar_set.get('im_file', []) |
||||
bboxes = similar_set.get('bboxes', []) if similar_set.get('bboxes') is not empty_boxes else [] |
||||
masks = similar_set.get('masks') if similar_set.get('masks')[0] != empty_masks else [] |
||||
kpts = similar_set.get('keypoints') if similar_set.get('keypoints')[0] != empty_masks else [] |
||||
cls = similar_set.get('cls', []) |
||||
|
||||
plot_size = 640 |
||||
imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], [] |
||||
for i, imf in enumerate(images): |
||||
im = cv2.imread(imf) |
||||
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) |
||||
h, w = im.shape[:2] |
||||
r = min(plot_size / h, plot_size / w) |
||||
imgs.append(LetterBox(plot_size, center=False)(image=im).transpose(2, 0, 1)) |
||||
if plot_labels: |
||||
if len(bboxes) > i and len(bboxes[i]) > 0: |
||||
box = np.array(bboxes[i], dtype=np.float32) |
||||
box[:, [0, 2]] *= r |
||||
box[:, [1, 3]] *= r |
||||
plot_boxes.append(box) |
||||
if len(masks) > i and len(masks[i]) > 0: |
||||
mask = np.array(masks[i], dtype=np.uint8)[0] |
||||
plot_masks.append(LetterBox(plot_size, center=False)(image=mask)) |
||||
if len(kpts) > i and kpts[i] is not None: |
||||
kpt = np.array(kpts[i], dtype=np.float32) |
||||
kpt[:, :, :2] *= r |
||||
plot_kpts.append(kpt) |
||||
batch_idx.append(np.ones(len(np.array(bboxes[i], dtype=np.float32))) * i) |
||||
imgs = np.stack(imgs, axis=0) |
||||
masks = np.stack(plot_masks, axis=0) if len(plot_masks) > 0 else np.zeros(0, dtype=np.uint8) |
||||
kpts = np.concatenate(plot_kpts, axis=0) if len(plot_kpts) > 0 else np.zeros((0, 51), dtype=np.float32) |
||||
boxes = xyxy2xywh(np.concatenate(plot_boxes, axis=0)) if len(plot_boxes) > 0 else np.zeros(0, dtype=np.float32) |
||||
batch_idx = np.concatenate(batch_idx, axis=0) |
||||
cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0) |
||||
|
||||
fname = 'temp_exp_grid.jpg' |
||||
plot_images(imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, fname=fname, |
||||
max_subplots=len(images)).join() |
||||
img = cv2.imread(fname, cv2.IMREAD_COLOR) |
||||
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) |
||||
Path(fname).unlink() |
||||
return img_rgb |