Add HeatMap `guide` in real-world-projects + Code in Solutions Directory (#6796)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Muhammad Rizwan Munawar 12 months ago committed by GitHub
parent 1e1247ddee
commit 742cbc1b4e
Changed files:

  1. docs/en/guides/heatmaps.md (142 lines changed)
  2. docs/en/guides/index.md (1 line changed)
  3. docs/en/guides/object-counting.md (127 lines changed)
  4. docs/en/guides/workouts-monitoring.md (106 lines changed)
  5. docs/en/yolov5/tutorials/clearml_logging_integration.md (12 lines changed)
  6. docs/en/yolov5/tutorials/comet_logging_integration.md (18 lines changed)
  7. docs/mkdocs.yml (1 line changed)
  8. ultralytics/solutions/ai_gym.py (2 lines changed)
  9. ultralytics/solutions/heatmap.py (102 lines changed)
  10. ultralytics/solutions/object_counter.py (1 line changed)

@@ -0,0 +1,142 @@
---
comments: true
description: Advanced Data Visualization with Ultralytics YOLOv8 Heatmaps
keywords: Ultralytics, YOLOv8, Advanced Data Visualization, Heatmap Technology, Object Detection and Tracking, Jupyter Notebook, Python SDK, Command Line Interface
---
# Advanced Data Visualization: Heatmaps using Ultralytics YOLOv8 🚀
## Introduction to Heatmaps
A heatmap generated with [Ultralytics YOLOv8](https://github.com/ultralytics/ultralytics/) transforms complex data into a vibrant, color-coded matrix. This visual tool employs a spectrum of colors to represent varying data values, where warmer hues indicate higher intensities and cooler tones signify lower values. Heatmaps excel in visualizing intricate data patterns, correlations, and anomalies, offering an accessible and engaging approach to data interpretation across diverse domains.
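To make this concrete, here is a minimal, self-contained sketch (illustrative only, not part of the Ultralytics API) of the normalize-then-colormap pipeline that turns a plain value matrix into such a color-coded image:

```python
# Illustrative sketch: convert a matrix of values into a color-coded heatmap.
# The random matrix stands in for real data such as accumulated detections.
import cv2
import numpy as np

values = np.random.rand(240, 320).astype(np.float32)  # synthetic data matrix

# Scale values to the 0-255 range, then map them onto a color spectrum,
# where warm colors mark high values and cool colors mark low values
normalized = cv2.normalize(values, None, 0, 255, cv2.NORM_MINMAX)
colored = cv2.applyColorMap(normalized.astype(np.uint8), cv2.COLORMAP_JET)

cv2.imwrite("heatmap_demo.png", colored)
```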
## Why Choose Heatmaps for Data Analysis?
- **Intuitive Data Distribution Visualization:** Heatmaps simplify the comprehension of data concentration and distribution, converting complex datasets into easy-to-understand visual formats.
- **Efficient Pattern Detection:** By visualizing data in heatmap format, it becomes easier to spot trends, clusters, and outliers, facilitating quicker analysis and insights.
- **Enhanced Spatial Analysis and Decision Making:** Heatmaps are instrumental in illustrating spatial relationships, aiding in decision-making processes in sectors such as business intelligence, environmental studies, and urban planning.
## Real World Applications
| Transportation | Retail |
|:-----------------------------------------------------------------------------------------------------------------------------------------------:|:---------------------------------------------------------------------------------------------------------------------------------------:|
| ![Ultralytics YOLOv8 Transportation Heatmap](https://github.com/RizwanMunawar/ultralytics/assets/62513924/50d197b8-c7f6-4ecf-a664-3d4363b073de) | ![Ultralytics YOLOv8 Retail Heatmap](https://github.com/RizwanMunawar/ultralytics/assets/62513924/ffd0649f-5ff5-48d2-876d-6bdffeff5c54) |
| Ultralytics YOLOv8 Transportation Heatmap | Ultralytics YOLOv8 Retail Heatmap |
???+ tip "heatmap_alpha"

    The `heatmap_alpha` value should be in the range 0.0 - 1.0.
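Under the hood this alpha is a blending weight: the colored heatmap is combined with the original frame by a weighted sum, so 0.0 shows only the frame and 1.0 only the heatmap. A minimal sketch of the idea, mirroring the `cv2.addWeighted` call in `heatmap.py`:

```python
# Illustrative sketch of alpha blending with stand-in images
import cv2
import numpy as np

frame = np.full((240, 320, 3), 128, dtype=np.uint8)        # stand-in video frame
heatmap_colored = np.zeros((240, 320, 3), dtype=np.uint8)  # stand-in colored heatmap
heatmap_alpha = 0.5

# output = (1 - alpha) * frame + alpha * heatmap
blended = cv2.addWeighted(frame, 1 - heatmap_alpha, heatmap_colored, heatmap_alpha, 0)
```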
!!! Example "Heatmap Example"

    === "Heatmap"

        ```python
        from ultralytics import YOLO
        from ultralytics.solutions import heatmap
        import cv2

        model = YOLO("yolov8s.pt")
        cap = cv2.VideoCapture("path/to/video/file.mp4")
        if not cap.isOpened():
            print("Error reading video file")
            exit(0)

        heatmap_obj = heatmap.Heatmap()
        heatmap_obj.set_args(colormap=cv2.COLORMAP_CIVIDIS,
                             imw=cap.get(4),  # heatmap rows = frame height (cv2.CAP_PROP_FRAME_HEIGHT)
                             imh=cap.get(3),  # heatmap columns = frame width (cv2.CAP_PROP_FRAME_WIDTH)
                             view_img=True)

        while cap.isOpened():
            success, im0 = cap.read()
            if not success:
                break
            results = model.track(im0, persist=True)
            im0 = heatmap_obj.generate_heatmap(im0, tracks=results)

        cap.release()
        cv2.destroyAllWindows()
        ```
=== "Heatmap with Specific Classes"
```python
from ultralytics import YOLO
from ultralytics.solutions import heatmap
import cv2
model = YOLO("yolov8s.pt")
cap = cv2.VideoCapture("path/to/video/file.mp4")
if not cap.isOpened():
print("Error reading video file")
exit(0)
classes_for_heatmap = [0, 2]
heatmap_obj = heatmap.Heatmap()
heatmap_obj.set_args(colormap=cv2.COLORMAP_CIVIDIS,
imw=cap.get(4), # should same as im0 width
imh=cap.get(3), # should same as im0 height
view_img=True)
while cap.isOpened():
success, im0 = cap.read()
if not success:
exit(0)
results = model.track(im0, persist=True,
classes=classes_for_heatmap)
frame = heatmap_obj.generate_heatmap(im0, tracks=results)
```
=== "Heatmap with Save Output"
```python
from ultralytics import YOLO
import heatmap
import cv2
model = YOLO("yolov8n.pt")
cap = cv2.VideoCapture("path/to/video/file.mp4")
if not cap.isOpened():
print("Error reading video file")
exit(0)
video_writer = cv2.VideoWriter("heatmap_output.avi",
cv2.VideoWriter_fourcc(*'mp4v'),
int(cap.get(5)),
(int(cap.get(3)), int(cap.get(4))))
heatmap_obj = heatmap.Heatmap()
heatmap_obj.set_args(colormap=cv2.COLORMAP_CIVIDIS,
imw=cap.get(4), # should same as im0 width
imh=cap.get(3), # should same as im0 height
view_img=True)
while cap.isOpened():
success, im0 = cap.read()
if not success:
exit(0)
results = model.track(im0, persist=True)
frame = heatmap_obj.generate_heatmap(im0, tracks=results)
video_writer.write(im0)
video_writer.release()
```
### Arguments `set_args`
| Name            | Type           | Default            | Description                    |
|-----------------|----------------|--------------------|--------------------------------|
| `view_img`      | `bool`         | `False`            | Display the frame with heatmap |
| `colormap`      | `cv2.COLORMAP` | `cv2.COLORMAP_JET` | cv2.COLORMAP for heatmap       |
| `imw`           | `int`          | `None`             | Width of heatmap               |
| `imh`           | `int`          | `None`             | Height of heatmap              |
| `heatmap_alpha` | `float`        | `0.5`              | Heatmap alpha value            |
### Arguments `model.track`
| Name | Type | Default | Description |
|-----------|---------|----------------|-------------------------------------------------------------|
| `source` | `im0` | `None` | source directory for images or videos |
| `persist` | `bool` | `False` | persisting tracks between frames |
| `tracker` | `str` | `botsort.yaml` | Tracking method 'bytetrack' or 'botsort' |
| `conf` | `float` | `0.3` | Confidence Threshold |
| `iou` | `float` | `0.5` | IOU Threshold |
| `classes` | `list` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] |
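As a hedged illustration of how these arguments combine in a single call (the values here are examples, not recommendations):

```python
# Illustrative sketch combining the tracking arguments listed above
from ultralytics import YOLO

model = YOLO("yolov8s.pt")
results = model.track(source="path/to/video/file.mp4",
                      persist=True,              # keep track IDs across frames
                      tracker="bytetrack.yaml",  # or the default "botsort.yaml"
                      conf=0.3,
                      iou=0.5,
                      classes=[0, 2])            # COCO indices: 0 = person, 2 = car
```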

@@ -34,6 +34,7 @@ Here's a compilation of in-depth guides to help you master different aspects of
* [Workouts Monitoring](workouts-monitoring.md) 🚀 NEW: Discover the comprehensive approach to monitoring workouts with Ultralytics YOLOv8. Acquire the skills and insights necessary to effectively use YOLOv8 for tracking and analyzing various aspects of fitness routines in real time.
* [Objects Counting in Regions](region-counting.md) 🚀 NEW: Explore counting objects in specific regions with Ultralytics YOLOv8 for precise and efficient object detection in varied areas.
* [Security Alarm System](security-alarm-system.md) 🚀 NEW: Discover the process of creating a security alarm system with Ultralytics YOLOv8. This system triggers alerts upon detecting new objects in the frame. Subsequently, you can customize the code to align with your specific use case.
* [Heatmaps](heatmaps.md) 🚀 NEW: Elevate your understanding of data with our Detection Heatmaps! These intuitive visual tools use vibrant color gradients to vividly illustrate the intensity of data values across a matrix. Essential in computer vision, heatmaps are skillfully designed to highlight areas of interest, providing an immediate, impactful way to interpret spatial information.
## Contribute to Our Guides

@@ -23,28 +23,100 @@ Object counting with [Ultralytics YOLOv8](https://github.com/ultralytics/ultraly
| ![Conveyor Belt Packets Counting Using Ultralytics YOLOv8](https://github.com/RizwanMunawar/ultralytics/assets/62513924/70e2d106-510c-4c6c-a57a-d34a765aa757) | ![Fish Counting in Sea using Ultralytics YOLOv8](https://github.com/RizwanMunawar/ultralytics/assets/62513924/c60d047b-3837-435f-8d29-bb9fc95d2191) |
| Conveyor Belt Packets Counting Using Ultralytics YOLOv8 | Fish Counting in Sea using Ultralytics YOLOv8 |
## Example

```python
from ultralytics import YOLO
from ultralytics.solutions import object_counter
import cv2

model = YOLO("yolov8n.pt")
cap = cv2.VideoCapture("path/to/video/file.mp4")

counter = object_counter.ObjectCounter()  # Init Object Counter
region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
counter.set_args(view_img=True, reg_pts=region_points,
                 classes_names=model.names, draw_tracks=True)

while cap.isOpened():
    success, frame = cap.read()
    if not success:
        exit(0)
    tracks = model.track(frame, persist=True, show=False)
    counter.start_counting(frame, tracks)
```
!!! Example "Object Counting Example"
=== "Object Counting"
```python
from ultralytics import YOLO
from ultralytics.solutions import object_counter
import cv2
model = YOLO("yolov8n.pt")
cap = cv2.VideoCapture("path/to/video/file.mp4")
if not cap.isOpened():
print("Error reading video file")
exit(0)
counter = object_counter.ObjectCounter() # Init Object Counter
region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
counter.set_args(view_img=True,
reg_pts=region_points,
classes_names=model.names,
draw_tracks=True)
while cap.isOpened():
success, im0 = cap.read()
if not success:
exit(0)
tracks = model.track(im0, persist=True, show=False)
im0 = counter.start_counting(im0, tracks)
```
=== "Object Counting with Specific Classes"
```python
from ultralytics import YOLO
from ultralytics.solutions import object_counter
import cv2
model = YOLO("yolov8n.pt")
cap = cv2.VideoCapture("path/to/video/file.mp4")
if not cap.isOpened():
print("Error reading video file")
exit(0)
classes_to_count = [0, 2]
counter = object_counter.ObjectCounter() # Init Object Counter
region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
counter.set_args(view_img=True,
reg_pts=region_points,
classes_names=model.names,
draw_tracks=True)
while cap.isOpened():
success, im0 = cap.read()
if not success:
exit(0)
tracks = model.track(im0, persist=True,
show=False,
classes=classes_to_count)
im0 = counter.start_counting(im0, tracks)
```
=== "Object Counting with Save Output"
```python
from ultralytics import YOLO
from ultralytics.solutions import object_counter
import cv2
model = YOLO("yolov8n.pt")
cap = cv2.VideoCapture("path/to/video/file.mp4")
if not cap.isOpened():
print("Error reading video file")
exit(0)
video_writer = cv2.VideoWriter("object_counting.avi",
cv2.VideoWriter_fourcc(*'mp4v'),
int(cap.get(5)),
(int(cap.get(3)), int(cap.get(4))))
counter = object_counter.ObjectCounter() # Init Object Counter
region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
counter.set_args(view_img=True,
reg_pts=region_points,
classes_names=model.names,
draw_tracks=True)
while cap.isOpened():
success, im0 = cap.read()
if not success:
exit(0)
tracks = model.track(im0, persist=True, show=False)
im0 = counter.start_counting(im0, tracks)
video_writer.write(im0)
video_writer.release()
```
???+ tip "Region is Movable"
@@ -60,3 +132,16 @@ while cap.isOpened():
| classes_names | `dict` | `model.model.names` | Classes Names Dict |
| region_color | `tuple` | `(0, 255, 0)` | Region Area Color |
| track_thickness | `int` | `2` | Tracking line thickness |
| draw_tracks | `bool` | `False` | Draw Tracks lines |
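For intuition, here is a hypothetical sketch (not the library's internal logic) of how a region defined by `reg_pts` can be tested against a tracked box using OpenCV's point-in-polygon test:

```python
# Hypothetical sketch: check whether a tracked box center lies inside the
# counting region; the example box and helper logic are assumptions.
import cv2
import numpy as np

region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
region = np.array(region_points, dtype=np.int32)

box = (300, 350, 360, 410)  # example box in xyxy format
cx, cy = (box[0] + box[2]) / 2, (box[1] + box[3]) / 2  # box center

# pointPolygonTest returns > 0 inside, 0 on the edge, < 0 outside
inside = cv2.pointPolygonTest(region, (cx, cy), False) >= 0
print(f"Center inside region: {inside}")
```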
### Arguments `model.track`
| Name | Type | Default | Description |
|-----------|---------|----------------|-------------------------------------------------------------|
| `source` | `im0` | `None` | source directory for images or videos |
| `persist` | `bool` | `False` | persisting tracks between frames |
| `tracker` | `str` | `botsort.yaml` | Tracking method 'bytetrack' or 'botsort' |
| `conf` | `float` | `0.3` | Confidence Threshold |
| `iou` | `float` | `0.5` | IOU Threshold |
| `classes` | `list` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] |

@@ -23,27 +23,72 @@ Monitoring workouts through pose estimation with [Ultralytics YOLOv8](https://gi
| ![PushUps Counting](https://github.com/RizwanMunawar/ultralytics/assets/62513924/cf016a41-589f-420f-8a8c-2cc8174a16de) | ![PullUps Counting](https://github.com/RizwanMunawar/ultralytics/assets/62513924/cb20f316-fac2-4330-8445-dcf5ffebe329) |
| PushUps Counting | PullUps Counting |
## Example

```python
from ultralytics import YOLO
from ultralytics.solutions import ai_gym
import cv2

model = YOLO("yolov8n-pose.pt")
cap = cv2.VideoCapture("path/to/video.mp4")

gym_object = ai_gym.AIGym()  # init AI GYM module
gym_object.set_args(line_thickness=2, view_img=True, pose_type="pushup", kpts_to_check=[6, 8, 10])

frame_count = 0
while cap.isOpened():
    success, frame = cap.read()
    if not success:
        exit(0)
    frame_count += 1
    results = model.predict(frame, verbose=False)
    gym_object.start_counting(frame, results, frame_count)
```
!!! Example "Workouts Monitoring Example"
=== "Workouts Monitoring"
```python
from ultralytics import YOLO
from ultralytics.solutions import ai_gym
import cv2
model = YOLO("yolov8n-pose.pt")
cap = cv2.VideoCapture("path/to/video/file.mp4")
if not cap.isOpened():
print("Error reading video file")
exit(0)
gym_object = ai_gym.AIGym() # init AI GYM module
gym_object.set_args(line_thickness=2,
view_img=True,
pose_type="pushup",
kpts_to_check=[6, 8, 10])
frame_count = 0
while cap.isOpened():
success, im0 = cap.read()
if not success:
exit(0)
frame_count += 1
results = model.predict(im0, verbose=False)
im0 = gym_object.start_counting(im0, results, frame_count)
```
=== "Workouts Monitoring with Save Output"
```python
from ultralytics import YOLO
from ultralytics.solutions import ai_gym
import cv2
model = YOLO("yolov8n-pose.pt")
cap = cv2.VideoCapture("path/to/video/file.mp4")
if not cap.isOpened():
print("Error reading video file")
exit(0)
video_writer = cv2.VideoWriter("workouts.avi",
cv2.VideoWriter_fourcc(*'mp4v'),
int(cap.get(5)),
(int(cap.get(3)), int(cap.get(4))))
gym_object = ai_gym.AIGym() # init AI GYM module
gym_object.set_args(line_thickness=2,
view_img=True,
pose_type="pushup",
kpts_to_check=[6, 8, 10])
frame_count = 0
while cap.isOpened():
success, im0 = cap.read()
if not success:
exit(0)
frame_count += 1
results = model.predict(im0, verbose=False)
im0 = gym_object.start_counting(im0, results, frame_count)
video_writer.write(im0)
video_writer.release()
```
???+ tip "Support"
@@ -51,7 +96,7 @@ while cap.isOpened():
### KeyPoints Map
![keyPoints Order Ultralytics YOLOv8 Pose](https://github.com/RizwanMunawar/ultralytics/assets/62513924/520059af-f961-433b-b2fb-7fe8c4336ee5)
![keyPoints Order Ultralytics YOLOv8 Pose](https://github.com/ultralytics/ultralytics/assets/62513924/f45d8315-b59f-47b7-b9c8-c61af1ce865b)
### Arguments `set_args`
@@ -63,3 +108,22 @@ while cap.isOpened():
| pose_type | `str` | `pushup` | Pose to monitor; "pullup" and "abworkout" are also supported |
| pose_up_angle | `int` | `145` | Pose Up Angle value |
| pose_down_angle | `int` | `90` | Pose Down Angle value |
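For intuition, a minimal sketch of how an angle at the middle keypoint of `kpts_to_check` (e.g. shoulder-elbow-wrist for push-ups) can be computed and compared against `pose_up_angle`/`pose_down_angle`; the helper and coordinates below are assumptions, not the module's internal implementation:

```python
# Hypothetical sketch: derive a joint angle from three (x, y) keypoints
import numpy as np

def joint_angle(a, b, c):
    """Angle in degrees at point b, formed by the segments b->a and b->c."""
    a, b, c = np.array(a), np.array(b), np.array(c)
    rad = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])
    deg = abs(np.degrees(rad))
    return 360 - deg if deg > 180 else deg

shoulder, elbow, wrist = (220, 180), (260, 260), (320, 300)  # assumed pixel coords
angle = joint_angle(shoulder, elbow, wrist)
print(f"Elbow angle: {angle:.1f} deg")  # compare against pose_up_angle / pose_down_angle
```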
### Arguments `model.predict`
| Name | Type | Default | Description |
|-----------------|----------------|------------------------|----------------------------------------------------------------------------|
| `source` | `str` | `'ultralytics/assets'` | source directory for images or videos |
| `conf` | `float` | `0.25` | object confidence threshold for detection |
| `iou` | `float` | `0.7` | intersection over union (IoU) threshold for NMS |
| `imgsz` | `int or tuple` | `640` | image size as scalar or (h, w) list, i.e. (640, 480) |
| `half` | `bool` | `False` | use half precision (FP16) |
| `device` | `None or str` | `None` | device to run on, i.e. cuda device=0/1/2/3 or device=cpu |
| `max_det` | `int` | `300` | maximum number of detections per image |
| `vid_stride` | `bool` | `False` | video frame-rate stride |
| `stream_buffer` | `bool` | `False` | buffer all streaming frames (True) or return the most recent frame (False) |
| `visualize` | `bool` | `False` | visualize model features |
| `augment` | `bool` | `False` | apply image augmentation to prediction sources |
| `agnostic_nms` | `bool` | `False` | class-agnostic NMS |
| `retina_masks` | `bool` | `False` | use high-resolution segmentation masks |
| `classes` | `None or list` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] |

@@ -38,13 +38,13 @@ To keep track of your experiments and/or data, ClearML needs to communicate to a
Either sign up for free to the [ClearML Hosted Service](https://cutt.ly/yolov5-tutorial-clearml) or you can set up your own server, see [here](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server). The server is open-source too, so even if you're dealing with sensitive data, you should be good to go!
1. Install the `clearml` python package:
- Install the `clearml` python package:
```bash
pip install clearml
```
2. Connect the ClearML SDK to the server by [creating credentials](https://app.clear.ml/settings/workspace-configuration) (go right top to Settings -> Workspace -> Create new credentials), then execute the command below and follow the instructions:
- Connect the ClearML SDK to the server by [creating credentials](https://app.clear.ml/settings/workspace-configuration) (go right top to Settings -> Workspace -> Create new credentials), then execute the command below and follow the instructions:
```bash
clearml-init
@@ -89,15 +89,13 @@ This will capture:
- Images with bounding boxes per epoch
- Mosaic per epoch
- Validation images per epoch
- ...
That's a lot, right? 🤯 Now we can visualize all of this information in the ClearML UI to get an overview of our training progress. Add custom columns to the table view (such as mAP_0.5) so you can easily sort on the best-performing model. Or select multiple experiments and directly compare them!
There's even more we can do with all of this information, like hyperparameter optimization and remote execution, so keep reading if you want to see how that works!
<br>
## 🔗 Dataset Version Management
### 🔗 Dataset Version Management
Versioning your data separately from your code is generally a good idea and makes it easy to acquire the latest version too. This repository supports supplying a dataset version ID, and it will make sure to get the data if it's not there yet. Next to that, this workflow also saves the used dataset ID as part of the task parameters, so you will always know for sure which data was used in which experiment!
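For example, a versioned dataset can be created with the `clearml-data` CLI (the project name and paths below are placeholders):

```bash
# Create a new dataset version, add files, then finalize it; `create` prints
# the dataset ID you can later pass to training
clearml-data create --project YOLOv5 --name coco128_version1
clearml-data add --files path/to/coco128
clearml-data close
```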
@@ -165,7 +163,7 @@ python train.py --img 640 --batch 16 --epochs 3 --data clearml://<your_dataset_i
<br>
## 👀 Hyperparameter Optimization
### 👀 Hyperparameter Optimization
Now that we have our experiments and data versioned, it's time to take a look at what we can build on top!

@@ -10,22 +10,22 @@ keywords: YOLOv5, Comet, Machine Learning, Ultralytics, Real time metrics tracki
This guide will cover how to use YOLOv5 with [Comet](https://bit.ly/yolov5-readme-comet2).
# About Comet
## About Comet
Comet builds tools that help data scientists, engineers, and team leaders accelerate and optimize machine learning and deep learning models.
Track and visualize model metrics in real time, save your hyperparameters, datasets, and model checkpoints, and visualize your model predictions with [Comet Custom Panels](https://www.comet.com/docs/v2/guides/comet-dashboard/code-panels/about-panels/?utm_source=yolov5&utm_medium=partner&utm_campaign=partner_yolov5_2022&utm_content=github)!
Comet makes sure you never lose track of your work and makes it easy to share results and collaborate across teams of all sizes!
# Getting Started
## Getting Started
## Install Comet
### Install Comet
```shell
pip install comet_ml
```
## Configure Comet Credentials
### Configure Comet Credentials
There are two ways to configure Comet with YOLOv5.
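One of them is exporting your credentials as environment variables in your shell; the other, a `.comet.config` file, is shown below:

```shell
export COMET_API_KEY=<Your Comet API Key>
export COMET_PROJECT_NAME=<Your Comet Project Name> # This will default to 'yolov5'
```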
@@ -48,7 +48,7 @@ api_key=<Your Comet API Key>
project_name=<Your Comet Project Name> # This will default to 'yolov5'
```
## Run the Training Script
### Run the Training Script
```shell
# Train YOLOv5s on COCO128 for 5 epochs
@@ -59,7 +59,7 @@ That's it! Comet will automatically log your hyperparameters, command line argum
<img width="1920" alt="yolo-ui" src="https://user-images.githubusercontent.com/26833433/202851203-164e94e1-2238-46dd-91f8-de020e9d6b41.png">
# Try out an Example!
## Try out an Example!
Check out an example of a [completed run here](https://www.comet.com/examples/comet-example-yolov5/a0e29e0e9b984e4a822db2a62d0cb357?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&xAxis=step&utm_source=yolov5&utm_medium=partner&utm_campaign=partner_yolov5_2022&utm_content=github)
@@ -67,7 +67,7 @@ Or better yet, try it out yourself in this Colab Notebook
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1RG0WOQyxlDlo5Km8GogJpIEJlg_5lyYO?usp=sharing)
# Log automatically
## Log automatically
By default, Comet will log the following items
@@ -88,7 +88,7 @@ By default, Comet will log the following items
- Plots for the PR and F1 curves across all classes
- Correlogram of the Class Labels
# Configure Comet Logging
## Configure Comet Logging
Comet can be configured to log additional data either through command line flags passed to the training script or through environment variables.
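For instance, a few logging options can be toggled via environment variables; the variable names below come from the Comet integration and are shown here as illustrative assumptions:

```shell
export COMET_MODE=offline               # log to an offline archive instead of comet.com
export COMET_MAX_IMAGE_UPLOADS=100      # cap how many prediction images are logged
export COMET_LOG_PER_CLASS_METRICS=true # log per-class mAP, precision, recall and f1
```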
@@ -254,7 +254,7 @@ comet optimizer -j <set number of workers> utils/loggers/comet/hpo.py \
utils/loggers/comet/optimizer_config.json"
```
### Visualizing Results
## Visualizing Results
Comet provides a number of ways to visualize the results of your sweep. Take a look at a [project with a completed sweep here](https://www.comet.com/examples/comet-example-yolov5/view/PrlArHGuuhDTKC1UuBmTtOSXD/panels?utm_source=yolov5&utm_medium=partner&utm_campaign=partner_yolov5_2022&utm_content=github)

@@ -270,6 +270,7 @@ nav:
- Workouts Monitoring: guides/workouts-monitoring.md
- Objects Counting in Regions: guides/region-counting.md
- Security Alarm System: guides/security-alarm-system.md
- Heatmaps: guides/heatmaps.md
- Integrations:
- integrations/index.md
- Comet ML: integrations/comet.md

@@ -125,6 +125,8 @@ class AIGym:
        if cv2.waitKey(1) & 0xFF == ord('q'):
            return

        return self.im0


if __name__ == '__main__':
    AIGym()

@@ -0,0 +1,102 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

import cv2
import numpy as np


class Heatmap:
    """A class to draw heatmaps in real-time video stream based on their tracks."""

    def __init__(self):
        """Initializes the heatmap class with default values for Visual, Image, track and heatmap parameters."""

        # Visual information
        self.annotator = None
        self.view_img = False

        # Image information
        self.imw = None
        self.imh = None
        self.im0 = None

        # Heatmap colormap and heatmap np array
        self.colormap = None
        self.heatmap = None
        self.heatmap_alpha = 0.5

        # Predict/track information
        self.boxes = None
        self.track_ids = None
        self.clss = None

    def set_args(self, imw, imh, colormap=cv2.COLORMAP_JET, heatmap_alpha=0.5, view_img=False):
        """
        Configures the heatmap colormap, width, height and display parameters.

        Args:
            imw (int): The width of the frame.
            imh (int): The height of the frame.
            colormap (cv2.COLORMAP): The colormap to be set.
            heatmap_alpha (float): alpha value for heatmap display
            view_img (bool): Flag indicating frame display
        """
        self.imw = imw
        self.imh = imh
        self.colormap = colormap
        self.heatmap_alpha = heatmap_alpha
        self.view_img = view_img

        # Heatmap accumulator: shape is (imw, imh), so `imw` supplies the row
        # (y) dimension and `imh` the column (x) dimension
        self.heatmap = np.zeros((int(self.imw), int(self.imh)), dtype=np.float32)

    def extract_results(self, tracks):
        """
        Extracts results from the provided tracking data.

        Args:
            tracks (list): List of tracks obtained from the object tracking process.
        """
        if tracks[0].boxes.id is None:
            return
        self.boxes = tracks[0].boxes.xyxy.cpu()
        self.clss = tracks[0].boxes.cls.cpu().tolist()
        self.track_ids = tracks[0].boxes.id.int().cpu().tolist()

    def generate_heatmap(self, im0, tracks):
        """
        Generate heatmap based on tracking data.

        Args:
            im0 (ndarray): Image
            tracks (list): List of tracks obtained from the object tracking process.
        """
        self.extract_results(tracks)
        self.im0 = im0

        # Accumulate intensity inside every tracked bounding box, guarding
        # against frames where no tracks have been extracted yet
        if self.boxes is not None:
            for box, cls in zip(self.boxes, self.clss):
                self.heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] += 1

        # Normalize, apply colormap to heatmap and combine with original image
        heatmap_normalized = cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX)
        heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), self.colormap)
        im0_with_heatmap = cv2.addWeighted(self.im0, 1 - self.heatmap_alpha, heatmap_colored, self.heatmap_alpha, 0)

        if self.view_img:
            self.display_frames(im0_with_heatmap)

        return im0_with_heatmap

    def display_frames(self, im0_with_heatmap):
        """
        Display heatmap.

        Args:
            im0_with_heatmap (ndarray): Original image with heatmap overlay
        """
        cv2.imshow('Ultralytics Heatmap', im0_with_heatmap)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            return


if __name__ == '__main__':
    Heatmap()

@@ -159,6 +159,7 @@ class ObjectCounter:
        if tracks[0].boxes.id is None:
            return
        self.extract_and_process_tracks(tracks)

        return self.im0


if __name__ == '__main__':
