Docs updates: Add Explorer to tab, YOLOv5 in Guides and Usage in Quickstart (#7438)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Haixuan Xavier Tao <tao.xavier@outlook.com>
pull/7446/head
Ayush Chaurasia 1 year ago committed by GitHub
parent 53150a925b
commit a92adf8231
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      docs/en/datasets/explorer/dashboard.md
  2. 13
      docs/en/datasets/explorer/index.md
  3. 7
      docs/en/datasets/index.md
  4. 22
      docs/en/guides/object-blurring.md
  5. 24
      docs/en/guides/object-cropping.md
  6. 6
      docs/en/index.md
  7. 4
      docs/en/reference/cfg/__init__.md
  8. 2
      docs/en/reference/solutions/distance_calculation.md
  9. 2
      docs/en/reference/solutions/speed_estimation.md
  10. 47
      docs/en/usage/python.md
  11. 85
      docs/mkdocs.yml
  12. 15
      tests/test_explorer.py
  13. 2
      ultralytics/data/explorer/__init__.py
  14. 39
      ultralytics/data/explorer/explorer.py
  15. 1
      ultralytics/data/explorer/gui/__init__.py
  16. 16
      ultralytics/data/explorer/gui/dash.py
  17. 6
      ultralytics/data/explorer/utils.py
  18. 20
      ultralytics/data/split_dota.py
  19. 2
      ultralytics/models/rtdetr/val.py
  20. 2
      ultralytics/models/yolo/detect/val.py
  21. 1
      ultralytics/models/yolo/obb/predict.py
  22. 2
      ultralytics/models/yolo/obb/val.py
  23. 2
      ultralytics/models/yolo/pose/val.py
  24. 2
      ultralytics/models/yolo/segment/val.py
  25. 2
      ultralytics/nn/modules/head.py
  26. 1
      ultralytics/nn/tasks.py
  27. 1
      ultralytics/solutions/object_counter.py
  28. 1
      ultralytics/trackers/basetrack.py
  29. 1
      ultralytics/utils/__init__.py
  30. 2
      ultralytics/utils/downloads.py

@ -6,7 +6,7 @@ keywords: Ultralytics, Explorer GUI, semantic search, vector similarity search,
# Explorer GUI # Explorer GUI
Explorer GUI is like a playground build using (Ultralytics Explorer API)[api.md]. It allows you to run semantic/vector similarity search, SQL queries and even search using natural language using our ask AI feature powered by LLMs. Explorer GUI is like a playground build using [Ultralytics Explorer API](api.md). It allows you to run semantic/vector similarity search, SQL queries and even search using natural language using our ask AI feature powered by LLMs.
### Installation ### Installation

@ -6,7 +6,13 @@ keywords: Ultralytics Explorer, CV Dataset Tools, Semantic Search, SQL Dataset Q
# Ultralytics Explorer # Ultralytics Explorer
Ultralytics Explorer is a tool for exploring CV datasets using semantic search, SQL queries and vector similarity search. It is also a Python API for accessing the same functionality. <p>
<img width="1709" alt="Screenshot 2024-01-08 at 7 19 48PM (1)" src="https://github.com/AyushExel/assets/assets/15766192/e536b0eb-6bce-43fe-b800-3e79510d2e5b">
</p>
Ultralytics Explorer is a tool for exploring CV datasets using semantic search, SQL queries, vector similarity search and even using natural language. It is also a Python API for accessing the same functionality.
### Installation of optional dependencies ### Installation of optional dependencies
@ -33,8 +39,3 @@ yolo explorer
!!! note "Note" !!! note "Note"
Ask AI feature works using OpenAI, so you'll be prompted to set the api key for OpenAI when you first run the GUI. Ask AI feature works using OpenAI, so you'll be prompted to set the api key for OpenAI when you first run the GUI.
You can set it like this - `yolo settings openai_api_key="..."` You can set it like this - `yolo settings openai_api_key="..."`
Example
<p>
<img width="1709" alt="Screenshot 2024-01-08 at 7 19 48PM (1)" src="https://github.com/AyushExel/assets/assets/15766192/e536b0eb-6bce-43fe-b800-3e79510d2e5b">
</p>

@ -10,7 +10,12 @@ Ultralytics provides support for various datasets to facilitate computer vision
## 🌟 New: Ultralytics Explorer 🌟 ## 🌟 New: Ultralytics Explorer 🌟
Create embeddings for your dataset, search for similar images, run SQL queries and perform semantic search. You can get started with our GUI app or build your own using the API. Learn more [here](explorer/index.md). Create embeddings for your dataset, search for similar images, run SQL queries, perform semantic search and even search using natural language! You can get started with our GUI app or build your own using the API. Learn more [here](explorer/index.md).
<p>
<img width="1709" alt="Screenshot 2024-01-08 at 7 19 48PM (1)" src="https://github.com/AyushExel/assets/assets/15766192/e536b0eb-6bce-43fe-b800-3e79510d2e5b">
</p>
- Try the [GUI Demo](explorer/index.md) - Try the [GUI Demo](explorer/index.md)
- Learn more about the [Explorer API](explorer/index.md) - Learn more about the [Explorer API](explorer/index.md)

@ -23,47 +23,47 @@ Object blurring with [Ultralytics YOLOv8](https://github.com/ultralytics/ultraly
from ultralytics import YOLO from ultralytics import YOLO
from ultralytics.utils.plotting import Annotator, colors from ultralytics.utils.plotting import Annotator, colors
import cv2 import cv2
model = YOLO("yolov8n.pt") model = YOLO("yolov8n.pt")
names = model.names names = model.names
cap = cv2.VideoCapture("path/to/video/file.mp4") cap = cv2.VideoCapture("path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file" assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
# Blur ratio # Blur ratio
blur_ratio = 50 blur_ratio = 50
# Video writer # Video writer
video_writer = cv2.VideoWriter("object_blurring_output.avi", video_writer = cv2.VideoWriter("object_blurring_output.avi",
cv2.VideoWriter_fourcc(*'mp4v'), cv2.VideoWriter_fourcc(*'mp4v'),
fps, (w, h)) fps, (w, h))
while cap.isOpened(): while cap.isOpened():
success, im0 = cap.read() success, im0 = cap.read()
if not success: if not success:
print("Video frame is empty or video processing has been successfully completed.") print("Video frame is empty or video processing has been successfully completed.")
break break
results = model.predict(im0, show=False) results = model.predict(im0, show=False)
boxes = results[0].boxes.xyxy.cpu().tolist() boxes = results[0].boxes.xyxy.cpu().tolist()
clss = results[0].boxes.cls.cpu().tolist() clss = results[0].boxes.cls.cpu().tolist()
annotator = Annotator(im0, line_width=2, example=names) annotator = Annotator(im0, line_width=2, example=names)
if boxes is not None: if boxes is not None:
for box, cls in zip(boxes, clss): for box, cls in zip(boxes, clss):
annotator.box_label(box, color=colors(int(cls), True), label=names[int(cls)]) annotator.box_label(box, color=colors(int(cls), True), label=names[int(cls)])
obj = im0[int(box[1]):int(box[3]), int(box[0]):int(box[2])] obj = im0[int(box[1]):int(box[3]), int(box[0]):int(box[2])]
blur_obj = cv2.blur(obj, (blur_ratio, blur_ratio)) blur_obj = cv2.blur(obj, (blur_ratio, blur_ratio))
im0[int(box[1]):int(box[3]), int(box[0]):int(box[2])] = blur_obj im0[int(box[1]):int(box[3]), int(box[0]):int(box[2])] = blur_obj
cv2.imshow("ultralytics", im0) cv2.imshow("ultralytics", im0)
video_writer.write(im0) video_writer.write(im0)
if cv2.waitKey(1) & 0xFF == ord('q'): if cv2.waitKey(1) & 0xFF == ord('q'):
break break
cap.release() cap.release()
video_writer.release() video_writer.release()
cv2.destroyAllWindows() cv2.destroyAllWindows()

@ -24,50 +24,50 @@ Object cropping with [Ultralytics YOLOv8](https://github.com/ultralytics/ultraly
from ultralytics.utils.plotting import Annotator, colors from ultralytics.utils.plotting import Annotator, colors
import cv2 import cv2
import os import os
model = YOLO("yolov8n.pt") model = YOLO("yolov8n.pt")
names = model.names names = model.names
cap = cv2.VideoCapture("path/to/video/file.mp4") cap = cv2.VideoCapture("path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file" assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
crop_dir_name = "ultralytics_crop" crop_dir_name = "ultralytics_crop"
if not os.path.exists(crop_dir_name): if not os.path.exists(crop_dir_name):
os.mkdir(crop_dir_name) os.mkdir(crop_dir_name)
# Video writer # Video writer
video_writer = cv2.VideoWriter("object_cropping_output.avi", video_writer = cv2.VideoWriter("object_cropping_output.avi",
cv2.VideoWriter_fourcc(*'mp4v'), cv2.VideoWriter_fourcc(*'mp4v'),
fps, (w, h)) fps, (w, h))
idx = 0 idx = 0
while cap.isOpened(): while cap.isOpened():
success, im0 = cap.read() success, im0 = cap.read()
if not success: if not success:
print("Video frame is empty or video processing has been successfully completed.") print("Video frame is empty or video processing has been successfully completed.")
break break
results = model.predict(im0, show=False) results = model.predict(im0, show=False)
boxes = results[0].boxes.xyxy.cpu().tolist() boxes = results[0].boxes.xyxy.cpu().tolist()
clss = results[0].boxes.cls.cpu().tolist() clss = results[0].boxes.cls.cpu().tolist()
annotator = Annotator(im0, line_width=2, example=names) annotator = Annotator(im0, line_width=2, example=names)
if boxes is not None: if boxes is not None:
for box, cls in zip(boxes, clss): for box, cls in zip(boxes, clss):
idx += 1 idx += 1
annotator.box_label(box, color=colors(int(cls), True), label=names[int(cls)]) annotator.box_label(box, color=colors(int(cls), True), label=names[int(cls)])
crop_obj = im0[int(box[1]):int(box[3]), int(box[0]):int(box[2])] crop_obj = im0[int(box[1]):int(box[3]), int(box[0]):int(box[2])]
cv2.imwrite(os.path.join(crop_dir_name, str(idx)+".png"), crop_obj) cv2.imwrite(os.path.join(crop_dir_name, str(idx)+".png"), crop_obj)
cv2.imshow("ultralytics", im0) cv2.imshow("ultralytics", im0)
video_writer.write(im0) video_writer.write(im0)
if cv2.waitKey(1) & 0xFF == ord('q'): if cv2.waitKey(1) & 0xFF == ord('q'):
break break
cap.release() cap.release()
video_writer.release() video_writer.release()
cv2.destroyAllWindows() cv2.destroyAllWindows()

@ -39,17 +39,13 @@ Introducing [Ultralytics](https://ultralytics.com) [YOLOv8](https://github.com/u
Explore the YOLOv8 Docs, a comprehensive resource designed to help you understand and utilize its features and capabilities. Whether you are a seasoned machine learning practitioner or new to the field, this hub aims to maximize YOLOv8's potential in your projects Explore the YOLOv8 Docs, a comprehensive resource designed to help you understand and utilize its features and capabilities. Whether you are a seasoned machine learning practitioner or new to the field, this hub aims to maximize YOLOv8's potential in your projects
# 🌟 New: Ultralytics Explorer 🌟
Create embeddings for your dataset, search for similar images, run SQL queries and perform semantic search. You can get started with our GUI app or build your own using the API. Learn more [here](datasets/explorer/index.md).
## Where to Start ## Where to Start
- **Install** `ultralytics` with pip and get up and running in minutes &nbsp; [:material-clock-fast: Get Started](quickstart.md){ .md-button } - **Install** `ultralytics` with pip and get up and running in minutes &nbsp; [:material-clock-fast: Get Started](quickstart.md){ .md-button }
- **Predict** new images and videos with YOLOv8 &nbsp; [:octicons-image-16: Predict on Images](modes/predict.md){ .md-button } - **Predict** new images and videos with YOLOv8 &nbsp; [:octicons-image-16: Predict on Images](modes/predict.md){ .md-button }
- **Train** a new YOLOv8 model on your own custom dataset &nbsp; [:fontawesome-solid-brain: Train a Model](modes/train.md){ .md-button } - **Train** a new YOLOv8 model on your own custom dataset &nbsp; [:fontawesome-solid-brain: Train a Model](modes/train.md){ .md-button }
- **Tasks** YOLOv8 tasks like segment, classify, pose and track &nbsp; [:material-magnify-expand: Explore Tasks](tasks/index.md){ .md-button } - **Tasks** YOLOv8 tasks like segment, classify, pose and track &nbsp; [:material-magnify-expand: Explore Tasks](tasks/index.md){ .md-button }
- **Explore** datasets with advanced semantic and SQL search &nbsp; [:material-magnify-expand: Run Explorer](datasets/explorer/index.md){ .md-button } - **NEW 🚀 Explore** datasets with advanced semantic and SQL search &nbsp; [:material-magnify-expand: Explore a Dataset](datasets/explorer/index.md){ .md-button }
<p align="center"> <p align="center">
<br> <br>

@ -43,6 +43,10 @@ keywords: Ultralytics, YOLO, Configuration, cfg2dict, handle_deprecation, merge_
<br><br> <br><br>
## ::: ultralytics.cfg.handle_explorer
<br><br>
## ::: ultralytics.cfg.parse_key_value_pair ## ::: ultralytics.cfg.parse_key_value_pair
<br><br> <br><br>

@ -7,7 +7,7 @@ keywords: Ultralytics, YOLO, distance calculation, object tracking, data visuali
!!! Note !!! Note
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/distance_calculation.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/distance_calculation.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/solutions/heatmap.py) 🛠. Thank you 🙏! This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/distance_calculation.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/distance_calculation.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/solutions/distance_calculation.py) 🛠. Thank you 🙏!
<br><br> <br><br>

@ -7,7 +7,7 @@ keywords: Ultralytics YOLO, speed estimation software, real-time vehicle trackin
!!! Note !!! Note
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/speed_estimation.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/speed_estimation.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/solutions/object_counter.py) 🛠. Thank you 🙏! This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/speed_estimation.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/speed_estimation.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/solutions/speed_estimation.py) 🛠. Thank you 🙏!
<br><br> <br><br>

@ -240,6 +240,53 @@ Benchmark mode is used to profile the speed and accuracy of various export forma
[Benchmark Examples](../modes/benchmark.md){ .md-button } [Benchmark Examples](../modes/benchmark.md){ .md-button }
## Explorer
Explorer API can be used to explore datasets with advanced semantic, vector-similarity and SQL search among other features. It also searching for images based on their content using natural language by utilizing the power of LLMs. The Explorer API allows you to write your own dataset exploration notebooks or scripts to get insights into your datasets.
!!! Example "Semantic Search Using Explorer"
=== "Using Images"
```python
from ultralytics import Explorer
# create an Explorer object
exp = Explorer(data='coco128.yaml', model='yolov8n.pt')
exp.create_embeddings_table()
similar = exp.get_similar(img='https://ultralytics.com/images/bus.jpg', limit=10)
print(similar.head())
# Search using multiple indices
similar = exp.get_similar(
img=['https://ultralytics.com/images/bus.jpg',
'https://ultralytics.com/images/bus.jpg'],
limit=10
)
print(similar.head())
```
=== "Using Dataset Indices"
```python
from ultralytics import Explorer
# create an Explorer object
exp = Explorer(data='coco128.yaml', model='yolov8n.pt')
exp.create_embeddings_table()
similar = exp.get_similar(idx=1, limit=10)
print(similar.head())
# Search using multiple indices
similar = exp.get_similar(idx=[1,10], limit=10)
print(similar.head())
```
[Explorer](../datasets/explorer/index.md){ .md-button }
## Using Trainers ## Using Trainers
`YOLO` model class is a high-level wrapper on the Trainer classes. Each YOLO task has its own trainer that inherits from `BaseTrainer`. `YOLO` model class is a high-level wrapper on the Trainer classes. Each YOLO task has its own trainer that inherits from `BaseTrainer`.

@ -170,12 +170,14 @@ nav:
- Classify: tasks/classify.md - Classify: tasks/classify.md
- Pose: tasks/pose.md - Pose: tasks/pose.md
- OBB: tasks/obb.md - OBB: tasks/obb.md
- Guides:
- guides/index.md
- Models: - Models:
- models/index.md - models/index.md
- Datasets: - Datasets:
- datasets/index.md - datasets/index.md
- Guides:
- guides/index.md
- NEW 🚀 Explorer:
- datasets/explorer/index.md
- Languages: - Languages:
- 🇬🇧&nbsp English: https://docs.ultralytics.com/ - 🇬🇧&nbsp English: https://docs.ultralytics.com/
- 🇨🇳&nbsp 简体中文: https://docs.ultralytics.com/zh/ - 🇨🇳&nbsp 简体中文: https://docs.ultralytics.com/zh/
@ -188,7 +190,14 @@ nav:
- 🇵🇹&nbsp Português: https://docs.ultralytics.com/pt/ - 🇵🇹&nbsp Português: https://docs.ultralytics.com/pt/
- 🇮🇳&nbsp हि: https://docs.ultralytics.com/hi/ - 🇮🇳&nbsp हि: https://docs.ultralytics.com/hi/
- 🇸🇦&nbsp العربية: https://docs.ultralytics.com/ar/ - 🇸🇦&nbsp العربية: https://docs.ultralytics.com/ar/
- Quickstart: quickstart.md - Quickstart:
- quickstart.md
- Usage:
- CLI: usage/cli.md
- Python: usage/python.md
- Callbacks: usage/callbacks.md
- Configuration: usage/cfg.md
- Advanced Customization: usage/engine.md
- Modes: - Modes:
- modes/index.md - modes/index.md
- Train: modes/train.md - Train: modes/train.md
@ -219,7 +228,7 @@ nav:
- RT-DETR (Realtime Detection Transformer): models/rtdetr.md - RT-DETR (Realtime Detection Transformer): models/rtdetr.md
- Datasets: - Datasets:
- datasets/index.md - datasets/index.md
- Explorer: - NEW 🚀 Explorer:
- datasets/explorer/index.md - datasets/explorer/index.md
- Explorer API: datasets/explorer/api.md - Explorer API: datasets/explorer/api.md
- Explorer Dashboard: datasets/explorer/dashboard.md - Explorer Dashboard: datasets/explorer/dashboard.md
@ -263,6 +272,11 @@ nav:
- DOTA8: datasets/obb/dota8.md - DOTA8: datasets/obb/dota8.md
- Multi-Object Tracking: - Multi-Object Tracking:
- datasets/track/index.md - datasets/track/index.md
- NEW 🚀 Explorer:
- datasets/explorer/index.md
- Explorer API: datasets/explorer/api.md
- Explorer Dashboard Demo: datasets/explorer/dashboard.md
- VOC Exploration Example: datasets/explorer/explorer.ipynb
- Guides: - Guides:
- guides/index.md - guides/index.md
- YOLO Common Issues: guides/yolo-common-issues.md - YOLO Common Issues: guides/yolo-common-issues.md
@ -290,6 +304,31 @@ nav:
- VisionEye Mapping: guides/vision-eye.md - VisionEye Mapping: guides/vision-eye.md
- Speed Estimation: guides/speed-estimation.md - Speed Estimation: guides/speed-estimation.md
- Distance Calculation: guides/distance-calculation.md - Distance Calculation: guides/distance-calculation.md
- YOLOv5:
- yolov5/index.md
- Quickstart: yolov5/quickstart_tutorial.md
- Environments:
- Amazon Web Services (AWS): yolov5/environments/aws_quickstart_tutorial.md
- Google Cloud (GCP): yolov5/environments/google_cloud_quickstart_tutorial.md
- AzureML: yolov5/environments/azureml_quickstart_tutorial.md
- Docker Image: yolov5/environments/docker_image_quickstart_tutorial.md
- Tutorials:
- Train Custom Data: yolov5/tutorials/train_custom_data.md
- Tips for Best Training Results: yolov5/tutorials/tips_for_best_training_results.md
- Multi-GPU Training: yolov5/tutorials/multi_gpu_training.md
- PyTorch Hub: yolov5/tutorials/pytorch_hub_model_loading.md
- TFLite, ONNX, CoreML, TensorRT Export: yolov5/tutorials/model_export.md
- NVIDIA Jetson Nano Deployment: yolov5/tutorials/running_on_jetson_nano.md
- Test-Time Augmentation (TTA): yolov5/tutorials/test_time_augmentation.md
- Model Ensembling: yolov5/tutorials/model_ensembling.md
- Pruning/Sparsity Tutorial: yolov5/tutorials/model_pruning_and_sparsity.md
- Hyperparameter evolution: yolov5/tutorials/hyperparameter_evolution.md
- Transfer learning with frozen layers: yolov5/tutorials/transfer_learning_with_frozen_layers.md
- Architecture Summary: yolov5/tutorials/architecture_description.md
- Roboflow Datasets: yolov5/tutorials/roboflow_datasets_integration.md
- Neural Magic's DeepSparse: yolov5/tutorials/neural_magic_pruning_quantization.md
- Comet Logging: yolov5/tutorials/comet_logging_integration.md
- Clearml Logging: yolov5/tutorials/clearml_logging_integration.md
- Integrations: - Integrations:
- integrations/index.md - integrations/index.md
- Comet ML: integrations/comet.md - Comet ML: integrations/comet.md
@ -303,37 +342,6 @@ nav:
- Neural Magic: integrations/neural-magic.md - Neural Magic: integrations/neural-magic.md
- TensorBoard: integrations/tensorboard.md - TensorBoard: integrations/tensorboard.md
- Amazon SageMaker: integrations/amazon-sagemaker.md - Amazon SageMaker: integrations/amazon-sagemaker.md
- Usage:
- CLI: usage/cli.md
- Python: usage/python.md
- Callbacks: usage/callbacks.md
- Configuration: usage/cfg.md
- Advanced Customization: usage/engine.md
- YOLOv5:
- yolov5/index.md
- Quickstart: yolov5/quickstart_tutorial.md
- Environments:
- Amazon Web Services (AWS): yolov5/environments/aws_quickstart_tutorial.md
- Google Cloud (GCP): yolov5/environments/google_cloud_quickstart_tutorial.md
- AzureML: yolov5/environments/azureml_quickstart_tutorial.md
- Docker Image: yolov5/environments/docker_image_quickstart_tutorial.md
- Tutorials:
- Train Custom Data: yolov5/tutorials/train_custom_data.md
- Tips for Best Training Results: yolov5/tutorials/tips_for_best_training_results.md
- Multi-GPU Training: yolov5/tutorials/multi_gpu_training.md
- PyTorch Hub: yolov5/tutorials/pytorch_hub_model_loading.md
- TFLite, ONNX, CoreML, TensorRT Export: yolov5/tutorials/model_export.md
- NVIDIA Jetson Nano Deployment: yolov5/tutorials/running_on_jetson_nano.md
- Test-Time Augmentation (TTA): yolov5/tutorials/test_time_augmentation.md
- Model Ensembling: yolov5/tutorials/model_ensembling.md
- Pruning/Sparsity Tutorial: yolov5/tutorials/model_pruning_and_sparsity.md
- Hyperparameter evolution: yolov5/tutorials/hyperparameter_evolution.md
- Transfer learning with frozen layers: yolov5/tutorials/transfer_learning_with_frozen_layers.md
- Architecture Summary: yolov5/tutorials/architecture_description.md
- Roboflow Datasets: yolov5/tutorials/roboflow_datasets_integration.md
- Neural Magic's DeepSparse: yolov5/tutorials/neural_magic_pruning_quantization.md
- Comet Logging: yolov5/tutorials/comet_logging_integration.md
- Clearml Logging: yolov5/tutorials/clearml_logging_integration.md
- HUB: - HUB:
- hub/index.md - hub/index.md
- Quickstart: hub/quickstart.md - Quickstart: hub/quickstart.md
@ -357,6 +365,11 @@ nav:
- build: reference/data/build.md - build: reference/data/build.md
- converter: reference/data/converter.md - converter: reference/data/converter.md
- dataset: reference/data/dataset.md - dataset: reference/data/dataset.md
- explorer:
- explorer: reference/data/explorer/explorer.md
- gui:
- dash: reference/data/explorer/gui/dash.md
- utils: reference/data/explorer/utils.md
- loaders: reference/data/loaders.md - loaders: reference/data/loaders.md
- split_dota: reference/data/split_dota.md - split_dota: reference/data/split_dota.md
- utils: reference/data/utils.md - utils: reference/data/utils.md
@ -436,10 +449,10 @@ nav:
- tasks: reference/nn/tasks.md - tasks: reference/nn/tasks.md
- solutions: - solutions:
- ai_gym: reference/solutions/ai_gym.md - ai_gym: reference/solutions/ai_gym.md
- distance_calculation: reference/solutions/distance_calculation.md
- heatmap: reference/solutions/heatmap.md - heatmap: reference/solutions/heatmap.md
- object_counter: reference/solutions/object_counter.md - object_counter: reference/solutions/object_counter.md
- speed_estimation: reference/solutions/speed_estimation.md - speed_estimation: reference/solutions/speed_estimation.md
- distance_calculation: reference/solutions/distance_calculation.md
- trackers: - trackers:
- basetrack: reference/trackers/basetrack.md - basetrack: reference/trackers/basetrack.md
- bot_sort: reference/trackers/bot_sort.md - bot_sort: reference/trackers/bot_sort.md

@ -1,8 +1,11 @@
import PIL
from ultralytics import Explorer from ultralytics import Explorer
from ultralytics.utils import ASSETS from ultralytics.utils import ASSETS
def test_similarity(): def test_similarity():
"""Test similarity calculations and SQL queries for correctness and response length."""
exp = Explorer() exp = Explorer()
exp.create_embeddings_table() exp.create_embeddings_table()
similar = exp.get_similar(idx=1) similar = exp.get_similar(idx=1)
@ -18,6 +21,7 @@ def test_similarity():
def test_det(): def test_det():
"""Test detection functionalities and ensure the embedding table has bounding boxes."""
exp = Explorer(data='coco8.yaml', model='yolov8n.pt') exp = Explorer(data='coco8.yaml', model='yolov8n.pt')
exp.create_embeddings_table(force=True) exp.create_embeddings_table(force=True)
assert len(exp.table.head()['bboxes']) > 0 assert len(exp.table.head()['bboxes']) > 0
@ -25,27 +29,26 @@ def test_det():
assert len(similar) > 0 assert len(similar) > 0
# This is a loose test, just checks errors not correctness # This is a loose test, just checks errors not correctness
similar = exp.plot_similar(idx=[1, 2], limit=10) similar = exp.plot_similar(idx=[1, 2], limit=10)
assert similar is not None assert isinstance(similar, PIL.Image.Image)
similar.show()
def test_seg(): def test_seg():
"""Test segmentation functionalities and verify the embedding table includes masks."""
exp = Explorer(data='coco8-seg.yaml', model='yolov8n-seg.pt') exp = Explorer(data='coco8-seg.yaml', model='yolov8n-seg.pt')
exp.create_embeddings_table(force=True) exp.create_embeddings_table(force=True)
assert len(exp.table.head()['masks']) > 0 assert len(exp.table.head()['masks']) > 0
similar = exp.get_similar(idx=[1, 2], limit=10) similar = exp.get_similar(idx=[1, 2], limit=10)
assert len(similar) > 0 assert len(similar) > 0
similar = exp.plot_similar(idx=[1, 2], limit=10) similar = exp.plot_similar(idx=[1, 2], limit=10)
assert similar is not None assert isinstance(similar, PIL.Image.Image)
similar.show()
def test_pose(): def test_pose():
"""Test pose estimation functionalities and check the embedding table for keypoints."""
exp = Explorer(data='coco8-pose.yaml', model='yolov8n-pose.pt') exp = Explorer(data='coco8-pose.yaml', model='yolov8n-pose.pt')
exp.create_embeddings_table(force=True) exp.create_embeddings_table(force=True)
assert len(exp.table.head()['keypoints']) > 0 assert len(exp.table.head()['keypoints']) > 0
similar = exp.get_similar(idx=[1, 2], limit=10) similar = exp.get_similar(idx=[1, 2], limit=10)
assert len(similar) > 0 assert len(similar) > 0
similar = exp.plot_similar(idx=[1, 2], limit=10) similar = exp.plot_similar(idx=[1, 2], limit=10)
assert similar is not None assert isinstance(similar, PIL.Image.Image)
similar.show()

@ -1,3 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .utils import plot_query_result from .utils import plot_query_result
__all__ = ['plot_query_result'] __all__ = ['plot_query_result']

@ -1,3 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import Any, List, Tuple, Union from typing import Any, List, Tuple, Union
@ -24,9 +26,8 @@ class ExplorerDataset(YOLODataset):
def __init__(self, *args, data: dict = None, **kwargs) -> None: def __init__(self, *args, data: dict = None, **kwargs) -> None:
super().__init__(*args, data=data, **kwargs) super().__init__(*args, data=data, **kwargs)
# NOTE: Load the image directly without any resize operations.
def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]: def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]:
"""Loads 1 image from dataset index 'i', returns (im, resized hw).""" """Loads 1 image from dataset index 'i' without any resize ops."""
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i] im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
if im is None: # not cached in RAM if im is None: # not cached in RAM
if fn.exists(): # load npy if fn.exists(): # load npy
@ -41,6 +42,7 @@ class ExplorerDataset(YOLODataset):
return self.ims[i], self.im_hw0[i], self.im_hw[i] return self.ims[i], self.im_hw0[i], self.im_hw[i]
def build_transforms(self, hyp: IterableSimpleNamespace = None): def build_transforms(self, hyp: IterableSimpleNamespace = None):
"""Creates transforms for dataset images without resizing."""
return Format( return Format(
bbox_format='xyxy', bbox_format='xyxy',
normalize=False, normalize=False,
@ -122,7 +124,7 @@ class Explorer:
self.table = table self.table = table
def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]): def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]):
# Implement Batching """Generates batches of data for embedding, excluding specified keys."""
for i in tqdm(range(len(dataset))): for i in tqdm(range(len(dataset))):
self.progress = float(i + 1) / len(dataset) self.progress = float(i + 1) / len(dataset)
batch = dataset[i] batch = dataset[i]
@ -143,7 +145,7 @@ class Explorer:
limit (int): Number of results to return. limit (int): Number of results to return.
Returns: Returns:
An arrow table containing the results. Supports converting to: (pyarrow.Table): An arrow table containing the results. Supports converting to:
- pandas dataframe: `result.to_pandas()` - pandas dataframe: `result.to_pandas()`
- dict of lists: `result.to_pydict()` - dict of lists: `result.to_pydict()`
@ -175,7 +177,7 @@ class Explorer:
return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'. return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
Returns: Returns:
An arrow table containing the results. (pyarrow.Table): An arrow table containing the results.
Example: Example:
```python ```python
@ -216,7 +218,7 @@ class Explorer:
labels (bool): Whether to plot the labels or not. labels (bool): Whether to plot the labels or not.
Returns: Returns:
PIL Image containing the plot. (PIL.Image): Image containing the plot.
Example: Example:
```python ```python
@ -248,7 +250,7 @@ class Explorer:
return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'. return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
Returns: Returns:
A table or pandas dataframe containing the results. (pandas.DataFrame): A dataframe containing the results.
Example: Example:
```python ```python
@ -282,7 +284,7 @@ class Explorer:
limit (int): Number of results to return. Defaults to 25. limit (int): Number of results to return. Defaults to 25.
Returns: Returns:
PIL Image containing the plot. (PIL.Image): Image containing the plot.
Example: Example:
```python ```python
@ -306,11 +308,12 @@ class Explorer:
Args: Args:
max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2. max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when running top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when running
vector search. Defaults: None. vector search. Defaults: None.
force (bool): Whether to overwrite the existing similarity index or not. Defaults to True. force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
Returns: Returns:
A pandas dataframe containing the similarity index. (pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image, and columns
include indices of similar images and their respective distances.
Example: Example:
```python ```python
@ -340,6 +343,7 @@ class Explorer:
sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode='overwrite') sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode='overwrite')
def _yield_sim_idx(): def _yield_sim_idx():
"""Generates a dataframe with similarity indices and distances for images."""
for i in tqdm(range(len(embeddings))): for i in tqdm(range(len(embeddings))):
sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f'_distance <= {max_dist}') sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f'_distance <= {max_dist}')
yield [{ yield [{
@ -364,7 +368,7 @@ class Explorer:
force (bool): Whether to overwrite the existing similarity index or not. Defaults to True. force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
Returns: Returns:
PIL.PngImagePlugin.PngImageFile containing the plot. (PIL.Image): Image containing the plot.
Example: Example:
```python ```python
@ -416,7 +420,7 @@ class Explorer:
query (str): Question to ask. query (str): Question to ask.
Returns: Returns:
Answer from AI. (pandas.DataFrame): A dataframe containing filtered results to the SQL query.
Example: Example:
```python ```python
@ -436,14 +440,17 @@ class Explorer:
def visualize(self, result): def visualize(self, result):
""" """
Visualize the results of a query. Visualize the results of a query. TODO.
Args: Args:
result (arrow table): Arrow table containing the results of a query. result (pyarrow.Table): Table containing the results of a query.
""" """
# TODO:
pass pass
def generate_report(self, result): def generate_report(self, result):
"""Generate a report of the dataset.""" """
Generate a report of the dataset.
TODO
"""
pass pass

@ -0,0 +1 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

@ -1,3 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import time import time
from threading import Thread from threading import Thread
@ -7,13 +9,13 @@ from ultralytics import Explorer
from ultralytics.utils import ROOT, SETTINGS from ultralytics.utils import ROOT, SETTINGS
from ultralytics.utils.checks import check_requirements from ultralytics.utils.checks import check_requirements
check_requirements('streamlit>=1.29.0') check_requirements(('streamlit>=1.29.0', 'streamlit-select>=0.2'))
check_requirements('streamlit-select>=0.2')
import streamlit as st import streamlit as st
from streamlit_select import image_select from streamlit_select import image_select
def _get_explorer(): def _get_explorer():
"""Initializes and returns an instance of the Explorer class."""
exp = Explorer(data=st.session_state.get('dataset'), model=st.session_state.get('model')) exp = Explorer(data=st.session_state.get('dataset'), model=st.session_state.get('model'))
thread = Thread(target=exp.create_embeddings_table, thread = Thread(target=exp.create_embeddings_table,
kwargs={'force': st.session_state.get('force_recreate_embeddings')}) kwargs={'force': st.session_state.get('force_recreate_embeddings')})
@ -28,6 +30,7 @@ def _get_explorer():
def init_explorer_form(): def init_explorer_form():
"""Initializes an Explorer instance and creates embeddings table with progress tracking."""
datasets = ROOT / 'cfg' / 'datasets' datasets = ROOT / 'cfg' / 'datasets'
ds = [d.name for d in datasets.glob('*.yaml')] ds = [d.name for d in datasets.glob('*.yaml')]
models = [ models = [
@ -46,6 +49,7 @@ def init_explorer_form():
def query_form(): def query_form():
"""Sets up a form in Streamlit to initialize Explorer with dataset and model selection."""
with st.form('query_form'): with st.form('query_form'):
col1, col2 = st.columns([0.8, 0.2]) col1, col2 = st.columns([0.8, 0.2])
with col1: with col1:
@ -58,6 +62,7 @@ def query_form():
def ai_query_form(): def ai_query_form():
"""Sets up a Streamlit form for user input to initialize Explorer with dataset and model selection."""
with st.form('ai_query_form'): with st.form('ai_query_form'):
col1, col2 = st.columns([0.8, 0.2]) col1, col2 = st.columns([0.8, 0.2])
with col1: with col1:
@ -67,6 +72,7 @@ def ai_query_form():
def find_similar_imgs(imgs): def find_similar_imgs(imgs):
"""Initializes a Streamlit form for AI-based image querying with custom input."""
exp = st.session_state['explorer'] exp = st.session_state['explorer']
similar = exp.get_similar(img=imgs, limit=st.session_state.get('limit'), return_type='arrow') similar = exp.get_similar(img=imgs, limit=st.session_state.get('limit'), return_type='arrow')
paths = similar.to_pydict()['im_file'] paths = similar.to_pydict()['im_file']
@ -74,6 +80,7 @@ def find_similar_imgs(imgs):
def similarity_form(selected_imgs): def similarity_form(selected_imgs):
"""Initializes a form for AI-based image querying with custom input in Streamlit."""
st.write('Similarity Search') st.write('Similarity Search')
with st.form('similarity_form'): with st.form('similarity_form'):
subcol1, subcol2 = st.columns([1, 1]) subcol1, subcol2 = st.columns([1, 1])
@ -109,6 +116,7 @@ def similarity_form(selected_imgs):
def run_sql_query(): def run_sql_query():
"""Executes an SQL query and returns the results."""
st.session_state['error'] = None st.session_state['error'] = None
query = st.session_state.get('query') query = st.session_state.get('query')
if query.rstrip().lstrip(): if query.rstrip().lstrip():
@ -118,6 +126,7 @@ def run_sql_query():
def run_ai_query(): def run_ai_query():
"""Execute SQL query and update session state with query results."""
if not SETTINGS['openai_api_key']: if not SETTINGS['openai_api_key']:
st.session_state[ st.session_state[
'error'] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."' 'error'] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
@ -134,12 +143,14 @@ def run_ai_query():
def reset_explorer(): def reset_explorer():
"""Resets the explorer to its initial state by clearing session variables."""
st.session_state['explorer'] = None st.session_state['explorer'] = None
st.session_state['imgs'] = None st.session_state['imgs'] = None
st.session_state['error'] = None st.session_state['error'] = None
def utralytics_explorer_docs_callback(): def utralytics_explorer_docs_callback():
"""Resets the explorer to its initial state by clearing session variables."""
with st.container(border=True): with st.container(border=True):
st.image('https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg', st.image('https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg',
width=100) width=100)
@ -151,6 +162,7 @@ def utralytics_explorer_docs_callback():
def layout(): def layout():
"""Resets explorer session variables and provides documentation with a link to API docs."""
st.set_page_config(layout='wide', initial_sidebar_state='collapsed') st.set_page_config(layout='wide', initial_sidebar_state='collapsed')
st.markdown("<h1 style='text-align: center;'>Ultralytics Explorer Demo</h1>", unsafe_allow_html=True) st.markdown("<h1 style='text-align: center;'>Ultralytics Explorer Demo</h1>", unsafe_allow_html=True)

@ -1,3 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import getpass import getpass
from typing import List from typing import List
@ -14,6 +16,7 @@ from ultralytics.utils.plotting import plot_images
def get_table_schema(vector_size): def get_table_schema(vector_size):
"""Extracts and returns the schema of a database table."""
from lancedb.pydantic import LanceModel, Vector from lancedb.pydantic import LanceModel, Vector
class Schema(LanceModel): class Schema(LanceModel):
@ -29,6 +32,7 @@ def get_table_schema(vector_size):
def get_sim_index_schema(): def get_sim_index_schema():
"""Returns a LanceModel schema for a database table with specified vector size."""
from lancedb.pydantic import LanceModel from lancedb.pydantic import LanceModel
class Schema(LanceModel): class Schema(LanceModel):
@ -41,6 +45,7 @@ def get_sim_index_schema():
def sanitize_batch(batch, dataset_info): def sanitize_batch(batch, dataset_info):
"""Sanitizes input batch for inference, ensuring correct format and dimensions."""
batch['cls'] = batch['cls'].flatten().int().tolist() batch['cls'] = batch['cls'].flatten().int().tolist()
box_cls_pair = sorted(zip(batch['bboxes'].tolist(), batch['cls']), key=lambda x: x[1]) box_cls_pair = sorted(zip(batch['bboxes'].tolist(), batch['cls']), key=lambda x: x[1])
batch['bboxes'] = [box for box, _ in box_cls_pair] batch['bboxes'] = [box for box, _ in box_cls_pair]
@ -111,6 +116,7 @@ def plot_query_result(similar_set, plot_labels=True):
def prompt_sql_query(query): def prompt_sql_query(query):
"""Plots images with optional labels from a similar data set."""
check_requirements('openai>=1.6.1') check_requirements('openai>=1.6.1')
from openai import OpenAI from openai import OpenAI

@ -1,3 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import itertools import itertools
import os import os
from glob import glob from glob import glob
@ -53,10 +55,13 @@ def bbox_iof(polygon1, bbox2, eps=1e-6):
def load_yolo_dota(data_root, split='train'): def load_yolo_dota(data_root, split='train'):
"""Load DOTA dataset. """
Load DOTA dataset.
Args: Args:
data_root (str): Data root. data_root (str): Data root.
split (str): The split data set, could be train or val. split (str): The split data set, could be train or val.
Notes: Notes:
The directory structure assumed for the DOTA dataset: The directory structure assumed for the DOTA dataset:
- data_root - data_root
@ -133,7 +138,7 @@ def get_window_obj(anno, windows, iof_thr=0.7):
label[:, 1::2] *= w label[:, 1::2] *= w
label[:, 2::2] *= h label[:, 2::2] *= h
iofs = bbox_iof(label[:, 1:], windows) iofs = bbox_iof(label[:, 1:], windows)
# unnormalized and misaligned coordinates # Unnormalized and misaligned coordinates
window_anns = [(label[iofs[:, i] >= iof_thr]) for i in range(len(windows))] window_anns = [(label[iofs[:, i] >= iof_thr]) for i in range(len(windows))]
else: else:
window_anns = [np.zeros((0, 9), dtype=np.float32) for _ in range(len(windows))] window_anns = [np.zeros((0, 9), dtype=np.float32) for _ in range(len(windows))]
@ -141,13 +146,16 @@ def get_window_obj(anno, windows, iof_thr=0.7):
def crop_and_save(anno, windows, window_objs, im_dir, lb_dir): def crop_and_save(anno, windows, window_objs, im_dir, lb_dir):
"""Crop images and save new labels. """
Crop images and save new labels.
Args: Args:
anno (dict): Annotation dict, including `filepath`, `label`, `ori_size` as its keys. anno (dict): Annotation dict, including `filepath`, `label`, `ori_size` as its keys.
windows (list): A list of windows coordinates. windows (list): A list of windows coordinates.
window_objs (list): A list of labels inside each window. window_objs (list): A list of labels inside each window.
im_dir (str): The output directory path of images. im_dir (str): The output directory path of images.
lb_dir (str): The output directory path of labels. lb_dir (str): The output directory path of labels.
Notes: Notes:
The directory structure assumed for the DOTA dataset: The directory structure assumed for the DOTA dataset:
- data_root - data_root
@ -185,7 +193,7 @@ def split_images_and_labels(data_root, save_dir, split='train', crop_sizes=[1024
""" """
Split both images and labels. Split both images and labels.
NOTES: Notes:
The directory structure assumed for the DOTA dataset: The directory structure assumed for the DOTA dataset:
- data_root - data_root
- images - images
@ -215,7 +223,7 @@ def split_trainval(data_root, save_dir, crop_size=1024, gap=200, rates=[1.0]):
""" """
Split train and val set of DOTA. Split train and val set of DOTA.
NOTES: Notes:
The directory structure assumed for the DOTA dataset: The directory structure assumed for the DOTA dataset:
- data_root - data_root
- images - images
@ -245,7 +253,7 @@ def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=[1.0]):
""" """
Split test set of DOTA, labels are not included within this set. Split test set of DOTA, labels are not included within this set.
NOTES: Notes:
The directory structure assumed for the DOTA dataset: The directory structure assumed for the DOTA dataset:
- data_root - data_root
- images - images

@ -107,6 +107,7 @@ class RTDETRValidator(DetectionValidator):
return outputs return outputs
def _prepare_batch(self, si, batch): def _prepare_batch(self, si, batch):
"""Prepares a batch for training or inference by applying transformations."""
idx = batch['batch_idx'] == si idx = batch['batch_idx'] == si
cls = batch['cls'][idx].squeeze(-1) cls = batch['cls'][idx].squeeze(-1)
bbox = batch['bboxes'][idx] bbox = batch['bboxes'][idx]
@ -121,6 +122,7 @@ class RTDETRValidator(DetectionValidator):
return prepared_batch return prepared_batch
def _prepare_pred(self, pred, pbatch): def _prepare_pred(self, pred, pbatch):
"""Prepares and returns a batch with transformed bounding boxes and class labels."""
predn = pred.clone() predn = pred.clone()
predn[..., [0, 2]] *= pbatch['ori_shape'][1] / self.args.imgsz # native-space pred predn[..., [0, 2]] *= pbatch['ori_shape'][1] / self.args.imgsz # native-space pred
predn[..., [1, 3]] *= pbatch['ori_shape'][0] / self.args.imgsz # native-space pred predn[..., [1, 3]] *= pbatch['ori_shape'][0] / self.args.imgsz # native-space pred

@ -87,6 +87,7 @@ class DetectionValidator(BaseValidator):
max_det=self.args.max_det) max_det=self.args.max_det)
def _prepare_batch(self, si, batch): def _prepare_batch(self, si, batch):
"""Prepares a batch of images and annotations for validation."""
idx = batch['batch_idx'] == si idx = batch['batch_idx'] == si
cls = batch['cls'][idx].squeeze(-1) cls = batch['cls'][idx].squeeze(-1)
bbox = batch['bboxes'][idx] bbox = batch['bboxes'][idx]
@ -100,6 +101,7 @@ class DetectionValidator(BaseValidator):
return prepared_batch return prepared_batch
def _prepare_pred(self, pred, pbatch): def _prepare_pred(self, pred, pbatch):
"""Prepares a batch of images and annotations for validation."""
predn = pred.clone() predn = pred.clone()
ops.scale_boxes(pbatch['imgsz'], predn[:, :4], pbatch['ori_shape'], ops.scale_boxes(pbatch['imgsz'], predn[:, :4], pbatch['ori_shape'],
ratio_pad=pbatch['ratio_pad']) # native-space pred ratio_pad=pbatch['ratio_pad']) # native-space pred

@ -23,6 +23,7 @@ class OBBPredictor(DetectionPredictor):
""" """
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initializes OBBPredictor with optional model and data configuration overrides."""
super().__init__(cfg, overrides, _callbacks) super().__init__(cfg, overrides, _callbacks)
self.args.task = 'obb' self.args.task = 'obb'

@ -65,6 +65,7 @@ class OBBValidator(DetectionValidator):
return self.match_predictions(detections[:, 5], gt_cls, iou) return self.match_predictions(detections[:, 5], gt_cls, iou)
def _prepare_batch(self, si, batch): def _prepare_batch(self, si, batch):
"""Prepares and returns a batch for OBB validation."""
idx = batch['batch_idx'] == si idx = batch['batch_idx'] == si
cls = batch['cls'][idx].squeeze(-1) cls = batch['cls'][idx].squeeze(-1)
bbox = batch['bboxes'][idx] bbox = batch['bboxes'][idx]
@ -78,6 +79,7 @@ class OBBValidator(DetectionValidator):
return prepared_batch return prepared_batch
def _prepare_pred(self, pred, pbatch): def _prepare_pred(self, pred, pbatch):
"""Prepares and returns a batch for OBB validation with scaled and padded bounding boxes."""
predn = pred.clone() predn = pred.clone()
ops.scale_boxes(pbatch['imgsz'], predn[:, :4], pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'], ops.scale_boxes(pbatch['imgsz'], predn[:, :4], pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'],
xywh=True) # native-space pred xywh=True) # native-space pred

@ -69,6 +69,7 @@ class PoseValidator(DetectionValidator):
self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[]) self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[])
def _prepare_batch(self, si, batch): def _prepare_batch(self, si, batch):
"""Prepares a batch for processing by converting keypoints to float and moving to device."""
pbatch = super()._prepare_batch(si, batch) pbatch = super()._prepare_batch(si, batch)
kpts = batch['keypoints'][batch['batch_idx'] == si] kpts = batch['keypoints'][batch['batch_idx'] == si]
h, w = pbatch['imgsz'] h, w = pbatch['imgsz']
@ -80,6 +81,7 @@ class PoseValidator(DetectionValidator):
return pbatch return pbatch
def _prepare_pred(self, pred, pbatch): def _prepare_pred(self, pred, pbatch):
"""Prepares and scales keypoints in a batch for pose processing."""
predn = super()._prepare_pred(pred, pbatch) predn = super()._prepare_pred(pred, pbatch)
nk = pbatch['kpts'].shape[1] nk = pbatch['kpts'].shape[1]
pred_kpts = predn[:, 6:].view(len(predn), nk, -1) pred_kpts = predn[:, 6:].view(len(predn), nk, -1)

@ -72,12 +72,14 @@ class SegmentationValidator(DetectionValidator):
return p, proto return p, proto
def _prepare_batch(self, si, batch): def _prepare_batch(self, si, batch):
"""Prepares a batch for training or inference by processing images and targets."""
prepared_batch = super()._prepare_batch(si, batch) prepared_batch = super()._prepare_batch(si, batch)
midx = [si] if self.args.overlap_mask else batch['batch_idx'] == si midx = [si] if self.args.overlap_mask else batch['batch_idx'] == si
prepared_batch['masks'] = batch['masks'][midx] prepared_batch['masks'] = batch['masks'][midx]
return prepared_batch return prepared_batch
def _prepare_pred(self, pred, pbatch, proto): def _prepare_pred(self, pred, pbatch, proto):
"""Prepares a batch for training or inference by processing images and targets."""
predn = super()._prepare_pred(pred, pbatch) predn = super()._prepare_pred(pred, pbatch)
pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch['imgsz']) pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch['imgsz'])
return predn, pred_masks return predn, pred_masks

@ -116,6 +116,7 @@ class OBB(Detect):
"""YOLOv8 OBB detection head for detection with rotation models.""" """YOLOv8 OBB detection head for detection with rotation models."""
def __init__(self, nc=80, ne=1, ch=()): def __init__(self, nc=80, ne=1, ch=()):
"""Initialize OBB with number of classes `nc` and layer channels `ch`."""
super().__init__(nc, ch) super().__init__(nc, ch)
self.ne = ne # number of extra parameters self.ne = ne # number of extra parameters
self.detect = Detect.forward self.detect = Detect.forward
@ -124,6 +125,7 @@ class OBB(Detect):
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch) self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)
def forward(self, x): def forward(self, x):
"""Concatenates and returns predicted bounding boxes and class probabilities."""
bs = x[0].shape[0] # batch size bs = x[0].shape[0] # batch size
angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2) # OBB theta logits angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2) # OBB theta logits
# NOTE: set `angle` as an attribute so that `decode_bboxes` could use it. # NOTE: set `angle` as an attribute so that `decode_bboxes` could use it.

@ -306,6 +306,7 @@ class OBBModel(DetectionModel):
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
def init_criterion(self): def init_criterion(self):
"""Initialize the loss criterion for the model."""
return v8OBBLoss(self) return v8OBBLoss(self)

@ -153,6 +153,7 @@ class ObjectCounter:
self.selected_point = None self.selected_point = None
def extract_and_process_tracks(self, tracks): def extract_and_process_tracks(self, tracks):
"""Extracts and processes tracks for object counting in a video stream."""
boxes = tracks[0].boxes.xyxy.cpu() boxes = tracks[0].boxes.xyxy.cpu()
clss = tracks[0].boxes.cls.cpu().tolist() clss = tracks[0].boxes.cls.cpu().tolist()
track_ids = tracks[0].boxes.id.int().cpu().tolist() track_ids = tracks[0].boxes.id.int().cpu().tolist()

@ -55,6 +55,7 @@ class BaseTrack:
_count = 0 _count = 0
def __init__(self): def __init__(self):
"""Initializes a new track with unique ID and foundational tracking attributes."""
self.track_id = 0 self.track_id = 0
self.is_activated = False self.is_activated = False
self.state = TrackState.New self.state = TrackState.New

@ -245,6 +245,7 @@ def set_logging(name=LOGGING_NAME, verbose=True):
class CustomFormatter(logging.Formatter): class CustomFormatter(logging.Formatter):
def format(self, record): def format(self, record):
"""Sets up logging with UTF-8 encoding and configurable verbosity."""
return emojis(super().format(record)) return emojis(super().format(record))
formatter = CustomFormatter('%(message)s') # Use CustomFormatter to eliminate UTF-8 output as last recourse formatter = CustomFormatter('%(message)s') # Use CustomFormatter to eliminate UTF-8 output as last recourse

@ -206,7 +206,7 @@ def check_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=1.5, h
# Check file size # Check file size
gib = 1 << 30 # bytes per GiB gib = 1 << 30 # bytes per GiB
data = int(r.headers.get('Content-Length', 0)) / gib # file size (GB) data = int(r.headers.get('Content-Length', 0)) / gib # file size (GB)
total, used, free = (x / gib for x in shutil.disk_usage('/')) # bytes total, used, free = (x / gib for x in shutil.disk_usage(Path.cwd())) # bytes
if data * sf < free: if data * sf < free:
return True # sufficient space return True # sufficient space

Loading…
Cancel
Save