Explorer Cleanup (#7364)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Muhammad Rizwan Munawar <chr043416@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/7367/head
Glenn Jocher 10 months ago committed by GitHub
parent aca8eb1fd4
commit ed73c0fedc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 16
      README.md
  2. 18
      README.zh-CN.md
  3. 2
      docker/Dockerfile
  4. 2
      docker/Dockerfile-cpu
  5. 2
      docker/Dockerfile-python
  6. 16
      docs/en/guides/instance-segmentation-and-tracking.md
  7. 12
      docs/en/guides/vision-eye.md
  8. 144
      examples/heatmaps.ipynb
  9. 146
      examples/object_counting.ipynb
  10. 204
      examples/object_tracking.ipynb
  11. 5
      tests/test_explorer.py
  12. 19
      ultralytics/data/explorer/explorer.py
  13. 17
      ultralytics/data/explorer/utils.py
  14. 15
      ultralytics/utils/__init__.py
  15. 24
      ultralytics/utils/plotting.py

@ -66,7 +66,7 @@ For alternative installation methods including [Conda](https://anaconda.org/cond
<details open>
<summary>Usage</summary>
#### CLI
### CLI
YOLOv8 may be used directly in the Command Line Interface (CLI) with a `yolo` command:
@ -76,7 +76,7 @@ yolo predict model=yolov8n.pt source='https://ultralytics.com/images/bus.jpg'
`yolo` can be used for a variety of tasks and modes and accepts additional arguments, i.e. `imgsz=640`. See the YOLOv8 [CLI Docs](https://docs.ultralytics.com/usage/cli) for examples.
#### Python
### Python
YOLOv8 may also be used directly in a Python environment, and accepts the same [arguments](https://docs.ultralytics.com/usage/cfg/) as in the CLI example above:
@ -98,6 +98,18 @@ See YOLOv8 [Python Docs](https://docs.ultralytics.com/usage/python) for more exa
</details>
### Notebooks
Ultralytics provides interactive notebooks for YOLOv8, covering training, validation, tracking, and more. Each notebook is paired with a [YouTube](https://youtube.com/ultralytics) tutorial, making it easy to learn and implement advanced YOLOv8 features.
| Docs | Notebook | YouTube |
| --------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| <a href="https://docs.ultralytics.com/modes/">YOLOv8 Train, Val, Predict and Export Modes</a> | <a href="https://colab.research.google.com/github/ultralytics/ultralytics/blob/main/examples/tutorial.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a> | <a href="https://youtu.be/j8uQc0qB91s"><center><img width=30% src="https://raw.githubusercontent.com/ultralytics/assets/main/social/logo-social-youtube-rect.png" alt="Ultralytics Youtube Video"></center></a> |
| <a href="https://docs.ultralytics.com/hub/quickstart/">Ultralytics HUB QuickStart</a> | <a href="https://colab.research.google.com/github/ultralytics/ultralytics/blob/main/examples/hub.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a> | <a href="https://youtu.be/lveF9iCMIzc"><center><img width=30% src="https://raw.githubusercontent.com/ultralytics/assets/main/social/logo-social-youtube-rect.png" alt="Ultralytics Youtube Video"></center></a> |
| <a href="https://docs.ultralytics.com/modes/track/">YOLOv8 Multi-Object Tracking in Videos</a> | <a href="https://colab.research.google.com/github/ultralytics/ultralytics/blob/main/examples/object_tracking.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a> | <a href="https://youtu.be/hHyHmOtmEgs"><center><img width=30% src="https://raw.githubusercontent.com/ultralytics/assets/main/social/logo-social-youtube-rect.png" alt="Ultralytics Youtube Video"></center></a> |
| <a href="https://docs.ultralytics.com/guides/object-counting/">YOLOv8 Object Counting in Videos</a> | <a href="https://colab.research.google.com/github/ultralytics/ultralytics/blob/main/examples/object_counting.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a> | <a href="https://youtu.be/Ag2e-5_NpS0"><center><img width=30% src="https://raw.githubusercontent.com/ultralytics/assets/main/social/logo-social-youtube-rect.png" alt="Ultralytics Youtube Video"></center></a> |
| <a href="https://docs.ultralytics.com/guides/heatmaps/">YOLOv8 Heatmaps in Videos</a> | <a href="https://colab.research.google.com/github/ultralytics/ultralytics/blob/main/examples/heatmaps.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a> | <a href="https://youtu.be/4ezde5-nZZw"><center><img width=30% src="https://raw.githubusercontent.com/ultralytics/assets/main/social/logo-social-youtube-rect.png" alt="Ultralytics Youtube Video"></center></a> |
## <div align="center">Models</div>
YOLOv8 [Detect](https://docs.ultralytics.com/tasks/detect), [Segment](https://docs.ultralytics.com/tasks/segment) and [Pose](https://docs.ultralytics.com/tasks/pose) models pretrained on the [COCO](https://docs.ultralytics.com/datasets/detect/coco) dataset are available here, as well as YOLOv8 [Classify](https://docs.ultralytics.com/tasks/classify) models pretrained on the [ImageNet](https://docs.ultralytics.com/datasets/classify/imagenet) dataset. [Track](https://docs.ultralytics.com/modes/track) mode is available for all Detect, Segment and Pose models.

@ -44,6 +44,8 @@
</div>
</div>
以下是提供的内容的中文翻译:
## <div align="center">文档</div>
请参阅下面的快速安装和使用示例,以及 [YOLOv8 文档](https://docs.ultralytics.com) 上有关训练、验证、预测和部署的完整文档。
@ -66,7 +68,7 @@ pip install ultralytics
<details open>
<summary>Usage</summary>
#### CLI
### CLI
YOLOv8 可以在命令行界面(CLI)中直接使用,只需输入 `yolo` 命令:
@ -76,7 +78,7 @@ yolo predict model=yolov8n.pt source='https://ultralytics.com/images/bus.jpg'
`yolo` 可用于各种任务和模式,并接受其他参数,例如 `imgsz=640`。查看 YOLOv8 [CLI 文档](https://docs.ultralytics.com/usage/cli)以获取示例。
#### Python
### Python
YOLOv8 也可以在 Python 环境中直接使用,并接受与上述 CLI 示例中相同的[参数](https://docs.ultralytics.com/usage/cfg/):
@ -98,6 +100,18 @@ success = model.export(format="onnx") # 将模型导出为 ONNX 格式
</details>
### 笔记本
Ultralytics 提供了 YOLOv8 的交互式笔记本,涵盖训练、验证、跟踪等内容。每个笔记本都配有 [YouTube](https://youtube.com/ultralytics) 教程,使学习和实现高级 YOLOv8 功能变得简单。
| 文档 | 笔记本 | YouTube |
| ---------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| <a href="https://docs.ultralytics.com/modes/">YOLOv8 训练、验证、预测和导出模式</a> | <a href="https://colab.research.google.com/github/ultralytics/ultralytics/blob/main/examples/tutorial.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="在 Colab 中打开"></a> | <a href="https://youtu.be/j8uQc0qB91s"><center><img width=30% src="https://raw.githubusercontent.com/ultralytics/assets/main/social/logo-social-youtube-rect.png" alt="Ultralytics Youtube 视频"></center></a> |
| <a href="https://docs.ultralytics.com/hub/quickstart/">Ultralytics HUB 快速开始</a> | <a href="https://colab.research.google.com/github/ultralytics/ultralytics/blob/main/examples/hub.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="在 Colab 中打开"></a> | <a href="https://youtu.be/lveF9iCMIzc"><center><img width=30% src="https://raw.githubusercontent.com/ultralytics/assets/main/social/logo-social-youtube-rect.png" alt="Ultralytics Youtube 视频"></center></a> |
| <a href="https://docs.ultralytics.com/modes/track/">YOLOv8 视频中的多对象跟踪</a> | <a href="https://colab.research.google.com/github/ultralytics/ultralytics/blob/main/examples/object_tracking.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="在 Colab 中打开"></a> | <a href="https://youtu.be/hHyHmOtmEgs"><center><img width=30% src="https://raw.githubusercontent.com/ultralytics/assets/main/social/logo-social-youtube-rect.png" alt="Ultralytics Youtube 视频"></center></a> |
| <a href="https://docs.ultralytics.com/guides/object-counting/">YOLOv8 视频中的对象计数</a> | <a href="https://colab.research.google.com/github/ultralytics/ultralytics/blob/main/examples/object_counting.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="在 Colab 中打开"></a> | <a href="https://youtu.be/Ag2e-5_NpS0"><center><img width=30% src="https://raw.githubusercontent.com/ultralytics/assets/main/social/logo-social-youtube-rect.png" alt="Ultralytics Youtube 视频"></center></a> |
| <a href="https://docs.ultralytics.com/guides/heatmaps/">YOLOv8 视频中的热图</a> | <a href="https://colab.research.google.com/github/ultralytics/ultralytics/blob/main/examples/heatmaps.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="在 Colab 中打开"></a> | <a href="https://youtu.be/4ezde5-nZZw"><center><img width=30% src="https://raw.githubusercontent.com/ultralytics/assets/main/social/logo-social-youtube-rect.png" alt="Ultralytics Youtube 视频"></center></a> |
## <div align="center">模型</div>
在[COCO](https://docs.ultralytics.com/datasets/detect/coco)数据集上预训练的YOLOv8 [检测](https://docs.ultralytics.com/tasks/detect),[分割](https://docs.ultralytics.com/tasks/segment)和[姿态](https://docs.ultralytics.com/tasks/pose)模型可以在这里找到,以及在[ImageNet](https://docs.ultralytics.com/datasets/classify/imagenet)数据集上预训练的YOLOv8 [分类](https://docs.ultralytics.com/tasks/classify)模型。所有的检测,分割和姿态模型都支持[追踪](https://docs.ultralytics.com/modes/track)模式。

@ -28,7 +28,7 @@ ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /u
# Install pip packages
RUN python3 -m pip install --upgrade pip wheel
RUN pip install --no-cache -e ".[export]" albumentations comet pycocotools pytest-cov
RUN pip install --no-cache -e ".[export]" albumentations comet pycocotools lancedb pytest-cov
# Run exports to AutoInstall packages
RUN yolo export model=tmp/yolov8n.pt format=edgetpu imgsz=32

@ -26,7 +26,7 @@ RUN rm -rf /usr/lib/python3.11/EXTERNALLY-MANAGED
# Install pip packages
RUN python3 -m pip install --upgrade pip wheel
RUN pip install --no-cache -e ".[export]" --extra-index-url https://download.pytorch.org/whl/cpu
RUN pip install --no-cache -e ".[export]" lancedb --extra-index-url https://download.pytorch.org/whl/cpu
# Run exports to AutoInstall packages
RUN yolo export model=tmp/yolov8n.pt format=edgetpu imgsz=32

@ -26,7 +26,7 @@ ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /u
# Install pip packages
RUN python3 -m pip install --upgrade pip wheel
RUN pip install --no-cache -e ".[export]" --extra-index-url https://download.pytorch.org/whl/cpu
RUN pip install --no-cache -e ".[export]" lancedb --extra-index-url https://download.pytorch.org/whl/cpu
# Run exports to AutoInstall packages
RUN yolo export model=tmp/yolov8n.pt format=edgetpu imgsz=32

@ -91,16 +91,18 @@ There are two types of instance segmentation tracking available in the Ultralyti
print("Video frame is empty or video processing has been successfully completed.")
break
annotator = Annotator(im0, line_width=2)
results = model.track(im0, persist=True)
masks = results[0].masks.xy
track_ids = results[0].boxes.id.int().cpu().tolist()
annotator = Annotator(im0, line_width=2)
if results[0].boxes.id is not None:
masks = results[0].masks.xy
track_ids = results[0].boxes.id.int().cpu().tolist()
for mask, track_id in zip(masks, track_ids):
annotator.seg_bbox(mask=mask,
mask_color=colors(track_id, True),
track_label=str(track_id))
for mask, track_id in zip(masks, track_ids):
annotator.seg_bbox(mask=mask,
mask_color=colors(track_id, True),
track_label=str(track_id))
out.write(im0)
cv2.imshow("instance-segmentation-object-tracking", im0)

@ -81,15 +81,17 @@ keywords: Ultralytics, YOLOv8, Object Detection, Object Tracking, IDetection, Vi
print("Video frame is empty or video processing has been successfully completed.")
break
annotator = Annotator(im0, line_width=2)
results = model.track(im0, persist=True)
boxes = results[0].boxes.xyxy.cpu()
track_ids = results[0].boxes.id.int().cpu().tolist()
annotator = Annotator(im0, line_width=2)
if results[0].boxes.id is not None:
track_ids = results[0].boxes.id.int().cpu().tolist()
for box, track_id in zip(boxes, track_ids):
annotator.box_label(box, label=str(track_id), color=colors(int(track_id)))
annotator.visioneye(box, center_point)
for box, track_id in zip(boxes, track_ids):
annotator.box_label(box, label=str(track_id), color=colors(int(track_id)))
annotator.visioneye(box, center_point)
out.write(im0)
cv2.imshow("visioneye-pinpoint", im0)

@ -0,0 +1,144 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"source": [
"<div align=\"center\">\n",
"\n",
" <a href=\"https://ultralytics.com/yolov8\" target=\"_blank\">\n",
" <img width=\"1024\", src=\"https://raw.githubusercontent.com/ultralytics/assets/main/yolov8/banner-yolov8.png\"></a>\n",
"\n",
" [中文](https://docs.ultralytics.com/zh/) | [한국어](https://docs.ultralytics.com/ko/) | [日本語](https://docs.ultralytics.com/ja/) | [Русский](https://docs.ultralytics.com/ru/) | [Deutsch](https://docs.ultralytics.com/de/) | [Français](https://docs.ultralytics.com/fr/) | [Español](https://docs.ultralytics.com/es/) | [Português](https://docs.ultralytics.com/pt/) | [हि](https://docs.ultralytics.com/hi/) | [العربية](https://docs.ultralytics.com/ar/)\n",
"\n",
" <a href=\"https://colab.research.google.com/github/ultralytics/ultralytics/blob/main/examples/heatmaps.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"></a>\n",
"\n",
"Welcome to the Ultralytics YOLOv8 🚀 notebook! <a href=\"https://github.com/ultralytics/ultralytics\">YOLOv8</a> is the latest version of the YOLO (You Only Look Once) AI models developed by <a href=\"https://ultralytics.com\">Ultralytics</a>. This notebook serves as the starting point for exploring the <a href=\"https://docs.ultralytics.com/guides/heatmaps/\">heatmaps</a> and understand its features and capabilities.\n",
"\n",
"YOLOv8 models are fast, accurate, and easy to use, making them ideal for various object detection and image segmentation tasks. They can be trained on large datasets and run on diverse hardware platforms, from CPUs to GPUs.\n",
"\n",
"We hope that the resources in this notebook will help you get the most out of <a href=\"https://docs.ultralytics.com/guides/heatmaps/\">Ultralytics Heatmaps</a>. Please browse the YOLOv8 <a href=\"https://docs.ultralytics.com/\">Docs</a> for details, raise an issue on <a href=\"https://github.com/ultralytics/ultralytics\">GitHub</a> for support, and join our <a href=\"https://ultralytics.com/discord\">Discord</a> community for questions and discussions!\n",
"\n",
"</div>"
],
"metadata": {
"id": "PN1cAxdvd61e"
}
},
{
"cell_type": "markdown",
"source": [
"# Setup\n",
"\n",
"Pip install `ultralytics` and [dependencies](https://github.com/ultralytics/ultralytics/blob/main/pyproject.toml) and check software and hardware."
],
"metadata": {
"id": "o68Sg1oOeZm2"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9dSwz_uOReMI"
},
"outputs": [],
"source": [
"!pip install ultralytics"
]
},
{
"cell_type": "markdown",
"source": [
"# Ultralytics Heatmaps\n",
"\n",
"Heatmap is color-coded matrix, generated by Ultralytics YOLOv8, simplifies intricate data by using vibrant colors. This visual representation employs warmer hues for higher intensities and cooler tones for lower values. Heatmaps are effective in illustrating complex data patterns, correlations, and anomalies, providing a user-friendly and engaging way to interpret data across various domains."
],
"metadata": {
"id": "m7VkxQ2aeg7k"
}
},
{
"cell_type": "code",
"source": [
"from ultralytics import YOLO\n",
"from ultralytics.solutions import heatmap\n",
"import cv2\n",
"\n",
"model = YOLO(\"yolov8n.pt\")\n",
"cap = cv2.VideoCapture(\"path/to/video/file.mp4\")\n",
"assert cap.isOpened(), \"Error reading video file\"\n",
"\n",
"# Video writer\n",
"video_writer = cv2.VideoWriter(\"heatmap_output.avi\",\n",
" cv2.VideoWriter_fourcc(*'mp4v'),\n",
" int(cap.get(5)),\n",
" (int(cap.get(3)), int(cap.get(4))))\n",
"\n",
"# Init heatmap\n",
"heatmap_obj = heatmap.Heatmap()\n",
"heatmap_obj.set_args(colormap=cv2.COLORMAP_PARULA ,\n",
" imw=cap.get(4), # should same as cap height\n",
" imh=cap.get(3), # should same as cap width\n",
" view_img=True,\n",
" shape=\"circle\")\n",
"\n",
"while cap.isOpened():\n",
" success, im0 = cap.read()\n",
" if not success:\n",
" print(\"Video frame is empty or video processing has been successfully completed.\")\n",
" break\n",
" tracks = model.track(im0, persist=True, show=False)\n",
"\n",
" im0 = heatmap_obj.generate_heatmap(im0, tracks)\n",
" video_writer.write(im0)\n",
"\n",
"cap.release()\n",
"video_writer.release()\n",
"cv2.destroyAllWindows()"
],
"metadata": {
"id": "Cx-u59HQdu2o"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"#Community Support\n",
"\n",
"For more information, you can explore <a href=\"https://docs.ultralytics.com/guides/heatmaps/#heatmap-colormaps\">Ultralytics Heatmaps Docs</a>\n",
"\n",
"Ultralytics ⚡ resources\n",
"- About Us – https://ultralytics.com/about\n",
"- Join Our Team – https://ultralytics.com/work\n",
"- Contact Us – https://ultralytics.com/contact\n",
"- Discord – https://discord.gg/2wNGbc6g9X\n",
"- Ultralytics License – https://ultralytics.com/license\n",
"\n",
"YOLOv8 🚀 resources\n",
"- GitHub – https://github.com/ultralytics/ultralytics\n",
"- Docs – https://docs.ultralytics.com/"
],
"metadata": {
"id": "QrlKg-y3fEyD"
}
}
]
}

@ -0,0 +1,146 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"source": [
"<div align=\"center\">\n",
"\n",
" <a href=\"https://ultralytics.com/yolov8\" target=\"_blank\">\n",
" <img width=\"1024\", src=\"https://raw.githubusercontent.com/ultralytics/assets/main/yolov8/banner-yolov8.png\"></a>\n",
"\n",
" [中文](https://docs.ultralytics.com/zh/) | [한국어](https://docs.ultralytics.com/ko/) | [日本語](https://docs.ultralytics.com/ja/) | [Русский](https://docs.ultralytics.com/ru/) | [Deutsch](https://docs.ultralytics.com/de/) | [Français](https://docs.ultralytics.com/fr/) | [Español](https://docs.ultralytics.com/es/) | [Português](https://docs.ultralytics.com/pt/) | [हि](https://docs.ultralytics.com/hi/) | [العربية](https://docs.ultralytics.com/ar/)\n",
"\n",
" <a href=\"https://colab.research.google.com/github/ultralytics/ultralytics/blob/main/examples/object_counting.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"></a>\n",
"\n",
"Welcome to the Ultralytics YOLOv8 🚀 notebook! <a href=\"https://github.com/ultralytics/ultralytics\">YOLOv8</a> is the latest version of the YOLO (You Only Look Once) AI models developed by <a href=\"https://ultralytics.com\">Ultralytics</a>. This notebook serves as the starting point for exploring the <a href=\"https://docs.ultralytics.com/guides/object-counting/\">Object Counting</a> and understand its features and capabilities.\n",
"\n",
"YOLOv8 models are fast, accurate, and easy to use, making them ideal for various object detection and image segmentation tasks. They can be trained on large datasets and run on diverse hardware platforms, from CPUs to GPUs.\n",
"\n",
"We hope that the resources in this notebook will help you get the most out of <a href=\"https://docs.ultralytics.com/guides/object-counting/\">Ultralytics Object Counting</a>. Please browse the YOLOv8 <a href=\"https://docs.ultralytics.com/\">Docs</a> for details, raise an issue on <a href=\"https://github.com/ultralytics/ultralytics\">GitHub</a> for support, and join our <a href=\"https://ultralytics.com/discord\">Discord</a> community for questions and discussions!\n",
"\n",
"</div>"
],
"metadata": {
"id": "PN1cAxdvd61e"
}
},
{
"cell_type": "markdown",
"source": [
"# Setup\n",
"\n",
"Pip install `ultralytics` and [dependencies](https://github.com/ultralytics/ultralytics/blob/main/pyproject.toml) and check software and hardware."
],
"metadata": {
"id": "o68Sg1oOeZm2"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9dSwz_uOReMI"
},
"outputs": [],
"source": [
"!pip install ultralytics"
]
},
{
"cell_type": "markdown",
"source": [
"# Ultralytics Object Counting\n",
"\n",
"Counting objects using Ultralytics YOLOv8 entails the precise detection and enumeration of specific objects within videos and camera streams. YOLOv8 demonstrates exceptional performance in real-time applications, delivering efficient and accurate object counting across diverse scenarios such as crowd analysis and surveillance. This is attributed to its advanced algorithms and deep learning capabilities."
],
"metadata": {
"id": "m7VkxQ2aeg7k"
}
},
{
"cell_type": "code",
"source": [
"from ultralytics import YOLO\n",
"from ultralytics.solutions import object_counter\n",
"import cv2\n",
"\n",
"model = YOLO(\"yolov8n.pt\")\n",
"cap = cv2.VideoCapture(\"path/to/video/file.mp4\")\n",
"assert cap.isOpened(), \"Error reading video file\"\n",
"\n",
"# Define line points\n",
"line_points = [(20, 400), (1080, 400)]\n",
"\n",
"# Video writer\n",
"video_writer = cv2.VideoWriter(\"object_counting_output.avi\",\n",
" cv2.VideoWriter_fourcc(*'mp4v'),\n",
" int(cap.get(5)),\n",
" (int(cap.get(3)), int(cap.get(4))))\n",
"\n",
"# Init Object Counter\n",
"counter = object_counter.ObjectCounter()\n",
"counter.set_args(view_img=True,\n",
" reg_pts=line_points,\n",
" classes_names=model.names,\n",
" draw_tracks=True)\n",
"\n",
"while cap.isOpened():\n",
" success, im0 = cap.read()\n",
" if not success:\n",
" print(\"Video frame is empty or video processing has been successfully completed.\")\n",
" break\n",
" tracks = model.track(im0, persist=True, show=False)\n",
"\n",
" im0 = counter.start_counting(im0, tracks)\n",
" video_writer.write(im0)\n",
"\n",
"cap.release()\n",
"video_writer.release()\n",
"cv2.destroyAllWindows()"
],
"metadata": {
"id": "Cx-u59HQdu2o"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"#Community Support\n",
"\n",
"For more information, you can explore <a href=\"https://docs.ultralytics.com/guides/object-counting/\">Ultralytics Object Counting Docs</a>\n",
"\n",
"Ultralytics ⚡ resources\n",
"- About Us – https://ultralytics.com/about\n",
"- Join Our Team – https://ultralytics.com/work\n",
"- Contact Us – https://ultralytics.com/contact\n",
"- Discord – https://discord.gg/2wNGbc6g9X\n",
"- Ultralytics License – https://ultralytics.com/license\n",
"\n",
"YOLOv8 🚀 resources\n",
"- GitHub – https://github.com/ultralytics/ultralytics\n",
"- Docs – https://docs.ultralytics.com/"
],
"metadata": {
"id": "QrlKg-y3fEyD"
}
}
]
}

@ -0,0 +1,204 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"source": [
"<div align=\"center\">\n",
"\n",
" <a href=\"https://ultralytics.com/yolov8\" target=\"_blank\">\n",
" <img width=\"1024\", src=\"https://raw.githubusercontent.com/ultralytics/assets/main/yolov8/banner-yolov8.png\"></a>\n",
"\n",
" [中文](https://docs.ultralytics.com/zh/) | [한국어](https://docs.ultralytics.com/ko/) | [日本語](https://docs.ultralytics.com/ja/) | [Русский](https://docs.ultralytics.com/ru/) | [Deutsch](https://docs.ultralytics.com/de/) | [Français](https://docs.ultralytics.com/fr/) | [Español](https://docs.ultralytics.com/es/) | [Português](https://docs.ultralytics.com/pt/) | [हि](https://docs.ultralytics.com/hi/) | [العربية](https://docs.ultralytics.com/ar/)\n",
"\n",
" <a href=\"https://colab.research.google.com/github/ultralytics/ultralytics/blob/main/examples/object_tracking.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"></a>\n",
"\n",
"Welcome to the Ultralytics YOLOv8 🚀 notebook! <a href=\"https://github.com/ultralytics/ultralytics\">YOLOv8</a> is the latest version of the YOLO (You Only Look Once) AI models developed by <a href=\"https://ultralytics.com\">Ultralytics</a>. This notebook serves as the starting point for exploring the <a href=\"https://docs.ultralytics.com/modes/track/\">Object Tracking</a> and understand its features and capabilities.\n",
"\n",
"YOLOv8 models are fast, accurate, and easy to use, making them ideal for various object detection and image segmentation tasks. They can be trained on large datasets and run on diverse hardware platforms, from CPUs to GPUs.\n",
"\n",
"We hope that the resources in this notebook will help you get the most out of <a href=\"https://docs.ultralytics.com/modes/track/\">Ultralytics Object Tracking</a>. Please browse the YOLOv8 <a href=\"https://docs.ultralytics.com/\">Docs</a> for details, raise an issue on <a href=\"https://github.com/ultralytics/ultralytics\">GitHub</a> for support, and join our <a href=\"https://ultralytics.com/discord\">Discord</a> community for questions and discussions!\n",
"\n",
"</div>"
],
"metadata": {
"id": "PN1cAxdvd61e"
}
},
{
"cell_type": "markdown",
"source": [
"# Setup\n",
"\n",
"Pip install `ultralytics` and [dependencies](https://github.com/ultralytics/ultralytics/blob/main/pyproject.toml) and check software and hardware."
],
"metadata": {
"id": "o68Sg1oOeZm2"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9dSwz_uOReMI"
},
"outputs": [],
"source": [
"!pip install ultralytics"
]
},
{
"cell_type": "markdown",
"source": [
"# Ultralytics Object Tracking\n",
"\n",
"Within the domain of video analytics, object tracking stands out as a crucial undertaking. It goes beyond merely identifying the location and class of objects within the frame; it also involves assigning a unique ID to each detected object as the video unfolds. The applications of this technology are vast, spanning from surveillance and security to real-time sports analytics."
],
"metadata": {
"id": "m7VkxQ2aeg7k"
}
},
{
"cell_type": "markdown",
"source": [
"## CLI"
],
"metadata": {
"id": "-ZF9DM6e6gz0"
}
},
{
"cell_type": "code",
"source": [
"!yolo track source=\"/content/people walking gray.mp4\" save=True"
],
"metadata": {
"id": "-XJqhOwo6iqT"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Python\n",
"\n",
"- Draw Object tracking trails"
],
"metadata": {
"id": "XRcw0vIE6oNb"
}
},
{
"cell_type": "code",
"source": [
"import cv2\n",
"import numpy as np\n",
"from ultralytics import YOLO\n",
"\n",
"from ultralytics.utils.checks import check_imshow\n",
"from ultralytics.utils.plotting import Annotator, colors\n",
"\n",
"from collections import defaultdict\n",
"\n",
"track_history = defaultdict(lambda: [])\n",
"model = YOLO(\"yolov8n.pt\")\n",
"names = model.model.names\n",
"\n",
"video_path = \"/path/to/video/file.mp4\"\n",
"cap = cv2.VideoCapture(video_path)\n",
"assert cap.isOpened(), \"Error reading video file\"\n",
"\n",
"frame_width = int(cap.get(3))\n",
"frame_height = int(cap.get(4))\n",
"size = (frame_width, frame_height)\n",
"result = cv2.VideoWriter('object_tracking.avi',\n",
" cv2.VideoWriter_fourcc(*'MJPG'),\n",
" int(cap.get(5)), size)\n",
"\n",
"\n",
"while cap.isOpened():\n",
" success, frame = cap.read()\n",
" if success:\n",
" results = model.track(frame, persist=True, verbose=False)\n",
" boxes = results[0].boxes.xyxy.cpu()\n",
"\n",
" if results[0].boxes.id is not None:\n",
"\n",
" # Extract prediction results\n",
" clss = results[0].boxes.cls.cpu().tolist()\n",
" track_ids = results[0].boxes.id.int().cpu().tolist()\n",
" confs = results[0].boxes.conf.float().cpu().tolist()\n",
"\n",
" # Annotator Init\n",
" annotator = Annotator(frame, line_width=2)\n",
"\n",
" for box, cls, track_id in zip(boxes, clss, track_ids):\n",
" annotator.box_label(box, color=colors(int(cls), True), label=names[int(cls)])\n",
"\n",
" # Store tracking history\n",
" track = track_history[track_id]\n",
" track.append((int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)))\n",
" if len(track) > 30:\n",
" track.pop(0)\n",
"\n",
" # Plot tracks\n",
" points = np.array(track, dtype=np.int32).reshape((-1, 1, 2))\n",
" cv2.circle(frame, (track[-1]), 7, colors(int(cls), True), -1)\n",
" cv2.polylines(frame, [points], isClosed=False, color=colors(int(cls), True), thickness=2)\n",
"\n",
" result.write(frame)\n",
" if cv2.waitKey(1) & 0xFF == ord(\"q\"):\n",
" break\n",
" else:\n",
" break\n",
"\n",
"result.release()\n",
"cap.release()\n",
"cv2.destroyAllWindows()"
],
"metadata": {
"id": "Cx-u59HQdu2o"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"#Community Support\n",
"\n",
"For more information, you can explore <a href=\"https://docs.ultralytics.com/modes/track/\">Ultralytics Object Tracking Docs</a>\n",
"\n",
"Ultralytics ⚡ resources\n",
"- About Us – https://ultralytics.com/about\n",
"- Join Our Team – https://ultralytics.com/work\n",
"- Contact Us – https://ultralytics.com/contact\n",
"- Discord – https://discord.gg/2wNGbc6g9X\n",
"- Ultralytics License – https://ultralytics.com/license\n",
"\n",
"YOLOv8 🚀 resources\n",
"- GitHub – https://github.com/ultralytics/ultralytics\n",
"- Docs – https://docs.ultralytics.com/"
],
"metadata": {
"id": "QrlKg-y3fEyD"
}
}
]
}

@ -1,4 +1,5 @@
from ultralytics import Explorer
from ultralytics.utils import ASSETS
def test_similarity():
@ -6,14 +7,14 @@ def test_similarity():
exp.create_embeddings_table()
similar = exp.get_similar(idx=1)
assert len(similar) == 25
similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
similar = exp.get_similar(img=ASSETS / 'zidane.jpg')
assert len(similar) == 25
similar = exp.get_similar(idx=[1, 2], limit=10)
assert len(similar) == 10
sim_idx = exp.similarity_index()
assert len(sim_idx) > 0
sql = exp.sql_query("WHERE labels LIKE '%person%'")
len(sql) > 0
assert len(sql) > 0
def test_det():

@ -40,7 +40,7 @@ class ExplorerDataset(YOLODataset):
return self.ims[i], self.im_hw0[i], self.im_hw[i]
def build_transforms(self, hyp=None):
transforms = Format(
return Format(
bbox_format='xyxy',
normalize=False,
return_mask=self.use_segments,
@ -49,7 +49,6 @@ class ExplorerDataset(YOLODataset):
mask_ratio=hyp.mask_ratio,
mask_overlap=hyp.overlap_mask,
)
return transforms
class Explorer:
@ -161,8 +160,7 @@ class Explorer:
embeds = self.model.embed(imgs)
# Get avg if multiple images are passed (len > 1)
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
query = self.table.search(embeds).limit(limit).to_arrow()
return query
return self.table.search(embeds).limit(limit).to_arrow()
def sql_query(self, query, return_type='pandas'):
"""
@ -223,8 +221,7 @@ class Explorer:
"""
result = self.sql_query(query, return_type='arrow')
img = plot_similar_images(result, plot_labels=labels)
img = Image.fromarray(img)
return img
return Image.fromarray(img)
def get_similar(self, img=None, idx=None, limit=25, return_type='pandas'):
"""
@ -276,8 +273,7 @@ class Explorer:
"""
similar = self.get_similar(img, idx, limit, return_type='arrow')
img = plot_similar_images(similar, plot_labels=labels)
img = Image.fromarray(img)
return img
return Image.fromarray(img)
def similarity_index(self, max_dist=0.2, top_k=None, force=False):
"""
@ -331,7 +327,6 @@ class Explorer:
sim_table.add(_yield_sim_idx())
self.sim_index = sim_table
return sim_table.to_pandas()
def plot_similarity_index(self, max_dist=0.2, top_k=None, force=False):
@ -373,8 +368,7 @@ class Explorer:
buffer.seek(0)
# Use Pillow to open the image from the buffer
image = Image.open(buffer)
return image
return Image.open(buffer)
def _check_imgs_or_idxs(self, img, idx):
if img is None and idx is None:
@ -385,8 +379,7 @@ class Explorer:
idx = idx if isinstance(idx, list) else [idx]
img = self.table.to_lance().take(idx, columns=['im_file']).to_pydict()['im_file']
img = img if isinstance(img, list) else [img]
return img
return img if isinstance(img, list) else [img]
def visualize(self, result):
"""

@ -1,4 +1,3 @@
from pathlib import Path
from typing import List
import cv2
@ -94,10 +93,12 @@ def plot_similar_images(similar_set, plot_labels=True):
batch_idx = np.concatenate(batch_idx, axis=0)
cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0)
fname = 'temp_exp_grid.jpg'
plot_images(imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, fname=fname,
max_subplots=len(images)).join()
img = cv2.imread(fname, cv2.IMREAD_COLOR)
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
Path(fname).unlink()
return img_rgb
return plot_images(imgs,
batch_idx,
cls,
bboxes=boxes,
masks=masks,
kpts=kpts,
max_subplots=len(images),
save=False,
threaded=False)

@ -736,16 +736,19 @@ class TryExcept(contextlib.ContextDecorator):
def threaded(func):
"""
Multi-threads a target function and returns thread.
Multi-threads a target function by default and returns the thread or function result.
Use as @threaded decorator.
Use as @threaded decorator. The function runs in a separate thread unless 'threaded=False' is passed.
"""
def wrapper(*args, **kwargs):
"""Multi-threads a given function and returns the thread."""
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread
"""Multi-threads a given function based on 'threaded' kwarg and returns the thread or function result."""
if kwargs.pop('threaded', True): # run in thread
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start()
return thread
else:
return func(*args, **kwargs)
return wrapper

@ -125,7 +125,7 @@ class Annotator:
if rotated:
p1 = [int(b) for b in box[0]]
# NOTE: cv2-version polylines needs np.asarray type.
cv2.polylines(self.im, [np.asarray(box, dtype=np.int)], True, color, self.lw)
cv2.polylines(self.im, [np.asarray(box, dtype=int)], True, color, self.lw)
else:
p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
@ -580,7 +580,8 @@ def plot_images(images,
fname='images.jpg',
names=None,
on_plot=None,
max_subplots=16):
max_subplots=16,
save=True):
"""Plot image grid with labels."""
if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy()
@ -596,7 +597,6 @@ def plot_images(images,
batch_idx = batch_idx.cpu().numpy()
max_size = 1920 # max image size
max_subplots = max_subplots # max image subplots, i.e. 4x4
bs, _, h, w = images.shape # batch size, _, height, width
bs = min(bs, max_subplots) # limit plot images
ns = np.ceil(bs ** 0.5) # number of subplots (square)
@ -605,12 +605,9 @@ def plot_images(images,
# Build Image
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
for i, im in enumerate(images):
if i == max_subplots: # if last batch has fewer images than we expect
break
for i in range(bs):
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
im = im.transpose(1, 2, 0)
mosaic[y:y + h, x:x + w, :] = im
mosaic[y:y + h, x:x + w, :] = images[i].transpose(1, 2, 0)
# Resize (optional)
scale = max_size / ns / max(h, w)
@ -622,7 +619,7 @@ def plot_images(images,
# Annotate
fs = int((h + w) * ns * 0.01) # font size
annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
for i in range(i + 1):
for i in range(bs):
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
if paths:
@ -699,9 +696,12 @@ def plot_images(images,
with contextlib.suppress(Exception):
im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
annotator.fromarray(im)
annotator.im.save(fname) # save
if on_plot:
on_plot(fname)
if save:
annotator.im.save(fname) # save
if on_plot:
on_plot(fname)
else:
return np.asarray(annotator.im)
@plt_settings()

Loading…
Cancel
Save