Merge branch 'main' into cli-info

cli-info
Burhan 2 months ago committed by GitHub
commit 0656a3652a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      docker/Dockerfile
  2. 154
      docs/en/guides/heatmaps.md
  3. 4
      docs/en/reference/utils/torch_utils.md
  4. 10
      docs/overrides/javascript/extra.js
  5. 8
      tests/test_solutions.py
  6. 2
      ultralytics/__init__.py
  7. 8
      ultralytics/cfg/solutions/default.yaml
  8. 8
      ultralytics/engine/trainer.py
  9. 280
      ultralytics/solutions/heatmap.py
  10. 8
      ultralytics/solutions/object_counter.py
  11. 3
      ultralytics/solutions/solutions.py
  12. 56
      ultralytics/utils/checks.py
  13. 9
      ultralytics/utils/torch_utils.py

@ -3,7 +3,7 @@
# Image is CUDA-optimized for YOLO11 single/multi-GPU training and inference # Image is CUDA-optimized for YOLO11 single/multi-GPU training and inference
# Start FROM PyTorch image https://hub.docker.com/r/pytorch/pytorch or nvcr.io/nvidia/pytorch:23.03-py3 # Start FROM PyTorch image https://hub.docker.com/r/pytorch/pytorch or nvcr.io/nvidia/pytorch:23.03-py3
FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime FROM pytorch/pytorch:2.4.1-cuda12.1-cudnn9-runtime
# Set environment variables # Set environment variables
# Avoid DDP error "MKL_THREADING_LAYER=INTEL is incompatible with libgomp.so.1 library" https://github.com/pytorch/pytorch/issues/37377 # Avoid DDP error "MKL_THREADING_LAYER=INTEL is incompatible with libgomp.so.1 library" https://github.com/pytorch/pytorch/issues/37377

