# Ultralytics YOLO 🚀, AGPL-3.0 license

from io import BytesIO
from pathlib import Path
from typing import Any, List, Tuple, Union

import cv2
import numpy as np
import torch
from matplotlib import pyplot as plt
from pandas import DataFrame
from PIL import Image
from tqdm import tqdm

from ultralytics.data.augment import Format
from ultralytics.data.dataset import YOLODataset
from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.model import YOLO
from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks

from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch
|
|
|
|
|
class ExplorerDataset(YOLODataset):
    """YOLO dataset variant that yields images at their native resolution (no resize) for embedding extraction."""

    def __init__(self, *args, data: dict = None, **kwargs) -> None:
        """Initialize the dataset, forwarding all arguments to the YOLODataset base class."""
        super().__init__(*args, data=data, **kwargs)

    def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]:
        """Loads 1 image from dataset index 'i' without any resize ops."""
        cached = self.ims[i]
        if cached is not None:  # already cached in RAM
            return self.ims[i], self.im_hw0[i], self.im_hw[i]

        path, npy_path = self.im_files[i], self.npy_files[i]
        if npy_path.exists():  # prefer the cached .npy file when present
            image = np.load(npy_path)
        else:  # fall back to reading the image file (BGR)
            image = cv2.imread(path)
            if image is None:
                raise FileNotFoundError(f"Image Not Found {path}")
        height0, width0 = image.shape[:2]  # original hw
        return image, (height0, width0), image.shape[:2]

    def build_transforms(self, hyp: IterableSimpleNamespace = None):
        """Creates transforms for dataset images without resizing."""
        fmt_kwargs = dict(
            bbox_format="xyxy",
            normalize=False,
            return_mask=self.use_segments,
            return_keypoint=self.use_keypoints,
            batch_idx=True,
            mask_ratio=hyp.mask_ratio,
            mask_overlap=hyp.overlap_mask,
        )
        return Format(**fmt_kwargs)
