Explorer with LanceDB, Actions and Docs updates (#7487)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Muhammad Rizwan Munawar <chr043416@gmail.com>
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
pull/8052/head
Glenn Jocher 11 months ago committed by GitHub
parent 0e7221fb62
commit 09ee982d35
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 14
      .pre-commit-config.yaml
  2. 13
      docs/en/datasets/explorer/api.md
  3. 12
      docs/en/datasets/explorer/dashboard.md
  4. 11
      docs/en/datasets/explorer/index.md
  5. 2
      docs/mkdocs.yml
  6. 12
      docs/overrides/main.html
  7. 2
      pyproject.toml
  8. 8
      ultralytics/models/utils/loss.py
  9. 2
      ultralytics/models/yolo/obb/predict.py

@ -30,17 +30,11 @@ repos:
- id: pyupgrade
name: Upgrade code
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.11
hooks:
- id: isort
name: Sort imports
- repo: https://github.com/google/yapf
rev: v0.40.2
hooks:
- id: yapf
name: YAPF formatting
- id: ruff
args: [--fix]
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.17

@ -34,9 +34,16 @@ explorer.create_embeddings_table()
dataframe = explorer.get_similar(img='path/to/image.jpg')
# Or search for similar images to a given index/indices
dataframe = explorer.get_similar()(idx=0)
dataframe = explorer.get_similar(idx=0)
```
!!! Tip "Note"
Embeddings table for a given dataset and model pair is only created once and reused. These use [LanceDB](https://lancedb.github.io/lancedb/) under the hood, which scales on-disk, so you can create and reuse embeddings for large datasets like COCO without running out of memory.
In case you want to force update the embeddings table, you can pass `force=True` to `create_embeddings_table` method.
You can direclty access the LanceDB table object to perform advanced analysis. Learn more about it in [Working with table section](#4-advanced---working-with-embeddings-table)
## 1. Similarity Search
Similarity search is a technique for finding similar images to a given image. It is based on the idea that similar images will have similar embeddings. Once the embeddings table is built, you can get run semantic search in any of the following ways:
@ -178,7 +185,7 @@ You can also plot the results of a SQL query using the `plot_sql_query` method.
print(df.head())
```
## 4. Working with embeddings Table (Advanced)
## 4. Advanced - Working with Embeddings Table
You can also work with the embeddings table directly. Once the embeddings table is created, you can access it using the `Explorer.table`
@ -230,7 +237,7 @@ Here are some examples of what you can do with the table:
When using large datasets, you can also create a dedicated vector index for faster querying. This is done using the `create_index` method on LanceDB table.
```python
table.create_index(num_partitions=..., num_sub_vectors=...)
table.create_index(num_partitions=..., num_sub_vectors=...)
```
Find more details on the type vector indices available and parameters [here](https://lancedb.github.io/lancedb/ann_indexes/#types-of-index) In the future, we will add support for creating vector indices directly from Explorer API.

@ -8,6 +8,10 @@ keywords: Ultralytics, Explorer GUI, semantic search, vector similarity search,
Explorer GUI is like a playground build using [Ultralytics Explorer API](api.md). It allows you to run semantic/vector similarity search, SQL queries and even search using natural language using our ask AI feature powered by LLMs.
<p>
<img width="1709" alt="Explorer Dashboard Screenshot 1" src="https://github.com/AyushExel/assets/assets/15766192/f9c3c704-df3f-4209-81e9-c777b1f6ed9c">
</p>
### Installation
```bash
@ -26,19 +30,19 @@ Semantic search is a technique for finding similar images to a given image. It i
For example:
In this VOC Exploration dashboard, user selects a couple aeroplane images like this:
<p>
<img width="1710" alt="Screenshot 2024-01-08 at 8 46 33PM" src="https://github.com/AyushExel/assets/assets/15766192/da5f1b0a-9eb5-4712-919c-7d5512240dd8">
<img width="1710" alt="Explorer Dashboard Screenshot 2" src="https://github.com/AyushExel/assets/assets/15766192/da5f1b0a-9eb5-4712-919c-7d5512240dd8">
</p>
On performing similarity search, you should see a similar result:
<p>
<img width="1710" alt="Screenshot 2024-01-08 at 8 46 46PM" src="https://github.com/AyushExel/assets/assets/15766192/5e4c6445-8e4e-48bb-a15a-9fb6c6994af8">
<img width="1710" alt="Explorer Dashboard Screenshot 3" src="https://github.com/AyushExel/assets/assets/15766192/5e4c6445-8e4e-48bb-a15a-9fb6c6994af8">
</p>
## Ask AI
This allows you to write how you want to filter your dataset using natural language. You don't have to be proficient in writing SQL queries. Our AI powered query generator will automatically do that under the hood. For example - you can say - "show me 100 images with exactly one person and 2 dogs. There can be other objects too" and it'll internally generate the query and show you those results. Here's an example output when asked to "Show 10 images with exactly 5 persons" and you'll see a result like this:
<p>
<img width="1709" alt="Screenshot 2024-01-08 at 7 19 48PM (1)" src="https://github.com/AyushExel/assets/assets/15766192/e536b0eb-6bce-43fe-b800-3e79510d2e5b">
<img width="1709" alt="Explorer Dashboard Screenshot 4" src="https://github.com/AyushExel/assets/assets/15766192/e536b0eb-6bce-43fe-b800-3e79510d2e5b">
</p>
Note: This works using LLMs under the hood so the results are probabilistic and might get things wrong sometimes
@ -52,7 +56,7 @@ WHERE labels LIKE '%person%' AND labels LIKE '%dog%'
```
<p>
<img width="1707" alt="Screenshot 2024-01-08 at 8 57 49PM" src="https://github.com/AyushExel/assets/assets/15766192/71619e16-4db9-4fdb-b951-0d1fdbf59a6a">
<img width="1707" alt="Explorer Dashboard Screenshot 5" src="https://github.com/AyushExel/assets/assets/15766192/71619e16-4db9-4fdb-b951-0d1fdbf59a6a">
</p>
This is a Demo build using the Explorer API. You can use the API to build your own exploratory notebooks or scripts to get insights into your datasets. Learn more about the Explorer API [here](api.md).

@ -7,7 +7,7 @@ keywords: Ultralytics Explorer, CV Dataset Tools, Semantic Search, SQL Dataset Q
# Ultralytics Explorer
<p>
<img width="1709" alt="Screenshot 2024-01-08 at 7 19 48PM (1)" src="https://github.com/AyushExel/assets/assets/15766192/e536b0eb-6bce-43fe-b800-3e79510d2e5b">
<img width="1709" alt="Ultralytics Explorer Screenshot 1" src="https://github.com/AyushExel/assets/assets/15766192/85675606-fb7f-4b0c-ad1b-d9d07c919414">
</p>
<a href="https://colab.research.google.com/github/ultralytics/ultralytics/blob/main/docs/en/datasets/explorer/explorer.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
@ -21,6 +21,11 @@ Explorer depends on external libraries for some of its functionality. These are
pip install ultralytics[explorer]
```
!!! tip
Explorer works on embedding/semantic search & SQL querying and is powered by [LanceDB](https://lancedb.com/) serverless vector database. Unlike traditional in-memory DBs, it is persisted on disk without sacrificing performance, so you can scale locally to large datasets like COCO without running out of memory.
### Explorer API
This is a Python API for Exploring your datasets. It also powers the GUI Explorer. You can use this to create your own exploratory notebooks or scripts to get insights into your datasets.
@ -38,3 +43,7 @@ yolo explorer
!!! note "Note"
Ask AI feature works using OpenAI, so you'll be prompted to set the api key for OpenAI when you first run the GUI.
You can set it like this - `yolo settings openai_api_key="..."`
<p>
<img width="1709" alt="Ultralytics Explorer OpenAI Integration" src="https://github.com/AyushExel/assets/assets/15766192/1b5f3708-be3e-44c5-9ea3-adcd522dfc75">
</p>

@ -42,7 +42,7 @@ theme:
icon: material/brightness-7
name: Switch to dark mode
features:
- announce.dismiss
# - announce.dismiss
- content.action.edit
- content.code.annotate
- content.code.copy

@ -0,0 +1,12 @@
<!--Ultralytics YOLO 🚀, AGPL-3.0 license-->
{% extends "base.html" %}
{% block announce %}
<div style="text-align: center;">
<a href="https://www.ultralytics.com/blog/ultralytics-yolov8-turns-one-a-year-of-breakthroughs-and-innovations"
target="_blank" style="color: #FFFFFF;">
Ultralytics YOLOv8 Turns One! 🎉 A Year of Breakthroughs and Innovations &nbsp;
</a>
</div>
{% endblock %}

@ -180,4 +180,4 @@ close-quotes-on-newline = true
[tool.codespell]
ignore-words-list = "crate,nd,ned,strack,dota,ane,segway,fo,gool,winn,commend,bloc,nam,afterall"
skip = "*.pt,*.pth,*.torchscript,*.onnx,*.tflite,*.pb,*.bin,*.param,*.mlmodel,*.engine,*.npy,*.data*,*.csv,*pnnx*,*venv*,__pycache__*,*.ico,*.jpg,*.png,*.mp4,*.mov,/runs,/.git,./docs/??/*.md,./docs/mkdocs_??.yml"
skip = '*.pt,*.pth,*.torchscript,*.onnx,*.tflite,*.pb,*.bin,*.param,*.mlmodel,*.engine,*.npy,*.data*,*.csv,*pnnx*,*venv*,*translat*,__pycache__*,*.ico,*.jpg,*.png,*.mp4,*.mov,/runs,/.git,./docs/??/*.md,./docs/mkdocs_??.yml'

@ -200,14 +200,14 @@ class DETRLoss(nn.Module):
"""Assigns predicted bounding boxes to ground truth bounding boxes based on the match indices."""
pred_assigned = torch.cat(
[
t[I] if len(I) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
for t, (I, _) in zip(pred_bboxes, match_indices)
t[i] if len(i) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
for t, (i, _) in zip(pred_bboxes, match_indices)
]
)
gt_assigned = torch.cat(
[
t[J] if len(J) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
for t, (_, J) in zip(gt_bboxes, match_indices)
t[j] if len(j) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
for t, (_, j) in zip(gt_bboxes, match_indices)
]
)
return pred_assigned, gt_assigned

@ -44,7 +44,7 @@ class OBBPredictor(DetectionPredictor):
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
results = []
for i, (pred, orig_img, img_path) in enumerate(zip(preds, orig_imgs, self.batch[0])):
for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape, xywh=True)
# xywh, r, conf, cls
obb = torch.cat([pred[:, :4], pred[:, -1:], pred[:, 4:6]], dim=-1)

Loading…
Cancel
Save