@ -41,10 +41,9 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
```python ```python
import cv2 import cv2
from ultralytics import YOLO, solutions from ultralytics import solutions
model = YOLO("yolo11n.pt") cap = cv2.VideoCapture("Path/to/video/file.mp4")
cap = cv2.VideoCapture("path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file" assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
@ -52,11 +51,10 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
video_writer = cv2.VideoWriter("heatmap_output.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)) video_writer = cv2.VideoWriter("heatmap_output.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
# Init heatmap # Init heatmap
heatmap_obj = solutions.Heatmap( heatmap = solutions.Heatmap(
show=True,
model="yolo11n.pt",
colormap=cv2.COLORMAP_PARULA, colormap=cv2.COLORMAP_PARULA,
view_img=True,
shape="circle",
names=model.names,
) )
while cap.isOpened(): while cap.isOpened():
@ -64,9 +62,7 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
if not success: if not success:
print("Video frame is empty or video processing has been successfully completed.") print("Video frame is empty or video processing has been successfully completed.")
break break
tracks = model.track(im0, persist=True, show=False) im0 = heatmap.generate_heatmap(im0)
im0 = heatmap_obj.generate_heatmap(im0, tracks)
video_writer.write(im0) video_writer.write(im0)
cap.release() cap.release()
@ -79,25 +75,24 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
```python ```python
import cv2 import cv2
from ultralytics import YOLO, solutions from ultralytics import solutions
model = YOLO("yolo11n.pt") cap = cv2.VideoCapture("Path/to/video/file.mp4")
cap = cv2.VideoCapture("path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file" assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
# Video writer # Video writer
video_writer = cv2.VideoWriter("heatmap_output.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)) video_writer = cv2.VideoWriter("heatmap_output.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
line_points = [(20, 400), (1080, 404)] # line for object counting # line for object counting
line_points = [(20, 400), (1080, 404)]
# Init heatmap # Init heatmap
heatmap_obj = solutions.Heatmap( heatmap = solutions.Heatmap(
show=True,
model="yolo11n.pt",
colormap=cv2.COLORMAP_PARULA, colormap=cv2.COLORMAP_PARULA,
view_img=True, region=line_points,
shape="circle",
count_reg_pts=line_points,
names=model.names,
) )
while cap.isOpened(): while cap.isOpened():
@ -105,9 +100,7 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
if not success: if not success:
print("Video frame is empty or video processing has been successfully completed.") print("Video frame is empty or video processing has been successfully completed.")
break break
im0 = heatmap.generate_heatmap(im0)
tracks = model.track(im0, persist=True, show=False)
im0 = heatmap_obj.generate_heatmap(im0, tracks)
video_writer.write(im0) video_writer.write(im0)
cap.release() cap.release()
@ -120,10 +113,9 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
```python ```python
import cv2 import cv2
from ultralytics import YOLO, solutions from ultralytics import solutions
model = YOLO("yolo11n.pt") cap = cv2.VideoCapture("Path/to/video/file.mp4")
cap = cv2.VideoCapture("path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file" assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
@ -134,12 +126,11 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360), (20, 400)] region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360), (20, 400)]
# Init heatmap # Init heatmap
heatmap_obj = solutions.Heatmap( heatmap = solutions.Heatmap(
show=True,
model="yolo11n.pt",
colormap=cv2.COLORMAP_PARULA, colormap=cv2.COLORMAP_PARULA,
view_img=True, region=region_points,
shape="circle",
count_reg_pts=region_points,
names=model.names,
) )
while cap.isOpened(): while cap.isOpened():
@ -147,9 +138,7 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
if not success: if not success:
print("Video frame is empty or video processing has been successfully completed.") print("Video frame is empty or video processing has been successfully completed.")
break break
im0 = heatmap.generate_heatmap(im0)
tracks = model.track(im0, persist=True, show=False)
im0 = heatmap_obj.generate_heatmap(im0, tracks)
video_writer.write(im0) video_writer.write(im0)
cap.release() cap.release()
@ -162,10 +151,9 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
```python ```python
import cv2 import cv2
from ultralytics import YOLO, solutions from ultralytics import solutions
model = YOLO("yolo11n.pt") cap = cv2.VideoCapture("Path/to/video/file.mp4")
cap = cv2.VideoCapture("path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file" assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
@ -176,12 +164,11 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)] region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
# Init heatmap # Init heatmap
heatmap_obj = solutions.Heatmap( heatmap = solutions.Heatmap(
show=True,
model="yolo11n.pt",
colormap=cv2.COLORMAP_PARULA, colormap=cv2.COLORMAP_PARULA,
view_img=True, region=region_points,
shape="circle",
count_reg_pts=region_points,
names=model.names,
) )
while cap.isOpened(): while cap.isOpened():
@ -189,9 +176,7 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
if not success: if not success:
print("Video frame is empty or video processing has been successfully completed.") print("Video frame is empty or video processing has been successfully completed.")
break break
im0 = heatmap.generate_heatmap(im0)
tracks = model.track(im0, persist=True, show=False)
im0 = heatmap_obj.generate_heatmap(im0, tracks)
video_writer.write(im0) video_writer.write(im0)
cap.release() cap.release()
@ -199,54 +184,25 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
cv2.destroyAllWindows() cv2.destroyAllWindows()
``` ```
=== "Im0"
```python
import cv2
from ultralytics import YOLO, solutions
model = YOLO("yolo11n.pt") # YOLO11 custom/pretrained model
im0 = cv2.imread("path/to/image.png") # path to image file
h, w = im0.shape[:2] # image height and width
# Heatmap Init
heatmap_obj = solutions.Heatmap(
colormap=cv2.COLORMAP_PARULA,
view_img=True,
shape="circle",
names=model.names,
)
results = model.track(im0, persist=True)
im0 = heatmap_obj.generate_heatmap(im0, tracks=results)
cv2.imwrite("ultralytics_output.png", im0)
```
=== "Specific Classes" === "Specific Classes"
```python ```python
import cv2 import cv2
from ultralytics import YOLO, solutions from ultralytics import solutions
model = YOLO("yolo11n.pt") cap = cv2.VideoCapture("Path/to/video/file.mp4")
cap = cv2.VideoCapture("path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file" assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
# Video writer # Video writer
video_writer = cv2.VideoWriter("heatmap_output.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)) video_writer = cv2.VideoWriter("heatmap_output.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
classes_for_heatmap = [0, 2] # classes for heatmap
# Init heatmap # Init heatmap
heatmap_obj = solutions.Heatmap( heatmap = solutions.Heatmap(
colormap=cv2.COLORMAP_PARULA, show=True,
view_img=True, model="yolo11n.pt",
shape="circle", classes=[0, 2],
names=model.names,
) )
while cap.isOpened(): while cap.isOpened():
@ -254,9 +210,7 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
if not success: if not success:
print("Video frame is empty or video processing has been successfully completed.") print("Video frame is empty or video processing has been successfully completed.")
break break
tracks = model.track(im0, persist=True, show=False, classes=classes_for_heatmap) im0 = heatmap.generate_heatmap(im0)
im0 = heatmap_obj.generate_heatmap(im0, tracks)
video_writer.write(im0) video_writer.write(im0)
cap.release() cap.release()
@ -267,20 +221,13 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
### Arguments `Heatmap()` ### Arguments `Heatmap()`
| Name | Type | Default | Description | | Name | Type | Default | Description |
| ------------------ | ---------------- | ------------------ | ----------------------------------------------------------------- | | ------------ | ------ | ------------------ | ----------------------------------------------------------------- |
| `names` | `list` | `None` | Dictionary of class names. |
| `colormap` | `int` | `cv2.COLORMAP_JET` | Colormap to use for the heatmap. | | `colormap` | `int` | `cv2.COLORMAP_JET` | Colormap to use for the heatmap. |
| `view_img` | `bool` | `False` | Whether to display the image with the heatmap overlay. | | `show` | `bool` | `False` | Whether to display the image with the heatmap overlay. |
| `view_in_counts` | `bool` | `True` | Whether to display the count of objects entering the region. | | `show_in` | `bool` | `True` | Whether to display the count of objects entering the region. |
| `view_out_counts` | `bool` | `True` | Whether to display the count of objects exiting the region. | | `show_out` | `bool` | `True` | Whether to display the count of objects exiting the region. |
| `count_reg_pts` | `list` or `None` | `None` | Points defining the counting region (either a line or a polygon). | | `region` | `list` | `None` | Points defining the counting region (either a line or a polygon). |
| `count_txt_color` | `tuple` | `(0, 0, 0)` | Text color for displaying counts. | | `line_width` | `int` | `2` | Thickness of the lines used in drawing. |
| `count_bg_color` | `tuple` | `(255, 255, 255)` | Background color for displaying counts. |
| `count_reg_color` | `tuple` | `(255, 0, 255)` | Color for the counting region. |
| `region_thickness` | `int` | `5` | Thickness of the region line. |
| `line_dist_thresh` | `int` | `15` | Distance threshold for line-based counting. |
| `line_thickness` | `int` | `2` | Thickness of the lines used in drawing. |
| `shape` | `str` | `"circle"` | Shape of the heatmap blobs ('circle' or 'rect'). |
### Arguments `model.track` ### Arguments `model.track`
@ -328,18 +275,16 @@ Yes, Ultralytics YOLO11 supports object tracking and heatmap generation concurre
```python ```python
import cv2 import cv2
from ultralytics import YOLO, solutions from ultralytics import solutions
model = YOLO("yolo11n.pt")
cap = cv2.VideoCapture("path/to/video/file.mp4") cap = cv2.VideoCapture("path/to/video/file.mp4")
heatmap_obj = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, view_img=True, shape="circle", names=model.names) heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, show=True, model="yolo11n.pt")
while cap.isOpened(): while cap.isOpened():
success, im0 = cap.read() success, im0 = cap.read()
if not success: if not success:
break break
tracks = model.track(im0, persist=True, show=False) im0 = heatmap.generate_heatmap(im0)
im0 = heatmap_obj.generate_heatmap(im0, tracks)
cv2.imshow("Heatmap", im0) cv2.imshow("Heatmap", im0)
if cv2.waitKey(1) & 0xFF == ord("q"): if cv2.waitKey(1) & 0xFF == ord("q"):
break break
@ -361,19 +306,16 @@ You can visualize specific object classes by specifying the desired classes in t
```python ```python
import cv2 import cv2
from ultralytics import YOLO, solutions from ultralytics import solutions
model = YOLO("yolo11n.pt")
cap = cv2.VideoCapture("path/to/video/file.mp4") cap = cv2.VideoCapture("path/to/video/file.mp4")
heatmap_obj = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, view_img=True, shape="circle", names=model.names) heatmap = solutions.Heatmap(show=True, model="yolo11n.pt", classes=[0, 2])
classes_for_heatmap = [0, 2] # Classes to visualize
while cap.isOpened(): while cap.isOpened():
success, im0 = cap.read() success, im0 = cap.read()
if not success: if not success:
break break
tracks = model.track(im0, persist=True, show=False, classes=classes_for_heatmap) im0 = heatmap.generate_heatmap(im0)
im0 = heatmap_obj.generate_heatmap(im0, tracks)
cv2.imshow("Heatmap", im0) cv2.imshow("Heatmap", im0)
if cv2.waitKey(1) & 0xFF == ord("q"): if cv2.waitKey(1) & 0xFF == ord("q"):
break break