|
|
|
|
|
class Explorer: |
|
def __init__( |
|
self, data: Union[str, Path] = "coco128.yaml", model: str = "yolov8n.pt", uri: str = "~/ultralytics/explorer" |
|
) -> None: |
|
checks.check_requirements(["lancedb>=0.4.3", "duckdb"]) |
|
import lancedb |
|
|
|
self.connection = lancedb.connect(uri) |
|
self.table_name = Path(data).name.lower() + "_" + model.lower() |
|
self.sim_idx_base_name = ( |
|
f"{self.table_name}_sim_idx".lower() |
|
) # Use this name and append thres and top_k to reuse the table |
|
self.model = YOLO(model) |
|
self.data = data # None |
|
self.choice_set = None |
|
|
|
self.table = None |
|
self.progress = 0 |
|
|
|
def create_embeddings_table(self, force: bool = False, split: str = "train") -> None: |
|
""" |
|
Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it |
|
already exists. Pass force=True to overwrite the existing table. |
|
|
|
Args: |
|
force (bool): Whether to overwrite the existing table or not. Defaults to False. |
|
split (str): Split of the dataset to use. Defaults to 'train'. |
|
|
|
Example: |
|
```python |
|
exp = Explorer() |
|
exp.create_embeddings_table() |
|
``` |
|
""" |
|
if self.table is not None and not force: |
|
LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.") |
|
return |
|
if self.table_name in self.connection.table_names() and not force: |
|
LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.") |
|
self.table = self.connection.open_table(self.table_name) |
|
self.progress = 1 |
|
return |
|
if self.data is None: |
|
raise ValueError("Data must be provided to create embeddings table") |
|
|
|
data_info = check_det_dataset(self.data) |
|
if split not in data_info: |
|
raise ValueError( |
|
f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}" |
|
) |
|
|
|
choice_set = data_info[split] |
|
choice_set = choice_set if isinstance(choice_set, list) else [choice_set] |
|
self.choice_set = choice_set |
|
dataset = ExplorerDataset(img_path=choice_set, data=data_info, augment=False, cache=False, task=self.model.task) |
|
|
|
# Create the table schema |
|
batch = dataset[0] |
|
vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0] |
|
table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite") |
|
table.add( |
|
self._yield_batches( |
|
dataset, |
|
data_info, |
|
self.model, |
|
exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"], |
|
) |
|
) |
|
|
|
self.table = table |
|
|
|
def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]): |
|
"""Generates batches of data for embedding, excluding specified keys.""" |
|
for i in tqdm(range(len(dataset))): |
|
self.progress = float(i + 1) / len(dataset) |
|
batch = dataset[i] |
|
for k in exclude_keys: |
|
batch.pop(k, None) |
|
batch = sanitize_batch(batch, data_info) |
|
batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist() |
|
yield [batch] |
|
|
|
def query( |
|
self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25 |
|
) -> Any: # pyarrow.Table |
|
""" |
|
Query the table for similar images. Accepts a single image or a list of images. |
|
|
|
Args: |
|
imgs (str or list): Path to the image or a list of paths to the images. |
|
limit (int): Number of results to return. |
|
|
|
Returns: |
|
(pyarrow.Table): An arrow table containing the results. Supports converting to: |
|
- pandas dataframe: `result.to_pandas()` |
|
- dict of lists: `result.to_pydict()` |
|
|
|
Example: |
|
```python |
|
exp = Explorer() |
|
exp.create_embeddings_table() |
|
similar = exp.query(img='https://ultralytics.com/images/zidane.jpg') |
|
``` |
|
""" |
|
if self.table is None: |
|
raise ValueError("Table is not created. Please create the table first.") |
|
if isinstance(imgs, str): |
|
imgs = [imgs] |
|
assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}" |
|
embeds = self.model.embed(imgs) |
|
# Get avg if multiple images are passed (len > 1) |
|
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy() |
|
return self.table.search(embeds).limit(limit).to_arrow() |
|
|
|
def sql_query( |
|
self, query: str, return_type: str = "pandas" |
|
) -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table |
|
""" |
|
Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown. |
|
|
|
Args: |
|
query (str): SQL query to run. |
|
return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'. |
|
|
|
Returns: |
|
(pyarrow.Table): An arrow table containing the results. |
|
|
|
Example: |
|
```python |
|
exp = Explorer() |
|
exp.create_embeddings_table() |
|
query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'" |
|
result = exp.sql_query(query) |
|
``` |
|
""" |
|
assert return_type in [ |
|
"pandas", |
|
"arrow", |
|
], f"Return type should be either `pandas` or `arrow`, but got {return_type}" |
|
import duckdb |
|
|
|
if self.table is None: |
|
raise ValueError("Table is not created. Please create the table first.") |
|
|
|
# Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this. |
|
table = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB |
|
if not query.startswith("SELECT") and not query.startswith("WHERE"): |
|
raise ValueError( |
|
f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}" |
|
) |
|
if query.startswith("WHERE"): |
|
query = f"SELECT * FROM 'table' {query}" |
|
LOGGER.info(f"Running query: {query}") |
|
|
|
rs = duckdb.sql(query) |
|
if return_type == "pandas": |
|
return rs.df() |
|
elif return_type == "arrow": |
|
return rs.arrow() |
|
|
|
def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image: |
|
""" |
|
Plot the results of a SQL-Like query on the table. |
|
Args: |
|
query (str): SQL query to run. |
|
labels (bool): Whether to plot the labels or not. |
|
|
|
Returns: |
|
(PIL.Image): Image containing the plot. |
|
|
|
Example: |
|
```python |
|
exp = Explorer() |
|
exp.create_embeddings_table() |
|
query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'" |
|
result = exp.plot_sql_query(query) |
|
``` |
|
""" |
|
result = self.sql_query(query, return_type="arrow") |
|
if len(result) == 0: |
|
LOGGER.info("No results found.") |
|
return None |
|
img = plot_query_result(result, plot_labels=labels) |
|
return Image.fromarray(img) |
|
|
|
def get_similar( |
|
self, |
|
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, |
|
idx: Union[int, List[int]] = None, |
|
limit: int = 25, |
|
return_type: str = "pandas", |
|
) -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table |
|
""" |
|
Query the table for similar images. Accepts a single image or a list of images. |
|
|
|
Args: |
|
img (str or list): Path to the image or a list of paths to the images. |
|
idx (int or list): Index of the image in the table or a list of indexes. |
|
limit (int): Number of results to return. Defaults to 25. |
|
return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'. |
|
|
|
Returns: |
|
(pandas.DataFrame): A dataframe containing the results. |
|
|
|
Example: |
|
```python |
|
exp = Explorer() |
|
exp.create_embeddings_table() |
|
similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg') |
|
``` |
|
""" |
|
assert return_type in [ |
|
"pandas", |
|
"arrow", |
|
], f"Return type should be either `pandas` or `arrow`, but got {return_type}" |
|
img = self._check_imgs_or_idxs(img, idx) |
|
similar = self.query(img, limit=limit) |
|
|
|
if return_type == "pandas": |
|
return similar.to_pandas() |
|
elif return_type == "arrow": |
|
return similar |
|
|
|
def plot_similar( |
|
self, |
|
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, |
|
idx: Union[int, List[int]] = None, |
|
limit: int = 25, |
|
labels: bool = True, |
|
) -> Image.Image: |
|
""" |
|
Plot the similar images. Accepts images or indexes. |
|
|
|
Args: |
|
img (str or list): Path to the image or a list of paths to the images. |
|
idx (int or list): Index of the image in the table or a list of indexes. |
|
labels (bool): Whether to plot the labels or not. |
|
limit (int): Number of results to return. Defaults to 25. |
|
|
|
Returns: |
|
(PIL.Image): Image containing the plot. |
|
|
|
Example: |
|
```python |
|
exp = Explorer() |
|
exp.create_embeddings_table() |
|
similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg') |
|
``` |
|
""" |
|
similar = self.get_similar(img, idx, limit, return_type="arrow") |
|
if len(similar) == 0: |
|
LOGGER.info("No results found.") |
|
return None |
|
img = plot_query_result(similar, plot_labels=labels) |
|
return Image.fromarray(img) |
|
|
|
def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame: |
|
""" |
|
Calculate the similarity index of all the images in the table. Here, the index will contain the data points that |
|
are max_dist or closer to the image in the embedding space at a given index. |
|
|
|
Args: |
|
max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2. |
|
top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when running |
|
vector search. Defaults: None. |
|
force (bool): Whether to overwrite the existing similarity index or not. Defaults to True. |
|
|
|
Returns: |
|
(pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image, and columns |
|
include indices of similar images and their respective distances. |
|
|
|
Example: |
|
```python |
|
exp = Explorer() |
|
exp.create_embeddings_table() |
|
sim_idx = exp.similarity_index() |
|
``` |
|
""" |
|
if self.table is None: |
|
raise ValueError("Table is not created. Please create the table first.") |
|
sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower() |
|
if sim_idx_table_name in self.connection.table_names() and not force: |
|
LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.") |
|
return self.connection.open_table(sim_idx_table_name).to_pandas() |
|
|
|
if top_k and not (1.0 >= top_k >= 0.0): |
|
raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}") |
|
if max_dist < 0.0: |
|
raise ValueError(f"max_dist must be greater than 0. Got {max_dist}") |
|
|
|
top_k = int(top_k * len(self.table)) if top_k else len(self.table) |
|
top_k = max(top_k, 1) |
|
features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict() |
|
im_files = features["im_file"] |
|
embeddings = features["vector"] |
|
|
|
sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite") |
|
|
|
def _yield_sim_idx(): |
|
"""Generates a dataframe with similarity indices and distances for images.""" |
|
for i in tqdm(range(len(embeddings))): |
|
sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}") |
|
yield [ |
|
{ |
|
"idx": i, |
|
"im_file": im_files[i], |
|
"count": len(sim_idx), |
|
"sim_im_files": sim_idx["im_file"].tolist(), |
|
} |
|
] |
|
|
|
sim_table.add(_yield_sim_idx()) |
|
self.sim_index = sim_table |
|
return sim_table.to_pandas() |
|
|
|
def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image: |
|
""" |
|
Plot the similarity index of all the images in the table. Here, the index will contain the data points that are |
|
max_dist or closer to the image in the embedding space at a given index. |
|
|
|
Args: |
|
max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2. |
|
top_k (float): Percentage of closest data points to consider when counting. Used to apply limit when |
|
running vector search. Defaults to 0.01. |
|
force (bool): Whether to overwrite the existing similarity index or not. Defaults to True. |
|
|
|
Returns: |
|
(PIL.Image): Image containing the plot. |
|
|
|
Example: |
|
```python |
|
exp = Explorer() |
|
exp.create_embeddings_table() |
|
|
|
similarity_idx_plot = exp.plot_similarity_index() |
|
similarity_idx_plot.show() # view image preview |
|
similarity_idx_plot.save('path/to/save/similarity_index_plot.png') # save contents to file |
|
``` |
|
""" |
|
sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force) |
|
sim_count = sim_idx["count"].tolist() |
|
sim_count = np.array(sim_count) |
|
|
|
indices = np.arange(len(sim_count)) |
|
|
|
# Create the bar plot |
|
plt.bar(indices, sim_count) |
|
|
|
# Customize the plot (optional) |
|
plt.xlabel("data idx") |
|
plt.ylabel("Count") |
|
plt.title("Similarity Count") |
|
buffer = BytesIO() |
|
plt.savefig(buffer, format="png") |
|
buffer.seek(0) |
|
|
|
# Use Pillow to open the image from the buffer |
|
return Image.fromarray(np.array(Image.open(buffer))) |
|
|
|
def _check_imgs_or_idxs( |
|
self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]] |
|
) -> List[np.ndarray]: |
|
if img is None and idx is None: |
|
raise ValueError("Either img or idx must be provided.") |
|
if img is not None and idx is not None: |
|
raise ValueError("Only one of img or idx must be provided.") |
|
if idx is not None: |
|
idx = idx if isinstance(idx, list) else [idx] |
|
img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"] |
|
|
|
return img if isinstance(img, list) else [img] |
|
|
|
def ask_ai(self, query): |
|
""" |
|
Ask AI a question. |
|
|
|
Args: |
|
query (str): Question to ask. |
|
|
|
Returns: |
|
(pandas.DataFrame): A dataframe containing filtered results to the SQL query. |
|
|
|
Example: |
|
```python |
|
exp = Explorer() |
|
exp.create_embeddings_table() |
|
answer = exp.ask_ai('Show images with 1 person and 2 dogs') |
|
``` |
|
""" |
|
result = prompt_sql_query(query) |
|
try: |
|
df = self.sql_query(result) |
|
except Exception as e: |
|
LOGGER.error("AI generated query is not valid. Please try again with a different prompt") |
|
LOGGER.error(e) |
|
return None |
|
return df |
|
|
|
def visualize(self, result): |
|
""" |
|
Visualize the results of a query. TODO. |
|
|
|
Args: |
|
result (pyarrow.Table): Table containing the results of a query. |
|
""" |
|
pass |
|
|
|
def generate_report(self, result): |
|
""" |
|
Generate a report of the dataset. |
|
|
|
TODO |
|
""" |
|
pass
|
|
|