Add type hinting to explorer.py (#7388)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/7375/head
Burhan 11 months ago committed by GitHub
parent e19398a537
commit d0562d7a2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 62
      ultralytics/data/explorer/explorer.py

@ -1,11 +1,12 @@
from io import BytesIO
from pathlib import Path
from typing import List
from typing import Any, List, Tuple, Union
import cv2
import numpy as np
import torch
from matplotlib import pyplot as plt
from pandas import DataFrame
from PIL import Image
from tqdm import tqdm
@ -13,18 +14,18 @@ from ultralytics.data.augment import Format
from ultralytics.data.dataset import YOLODataset
from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.model import YOLO
from ultralytics.utils import LOGGER, checks
from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks
from .utils import get_sim_index_schema, get_table_schema, plot_similar_images, sanitize_batch
class ExplorerDataset(YOLODataset):
def __init__(self, *args, data=None, **kwargs):
def __init__(self, *args, data: dict = None, **kwargs) -> None:
super().__init__(*args, data=data, **kwargs)
# NOTE: Load the image directly without any resize operations.
def load_image(self, i):
def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]:
"""Loads 1 image from dataset index 'i', returns (im, resized hw)."""
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
if im is None: # not cached in RAM
@ -39,7 +40,7 @@ class ExplorerDataset(YOLODataset):
return self.ims[i], self.im_hw0[i], self.im_hw[i]
def build_transforms(self, hyp=None):
def build_transforms(self, hyp: IterableSimpleNamespace = None):
return Format(
bbox_format='xyxy',
normalize=False,
@ -53,7 +54,10 @@ class ExplorerDataset(YOLODataset):
class Explorer:
def __init__(self, data='coco128.yaml', model='yolov8n.pt', uri='~/ultralytics/explorer') -> None:
def __init__(self,
data: Union[str, Path] = 'coco128.yaml',
model: str = 'yolov8n.pt',
uri: str = '~/ultralytics/explorer') -> None:
checks.check_requirements(['lancedb', 'duckdb'])
import lancedb
@ -68,7 +72,7 @@ class Explorer:
self.table = None
self.progress = 0
def create_embeddings_table(self, force=False, split='train'):
def create_embeddings_table(self, force: bool = False, split: str = 'train') -> None:
"""
Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
already exists. Pass force=True to overwrite the existing table.
@ -118,7 +122,7 @@ class Explorer:
self.table = table
def _yield_batches(self, dataset, data_info, model, exclude_keys: List):
def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]):
# Implement Batching
for i in tqdm(range(len(dataset))):
self.progress = float(i + 1) / len(dataset)
@ -129,7 +133,9 @@ class Explorer:
batch['vector'] = model.embed(batch['im_file'], verbose=False)[0].detach().tolist()
yield [batch]
def query(self, imgs=None, limit=25):
def query(self,
imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
limit: int = 25) -> Any: # pyarrow.Table
"""
Query the table for similar images. Accepts a single image or a list of images.
@ -162,7 +168,9 @@ class Explorer:
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
return self.table.search(embeds).limit(limit).to_arrow()
def sql_query(self, query, return_type='pandas'):
def sql_query(self,
query: str,
return_type: str = 'pandas') -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table
"""
Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
@ -177,7 +185,7 @@ class Explorer:
```python
exp = Explorer()
exp.create_embeddings_table()
query = 'SELECT * FROM table WHERE labels LIKE "%person%"'
query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
result = exp.sql_query(query)
```
"""
@ -201,7 +209,7 @@ class Explorer:
elif return_type == 'arrow':
return rs.arrow()
def plot_sql_query(self, query, labels=True):
def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
"""
Plot the results of a SQL-Like query on the table.
Args:
@ -215,7 +223,7 @@ class Explorer:
```python
exp = Explorer()
exp.create_embeddings_table()
query = 'SELECT * FROM table WHERE labels LIKE "%person%"'
query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
result = exp.plot_sql_query(query)
```
"""
@ -223,7 +231,11 @@ class Explorer:
img = plot_similar_images(result, plot_labels=labels)
return Image.fromarray(img)
def get_similar(self, img=None, idx=None, limit=25, return_type='pandas'):
def get_similar(self,
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
idx: Union[int, List[int]] = None,
limit: int = 25,
return_type: str = 'pandas') -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table
"""
Query the table for similar images. Accepts a single image or a list of images.
@ -251,7 +263,11 @@ class Explorer:
elif return_type == 'arrow':
return similar
def plot_similar(self, img=None, idx=None, limit=25, labels=True):
def plot_similar(self,
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
idx: Union[int, List[int]] = None,
limit: int = 25,
labels: bool = True) -> Image.Image:
"""
Plot the similar images. Accepts images or indexes.
@ -275,7 +291,7 @@ class Explorer:
img = plot_similar_images(similar, plot_labels=labels)
return Image.fromarray(img)
def similarity_index(self, max_dist=0.2, top_k=None, force=False):
def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
"""
Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
are max_dist or closer to the image in the embedding space at a given index.
@ -329,7 +345,7 @@ class Explorer:
self.sim_index = sim_table
return sim_table.to_pandas()
def plot_similarity_index(self, max_dist=0.2, top_k=None, force=False):
def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image:
"""
Plot the similarity index of all the images in the table. Here, the index will contain the data points that are
max_dist or closer to the image in the embedding space at a given index.
@ -341,13 +357,16 @@ class Explorer:
force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
Returns:
PIL Image containing the plot.
PIL.PngImagePlugin.PngImageFile containing the plot.
Example:
```python
exp = Explorer()
exp.create_embeddings_table()
exp.plot_similarity_index()
similarity_idx_plot = exp.plot_similarity_index()
similarity_idx_plot.show() # view image preview
similarity_idx_plot.save('path/to/save/similarity_index_plot.png') # save contents to file
```
"""
sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
@ -368,9 +387,10 @@ class Explorer:
buffer.seek(0)
# Use Pillow to open the image from the buffer
return Image.open(buffer)
return Image.fromarray(np.array(Image.open(buffer)))
def _check_imgs_or_idxs(self, img, idx):
def _check_imgs_or_idxs(self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None],
idx: Union[None, int, List[int]]) -> List[np.ndarray]:
if img is None and idx is None:
raise ValueError('Either img or idx must be provided.')
if img is not None and idx is not None:

Loading…
Cancel
Save