@ -35,6 +35,10 @@ keywords: Ultralytics, torch utils, model optimization, device selection, infere
<br><br><hr><br> <br><br><hr><br>
## ::: ultralytics.utils.torch_utils.get_gpu_info
<br><br><hr><br>
## ::: ultralytics.utils.torch_utils.select_device ## ::: ultralytics.utils.torch_utils.select_device
<br><br><hr><br> <br><br><hr><br>

@ -94,13 +94,13 @@ document.addEventListener("DOMContentLoaded", () => {
fixedPositionYOffset: "3rem", fixedPositionYOffset: "3rem",
chatButtonBgColor: "#E1FF25", chatButtonBgColor: "#E1FF25",
baseSettings: { baseSettings: {
apiKey: "13dfec2e75982bc9bae3199a08e13b86b5fbacd64e9b2f89", // required apiKey: "13dfec2e75982bc9bae3199a08e13b86b5fbacd64e9b2f89",
integrationId: "cm1shscmm00y26sj83lgxzvkw", // required integrationId: "cm1shscmm00y26sj83lgxzvkw",
organizationId: "org_e3869az6hQZ0mXdF", // required organizationId: "org_e3869az6hQZ0mXdF",
primaryBrandColor: "#E1FF25", // Ultralytics brand color primaryBrandColor: "#E1FF25",
organizationDisplayName: "Ultralytics", organizationDisplayName: "Ultralytics",
theme: { theme: {
stylesheetUrls: ["../stylesheets/style.css"], stylesheetUrls: ["/stylesheets/style.css"],
}, },
// ...optional settings // ...optional settings
}, },

@ -19,8 +19,8 @@ def test_major_solutions():
cap = cv2.VideoCapture("solutions_ci_demo.mp4") cap = cv2.VideoCapture("solutions_ci_demo.mp4")
assert cap.isOpened(), "Error reading video file" assert cap.isOpened(), "Error reading video file"
region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)] region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
# counter = solutions.ObjectCounter(reg_pts=region_points, names=names, view_img=False) counter = solutions.ObjectCounter(region=region_points, model="yolo11n.pt", show=False)
heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, names=names, view_img=False) heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False)
speed = solutions.SpeedEstimator(reg_pts=region_points, names=names, view_img=False) speed = solutions.SpeedEstimator(reg_pts=region_points, names=names, view_img=False)
queue = solutions.QueueManager(names=names, reg_pts=region_points, view_img=False) queue = solutions.QueueManager(names=names, reg_pts=region_points, view_img=False)
while cap.isOpened(): while cap.isOpened():
@ -29,8 +29,8 @@ def test_major_solutions():
break break
original_im0 = im0.copy() original_im0 = im0.copy()
tracks = model.track(im0, persist=True, show=False) tracks = model.track(im0, persist=True, show=False)
# _ = counter.start_counting(original_im0.copy(), tracks) _ = counter.count(original_im0.copy())
_ = heatmap.generate_heatmap(original_im0.copy(), tracks) _ = heatmap.generate_heatmap(original_im0.copy())
_ = speed.estimate_speed(original_im0.copy(), tracks) _ = speed.estimate_speed(original_im0.copy(), tracks)
_ = queue.process_queue(original_im0.copy(), tracks) _ = queue.process_queue(original_im0.copy(), tracks)
cap.release() cap.release()

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.3.5" __version__ = "8.3.6"
import os import os

