commit
ae7cc8fdae
122 changed files with 2153 additions and 3516 deletions
@ -1,17 +1,17 @@ |
||||
| Argument | Type | Default | Description | |
||||
| --------------- | -------------- | ---------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | |
||||
| `source` | `str` | `'ultralytics/assets'` | Specifies the data source for inference. Can be an image path, video file, directory, URL, or device ID for live feeds. Supports a wide range of formats and sources, enabling flexible application across [different types of input](/modes/predict.md/#inference-sources). | |
||||
| `conf` | `float` | `0.25` | Sets the minimum confidence threshold for detections. Objects detected with confidence below this threshold will be disregarded. Adjusting this value can help reduce false positives. | |
||||
| `iou` | `float` | `0.7` | [Intersection Over Union](https://www.ultralytics.com/glossary/intersection-over-union-iou) (IoU) threshold for Non-Maximum Suppression (NMS). Lower values result in fewer detections by eliminating overlapping boxes, useful for reducing duplicates. | |
||||
| `imgsz` | `int or tuple` | `640` | Defines the image size for inference. Can be a single integer `640` for square resizing or a (height, width) tuple. Proper sizing can improve detection [accuracy](https://www.ultralytics.com/glossary/accuracy) and processing speed. | |
||||
| `half` | `bool` | `False` | Enables half-[precision](https://www.ultralytics.com/glossary/precision) (FP16) inference, which can speed up model inference on supported GPUs with minimal impact on accuracy. | |
||||
| `device` | `str` | `None` | Specifies the device for inference (e.g., `cpu`, `cuda:0` or `0`). Allows users to select between CPU, a specific GPU, or other compute devices for model execution. | |
||||
| `max_det` | `int` | `300` | Maximum number of detections allowed per image. Limits the total number of objects the model can detect in a single inference, preventing excessive outputs in dense scenes. | |
||||
| `vid_stride` | `int` | `1` | Frame stride for video inputs. Allows skipping frames in videos to speed up processing at the cost of temporal resolution. A value of 1 processes every frame, higher values skip frames. | |
||||
| `stream_buffer` | `bool` | `False` | Determines the frame processing strategy for video streams. If `False` processing only the most recent frame, minimizing latency (optimized for real-time applications). If `True' processes all frames in order, ensuring no frames are skipped. | |
||||
| `visualize` | `bool` | `False` | Activates visualization of model features during inference, providing insights into what the model is "seeing". Useful for debugging and model interpretation. | |
||||
| `augment` | `bool` | `False` | Enables test-time augmentation (TTA) for predictions, potentially improving detection robustness at the cost of inference speed. | |
||||
| `agnostic_nms` | `bool` | `False` | Enables class-agnostic Non-Maximum Suppression (NMS), which merges overlapping boxes of different classes. Useful in multi-class detection scenarios where class overlap is common. | |
||||
| `classes` | `list[int]` | `None` | Filters predictions to a set of class IDs. Only detections belonging to the specified classes will be returned. Useful for focusing on relevant objects in multi-class detection tasks. | |
||||
| `retina_masks` | `bool` | `False` | Uses high-resolution segmentation masks if available in the model. This can enhance mask quality for segmentation tasks, providing finer detail. | |
||||
| `embed` | `list[int]` | `None` | Specifies the layers from which to extract feature vectors or [embeddings](https://www.ultralytics.com/glossary/embeddings). Useful for downstream tasks like clustering or similarity search. | |
||||
| Argument | Type | Default | Description | |
||||
| --------------- | -------------- | ---------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | |
||||
| `source` | `str` | `'ultralytics/assets'` | Specifies the data source for inference. Can be an image path, video file, directory, URL, or device ID for live feeds. Supports a wide range of formats and sources, enabling flexible application across [different types of input](/modes/predict.md/#inference-sources). | |
||||
| `conf` | `float` | `0.25` | Sets the minimum confidence threshold for detections. Objects detected with confidence below this threshold will be disregarded. Adjusting this value can help reduce false positives. | |
||||
| `iou` | `float` | `0.7` | [Intersection Over Union](https://www.ultralytics.com/glossary/intersection-over-union-iou) (IoU) threshold for Non-Maximum Suppression (NMS). Lower values result in fewer detections by eliminating overlapping boxes, useful for reducing duplicates. | |
||||
| `imgsz` | `int or tuple` | `640` | Defines the image size for inference. Can be a single integer `640` for square resizing or a (height, width) tuple. Proper sizing can improve detection [accuracy](https://www.ultralytics.com/glossary/accuracy) and processing speed. | |
||||
| `half` | `bool` | `False` | Enables half-[precision](https://www.ultralytics.com/glossary/precision) (FP16) inference, which can speed up model inference on supported GPUs with minimal impact on accuracy. | |
||||
| `device` | `str` | `None` | Specifies the device for inference (e.g., `cpu`, `cuda:0` or `0`). Allows users to select between CPU, a specific GPU, or other compute devices for model execution. | |
||||
| `max_det` | `int` | `300` | Maximum number of detections allowed per image. Limits the total number of objects the model can detect in a single inference, preventing excessive outputs in dense scenes. | |
||||
| `vid_stride` | `int` | `1` | Frame stride for video inputs. Allows skipping frames in videos to speed up processing at the cost of temporal resolution. A value of 1 processes every frame, higher values skip frames. | |
||||
| `stream_buffer` | `bool` | `False` | Determines whether to queue incoming frames for video streams. If `False`, old frames get dropped to accomodate new frames (optimized for real-time applications). If `True', queues new frames in a buffer, ensuring no frames get skipped, but will cause latency if inference FPS is lower than stream FPS. | |
||||
| `visualize` | `bool` | `False` | Activates visualization of model features during inference, providing insights into what the model is "seeing". Useful for debugging and model interpretation. | |
||||
| `augment` | `bool` | `False` | Enables test-time augmentation (TTA) for predictions, potentially improving detection robustness at the cost of inference speed. | |
||||
| `agnostic_nms` | `bool` | `False` | Enables class-agnostic Non-Maximum Suppression (NMS), which merges overlapping boxes of different classes. Useful in multi-class detection scenarios where class overlap is common. | |
||||
| `classes` | `list[int]` | `None` | Filters predictions to a set of class IDs. Only detections belonging to the specified classes will be returned. Useful for focusing on relevant objects in multi-class detection tasks. | |
||||
| `retina_masks` | `bool` | `False` | Uses high-resolution segmentation masks if available in the model. This can enhance mask quality for segmentation tasks, providing finer detail. | |
||||
| `embed` | `list[int]` | `None` | Specifies the layers from which to extract feature vectors or [embeddings](https://www.ultralytics.com/glossary/embeddings). Useful for downstream tasks like clustering or similarity search. | |
||||
|
@ -1,66 +0,0 @@ |
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
||||
|
||||
import PIL |
||||
import pytest |
||||
|
||||
from ultralytics import Explorer |
||||
from ultralytics.utils import ASSETS |
||||
from ultralytics.utils.torch_utils import TORCH_1_13 |
||||
|
||||
|
||||
@pytest.mark.slow |
||||
@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13") |
||||
def test_similarity(): |
||||
"""Test the correctness and response length of similarity calculations and SQL queries in the Explorer.""" |
||||
exp = Explorer(data="coco8.yaml") |
||||
exp.create_embeddings_table() |
||||
similar = exp.get_similar(idx=1) |
||||
assert len(similar) == 4 |
||||
similar = exp.get_similar(img=ASSETS / "bus.jpg") |
||||
assert len(similar) == 4 |
||||
similar = exp.get_similar(idx=[1, 2], limit=2) |
||||
assert len(similar) == 2 |
||||
sim_idx = exp.similarity_index() |
||||
assert len(sim_idx) == 4 |
||||
sql = exp.sql_query("WHERE labels LIKE '%zebra%'") |
||||
assert len(sql) == 1 |
||||
|
||||
|
||||
@pytest.mark.slow |
||||
@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13") |
||||
def test_det(): |
||||
"""Test detection functionalities and verify embedding table includes bounding boxes.""" |
||||
exp = Explorer(data="coco8.yaml", model="yolo11n.pt") |
||||
exp.create_embeddings_table(force=True) |
||||
assert len(exp.table.head()["bboxes"]) > 0 |
||||
similar = exp.get_similar(idx=[1, 2], limit=10) |
||||
assert len(similar) > 0 |
||||
# This is a loose test, just checks errors not correctness |
||||
similar = exp.plot_similar(idx=[1, 2], limit=10) |
||||
assert isinstance(similar, PIL.Image.Image) |
||||
|
||||
|
||||
@pytest.mark.slow |
||||
@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13") |
||||
def test_seg(): |
||||
"""Test segmentation functionalities and ensure the embedding table includes segmentation masks.""" |
||||
exp = Explorer(data="coco8-seg.yaml", model="yolo11n-seg.pt") |
||||
exp.create_embeddings_table(force=True) |
||||
assert len(exp.table.head()["masks"]) > 0 |
||||
similar = exp.get_similar(idx=[1, 2], limit=10) |
||||
assert len(similar) > 0 |
||||
similar = exp.plot_similar(idx=[1, 2], limit=10) |
||||
assert isinstance(similar, PIL.Image.Image) |
||||
|
||||
|
||||
@pytest.mark.slow |
||||
@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13") |
||||
def test_pose(): |
||||
"""Test pose estimation functionality and verify the embedding table includes keypoints.""" |
||||
exp = Explorer(data="coco8-pose.yaml", model="yolo11n-pose.pt") |
||||
exp.create_embeddings_table(force=True) |
||||
assert len(exp.table.head()["keypoints"]) > 0 |
||||
similar = exp.get_similar(idx=[1, 2], limit=10) |
||||
assert len(similar) > 0 |
||||
similar = exp.plot_similar(idx=[1, 2], limit=10) |
||||
assert isinstance(similar, PIL.Image.Image) |
@ -0,0 +1,17 @@ |
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
||||
|
||||
# Configuration for Ultralytics Solutions |
||||
|
||||
model: "yolo11n.pt" # The Ultralytics YOLO11 model to be used (e.g., yolo11n.pt for YOLO11 nano version and yolov8n.pt for YOLOv8 nano version) |
||||
|
||||
region: # Object counting, queue or speed estimation region points. Default region points are [(20, 400), (1080, 404), (1080, 360), (20, 360)] |
||||
line_width: 2 # Width of the annotator used to draw regions on the image/video frames + bounding boxes and tracks drawing. Default value is 2. |
||||
show: True # Flag to control whether to display output image or not, you can set this as False i.e. when deploying it on some embedded devices. |
||||
show_in: True # Flag to display objects moving *into* the defined region |
||||
show_out: True # Flag to display objects moving *out of* the defined region |
||||
classes: # To count specific classes. i.e, if you want to detect, track and count the person with COCO model, you can use classes=0, Default its None |
||||
up_angle: 145.0 # Workouts up_angle for counts, 145.0 is default value. You can adjust it for different workouts, based on position of keypoints. |
||||
down_angle: 90 # Workouts down_angle for counts, 90 is default value. You can change it for different workouts, based on position of keypoints. |
||||
kpts: [6, 8, 10] # Keypoints for workouts monitoring, i.e. If you want to consider keypoints for pushups that have mostly values of [6, 8, 10]. |
||||
colormap: # Colormap for heatmap, Only OPENCV supported colormaps can be used. By default COLORMAP_PARULA will be used for visualization. |
||||
analytics_type: "line" # Analytics type i.e "line", "pie", "bar" or "area" charts. By default, "line" analytics will be used for processing. |
@ -1,5 +0,0 @@ |
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
||||
|
||||
from .utils import plot_query_result |
||||
|
||||
__all__ = ["plot_query_result"] |
@ -1,460 +0,0 @@ |
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
||||
|
||||
from io import BytesIO |
||||
from pathlib import Path |
||||
from typing import Any, List, Tuple, Union |
||||
|
||||
import cv2 |
||||
import numpy as np |
||||
import torch |
||||
from matplotlib import pyplot as plt |
||||
from PIL import Image |
||||
from tqdm import tqdm |
||||
|
||||
from ultralytics.data.augment import Format |
||||
from ultralytics.data.dataset import YOLODataset |
||||
from ultralytics.data.utils import check_det_dataset |
||||
from ultralytics.models.yolo.model import YOLO |
||||
from ultralytics.utils import LOGGER, USER_CONFIG_DIR, IterableSimpleNamespace, checks |
||||
|
||||
from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch |
||||
|
||||
|
||||
class ExplorerDataset(YOLODataset): |
||||
"""Extends YOLODataset for advanced data exploration and manipulation in model training workflows.""" |
||||
|
||||
def __init__(self, *args, data: dict = None, **kwargs) -> None: |
||||
"""Initializes the ExplorerDataset with the provided data arguments, extending the YOLODataset class.""" |
||||
super().__init__(*args, data=data, **kwargs) |
||||
|
||||
def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]: |
||||
"""Loads 1 image from dataset index 'i' without any resize ops.""" |
||||
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i] |
||||
if im is None: # not cached in RAM |
||||
if fn.exists(): # load npy |
||||
im = np.load(fn) |
||||
else: # read image |
||||
im = cv2.imread(f) # BGR |
||||
if im is None: |
||||
raise FileNotFoundError(f"Image Not Found {f}") |
||||
h0, w0 = im.shape[:2] # orig hw |
||||
return im, (h0, w0), im.shape[:2] |
||||
|
||||
return self.ims[i], self.im_hw0[i], self.im_hw[i] |
||||
|
||||
def build_transforms(self, hyp: IterableSimpleNamespace = None): |
||||
"""Creates transforms for dataset images without resizing.""" |
||||
return Format( |
||||
bbox_format="xyxy", |
||||
normalize=False, |
||||
return_mask=self.use_segments, |
||||
return_keypoint=self.use_keypoints, |
||||
batch_idx=True, |
||||
mask_ratio=hyp.mask_ratio, |
||||
mask_overlap=hyp.overlap_mask, |
||||
) |
||||
|
||||
|
||||
class Explorer: |
||||
"""Utility class for image embedding, table creation, and similarity querying using LanceDB and YOLO models.""" |
||||
|
||||
def __init__( |
||||
self, |
||||
data: Union[str, Path] = "coco128.yaml", |
||||
model: str = "yolov8n.pt", |
||||
uri: str = USER_CONFIG_DIR / "explorer", |
||||
) -> None: |
||||
"""Initializes the Explorer class with dataset path, model, and URI for database connection.""" |
||||
# Note duckdb==0.10.0 bug https://github.com/ultralytics/ultralytics/pull/8181 |
||||
checks.check_requirements(["lancedb>=0.4.3", "duckdb<=0.9.2"]) |
||||
import lancedb |
||||
|
||||
self.connection = lancedb.connect(uri) |
||||
self.table_name = f"{Path(data).name.lower()}_{model.lower()}" |
||||
self.sim_idx_base_name = ( |
||||
f"{self.table_name}_sim_idx".lower() |
||||
) # Use this name and append thres and top_k to reuse the table |
||||
self.model = YOLO(model) |
||||
self.data = data # None |
||||
self.choice_set = None |
||||
|
||||
self.table = None |
||||
self.progress = 0 |
||||
|
||||
def create_embeddings_table(self, force: bool = False, split: str = "train") -> None: |
||||
""" |
||||
Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it |
||||
already exists. Pass force=True to overwrite the existing table. |
||||
|
||||
Args: |
||||
force (bool): Whether to overwrite the existing table or not. Defaults to False. |
||||
split (str): Split of the dataset to use. Defaults to 'train'. |
||||
|
||||
Example: |
||||
```python |
||||
exp = Explorer() |
||||
exp.create_embeddings_table() |
||||
``` |
||||
""" |
||||
if self.table is not None and not force: |
||||
LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.") |
||||
return |
||||
if self.table_name in self.connection.table_names() and not force: |
||||
LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.") |
||||
self.table = self.connection.open_table(self.table_name) |
||||
self.progress = 1 |
||||
return |
||||
if self.data is None: |
||||
raise ValueError("Data must be provided to create embeddings table") |
||||
|
||||
data_info = check_det_dataset(self.data) |
||||
if split not in data_info: |
||||
raise ValueError( |
||||
f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}" |
||||
) |
||||
|
||||
choice_set = data_info[split] |
||||
choice_set = choice_set if isinstance(choice_set, list) else [choice_set] |
||||
self.choice_set = choice_set |
||||
dataset = ExplorerDataset(img_path=choice_set, data=data_info, augment=False, cache=False, task=self.model.task) |
||||
|
||||
# Create the table schema |
||||
batch = dataset[0] |
||||
vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0] |
||||
table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite") |
||||
table.add( |
||||
self._yield_batches( |
||||
dataset, |
||||
data_info, |
||||
self.model, |
||||
exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"], |
||||
) |
||||
) |
||||
|
||||
self.table = table |
||||
|
||||
def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]): |
||||
"""Generates batches of data for embedding, excluding specified keys.""" |
||||
for i in tqdm(range(len(dataset))): |
||||
self.progress = float(i + 1) / len(dataset) |
||||
batch = dataset[i] |
||||
for k in exclude_keys: |
||||
batch.pop(k, None) |
||||
batch = sanitize_batch(batch, data_info) |
||||
batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist() |
||||
yield [batch] |
||||
|
||||
def query( |
||||
self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25 |
||||
) -> Any: # pyarrow.Table |
||||
""" |
||||
Query the table for similar images. Accepts a single image or a list of images. |
||||
|
||||
Args: |
||||
imgs (str or list): Path to the image or a list of paths to the images. |
||||
limit (int): Number of results to return. |
||||
|
||||
Returns: |
||||
(pyarrow.Table): An arrow table containing the results. Supports converting to: |
||||
- pandas dataframe: `result.to_pandas()` |
||||
- dict of lists: `result.to_pydict()` |
||||
|
||||
Example: |
||||
```python |
||||
exp = Explorer() |
||||
exp.create_embeddings_table() |
||||
similar = exp.query(img="https://ultralytics.com/images/zidane.jpg") |
||||
``` |
||||
""" |
||||
if self.table is None: |
||||
raise ValueError("Table is not created. Please create the table first.") |
||||
if isinstance(imgs, str): |
||||
imgs = [imgs] |
||||
assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}" |
||||
embeds = self.model.embed(imgs) |
||||
# Get avg if multiple images are passed (len > 1) |
||||
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy() |
||||
return self.table.search(embeds).limit(limit).to_arrow() |
||||
|
||||
def sql_query( |
||||
self, query: str, return_type: str = "pandas" |
||||
) -> Union[Any, None]: # pandas.DataFrame or pyarrow.Table |
||||
""" |
||||
Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown. |
||||
|
||||
Args: |
||||
query (str): SQL query to run. |
||||
return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'. |
||||
|
||||
Returns: |
||||
(pyarrow.Table): An arrow table containing the results. |
||||
|
||||
Example: |
||||
```python |
||||
exp = Explorer() |
||||
exp.create_embeddings_table() |
||||
query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'" |
||||
result = exp.sql_query(query) |
||||
``` |
||||
""" |
||||
assert return_type in { |
||||
"pandas", |
||||
"arrow", |
||||
}, f"Return type should be either `pandas` or `arrow`, but got {return_type}" |
||||
import duckdb |
||||
|
||||
if self.table is None: |
||||
raise ValueError("Table is not created. Please create the table first.") |
||||
|
||||
# Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this. |
||||
table = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB |
||||
if not query.startswith("SELECT") and not query.startswith("WHERE"): |
||||
raise ValueError( |
||||
f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE " |
||||
f"clause. found {query}" |
||||
) |
||||
if query.startswith("WHERE"): |
||||
query = f"SELECT * FROM 'table' {query}" |
||||
LOGGER.info(f"Running query: {query}") |
||||
|
||||
rs = duckdb.sql(query) |
||||
if return_type == "arrow": |
||||
return rs.arrow() |
||||
elif return_type == "pandas": |
||||
return rs.df() |
||||
|
||||
def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image: |
||||
""" |
||||
Plot the results of a SQL-Like query on the table. |
||||
|
||||
Args: |
||||
query (str): SQL query to run. |
||||
labels (bool): Whether to plot the labels or not. |
||||
|
||||
Returns: |
||||
(PIL.Image): Image containing the plot. |
||||
|
||||
Example: |
||||
```python |
||||
exp = Explorer() |
||||
exp.create_embeddings_table() |
||||
query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'" |
||||
result = exp.plot_sql_query(query) |
||||
``` |
||||
""" |
||||
result = self.sql_query(query, return_type="arrow") |
||||
if len(result) == 0: |
||||
LOGGER.info("No results found.") |
||||
return None |
||||
img = plot_query_result(result, plot_labels=labels) |
||||
return Image.fromarray(img) |
||||
|
||||
def get_similar( |
||||
self, |
||||
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, |
||||
idx: Union[int, List[int]] = None, |
||||
limit: int = 25, |
||||
return_type: str = "pandas", |
||||
) -> Any: # pandas.DataFrame or pyarrow.Table |
||||
""" |
||||
Query the table for similar images. Accepts a single image or a list of images. |
||||
|
||||
Args: |
||||
img (str or list): Path to the image or a list of paths to the images. |
||||
idx (int or list): Index of the image in the table or a list of indexes. |
||||
limit (int): Number of results to return. Defaults to 25. |
||||
return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'. |
||||
|
||||
Returns: |
||||
(pandas.DataFrame): A dataframe containing the results. |
||||
|
||||
Example: |
||||
```python |
||||
exp = Explorer() |
||||
exp.create_embeddings_table() |
||||
similar = exp.get_similar(img="https://ultralytics.com/images/zidane.jpg") |
||||
``` |
||||
""" |
||||
assert return_type in {"pandas", "arrow"}, f"Return type should be `pandas` or `arrow`, but got {return_type}" |
||||
img = self._check_imgs_or_idxs(img, idx) |
||||
similar = self.query(img, limit=limit) |
||||
|
||||
if return_type == "arrow": |
||||
return similar |
||||
elif return_type == "pandas": |
||||
return similar.to_pandas() |
||||
|
||||
def plot_similar( |
||||
self, |
||||
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, |
||||
idx: Union[int, List[int]] = None, |
||||
limit: int = 25, |
||||
labels: bool = True, |
||||
) -> Image.Image: |
||||
""" |
||||
Plot the similar images. Accepts images or indexes. |
||||
|
||||
Args: |
||||
img (str or list): Path to the image or a list of paths to the images. |
||||
idx (int or list): Index of the image in the table or a list of indexes. |
||||
labels (bool): Whether to plot the labels or not. |
||||
limit (int): Number of results to return. Defaults to 25. |
||||
|
||||
Returns: |
||||
(PIL.Image): Image containing the plot. |
||||
|
||||
Example: |
||||
```python |
||||
exp = Explorer() |
||||
exp.create_embeddings_table() |
||||
similar = exp.plot_similar(img="https://ultralytics.com/images/zidane.jpg") |
||||
``` |
||||
""" |
||||
similar = self.get_similar(img, idx, limit, return_type="arrow") |
||||
if len(similar) == 0: |
||||
LOGGER.info("No results found.") |
||||
return None |
||||
img = plot_query_result(similar, plot_labels=labels) |
||||
return Image.fromarray(img) |
||||
|
||||
def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Any: # pd.DataFrame |
||||
""" |
||||
Calculate the similarity index of all the images in the table. Here, the index will contain the data points that |
||||
are max_dist or closer to the image in the embedding space at a given index. |
||||
|
||||
Args: |
||||
max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2. |
||||
top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit. |
||||
vector search. Defaults: None. |
||||
force (bool): Whether to overwrite the existing similarity index or not. Defaults to True. |
||||
|
||||
Returns: |
||||
(pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image, |
||||
and columns include indices of similar images and their respective distances. |
||||
|
||||
Example: |
||||
```python |
||||
exp = Explorer() |
||||
exp.create_embeddings_table() |
||||
sim_idx = exp.similarity_index() |
||||
``` |
||||
""" |
||||
if self.table is None: |
||||
raise ValueError("Table is not created. Please create the table first.") |
||||
sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower() |
||||
if sim_idx_table_name in self.connection.table_names() and not force: |
||||
LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.") |
||||
return self.connection.open_table(sim_idx_table_name).to_pandas() |
||||
|
||||
if top_k and not (1.0 >= top_k >= 0.0): |
||||
raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}") |
||||
if max_dist < 0.0: |
||||
raise ValueError(f"max_dist must be greater than 0. Got {max_dist}") |
||||
|
||||
top_k = int(top_k * len(self.table)) if top_k else len(self.table) |
||||
top_k = max(top_k, 1) |
||||
features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict() |
||||
im_files = features["im_file"] |
||||
embeddings = features["vector"] |
||||
|
||||
sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite") |
||||
|
||||
def _yield_sim_idx(): |
||||
"""Generates a dataframe with similarity indices and distances for images.""" |
||||
for i in tqdm(range(len(embeddings))): |
||||
sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}") |
||||
yield [ |
||||
{ |
||||
"idx": i, |
||||
"im_file": im_files[i], |
||||
"count": len(sim_idx), |
||||
"sim_im_files": sim_idx["im_file"].tolist(), |
||||
} |
||||
] |
||||
|
||||
sim_table.add(_yield_sim_idx()) |
||||
self.sim_index = sim_table |
||||
return sim_table.to_pandas() |
||||
|
||||
def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image: |
||||
""" |
||||
Plot the similarity index of all the images in the table. Here, the index will contain the data points that are |
||||
max_dist or closer to the image in the embedding space at a given index. |
||||
|
||||
Args: |
||||
max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2. |
||||
top_k (float): Percentage of closest data points to consider when counting. Used to apply limit when |
||||
running vector search. Defaults to 0.01. |
||||
force (bool): Whether to overwrite the existing similarity index or not. Defaults to True. |
||||
|
||||
Returns: |
||||
(PIL.Image): Image containing the plot. |
||||
|
||||
Example: |
||||
```python |
||||
exp = Explorer() |
||||
exp.create_embeddings_table() |
||||
|
||||
similarity_idx_plot = exp.plot_similarity_index() |
||||
similarity_idx_plot.show() # view image preview |
||||
similarity_idx_plot.save("path/to/save/similarity_index_plot.png") # save contents to file |
||||
``` |
||||
""" |
||||
sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force) |
||||
sim_count = sim_idx["count"].tolist() |
||||
sim_count = np.array(sim_count) |
||||
|
||||
indices = np.arange(len(sim_count)) |
||||
|
||||
# Create the bar plot |
||||
plt.bar(indices, sim_count) |
||||
|
||||
# Customize the plot (optional) |
||||
plt.xlabel("data idx") |
||||
plt.ylabel("Count") |
||||
plt.title("Similarity Count") |
||||
buffer = BytesIO() |
||||
plt.savefig(buffer, format="png") |
||||
buffer.seek(0) |
||||
|
||||
# Use Pillow to open the image from the buffer |
||||
return Image.fromarray(np.array(Image.open(buffer))) |
||||
|
||||
def _check_imgs_or_idxs( |
||||
self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]] |
||||
) -> List[np.ndarray]: |
||||
"""Determines whether to fetch images or indexes based on provided arguments and returns image paths.""" |
||||
if img is None and idx is None: |
||||
raise ValueError("Either img or idx must be provided.") |
||||
if img is not None and idx is not None: |
||||
raise ValueError("Only one of img or idx must be provided.") |
||||
if idx is not None: |
||||
idx = idx if isinstance(idx, list) else [idx] |
||||
img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"] |
||||
|
||||
return img if isinstance(img, list) else [img] |
||||
|
||||
def ask_ai(self, query): |
||||
""" |
||||
Ask AI a question. |
||||
|
||||
Args: |
||||
query (str): Question to ask. |
||||
|
||||
Returns: |
||||
(pandas.DataFrame): A dataframe containing filtered results to the SQL query. |
||||
|
||||
Example: |
||||
```python |
||||
exp = Explorer() |
||||
exp.create_embeddings_table() |
||||
answer = exp.ask_ai("Show images with 1 person and 2 dogs") |
||||
``` |
||||
""" |
||||
result = prompt_sql_query(query) |
||||
try: |
||||
return self.sql_query(result) |
||||
except Exception as e: |
||||
LOGGER.error("AI generated query is not valid. Please try again with a different prompt") |
||||
LOGGER.error(e) |
||||
return None |
@ -1 +0,0 @@ |
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
@ -1,282 +0,0 @@ |
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
||||
|
||||
import sys |
||||
import time |
||||
from threading import Thread |
||||
|
||||
from ultralytics import Explorer |
||||
from ultralytics.utils import ROOT, SETTINGS |
||||
from ultralytics.utils.checks import check_requirements |
||||
|
||||
check_requirements(("streamlit>=1.29.0", "streamlit-select>=0.3")) |
||||
|
||||
import streamlit as st |
||||
from streamlit_select import image_select |
||||
|
||||
|
||||
def _get_explorer(): |
||||
"""Initializes and returns an instance of the Explorer class.""" |
||||
exp = Explorer(data=st.session_state.get("dataset"), model=st.session_state.get("model")) |
||||
thread = Thread( |
||||
target=exp.create_embeddings_table, |
||||
kwargs={"force": st.session_state.get("force_recreate_embeddings"), "split": st.session_state.get("split")}, |
||||
) |
||||
thread.start() |
||||
progress_bar = st.progress(0, text="Creating embeddings table...") |
||||
while exp.progress < 1: |
||||
time.sleep(0.1) |
||||
progress_bar.progress(exp.progress, text=f"Progress: {exp.progress * 100}%") |
||||
thread.join() |
||||
st.session_state["explorer"] = exp |
||||
progress_bar.empty() |
||||
|
||||
|
||||
def init_explorer_form(data=None, model=None): |
||||
"""Initializes an Explorer instance and creates embeddings table with progress tracking.""" |
||||
if data is None: |
||||
datasets = ROOT / "cfg" / "datasets" |
||||
ds = [d.name for d in datasets.glob("*.yaml")] |
||||
else: |
||||
ds = [data] |
||||
|
||||
if model is None: |
||||
models = [ |
||||
"yolov8n.pt", |
||||
"yolov8s.pt", |
||||
"yolov8m.pt", |
||||
"yolov8l.pt", |
||||
"yolov8x.pt", |
||||
"yolov8n-seg.pt", |
||||
"yolov8s-seg.pt", |
||||
"yolov8m-seg.pt", |
||||
"yolov8l-seg.pt", |
||||
"yolov8x-seg.pt", |
||||
"yolov8n-pose.pt", |
||||
"yolov8s-pose.pt", |
||||
"yolov8m-pose.pt", |
||||
"yolov8l-pose.pt", |
||||
"yolov8x-pose.pt", |
||||
] |
||||
else: |
||||
models = [model] |
||||
|
||||
splits = ["train", "val", "test"] |
||||
|
||||
with st.form(key="explorer_init_form"): |
||||
col1, col2, col3 = st.columns(3) |
||||
with col1: |
||||
st.selectbox("Select dataset", ds, key="dataset") |
||||
with col2: |
||||
st.selectbox("Select model", models, key="model") |
||||
with col3: |
||||
st.selectbox("Select split", splits, key="split") |
||||
st.checkbox("Force recreate embeddings", key="force_recreate_embeddings") |
||||
|
||||
st.form_submit_button("Explore", on_click=_get_explorer) |
||||
|
||||
|
||||
def query_form(): |
||||
"""Sets up a form in Streamlit to initialize Explorer with dataset and model selection.""" |
||||
with st.form("query_form"): |
||||
col1, col2 = st.columns([0.8, 0.2]) |
||||
with col1: |
||||
st.text_input( |
||||
"Query", |
||||
"WHERE labels LIKE '%person%' AND labels LIKE '%dog%'", |
||||
label_visibility="collapsed", |
||||
key="query", |
||||
) |
||||
with col2: |
||||
st.form_submit_button("Query", on_click=run_sql_query) |
||||
|
||||
|
||||
def ai_query_form(): |
||||
"""Sets up a Streamlit form for user input to initialize Explorer with dataset and model selection.""" |
||||
with st.form("ai_query_form"): |
||||
col1, col2 = st.columns([0.8, 0.2]) |
||||
with col1: |
||||
st.text_input("Query", "Show images with 1 person and 1 dog", label_visibility="collapsed", key="ai_query") |
||||
with col2: |
||||
st.form_submit_button("Ask AI", on_click=run_ai_query) |
||||
|
||||
|
||||
def find_similar_imgs(imgs): |
||||
"""Initializes a Streamlit form for AI-based image querying with custom input.""" |
||||
exp = st.session_state["explorer"] |
||||
similar = exp.get_similar(img=imgs, limit=st.session_state.get("limit"), return_type="arrow") |
||||
paths = similar.to_pydict()["im_file"] |
||||
st.session_state["imgs"] = paths |
||||
st.session_state["res"] = similar |
||||
|
||||
|
||||
def similarity_form(selected_imgs): |
||||
"""Initializes a form for AI-based image querying with custom input in Streamlit.""" |
||||
st.write("Similarity Search") |
||||
with st.form("similarity_form"): |
||||
subcol1, subcol2 = st.columns([1, 1]) |
||||
with subcol1: |
||||
st.number_input( |
||||
"limit", min_value=None, max_value=None, value=25, label_visibility="collapsed", key="limit" |
||||
) |
||||
|
||||
with subcol2: |
||||
disabled = not len(selected_imgs) |
||||
st.write("Selected: ", len(selected_imgs)) |
||||
st.form_submit_button( |
||||
"Search", |
||||
disabled=disabled, |
||||
on_click=find_similar_imgs, |
||||
args=(selected_imgs,), |
||||
) |
||||
if disabled: |
||||
st.error("Select at least one image to search.") |
||||
|
||||
|
||||
# def persist_reset_form(): |
||||
# with st.form("persist_reset"): |
||||
# col1, col2 = st.columns([1, 1]) |
||||
# with col1: |
||||
# st.form_submit_button("Reset", on_click=reset) |
||||
# |
||||
# with col2: |
||||
# st.form_submit_button("Persist", on_click=update_state, args=("PERSISTING", True)) |
||||
|
||||
|
||||
def run_sql_query(): |
||||
"""Executes an SQL query and returns the results.""" |
||||
st.session_state["error"] = None |
||||
query = st.session_state.get("query") |
||||
if query.rstrip().lstrip(): |
||||
exp = st.session_state["explorer"] |
||||
res = exp.sql_query(query, return_type="arrow") |
||||
st.session_state["imgs"] = res.to_pydict()["im_file"] |
||||
st.session_state["res"] = res |
||||
|
||||
|
||||
def run_ai_query(): |
||||
"""Execute SQL query and update session state with query results.""" |
||||
if not SETTINGS["openai_api_key"]: |
||||
st.session_state["error"] = ( |
||||
'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."' |
||||
) |
||||
return |
||||
import pandas # scope for faster 'import ultralytics' |
||||
|
||||
st.session_state["error"] = None |
||||
query = st.session_state.get("ai_query") |
||||
if query.rstrip().lstrip(): |
||||
exp = st.session_state["explorer"] |
||||
res = exp.ask_ai(query) |
||||
if not isinstance(res, pandas.DataFrame) or res.empty: |
||||
st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it." |
||||
return |
||||
st.session_state["imgs"] = res["im_file"].to_list() |
||||
st.session_state["res"] = res |
||||
|
||||
|
||||
def reset_explorer(): |
||||
"""Resets the explorer to its initial state by clearing session variables.""" |
||||
st.session_state["explorer"] = None |
||||
st.session_state["imgs"] = None |
||||
st.session_state["error"] = None |
||||
|
||||
|
||||
def utralytics_explorer_docs_callback(): |
||||
"""Resets the explorer to its initial state by clearing session variables.""" |
||||
with st.container(border=True): |
||||
st.image( |
||||
"https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg", |
||||
width=100, |
||||
) |
||||
st.markdown( |
||||
"<p>This demo is built using Ultralytics Explorer API. Visit <a href='https://docs.ultralytics.com/datasets/explorer/'>API docs</a> to try examples & learn more</p>", |
||||
unsafe_allow_html=True, |
||||
help=None, |
||||
) |
||||
st.link_button("Ultrlaytics Explorer API", "https://docs.ultralytics.com/datasets/explorer/") |
||||
|
||||
|
||||
def layout(data=None, model=None): |
||||
"""Resets explorer session variables and provides documentation with a link to API docs.""" |
||||
st.set_page_config(layout="wide", initial_sidebar_state="collapsed") |
||||
st.markdown("<h1 style='text-align: center;'>Ultralytics Explorer Demo</h1>", unsafe_allow_html=True) |
||||
|
||||
if st.session_state.get("explorer") is None: |
||||
init_explorer_form(data, model) |
||||
return |
||||
|
||||
st.button(":arrow_backward: Select Dataset", on_click=reset_explorer) |
||||
exp = st.session_state.get("explorer") |
||||
col1, col2 = st.columns([0.75, 0.25], gap="small") |
||||
imgs = [] |
||||
if st.session_state.get("error"): |
||||
st.error(st.session_state["error"]) |
||||
elif st.session_state.get("imgs"): |
||||
imgs = st.session_state.get("imgs") |
||||
else: |
||||
imgs = exp.table.to_lance().to_table(columns=["im_file"]).to_pydict()["im_file"] |
||||
st.session_state["res"] = exp.table.to_arrow() |
||||
total_imgs, selected_imgs = len(imgs), [] |
||||
with col1: |
||||
subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5) |
||||
with subcol1: |
||||
st.write("Max Images Displayed:") |
||||
with subcol2: |
||||
num = st.number_input( |
||||
"Max Images Displayed", |
||||
min_value=0, |
||||
max_value=total_imgs, |
||||
value=min(500, total_imgs), |
||||
key="num_imgs_displayed", |
||||
label_visibility="collapsed", |
||||
) |
||||
with subcol3: |
||||
st.write("Start Index:") |
||||
with subcol4: |
||||
start_idx = st.number_input( |
||||
"Start Index", |
||||
min_value=0, |
||||
max_value=total_imgs, |
||||
value=0, |
||||
key="start_index", |
||||
label_visibility="collapsed", |
||||
) |
||||
with subcol5: |
||||
reset = st.button("Reset", use_container_width=False, key="reset") |
||||
if reset: |
||||
st.session_state["imgs"] = None |
||||
st.experimental_rerun() |
||||
|
||||
query_form() |
||||
ai_query_form() |
||||
if total_imgs: |
||||
labels, boxes, masks, kpts, classes = None, None, None, None, None |
||||
task = exp.model.task |
||||
if st.session_state.get("display_labels"): |
||||
labels = st.session_state.get("res").to_pydict()["labels"][start_idx : start_idx + num] |
||||
boxes = st.session_state.get("res").to_pydict()["bboxes"][start_idx : start_idx + num] |
||||
masks = st.session_state.get("res").to_pydict()["masks"][start_idx : start_idx + num] |
||||
kpts = st.session_state.get("res").to_pydict()["keypoints"][start_idx : start_idx + num] |
||||
classes = st.session_state.get("res").to_pydict()["cls"][start_idx : start_idx + num] |
||||
imgs_displayed = imgs[start_idx : start_idx + num] |
||||
selected_imgs = image_select( |
||||
f"Total samples: {total_imgs}", |
||||
images=imgs_displayed, |
||||
use_container_width=False, |
||||
# indices=[i for i in range(num)] if select_all else None, |
||||
labels=labels, |
||||
classes=classes, |
||||
bboxes=boxes, |
||||
masks=masks if task == "segment" else None, |
||||
kpts=kpts if task == "pose" else None, |
||||
) |
||||
|
||||
with col2: |
||||
similarity_form(selected_imgs) |
||||
st.checkbox("Labels", value=False, key="display_labels") |
||||
utralytics_explorer_docs_callback() |
||||
|
||||
|
||||
if __name__ == "__main__": |
||||
kwargs = dict(zip(sys.argv[1::2], sys.argv[2::2])) |
||||
layout(**kwargs) |
@ -1,167 +0,0 @@ |
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
||||
|
||||
import getpass |
||||
from typing import List |
||||
|
||||
import cv2 |
||||
import numpy as np |
||||
|
||||
from ultralytics.data.augment import LetterBox |
||||
from ultralytics.utils import LOGGER as logger |
||||
from ultralytics.utils import SETTINGS |
||||
from ultralytics.utils.checks import check_requirements |
||||
from ultralytics.utils.ops import xyxy2xywh |
||||
from ultralytics.utils.plotting import plot_images |
||||
|
||||
|
||||
def get_table_schema(vector_size): |
||||
"""Extracts and returns the schema of a database table.""" |
||||
from lancedb.pydantic import LanceModel, Vector |
||||
|
||||
class Schema(LanceModel): |
||||
im_file: str |
||||
labels: List[str] |
||||
cls: List[int] |
||||
bboxes: List[List[float]] |
||||
masks: List[List[List[int]]] |
||||
keypoints: List[List[List[float]]] |
||||
vector: Vector(vector_size) |
||||
|
||||
return Schema |
||||
|
||||
|
||||
def get_sim_index_schema(): |
||||
"""Returns a LanceModel schema for a database table with specified vector size.""" |
||||
from lancedb.pydantic import LanceModel |
||||
|
||||
class Schema(LanceModel): |
||||
idx: int |
||||
im_file: str |
||||
count: int |
||||
sim_im_files: List[str] |
||||
|
||||
return Schema |
||||
|
||||
|
||||
def sanitize_batch(batch, dataset_info): |
||||
"""Sanitizes input batch for inference, ensuring correct format and dimensions.""" |
||||
batch["cls"] = batch["cls"].flatten().int().tolist() |
||||
box_cls_pair = sorted(zip(batch["bboxes"].tolist(), batch["cls"]), key=lambda x: x[1]) |
||||
batch["bboxes"] = [box for box, _ in box_cls_pair] |
||||
batch["cls"] = [cls for _, cls in box_cls_pair] |
||||
batch["labels"] = [dataset_info["names"][i] for i in batch["cls"]] |
||||
batch["masks"] = batch["masks"].tolist() if "masks" in batch else [[[]]] |
||||
batch["keypoints"] = batch["keypoints"].tolist() if "keypoints" in batch else [[[]]] |
||||
return batch |
||||
|
||||
|
||||
def plot_query_result(similar_set, plot_labels=True): |
||||
""" |
||||
Plot images from the similar set. |
||||
|
||||
Args: |
||||
similar_set (list): Pyarrow or pandas object containing the similar data points |
||||
plot_labels (bool): Whether to plot labels or not |
||||
""" |
||||
import pandas # scope for faster 'import ultralytics' |
||||
|
||||
similar_set = ( |
||||
similar_set.to_dict(orient="list") if isinstance(similar_set, pandas.DataFrame) else similar_set.to_pydict() |
||||
) |
||||
empty_masks = [[[]]] |
||||
empty_boxes = [[]] |
||||
images = similar_set.get("im_file", []) |
||||
bboxes = similar_set.get("bboxes", []) if similar_set.get("bboxes") is not empty_boxes else [] |
||||
masks = similar_set.get("masks") if similar_set.get("masks")[0] != empty_masks else [] |
||||
kpts = similar_set.get("keypoints") if similar_set.get("keypoints")[0] != empty_masks else [] |
||||
cls = similar_set.get("cls", []) |
||||
|
||||
plot_size = 640 |
||||
imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], [] |
||||
for i, imf in enumerate(images): |
||||
im = cv2.imread(imf) |
||||
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) |
||||
h, w = im.shape[:2] |
||||
r = min(plot_size / h, plot_size / w) |
||||
imgs.append(LetterBox(plot_size, center=False)(image=im).transpose(2, 0, 1)) |
||||
if plot_labels: |
||||
if len(bboxes) > i and len(bboxes[i]) > 0: |
||||
box = np.array(bboxes[i], dtype=np.float32) |
||||
box[:, [0, 2]] *= r |
||||
box[:, [1, 3]] *= r |
||||
plot_boxes.append(box) |
||||
if len(masks) > i and len(masks[i]) > 0: |
||||
mask = np.array(masks[i], dtype=np.uint8)[0] |
||||
plot_masks.append(LetterBox(plot_size, center=False)(image=mask)) |
||||
if len(kpts) > i and kpts[i] is not None: |
||||
kpt = np.array(kpts[i], dtype=np.float32) |
||||
kpt[:, :, :2] *= r |
||||
plot_kpts.append(kpt) |
||||
batch_idx.append(np.ones(len(np.array(bboxes[i], dtype=np.float32))) * i) |
||||
imgs = np.stack(imgs, axis=0) |
||||
masks = np.stack(plot_masks, axis=0) if plot_masks else np.zeros(0, dtype=np.uint8) |
||||
kpts = np.concatenate(plot_kpts, axis=0) if plot_kpts else np.zeros((0, 51), dtype=np.float32) |
||||
boxes = xyxy2xywh(np.concatenate(plot_boxes, axis=0)) if plot_boxes else np.zeros(0, dtype=np.float32) |
||||
batch_idx = np.concatenate(batch_idx, axis=0) |
||||
cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0) |
||||
|
||||
return plot_images( |
||||
imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, max_subplots=len(images), save=False, threaded=False |
||||
) |
||||
|
||||
|
||||
def prompt_sql_query(query): |
||||
"""Plots images with optional labels from a similar data set.""" |
||||
check_requirements("openai>=1.6.1") |
||||
from openai import OpenAI |
||||
|
||||
if not SETTINGS["openai_api_key"]: |
||||
logger.warning("OpenAI API key not found in settings. Please enter your API key below.") |
||||
openai_api_key = getpass.getpass("OpenAI API key: ") |
||||
SETTINGS.update({"openai_api_key": openai_api_key}) |
||||
openai = OpenAI(api_key=SETTINGS["openai_api_key"]) |
||||
|
||||
messages = [ |
||||
{ |
||||
"role": "system", |
||||
"content": """ |
||||
You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on |
||||
the following schema and a user request. You only need to output the format with fixed selection |
||||
statement that selects everything from "'table'", like `SELECT * from 'table'` |
||||
|
||||
Schema: |
||||
im_file: string not null |
||||
labels: list<item: string> not null |
||||
child 0, item: string |
||||
cls: list<item: int64> not null |
||||
child 0, item: int64 |
||||
bboxes: list<item: list<item: double>> not null |
||||
child 0, item: list<item: double> |
||||
child 0, item: double |
||||
masks: list<item: list<item: list<item: int64>>> not null |
||||
child 0, item: list<item: list<item: int64>> |
||||
child 0, item: list<item: int64> |
||||
child 0, item: int64 |
||||
keypoints: list<item: list<item: list<item: double>>> not null |
||||
child 0, item: list<item: list<item: double>> |
||||
child 0, item: list<item: double> |
||||
child 0, item: double |
||||
vector: fixed_size_list<item: float>[256] not null |
||||
child 0, item: float |
||||
|
||||
Some details about the schema: |
||||
- the "labels" column contains the string values like 'person' and 'dog' for the respective objects |
||||
in each image |
||||
- the "cls" column contains the integer values on these classes that map them the labels |
||||
|
||||
Example of a correct query: |
||||
request - Get all data points that contain 2 or more people and at least one dog |
||||
correct query- |
||||
SELECT * FROM 'table' WHERE ARRAY_LENGTH(cls) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1; |
||||
""", |
||||
}, |
||||
{"role": "user", "content": f"{query}"}, |
||||
] |
||||
|
||||
response = openai.chat.completions.create(model="gpt-3.5-turbo", messages=messages) |
||||
return response.choices[0].message.content |
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue