|
|
|
@ -1,11 +1,12 @@ |
|
|
|
|
from io import BytesIO |
|
|
|
|
from pathlib import Path |
|
|
|
|
from typing import List |
|
|
|
|
from typing import Any, List, Tuple, Union |
|
|
|
|
|
|
|
|
|
import cv2 |
|
|
|
|
import numpy as np |
|
|
|
|
import torch |
|
|
|
|
from matplotlib import pyplot as plt |
|
|
|
|
from pandas import DataFrame |
|
|
|
|
from PIL import Image |
|
|
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
@ -13,18 +14,18 @@ from ultralytics.data.augment import Format |
|
|
|
|
from ultralytics.data.dataset import YOLODataset |
|
|
|
|
from ultralytics.data.utils import check_det_dataset |
|
|
|
|
from ultralytics.models.yolo.model import YOLO |
|
|
|
|
from ultralytics.utils import LOGGER, checks |
|
|
|
|
from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks |
|
|
|
|
|
|
|
|
|
from .utils import get_sim_index_schema, get_table_schema, plot_similar_images, sanitize_batch |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExplorerDataset(YOLODataset): |
|
|
|
|
|
|
|
|
|
def __init__(self, *args, data=None, **kwargs): |
|
|
|
|
def __init__(self, *args, data: dict = None, **kwargs) -> None: |
|
|
|
|
super().__init__(*args, data=data, **kwargs) |
|
|
|
|
|
|
|
|
|
# NOTE: Load the image directly without any resize operations. |
|
|
|
|
def load_image(self, i): |
|
|
|
|
def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]: |
|
|
|
|
"""Loads 1 image from dataset index 'i', returns (im, original hw, resized hw).""" |
|
|
|
|
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i] |
|
|
|
|
if im is None: # not cached in RAM |
|
|
|
@ -39,7 +40,7 @@ class ExplorerDataset(YOLODataset): |
|
|
|
|
|
|
|
|
|
return self.ims[i], self.im_hw0[i], self.im_hw[i] |
|
|
|
|
|
|
|
|
|
def build_transforms(self, hyp=None): |
|
|
|
|
def build_transforms(self, hyp: IterableSimpleNamespace = None): |
|
|
|
|
return Format( |
|
|
|
|
bbox_format='xyxy', |
|
|
|
|
normalize=False, |
|
|
|
@ -53,7 +54,10 @@ class ExplorerDataset(YOLODataset): |
|
|
|
|
|
|
|
|
|
class Explorer: |
|
|
|
|
|
|
|
|
|
def __init__(self, data='coco128.yaml', model='yolov8n.pt', uri='~/ultralytics/explorer') -> None: |
|
|
|
|
def __init__(self, |
|
|
|
|
data: Union[str, Path] = 'coco128.yaml', |
|
|
|
|
model: str = 'yolov8n.pt', |
|
|
|
|
uri: str = '~/ultralytics/explorer') -> None: |
|
|
|
|
checks.check_requirements(['lancedb', 'duckdb']) |
|
|
|
|
import lancedb |
|
|
|
|
|
|
|
|
@ -68,7 +72,7 @@ class Explorer: |
|
|
|
|
self.table = None |
|
|
|
|
self.progress = 0 |
|
|
|
|
|
|
|
|
|
def create_embeddings_table(self, force=False, split='train'): |
|
|
|
|
def create_embeddings_table(self, force: bool = False, split: str = 'train') -> None: |
|
|
|
|
""" |
|
|
|
|
Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it |
|
|
|
|
already exists. Pass force=True to overwrite the existing table. |
|
|
|
@ -118,7 +122,7 @@ class Explorer: |
|
|
|
|
|
|
|
|
|
self.table = table |
|
|
|
|
|
|
|
|
|
def _yield_batches(self, dataset, data_info, model, exclude_keys: List): |
|
|
|
|
def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]): |
|
|
|
|
# Implement Batching |
|
|
|
|
for i in tqdm(range(len(dataset))): |
|
|
|
|
self.progress = float(i + 1) / len(dataset) |
|
|
|
@ -129,7 +133,9 @@ class Explorer: |
|
|
|
|
batch['vector'] = model.embed(batch['im_file'], verbose=False)[0].detach().tolist() |
|
|
|
|
yield [batch] |
|
|
|
|
|
|
|
|
|
def query(self, imgs=None, limit=25): |
|
|
|
|
def query(self, |
|
|
|
|
imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, |
|
|
|
|
limit: int = 25) -> Any: # pyarrow.Table |
|
|
|
|
""" |
|
|
|
|
Query the table for similar images. Accepts a single image or a list of images. |
|
|
|
|
|
|
|
|
@ -162,7 +168,9 @@ class Explorer: |
|
|
|
|
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy() |
|
|
|
|
return self.table.search(embeds).limit(limit).to_arrow() |
|
|
|
|
|
|
|
|
|
def sql_query(self, query, return_type='pandas'): |
|
|
|
|
def sql_query(self, |
|
|
|
|
query: str, |
|
|
|
|
return_type: str = 'pandas') -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table |
|
|
|
|
""" |
|
|
|
|
Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown. |
|
|
|
|
|
|
|
|
@ -177,7 +185,7 @@ class Explorer: |
|
|
|
|
```python |
|
|
|
|
exp = Explorer() |
|
|
|
|
exp.create_embeddings_table() |
|
|
|
|
query = 'SELECT * FROM table WHERE labels LIKE "%person%"' |
|
|
|
|
query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'" |
|
|
|
|
result = exp.sql_query(query) |
|
|
|
|
``` |
|
|
|
|
""" |
|
|
|
@ -201,7 +209,7 @@ class Explorer: |
|
|
|
|
elif return_type == 'arrow': |
|
|
|
|
return rs.arrow() |
|
|
|
|
|
|
|
|
|
def plot_sql_query(self, query, labels=True): |
|
|
|
|
def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image: |
|
|
|
|
""" |
|
|
|
|
Plot the results of a SQL-Like query on the table. |
|
|
|
|
Args: |
|
|
|
@ -215,7 +223,7 @@ class Explorer: |
|
|
|
|
```python |
|
|
|
|
exp = Explorer() |
|
|
|
|
exp.create_embeddings_table() |
|
|
|
|
query = 'SELECT * FROM table WHERE labels LIKE "%person%"' |
|
|
|
|
query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'" |
|
|
|
|
result = exp.plot_sql_query(query) |
|
|
|
|
``` |
|
|
|
|
""" |
|
|
|
@ -223,7 +231,11 @@ class Explorer: |
|
|
|
|
img = plot_similar_images(result, plot_labels=labels) |
|
|
|
|
return Image.fromarray(img) |
|
|
|
|
|
|
|
|
|
def get_similar(self, img=None, idx=None, limit=25, return_type='pandas'): |
|
|
|
|
def get_similar(self, |
|
|
|
|
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, |
|
|
|
|
idx: Union[int, List[int]] = None, |
|
|
|
|
limit: int = 25, |
|
|
|
|
return_type: str = 'pandas') -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table |
|
|
|
|
""" |
|
|
|
|
Query the table for similar images. Accepts a single image or a list of images. |
|
|
|
|
|
|
|
|
@ -251,7 +263,11 @@ class Explorer: |
|
|
|
|
elif return_type == 'arrow': |
|
|
|
|
return similar |
|
|
|
|
|
|
|
|
|
def plot_similar(self, img=None, idx=None, limit=25, labels=True): |
|
|
|
|
def plot_similar(self, |
|
|
|
|
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, |
|
|
|
|
idx: Union[int, List[int]] = None, |
|
|
|
|
limit: int = 25, |
|
|
|
|
labels: bool = True) -> Image.Image: |
|
|
|
|
""" |
|
|
|
|
Plot the similar images. Accepts images or indexes. |
|
|
|
|
|
|
|
|
@ -275,7 +291,7 @@ class Explorer: |
|
|
|
|
img = plot_similar_images(similar, plot_labels=labels) |
|
|
|
|
return Image.fromarray(img) |
|
|
|
|
|
|
|
|
|
def similarity_index(self, max_dist=0.2, top_k=None, force=False): |
|
|
|
|
def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame: |
|
|
|
|
""" |
|
|
|
|
Calculate the similarity index of all the images in the table. Here, the index will contain the data points that |
|
|
|
|
are max_dist or closer to the image in the embedding space at a given index. |
|
|
|
@ -329,7 +345,7 @@ class Explorer: |
|
|
|
|
self.sim_index = sim_table |
|
|
|
|
return sim_table.to_pandas() |
|
|
|
|
|
|
|
|
|
def plot_similarity_index(self, max_dist=0.2, top_k=None, force=False): |
|
|
|
|
def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image: |
|
|
|
|
""" |
|
|
|
|
Plot the similarity index of all the images in the table. Here, the index will contain the data points that are |
|
|
|
|
max_dist or closer to the image in the embedding space at a given index. |
|
|
|
@ -341,13 +357,16 @@ class Explorer: |
|
|
|
|
force (bool): Whether to overwrite the existing similarity index or not. Defaults to False. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
PIL Image containing the plot. |
|
|
|
|
PIL.Image.Image containing the plot. |
|
|
|
|
|
|
|
|
|
Example: |
|
|
|
|
```python |
|
|
|
|
exp = Explorer() |
|
|
|
|
exp.create_embeddings_table() |
|
|
|
|
exp.plot_similarity_index() |
|
|
|
|
|
|
|
|
|
similarity_idx_plot = exp.plot_similarity_index() |
|
|
|
|
similarity_idx_plot.show() # view image preview |
|
|
|
|
similarity_idx_plot.save('path/to/save/similarity_index_plot.png') # save contents to file |
|
|
|
|
``` |
|
|
|
|
""" |
|
|
|
|
sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force) |
|
|
|
@ -368,9 +387,10 @@ class Explorer: |
|
|
|
|
buffer.seek(0) |
|
|
|
|
|
|
|
|
|
# Use Pillow to open the image from the buffer |
|
|
|
|
return Image.open(buffer) |
|
|
|
|
return Image.fromarray(np.array(Image.open(buffer))) |
|
|
|
|
|
|
|
|
|
def _check_imgs_or_idxs(self, img, idx): |
|
|
|
|
def _check_imgs_or_idxs(self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], |
|
|
|
|
idx: Union[None, int, List[int]]) -> List[np.ndarray]: |
|
|
|
|
if img is None and idx is None: |
|
|
|
|
raise ValueError('Either img or idx must be provided.') |
|
|
|
|
if img is not None and idx is not None: |
|
|
|
|