@ -10,7 +10,7 @@ show: True # Flag to control whether to display output image or not
show_in: True # Flag to display objects moving *into* the defined region show_in: True # Flag to display objects moving *into* the defined region
show_out: True # Flag to display objects moving *out of* the defined region show_out: True # Flag to display objects moving *out of* the defined region
classes: # To count specific classes classes: # To count specific classes
up_angle: 145.0 # Workouts up_angle for counts, 145.0 is default value
up_angle: 145.0 # workouts up_angle for counts, 145.0 is default value down_angle: 90 # Workouts down_angle for counts, 90 is default value
down_angle: 90 # workouts down_angle for counts, 90 is default value kpts: [6, 8, 10] # Keypoints for workouts monitoring
kpts: [6, 8, 10] # keypoints for workouts monitoring colormap: # Colormap for heatmap

@ -469,11 +469,11 @@ class BaseTrainer:
if RANK in {-1, 0}: if RANK in {-1, 0}:
# Do final val with best.pt # Do final val with best.pt
LOGGER.info( epochs = epoch - self.start_epoch + 1 # total training epochs
f"\n{epoch - self.start_epoch + 1} epochs completed in " seconds = time.time() - self.train_time_start # total training seconds
f"{(time.time() - self.train_time_start) / 3600:.3f} hours." LOGGER.info(f"\n{epochs} epochs completed in {seconds / 3600:.3f} hours.")
)
self.final_eval() self.final_eval()
self.validator.metrics.training = {"epochs": epochs, "seconds": seconds} # add training speed
if self.args.plots: if self.args.plots:
self.plot_metrics() self.plot_metrics()
self.run_callbacks("on_train_end") self.run_callbacks("on_train_end")

@ -1,249 +1,93 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
from collections import defaultdict
import cv2 import cv2
import numpy as np import numpy as np
from ultralytics.utils.checks import check_imshow, check_requirements from ultralytics.solutions.object_counter import ObjectCounter # Import object counter class
from ultralytics.utils.plotting import Annotator from ultralytics.utils.plotting import Annotator
check_requirements("shapely>=2.0.0")
from shapely.geometry import LineString, Point, Polygon
class Heatmap(ObjectCounter):
class Heatmap:
"""A class to draw heatmaps in real-time video stream based on their tracks.""" """A class to draw heatmaps in real-time video stream based on their tracks."""
def __init__( def __init__(self, **kwargs):
self, """Initializes function for heatmap class with default values."""
names, super().__init__(**kwargs)
colormap=cv2.COLORMAP_JET,
view_img=False,
view_in_counts=True,
view_out_counts=True,
count_reg_pts=None,
count_txt_color=(0, 0, 0),
count_bg_color=(255, 255, 255),
count_reg_color=(255, 0, 255),
region_thickness=5,
line_dist_thresh=15,
line_thickness=2,
shape="circle",
):
"""Initializes the heatmap class with default values for Visual, Image, track, count and heatmap parameters."""
# Visual information
self.annotator = None
self.view_img = view_img
self.shape = shape
self.initialized = False
self.names = names # Classes names
# Image information
self.im0 = None
self.tf = line_thickness
self.view_in_counts = view_in_counts
self.view_out_counts = view_out_counts
# Heatmap colormap and heatmap np array
self.colormap = colormap
self.heatmap = None
# Predict/track information
self.boxes = []
self.track_ids = []
self.clss = []
self.track_history = defaultdict(list)
# Region & Line Information
self.counting_region = None
self.line_dist_thresh = line_dist_thresh
self.region_thickness = region_thickness
self.region_color = count_reg_color
# Object Counting Information
self.in_counts = 0
self.out_counts = 0
self.count_ids = []
self.class_wise_count = {}
self.count_txt_color = count_txt_color
self.count_bg_color = count_bg_color
self.cls_txtdisplay_gap = 50
# Check if environment supports imshow
self.env_check = check_imshow(warn=True)
# Region and line selection
self.count_reg_pts = count_reg_pts
print(self.count_reg_pts)
if self.count_reg_pts is not None:
if len(self.count_reg_pts) == 2:
print("Line Counter Initiated.")
self.counting_region = LineString(self.count_reg_pts)
elif len(self.count_reg_pts) >= 3:
print("Polygon Counter Initiated.")
self.counting_region = Polygon(self.count_reg_pts)
else:
print("Invalid Region points provided, region_points must be 2 for lines or >= 3 for polygons.")
print("Using Line Counter Now")
self.counting_region = LineString(self.count_reg_pts)
# Shape of heatmap, if not selected
if self.shape not in {"circle", "rect"}:
print("Unknown shape value provided, 'circle' & 'rect' supported")
print("Using Circular shape now")
self.shape = "circle"
def extract_results(self, tracks):
"""
Extracts results from the provided data.
Args: self.initialized = False # bool variable for heatmap initialization
tracks (list): List of tracks obtained from the object tracking process. if self.region is not None: # check if user provided the region coordinates
""" self.initialize_region()
if tracks[0].boxes.id is not None:
self.boxes = tracks[0].boxes.xyxy.cpu() # store colormap
self.clss = tracks[0].boxes.cls.tolist() self.colormap = cv2.COLORMAP_PARULA if self.CFG["colormap"] is None else self.CFG["colormap"]
self.track_ids = tracks[0].boxes.id.int().tolist()
def generate_heatmap(self, im0, tracks): def heatmap_effect(self, box):
""" """
Generate heatmap based on tracking data. Efficient calculation of heatmap area and effect location for applying colormap.
Args: Args:
im0 (nd array): Image box (list): Bounding Box coordinates data [x0, y0, x1, y1]
tracks (list): List of tracks obtained from the object tracking process.
""" """
self.im0 = im0 x0, y0, x1, y1 = map(int, box)
radius_squared = (min(x1 - x0, y1 - y0) // 2) ** 2
# Initialize heatmap only once # Create a meshgrid with region of interest (ROI) for vectorized distance calculations
if not self.initialized: xv, yv = np.meshgrid(np.arange(x0, x1), np.arange(y0, y1))
self.heatmap = np.zeros((int(self.im0.shape[0]), int(self.im0.shape[1])), dtype=np.float32)
self.initialized = True
self.heatmap *= 0.99 # decay factor # Calculate squared distances from the center
dist_squared = (xv - ((x0 + x1) // 2)) ** 2 + (yv - ((y0 + y1) // 2)) ** 2
self.extract_results(tracks) # Create a mask of points within the radius
self.annotator = Annotator(self.im0, self.tf, None) within_radius = dist_squared <= radius_squared
if self.track_ids: # Update only the values within the bounding box in a single vectorized operation
# Draw counting region self.heatmap[y0:y1, x0:x1][within_radius] += 2
if self.count_reg_pts is not None:
self.annotator.draw_region(
reg_pts=self.count_reg_pts, color=self.region_color, thickness=self.region_thickness
)
for box, cls, track_id in zip(self.boxes, self.clss, self.track_ids):
# Store class info
if self.names[cls] not in self.class_wise_count:
self.class_wise_count[self.names[cls]] = {"IN": 0, "OUT": 0}
if self.shape == "circle": def generate_heatmap(self, im0):
center = (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2)) """
radius = min(int(box[2]) - int(box[0]), int(box[3]) - int(box[1])) // 2 Generate heatmap for each frame using Ultralytics.
y, x = np.ogrid[0 : self.heatmap.shape[0], 0 : self.heatmap.shape[1]] Args:
mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius**2 im0 (ndarray): Input image array for processing
Returns:
im0 (ndarray): Processed image for further usage
"""
self.heatmap = np.zeros_like(im0, dtype=np.float32) * 0.99 if not self.initialized else self.heatmap
self.initialized = True # Initialize heatmap only once
self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += ( self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator
2 * mask[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] self.extract_tracks(im0) # Extract tracks
)
else: # Iterate over bounding boxes, track ids and classes index
self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += 2 for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
# Draw bounding box and counting region
self.heatmap_effect(box)
# Store tracking hist if self.region is not None:
track_line = self.track_history[track_id] self.annotator.draw_region(reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2)
track_line.append((float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2))) self.store_tracking_history(track_id, box) # Store track history
if len(track_line) > 30: self.store_classwise_counts(cls) # store classwise counts in dict
track_line.pop(0)
# Store tracking previous position and perform object counting
prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
self.count_objects(self.track_line, box, track_id, prev_position, cls) # Perform object counting
if self.count_reg_pts is not None: self.display_counts(im0) if self.region is not None else None # Display the counts on the frame
# Count objects in any polygon
if len(self.count_reg_pts) >= 3:
is_inside = self.counting_region.contains(Point(track_line[-1]))
if prev_position is not None and is_inside and track_id not in self.count_ids:
self.count_ids.append(track_id)
if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
self.in_counts += 1
self.class_wise_count[self.names[cls]]["IN"] += 1
else:
self.out_counts += 1
self.class_wise_count[self.names[cls]]["OUT"] += 1
# Count objects using line
elif len(self.count_reg_pts) == 2:
if prev_position is not None and track_id not in self.count_ids:
distance = Point(track_line[-1]).distance(self.counting_region)
if distance < self.line_dist_thresh and track_id not in self.count_ids:
self.count_ids.append(track_id)
if (box[0] - prev_position[0]) * (
self.counting_region.centroid.x - prev_position[0]
) > 0:
self.in_counts += 1
self.class_wise_count[self.names[cls]]["IN"] += 1
else:
self.out_counts += 1
self.class_wise_count[self.names[cls]]["OUT"] += 1
else:
for box, cls in zip(self.boxes, self.clss):
if self.shape == "circle":
center = (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2))
radius = min(int(box[2]) - int(box[0]), int(box[3]) - int(box[1])) // 2
y, x = np.ogrid[0 : self.heatmap.shape[0], 0 : self.heatmap.shape[1]]
mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius**2
self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += (
2 * mask[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])]
)
else:
self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += 2
if self.count_reg_pts is not None:
labels_dict = {}
for key, value in self.class_wise_count.items():
if value["IN"] != 0 or value["OUT"] != 0:
if not self.view_in_counts and not self.view_out_counts:
continue
elif not self.view_in_counts:
labels_dict[str.capitalize(key)] = f"OUT {value['OUT']}"
elif not self.view_out_counts:
labels_dict[str.capitalize(key)] = f"IN {value['IN']}"
else:
labels_dict[str.capitalize(key)] = f"IN {value['IN']} OUT {value['OUT']}"
if labels_dict is not None:
self.annotator.display_analytics(self.im0, labels_dict, self.count_txt_color, self.count_bg_color, 10)
# Normalize, apply colormap to heatmap and combine with original image # Normalize, apply colormap to heatmap and combine with original image
heatmap_normalized = cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX) im0 = (
heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), self.colormap) im0
self.im0 = cv2.addWeighted(self.im0, 0.5, heatmap_colored, 0.5, 0) if self.track_data.id is None
else cv2.addWeighted(
if self.env_check and self.view_img: im0,
self.display_frames() 0.5,
cv2.applyColorMap(
return self.im0 cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8), self.colormap
),
def display_frames(self): 0.5,
"""Display frame.""" 0,
cv2.imshow("Ultralytics Heatmap", self.im0) )
)
if cv2.waitKey(1) & 0xFF == ord("q"):
return
if __name__ == "__main__": self.display_output(im0) # display output with base class function
classes_names = {0: "person", 1: "car"} # example class names return im0 # return output image for more usage
heatmap = Heatmap(classes_names)

@ -19,8 +19,7 @@ class ObjectCounter(BaseSolution):
self.out_count = 0 # Counter for objects moving outward self.out_count = 0 # Counter for objects moving outward
self.counted_ids = [] # List of IDs of objects that have been counted self.counted_ids = [] # List of IDs of objects that have been counted
self.classwise_counts = {} # Dictionary for counts, categorized by object class self.classwise_counts = {} # Dictionary for counts, categorized by object class
self.region_initialized = False # Bool variable for region initialization
self.initialize_region() # Setup region and counting areas
self.show_in = self.CFG["show_in"] self.show_in = self.CFG["show_in"]
self.show_out = self.CFG["show_out"] self.show_out = self.CFG["show_out"]
@ -99,6 +98,10 @@ class ObjectCounter(BaseSolution):
Returns Returns
im0 (ndarray): The processed image for more usage im0 (ndarray): The processed image for more usage
""" """
if not self.region_initialized:
self.initialize_region()
self.region_initialized = True
self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator
self.extract_tracks(im0) # Extract tracks self.extract_tracks(im0) # Extract tracks
@ -107,7 +110,6 @@ class ObjectCounter(BaseSolution):
) # Draw region ) # Draw region
# Iterate over bounding boxes, track ids and classes index # Iterate over bounding boxes, track ids and classes index
if self.track_data is not None and self.track_data.id is not None:
for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss): for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
# Draw bounding box and counting region # Draw bounding box and counting region
self.annotator.box_label(box, label=self.names[cls], color=colors(track_id, True)) self.annotator.box_label(box, label=self.names[cls], color=colors(track_id, True))

@ -57,7 +57,8 @@ class BaseSolution:
self.clss = self.track_data.cls.cpu().tolist() self.clss = self.track_data.cls.cpu().tolist()
self.track_ids = self.track_data.id.int().cpu().tolist() self.track_ids = self.track_data.id.int().cpu().tolist()
else: else:
LOGGER.warning("WARNING ⚠ tracks none, no keypoints will be considered.") LOGGER.warning("WARNING ⚠ no tracks found!")
self.boxes, self.clss, self.track_ids = [], [], []
def store_tracking_history(self, track_id, box): def store_tracking_history(self, track_id, box):
""" """

@ -593,20 +593,29 @@ def collect_system_info():
import psutil import psutil
from ultralytics.utils import ENVIRONMENT # scope to avoid circular import from ultralytics.utils import ENVIRONMENT # scope to avoid circular import
from ultralytics.utils.torch_utils import get_cpu_info from ultralytics.utils.torch_utils import get_cpu_info, get_gpu_info
ram_info = psutil.virtual_memory().total / (1024**3) # Convert bytes to GB gib = 1 << 30 # bytes per GiB
cuda = torch and torch.cuda.is_available()
check_yolo() check_yolo()
LOGGER.info( total, used, free = shutil.disk_usage("/")
f"\n{'OS':<20}{platform.platform()}\n"
f"{'Environment':<20}{ENVIRONMENT}\n" info_dict = {
f"{'Python':<20}{PYTHON_VERSION}\n" "OS": platform.platform(),
f"{'Install':<20}{'git' if IS_GIT_DIR else 'pip' if IS_PIP_PACKAGE else 'other'}\n" "Environment": ENVIRONMENT,
f"{'RAM':<20}{ram_info:.2f} GB\n" "Python": PYTHON_VERSION,
f"{'CPU':<20}{get_cpu_info()}\n" "Install": "git" if IS_GIT_DIR else "pip" if IS_PIP_PACKAGE else "other",
f"{'CUDA':<20}{torch.version.cuda if torch and torch.cuda.is_available() else None}\n" "RAM": f"{psutil.virtual_memory().total / gib:.2f} GB",
) "Disk": f"{(total - free) / gib:.1f}/{total / gib:.1f} GB",
"CPU": get_cpu_info(),
"CPU count": os.cpu_count(),
"GPU": get_gpu_info(index=0) if cuda else None,
"GPU count": torch.cuda.device_count() if cuda else None,
"CUDA": torch.version.cuda if cuda else None,
}
LOGGER.info("\n" + "\n".join(f"{k:<20}{v}" for k, v in info_dict.items()) + "\n")
package_info = {}
for r in parse_requirements(package="ultralytics"): for r in parse_requirements(package="ultralytics"):
try: try:
current = metadata.version(r.name) current = metadata.version(r.name)
@ -614,17 +623,24 @@ def collect_system_info():
except metadata.PackageNotFoundError: except metadata.PackageNotFoundError:
current = "(not installed)" current = "(not installed)"
is_met = "" is_met = ""
LOGGER.info(f"{r.name:<20}{is_met}{current}{r.specifier}") package_info[r.name] = f"{is_met}{current}{r.specifier}"
LOGGER.info(f"{r.name:<20}{package_info[r.name]}")
info_dict["Package Info"] = package_info
if is_github_action_running(): if is_github_action_running():
LOGGER.info( github_info = {
f"\nRUNNER_OS: {os.getenv('RUNNER_OS')}\n" "RUNNER_OS": os.getenv("RUNNER_OS"),
f"GITHUB_EVENT_NAME: {os.getenv('GITHUB_EVENT_NAME')}\n" "GITHUB_EVENT_NAME": os.getenv("GITHUB_EVENT_NAME"),
f"GITHUB_WORKFLOW: {os.getenv('GITHUB_WORKFLOW')}\n" "GITHUB_WORKFLOW": os.getenv("GITHUB_WORKFLOW"),
f"GITHUB_ACTOR: {os.getenv('GITHUB_ACTOR')}\n" "GITHUB_ACTOR": os.getenv("GITHUB_ACTOR"),
f"GITHUB_REPOSITORY: {os.getenv('GITHUB_REPOSITORY')}\n" "GITHUB_REPOSITORY": os.getenv("GITHUB_REPOSITORY"),
f"GITHUB_REPOSITORY_OWNER: {os.getenv('GITHUB_REPOSITORY_OWNER')}\n" "GITHUB_REPOSITORY_OWNER": os.getenv("GITHUB_REPOSITORY_OWNER"),
) }
LOGGER.info("\n" + "\n".join(f"{k}: {v}" for k, v in github_info.items()))
info_dict["GitHub Info"] = github_info
return info_dict
def check_amp(model): def check_amp(model):

@ -123,6 +123,12 @@ def get_cpu_info():
return PERSISTENT_CACHE.get("cpu_info", "unknown") return PERSISTENT_CACHE.get("cpu_info", "unknown")
def get_gpu_info(index):
"""Return a string with system GPU information, i.e. 'Tesla T4, 15102MiB'."""
properties = torch.cuda.get_device_properties(index)
return f"{properties.name}, {properties.total_memory / (1 << 20):.0f}MiB"
def select_device(device="", batch=0, newline=False, verbose=True): def select_device(device="", batch=0, newline=False, verbose=True):
""" """
Selects the appropriate PyTorch device based on the provided arguments. Selects the appropriate PyTorch device based on the provided arguments.
@ -208,8 +214,7 @@ def select_device(device="", batch=0, newline=False, verbose=True):
) )
space = " " * (len(s) + 1) space = " " * (len(s) + 1)
for i, d in enumerate(devices): for i, d in enumerate(devices):
p = torch.cuda.get_device_properties(i) s += f"{'' if i == 0 else space}CUDA:{d} ({get_gpu_info(i)})\n" # bytes to MB
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
arg = "cuda:0" arg = "cuda:0"
elif mps and TORCH_2_0 and torch.backends.mps.is_available(): elif mps and TORCH_2_0 and torch.backends.mps.is_available():
# Prefer MPS if available # Prefer MPS if available

Loading…
Cancel
Save