diff --git a/docker/Dockerfile b/docker/Dockerfile
index 3283c65076..da2d9ed37c 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -3,7 +3,7 @@
# Image is CUDA-optimized for YOLO11 single/multi-GPU training and inference
# Start FROM PyTorch image https://hub.docker.com/r/pytorch/pytorch or nvcr.io/nvidia/pytorch:23.03-py3
-FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime
+FROM pytorch/pytorch:2.4.1-cuda12.1-cudnn9-runtime
# Set environment variables
# Avoid DDP error "MKL_THREADING_LAYER=INTEL is incompatible with libgomp.so.1 library" https://github.com/pytorch/pytorch/issues/37377
diff --git a/docs/en/guides/heatmaps.md b/docs/en/guides/heatmaps.md
index 4e0a665fa0..f33993134f 100644
--- a/docs/en/guides/heatmaps.md
+++ b/docs/en/guides/heatmaps.md
@@ -41,10 +41,9 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
```python
import cv2
- from ultralytics import YOLO, solutions
+ from ultralytics import solutions
- model = YOLO("yolo11n.pt")
- cap = cv2.VideoCapture("path/to/video/file.mp4")
+ cap = cv2.VideoCapture("Path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
@@ -52,11 +51,10 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
video_writer = cv2.VideoWriter("heatmap_output.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
# Init heatmap
- heatmap_obj = solutions.Heatmap(
+ heatmap = solutions.Heatmap(
+ show=True,
+ model="yolo11n.pt",
colormap=cv2.COLORMAP_PARULA,
- view_img=True,
- shape="circle",
- names=model.names,
)
while cap.isOpened():
@@ -64,9 +62,7 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
if not success:
print("Video frame is empty or video processing has been successfully completed.")
break
- tracks = model.track(im0, persist=True, show=False)
-
- im0 = heatmap_obj.generate_heatmap(im0, tracks)
+ im0 = heatmap.generate_heatmap(im0)
video_writer.write(im0)
cap.release()
@@ -79,25 +75,24 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
```python
import cv2
- from ultralytics import YOLO, solutions
+ from ultralytics import solutions
- model = YOLO("yolo11n.pt")
- cap = cv2.VideoCapture("path/to/video/file.mp4")
+ cap = cv2.VideoCapture("Path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
# Video writer
video_writer = cv2.VideoWriter("heatmap_output.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
- line_points = [(20, 400), (1080, 404)] # line for object counting
+ # line for object counting
+ line_points = [(20, 400), (1080, 404)]
# Init heatmap
- heatmap_obj = solutions.Heatmap(
+ heatmap = solutions.Heatmap(
+ show=True,
+ model="yolo11n.pt",
colormap=cv2.COLORMAP_PARULA,
- view_img=True,
- shape="circle",
- count_reg_pts=line_points,
- names=model.names,
+ region=line_points,
)
while cap.isOpened():
@@ -105,9 +100,7 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
if not success:
print("Video frame is empty or video processing has been successfully completed.")
break
-
- tracks = model.track(im0, persist=True, show=False)
- im0 = heatmap_obj.generate_heatmap(im0, tracks)
+ im0 = heatmap.generate_heatmap(im0)
video_writer.write(im0)
cap.release()
@@ -120,10 +113,9 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
```python
import cv2
- from ultralytics import YOLO, solutions
+ from ultralytics import solutions
- model = YOLO("yolo11n.pt")
- cap = cv2.VideoCapture("path/to/video/file.mp4")
+ cap = cv2.VideoCapture("Path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
@@ -134,12 +126,11 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360), (20, 400)]
# Init heatmap
- heatmap_obj = solutions.Heatmap(
+ heatmap = solutions.Heatmap(
+ show=True,
+ model="yolo11n.pt",
colormap=cv2.COLORMAP_PARULA,
- view_img=True,
- shape="circle",
- count_reg_pts=region_points,
- names=model.names,
+ region=region_points,
)
while cap.isOpened():
@@ -147,9 +138,7 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
if not success:
print("Video frame is empty or video processing has been successfully completed.")
break
-
- tracks = model.track(im0, persist=True, show=False)
- im0 = heatmap_obj.generate_heatmap(im0, tracks)
+ im0 = heatmap.generate_heatmap(im0)
video_writer.write(im0)
cap.release()
@@ -162,10 +151,9 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
```python
import cv2
- from ultralytics import YOLO, solutions
+ from ultralytics import solutions
- model = YOLO("yolo11n.pt")
- cap = cv2.VideoCapture("path/to/video/file.mp4")
+ cap = cv2.VideoCapture("Path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
@@ -176,12 +164,11 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
# Init heatmap
- heatmap_obj = solutions.Heatmap(
+ heatmap = solutions.Heatmap(
+ show=True,
+ model="yolo11n.pt",
colormap=cv2.COLORMAP_PARULA,
- view_img=True,
- shape="circle",
- count_reg_pts=region_points,
- names=model.names,
+ region=region_points,
)
while cap.isOpened():
@@ -189,9 +176,7 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
if not success:
print("Video frame is empty or video processing has been successfully completed.")
break
-
- tracks = model.track(im0, persist=True, show=False)
- im0 = heatmap_obj.generate_heatmap(im0, tracks)
+ im0 = heatmap.generate_heatmap(im0)
video_writer.write(im0)
cap.release()
@@ -199,54 +184,25 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
cv2.destroyAllWindows()
```
- === "Im0"
-
- ```python
- import cv2
-
- from ultralytics import YOLO, solutions
-
- model = YOLO("yolo11n.pt") # YOLO11 custom/pretrained model
-
- im0 = cv2.imread("path/to/image.png") # path to image file
- h, w = im0.shape[:2] # image height and width
-
- # Heatmap Init
- heatmap_obj = solutions.Heatmap(
- colormap=cv2.COLORMAP_PARULA,
- view_img=True,
- shape="circle",
- names=model.names,
- )
-
- results = model.track(im0, persist=True)
- im0 = heatmap_obj.generate_heatmap(im0, tracks=results)
- cv2.imwrite("ultralytics_output.png", im0)
- ```
-
=== "Specific Classes"
```python
import cv2
- from ultralytics import YOLO, solutions
+ from ultralytics import solutions
- model = YOLO("yolo11n.pt")
- cap = cv2.VideoCapture("path/to/video/file.mp4")
+ cap = cv2.VideoCapture("Path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
# Video writer
video_writer = cv2.VideoWriter("heatmap_output.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
- classes_for_heatmap = [0, 2] # classes for heatmap
-
# Init heatmap
- heatmap_obj = solutions.Heatmap(
- colormap=cv2.COLORMAP_PARULA,
- view_img=True,
- shape="circle",
- names=model.names,
+ heatmap = solutions.Heatmap(
+ show=True,
+ model="yolo11n.pt",
+ classes=[0, 2],
)
while cap.isOpened():
@@ -254,9 +210,7 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
if not success:
print("Video frame is empty or video processing has been successfully completed.")
break
- tracks = model.track(im0, persist=True, show=False, classes=classes_for_heatmap)
-
- im0 = heatmap_obj.generate_heatmap(im0, tracks)
+ im0 = heatmap.generate_heatmap(im0)
video_writer.write(im0)
cap.release()
@@ -266,21 +220,14 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult
### Arguments `Heatmap()`
-| Name | Type | Default | Description |
-| ------------------ | ---------------- | ------------------ | ----------------------------------------------------------------- |
-| `names` | `list` | `None` | Dictionary of class names. |
-| `colormap` | `int` | `cv2.COLORMAP_JET` | Colormap to use for the heatmap. |
-| `view_img` | `bool` | `False` | Whether to display the image with the heatmap overlay. |
-| `view_in_counts` | `bool` | `True` | Whether to display the count of objects entering the region. |
-| `view_out_counts` | `bool` | `True` | Whether to display the count of objects exiting the region. |
-| `count_reg_pts` | `list` or `None` | `None` | Points defining the counting region (either a line or a polygon). |
-| `count_txt_color` | `tuple` | `(0, 0, 0)` | Text color for displaying counts. |
-| `count_bg_color` | `tuple` | `(255, 255, 255)` | Background color for displaying counts. |
-| `count_reg_color` | `tuple` | `(255, 0, 255)` | Color for the counting region. |
-| `region_thickness` | `int` | `5` | Thickness of the region line. |
-| `line_dist_thresh` | `int` | `15` | Distance threshold for line-based counting. |
-| `line_thickness` | `int` | `2` | Thickness of the lines used in drawing. |
-| `shape` | `str` | `"circle"` | Shape of the heatmap blobs ('circle' or 'rect'). |
+| Name | Type | Default | Description |
+| ------------ | ------ | ------------------ | ----------------------------------------------------------------- |
+| `colormap` | `int` | `cv2.COLORMAP_JET` | Colormap to use for the heatmap. |
+| `show` | `bool` | `False` | Whether to display the image with the heatmap overlay. |
+| `show_in` | `bool` | `True` | Whether to display the count of objects entering the region. |
+| `show_out` | `bool` | `True` | Whether to display the count of objects exiting the region. |
+| `region` | `list` | `None` | Points defining the counting region (either a line or a polygon). |
+| `line_width` | `int` | `2` | Thickness of the lines used in drawing. |
### Arguments `model.track`
@@ -328,18 +275,16 @@ Yes, Ultralytics YOLO11 supports object tracking and heatmap generation concurre
```python
import cv2
-from ultralytics import YOLO, solutions
+from ultralytics import solutions
-model = YOLO("yolo11n.pt")
cap = cv2.VideoCapture("path/to/video/file.mp4")
-heatmap_obj = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, view_img=True, shape="circle", names=model.names)
+heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, show=True, model="yolo11n.pt")
while cap.isOpened():
success, im0 = cap.read()
if not success:
break
- tracks = model.track(im0, persist=True, show=False)
- im0 = heatmap_obj.generate_heatmap(im0, tracks)
+ im0 = heatmap.generate_heatmap(im0)
cv2.imshow("Heatmap", im0)
if cv2.waitKey(1) & 0xFF == ord("q"):
break
@@ -361,19 +306,16 @@ You can visualize specific object classes by specifying the desired classes in t
```python
import cv2
-from ultralytics import YOLO, solutions
+from ultralytics import solutions
-model = YOLO("yolo11n.pt")
cap = cv2.VideoCapture("path/to/video/file.mp4")
-heatmap_obj = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, view_img=True, shape="circle", names=model.names)
+heatmap = solutions.Heatmap(show=True, model="yolo11n.pt", classes=[0, 2])
-classes_for_heatmap = [0, 2] # Classes to visualize
while cap.isOpened():
success, im0 = cap.read()
if not success:
break
- tracks = model.track(im0, persist=True, show=False, classes=classes_for_heatmap)
- im0 = heatmap_obj.generate_heatmap(im0, tracks)
+ im0 = heatmap.generate_heatmap(im0)
cv2.imshow("Heatmap", im0)
if cv2.waitKey(1) & 0xFF == ord("q"):
break
diff --git a/docs/en/reference/utils/torch_utils.md b/docs/en/reference/utils/torch_utils.md
index 4f8f3d1b9c..ac31ec2c33 100644
--- a/docs/en/reference/utils/torch_utils.md
+++ b/docs/en/reference/utils/torch_utils.md
@@ -35,6 +35,10 @@ keywords: Ultralytics, torch utils, model optimization, device selection, infere
+## ::: ultralytics.utils.torch_utils.get_gpu_info
+
+
+
## ::: ultralytics.utils.torch_utils.select_device
diff --git a/docs/overrides/javascript/extra.js b/docs/overrides/javascript/extra.js
index d6ea4daf99..b106acdfe0 100644
--- a/docs/overrides/javascript/extra.js
+++ b/docs/overrides/javascript/extra.js
@@ -94,13 +94,13 @@ document.addEventListener("DOMContentLoaded", () => {
fixedPositionYOffset: "3rem",
chatButtonBgColor: "#E1FF25",
baseSettings: {
- apiKey: "13dfec2e75982bc9bae3199a08e13b86b5fbacd64e9b2f89", // required
- integrationId: "cm1shscmm00y26sj83lgxzvkw", // required
- organizationId: "org_e3869az6hQZ0mXdF", // required
- primaryBrandColor: "#E1FF25", // Ultralytics brand color
+ apiKey: "13dfec2e75982bc9bae3199a08e13b86b5fbacd64e9b2f89",
+ integrationId: "cm1shscmm00y26sj83lgxzvkw",
+ organizationId: "org_e3869az6hQZ0mXdF",
+ primaryBrandColor: "#E1FF25",
organizationDisplayName: "Ultralytics",
theme: {
- stylesheetUrls: ["../stylesheets/style.css"],
+ stylesheetUrls: ["/stylesheets/style.css"],
},
// ...optional settings
},
diff --git a/tests/test_solutions.py b/tests/test_solutions.py
index 485c795ee4..8ec68d260f 100644
--- a/tests/test_solutions.py
+++ b/tests/test_solutions.py
@@ -19,8 +19,8 @@ def test_major_solutions():
cap = cv2.VideoCapture("solutions_ci_demo.mp4")
assert cap.isOpened(), "Error reading video file"
region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
- # counter = solutions.ObjectCounter(reg_pts=region_points, names=names, view_img=False)
- heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, names=names, view_img=False)
+ counter = solutions.ObjectCounter(region=region_points, model="yolo11n.pt", show=False)
+ heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False)
speed = solutions.SpeedEstimator(reg_pts=region_points, names=names, view_img=False)
queue = solutions.QueueManager(names=names, reg_pts=region_points, view_img=False)
while cap.isOpened():
@@ -29,8 +29,8 @@ def test_major_solutions():
break
original_im0 = im0.copy()
tracks = model.track(im0, persist=True, show=False)
- # _ = counter.start_counting(original_im0.copy(), tracks)
- _ = heatmap.generate_heatmap(original_im0.copy(), tracks)
+ _ = counter.count(original_im0.copy())
+ _ = heatmap.generate_heatmap(original_im0.copy())
_ = speed.estimate_speed(original_im0.copy(), tracks)
_ = queue.process_queue(original_im0.copy(), tracks)
cap.release()
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index b5e68098b5..ce089ca6b5 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = "8.3.5"
+__version__ = "8.3.6"
import os
diff --git a/ultralytics/cfg/solutions/default.yaml b/ultralytics/cfg/solutions/default.yaml
index f22dce2c91..fd375223cd 100644
--- a/ultralytics/cfg/solutions/default.yaml
+++ b/ultralytics/cfg/solutions/default.yaml
@@ -10,7 +10,7 @@ show: True # Flag to control whether to display output image or not
show_in: True # Flag to display objects moving *into* the defined region
show_out: True # Flag to display objects moving *out of* the defined region
classes: # To count specific classes
-
-up_angle: 145.0 # workouts up_angle for counts, 145.0 is default value
-down_angle: 90 # workouts down_angle for counts, 90 is default value
-kpts: [6, 8, 10] # keypoints for workouts monitoring
+up_angle: 145.0 # Workouts up_angle for counts, 145.0 is default value
+down_angle: 90 # Workouts down_angle for counts, 90 is default value
+kpts: [6, 8, 10] # Keypoints for workouts monitoring
+colormap: # Colormap for heatmap
diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py
index 9fcc697040..aadf63b023 100644
--- a/ultralytics/engine/trainer.py
+++ b/ultralytics/engine/trainer.py
@@ -469,11 +469,11 @@ class BaseTrainer:
if RANK in {-1, 0}:
# Do final val with best.pt
- LOGGER.info(
- f"\n{epoch - self.start_epoch + 1} epochs completed in "
- f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
- )
+ epochs = epoch - self.start_epoch + 1 # total training epochs
+ seconds = time.time() - self.train_time_start # total training seconds
+ LOGGER.info(f"\n{epochs} epochs completed in {seconds / 3600:.3f} hours.")
self.final_eval()
+ self.validator.metrics.training = {"epochs": epochs, "seconds": seconds} # add training speed
if self.args.plots:
self.plot_metrics()
self.run_callbacks("on_train_end")
diff --git a/ultralytics/solutions/heatmap.py b/ultralytics/solutions/heatmap.py
index 79e37b9c2a..30d1817d76 100644
--- a/ultralytics/solutions/heatmap.py
+++ b/ultralytics/solutions/heatmap.py
@@ -1,249 +1,93 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-from collections import defaultdict
-
import cv2
import numpy as np
-from ultralytics.utils.checks import check_imshow, check_requirements
+from ultralytics.solutions.object_counter import ObjectCounter # Import object counter class
from ultralytics.utils.plotting import Annotator
-check_requirements("shapely>=2.0.0")
-
-from shapely.geometry import LineString, Point, Polygon
-
-class Heatmap:
+class Heatmap(ObjectCounter):
"""A class to draw heatmaps in real-time video stream based on their tracks."""
- def __init__(
- self,
- names,
- colormap=cv2.COLORMAP_JET,
- view_img=False,
- view_in_counts=True,
- view_out_counts=True,
- count_reg_pts=None,
- count_txt_color=(0, 0, 0),
- count_bg_color=(255, 255, 255),
- count_reg_color=(255, 0, 255),
- region_thickness=5,
- line_dist_thresh=15,
- line_thickness=2,
- shape="circle",
- ):
- """Initializes the heatmap class with default values for Visual, Image, track, count and heatmap parameters."""
- # Visual information
- self.annotator = None
- self.view_img = view_img
- self.shape = shape
-
- self.initialized = False
- self.names = names # Classes names
-
- # Image information
- self.im0 = None
- self.tf = line_thickness
- self.view_in_counts = view_in_counts
- self.view_out_counts = view_out_counts
-
- # Heatmap colormap and heatmap np array
- self.colormap = colormap
- self.heatmap = None
-
- # Predict/track information
- self.boxes = []
- self.track_ids = []
- self.clss = []
- self.track_history = defaultdict(list)
-
- # Region & Line Information
- self.counting_region = None
- self.line_dist_thresh = line_dist_thresh
- self.region_thickness = region_thickness
- self.region_color = count_reg_color
-
- # Object Counting Information
- self.in_counts = 0
- self.out_counts = 0
- self.count_ids = []
- self.class_wise_count = {}
- self.count_txt_color = count_txt_color
- self.count_bg_color = count_bg_color
- self.cls_txtdisplay_gap = 50
-
- # Check if environment supports imshow
- self.env_check = check_imshow(warn=True)
-
- # Region and line selection
- self.count_reg_pts = count_reg_pts
- print(self.count_reg_pts)
- if self.count_reg_pts is not None:
- if len(self.count_reg_pts) == 2:
- print("Line Counter Initiated.")
- self.counting_region = LineString(self.count_reg_pts)
- elif len(self.count_reg_pts) >= 3:
- print("Polygon Counter Initiated.")
- self.counting_region = Polygon(self.count_reg_pts)
- else:
- print("Invalid Region points provided, region_points must be 2 for lines or >= 3 for polygons.")
- print("Using Line Counter Now")
- self.counting_region = LineString(self.count_reg_pts)
-
- # Shape of heatmap, if not selected
- if self.shape not in {"circle", "rect"}:
- print("Unknown shape value provided, 'circle' & 'rect' supported")
- print("Using Circular shape now")
- self.shape = "circle"
-
- def extract_results(self, tracks):
- """
- Extracts results from the provided data.
+ def __init__(self, **kwargs):
+ """Initializes function for heatmap class with default values."""
+ super().__init__(**kwargs)
- Args:
- tracks (list): List of tracks obtained from the object tracking process.
- """
- if tracks[0].boxes.id is not None:
- self.boxes = tracks[0].boxes.xyxy.cpu()
- self.clss = tracks[0].boxes.cls.tolist()
- self.track_ids = tracks[0].boxes.id.int().tolist()
+ self.initialized = False # bool variable for heatmap initialization
+ if self.region is not None: # check if user provided the region coordinates
+ self.initialize_region()
+
+ # store colormap
+ self.colormap = cv2.COLORMAP_PARULA if self.CFG["colormap"] is None else self.CFG["colormap"]
- def generate_heatmap(self, im0, tracks):
+ def heatmap_effect(self, box):
"""
- Generate heatmap based on tracking data.
+ Efficient calculation of heatmap area and effect location for applying colormap.
Args:
- im0 (nd array): Image
- tracks (list): List of tracks obtained from the object tracking process.
+ box (list): Bounding Box coordinates data [x0, y0, x1, y1]
"""
- self.im0 = im0
+ x0, y0, x1, y1 = map(int, box)
+ radius_squared = (min(x1 - x0, y1 - y0) // 2) ** 2
- # Initialize heatmap only once
- if not self.initialized:
- self.heatmap = np.zeros((int(self.im0.shape[0]), int(self.im0.shape[1])), dtype=np.float32)
- self.initialized = True
+ # Create a meshgrid with region of interest (ROI) for vectorized distance calculations
+ xv, yv = np.meshgrid(np.arange(x0, x1), np.arange(y0, y1))
- self.heatmap *= 0.99 # decay factor
+ # Calculate squared distances from the center
+ dist_squared = (xv - ((x0 + x1) // 2)) ** 2 + (yv - ((y0 + y1) // 2)) ** 2
- self.extract_results(tracks)
- self.annotator = Annotator(self.im0, self.tf, None)
+ # Create a mask of points within the radius
+ within_radius = dist_squared <= radius_squared
- if self.track_ids:
- # Draw counting region
- if self.count_reg_pts is not None:
- self.annotator.draw_region(
- reg_pts=self.count_reg_pts, color=self.region_color, thickness=self.region_thickness
- )
+ # Update only the values within the bounding box in a single vectorized operation
+ self.heatmap[y0:y1, x0:x1][within_radius] += 2
- for box, cls, track_id in zip(self.boxes, self.clss, self.track_ids):
- # Store class info
- if self.names[cls] not in self.class_wise_count:
- self.class_wise_count[self.names[cls]] = {"IN": 0, "OUT": 0}
-
- if self.shape == "circle":
- center = (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2))
- radius = min(int(box[2]) - int(box[0]), int(box[3]) - int(box[1])) // 2
+ def generate_heatmap(self, im0):
+ """
+ Generate heatmap for each frame using Ultralytics.
- y, x = np.ogrid[0 : self.heatmap.shape[0], 0 : self.heatmap.shape[1]]
- mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius**2
+ Args:
+ im0 (ndarray): Input image array for processing
+ Returns:
+ im0 (ndarray): Processed image for further usage
+ """
+ self.heatmap = np.zeros_like(im0, dtype=np.float32) * 0.99 if not self.initialized else self.heatmap
+ self.initialized = True # Initialize heatmap only once
- self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += (
- 2 * mask[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])]
- )
+ self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator
+ self.extract_tracks(im0) # Extract tracks
- else:
- self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += 2
+ # Iterate over bounding boxes, track ids and classes index
+ for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
+ # Draw bounding box and counting region
+ self.heatmap_effect(box)
- # Store tracking hist
- track_line = self.track_history[track_id]
- track_line.append((float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2)))
- if len(track_line) > 30:
- track_line.pop(0)
+ if self.region is not None:
+ self.annotator.draw_region(reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2)
+ self.store_tracking_history(track_id, box) # Store track history
+ self.store_classwise_counts(cls) # store classwise counts in dict
+ # Store tracking previous position and perform object counting
prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
+ self.count_objects(self.track_line, box, track_id, prev_position, cls) # Perform object counting
- if self.count_reg_pts is not None:
- # Count objects in any polygon
- if len(self.count_reg_pts) >= 3:
- is_inside = self.counting_region.contains(Point(track_line[-1]))
-
- if prev_position is not None and is_inside and track_id not in self.count_ids:
- self.count_ids.append(track_id)
-
- if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
- self.in_counts += 1
- self.class_wise_count[self.names[cls]]["IN"] += 1
- else:
- self.out_counts += 1
- self.class_wise_count[self.names[cls]]["OUT"] += 1
-
- # Count objects using line
- elif len(self.count_reg_pts) == 2:
- if prev_position is not None and track_id not in self.count_ids:
- distance = Point(track_line[-1]).distance(self.counting_region)
- if distance < self.line_dist_thresh and track_id not in self.count_ids:
- self.count_ids.append(track_id)
-
- if (box[0] - prev_position[0]) * (
- self.counting_region.centroid.x - prev_position[0]
- ) > 0:
- self.in_counts += 1
- self.class_wise_count[self.names[cls]]["IN"] += 1
- else:
- self.out_counts += 1
- self.class_wise_count[self.names[cls]]["OUT"] += 1
-
- else:
- for box, cls in zip(self.boxes, self.clss):
- if self.shape == "circle":
- center = (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2))
- radius = min(int(box[2]) - int(box[0]), int(box[3]) - int(box[1])) // 2
-
- y, x = np.ogrid[0 : self.heatmap.shape[0], 0 : self.heatmap.shape[1]]
- mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius**2
-
- self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += (
- 2 * mask[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])]
- )
-
- else:
- self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += 2
-
- if self.count_reg_pts is not None:
- labels_dict = {}
-
- for key, value in self.class_wise_count.items():
- if value["IN"] != 0 or value["OUT"] != 0:
- if not self.view_in_counts and not self.view_out_counts:
- continue
- elif not self.view_in_counts:
- labels_dict[str.capitalize(key)] = f"OUT {value['OUT']}"
- elif not self.view_out_counts:
- labels_dict[str.capitalize(key)] = f"IN {value['IN']}"
- else:
- labels_dict[str.capitalize(key)] = f"IN {value['IN']} OUT {value['OUT']}"
-
- if labels_dict is not None:
- self.annotator.display_analytics(self.im0, labels_dict, self.count_txt_color, self.count_bg_color, 10)
+ self.display_counts(im0) if self.region is not None else None # Display the counts on the frame
# Normalize, apply colormap to heatmap and combine with original image
- heatmap_normalized = cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX)
- heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), self.colormap)
- self.im0 = cv2.addWeighted(self.im0, 0.5, heatmap_colored, 0.5, 0)
-
- if self.env_check and self.view_img:
- self.display_frames()
-
- return self.im0
-
- def display_frames(self):
- """Display frame."""
- cv2.imshow("Ultralytics Heatmap", self.im0)
-
- if cv2.waitKey(1) & 0xFF == ord("q"):
- return
-
-
-if __name__ == "__main__":
- classes_names = {0: "person", 1: "car"} # example class names
- heatmap = Heatmap(classes_names)
+ im0 = (
+ im0
+ if self.track_data.id is None
+ else cv2.addWeighted(
+ im0,
+ 0.5,
+ cv2.applyColorMap(
+ cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8), self.colormap
+ ),
+ 0.5,
+ 0,
+ )
+ )
+
+ self.display_output(im0) # display output with base class function
+ return im0 # return output image for more usage
diff --git a/ultralytics/solutions/object_counter.py b/ultralytics/solutions/object_counter.py
index 073599a4c7..5fdaef258a 100644
--- a/ultralytics/solutions/object_counter.py
+++ b/ultralytics/solutions/object_counter.py
@@ -19,8 +19,7 @@ class ObjectCounter(BaseSolution):
self.out_count = 0 # Counter for objects moving outward
self.counted_ids = [] # List of IDs of objects that have been counted
self.classwise_counts = {} # Dictionary for counts, categorized by object class
-
- self.initialize_region() # Setup region and counting areas
+ self.region_initialized = False # Bool variable for region initialization
self.show_in = self.CFG["show_in"]
self.show_out = self.CFG["show_out"]
@@ -99,6 +98,10 @@ class ObjectCounter(BaseSolution):
Returns
im0 (ndarray): The processed image for more usage
"""
+ if not self.region_initialized:
+ self.initialize_region()
+ self.region_initialized = True
+
self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator
self.extract_tracks(im0) # Extract tracks
@@ -107,21 +110,20 @@ class ObjectCounter(BaseSolution):
) # Draw region
# Iterate over bounding boxes, track ids and classes index
- if self.track_data is not None and self.track_data.id is not None:
- for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
- # Draw bounding box and counting region
- self.annotator.box_label(box, label=self.names[cls], color=colors(track_id, True))
- self.store_tracking_history(track_id, box) # Store track history
- self.store_classwise_counts(cls) # store classwise counts in dict
-
- # Draw centroid of objects
- self.annotator.draw_centroid_and_tracks(
- self.track_line, color=colors(int(track_id), True), track_thickness=self.line_width
- )
-
- # store previous position of track for object counting
- prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
- self.count_objects(self.track_line, box, track_id, prev_position, cls) # Perform object counting
+ for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
+ # Draw bounding box and counting region
+ self.annotator.box_label(box, label=self.names[cls], color=colors(track_id, True))
+ self.store_tracking_history(track_id, box) # Store track history
+ self.store_classwise_counts(cls) # store classwise counts in dict
+
+ # Draw centroid of objects
+ self.annotator.draw_centroid_and_tracks(
+ self.track_line, color=colors(int(track_id), True), track_thickness=self.line_width
+ )
+
+ # store previous position of track for object counting
+ prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
+ self.count_objects(self.track_line, box, track_id, prev_position, cls) # Perform object counting
self.display_counts(im0) # Display the counts on the frame
self.display_output(im0) # display output with base class function
diff --git a/ultralytics/solutions/solutions.py b/ultralytics/solutions/solutions.py
index ed53de654b..14b3f80a52 100644
--- a/ultralytics/solutions/solutions.py
+++ b/ultralytics/solutions/solutions.py
@@ -57,7 +57,8 @@ class BaseSolution:
self.clss = self.track_data.cls.cpu().tolist()
self.track_ids = self.track_data.id.int().cpu().tolist()
else:
- LOGGER.warning("WARNING ⚠️ tracks none, no keypoints will be considered.")
+ LOGGER.warning("WARNING ⚠️ no tracks found!")
+ self.boxes, self.clss, self.track_ids = [], [], []
def store_tracking_history(self, track_id, box):
"""
diff --git a/ultralytics/utils/checks.py b/ultralytics/utils/checks.py
index 85eccf67e3..2a461b037d 100644
--- a/ultralytics/utils/checks.py
+++ b/ultralytics/utils/checks.py
@@ -593,20 +593,29 @@ def collect_system_info():
import psutil
from ultralytics.utils import ENVIRONMENT # scope to avoid circular import
- from ultralytics.utils.torch_utils import get_cpu_info
+ from ultralytics.utils.torch_utils import get_cpu_info, get_gpu_info
- ram_info = psutil.virtual_memory().total / (1024**3) # Convert bytes to GB
+ gib = 1 << 30 # bytes per GiB
+ cuda = torch and torch.cuda.is_available()
check_yolo()
- LOGGER.info(
- f"\n{'OS':<20}{platform.platform()}\n"
- f"{'Environment':<20}{ENVIRONMENT}\n"
- f"{'Python':<20}{PYTHON_VERSION}\n"
- f"{'Install':<20}{'git' if IS_GIT_DIR else 'pip' if IS_PIP_PACKAGE else 'other'}\n"
- f"{'RAM':<20}{ram_info:.2f} GB\n"
- f"{'CPU':<20}{get_cpu_info()}\n"
- f"{'CUDA':<20}{torch.version.cuda if torch and torch.cuda.is_available() else None}\n"
- )
+ total, used, free = shutil.disk_usage("/")
+
+ info_dict = {
+ "OS": platform.platform(),
+ "Environment": ENVIRONMENT,
+ "Python": PYTHON_VERSION,
+ "Install": "git" if IS_GIT_DIR else "pip" if IS_PIP_PACKAGE else "other",
+ "RAM": f"{psutil.virtual_memory().total / gib:.2f} GB",
+ "Disk": f"{(total - free) / gib:.1f}/{total / gib:.1f} GB",
+ "CPU": get_cpu_info(),
+ "CPU count": os.cpu_count(),
+ "GPU": get_gpu_info(index=0) if cuda else None,
+ "GPU count": torch.cuda.device_count() if cuda else None,
+ "CUDA": torch.version.cuda if cuda else None,
+ }
+ LOGGER.info("\n" + "\n".join(f"{k:<20}{v}" for k, v in info_dict.items()) + "\n")
+ package_info = {}
for r in parse_requirements(package="ultralytics"):
try:
current = metadata.version(r.name)
@@ -614,17 +623,24 @@ def collect_system_info():
except metadata.PackageNotFoundError:
current = "(not installed)"
is_met = "❌ "
- LOGGER.info(f"{r.name:<20}{is_met}{current}{r.specifier}")
+ package_info[r.name] = f"{is_met}{current}{r.specifier}"
+ LOGGER.info(f"{r.name:<20}{package_info[r.name]}")
+
+ info_dict["Package Info"] = package_info
if is_github_action_running():
- LOGGER.info(
- f"\nRUNNER_OS: {os.getenv('RUNNER_OS')}\n"
- f"GITHUB_EVENT_NAME: {os.getenv('GITHUB_EVENT_NAME')}\n"
- f"GITHUB_WORKFLOW: {os.getenv('GITHUB_WORKFLOW')}\n"
- f"GITHUB_ACTOR: {os.getenv('GITHUB_ACTOR')}\n"
- f"GITHUB_REPOSITORY: {os.getenv('GITHUB_REPOSITORY')}\n"
- f"GITHUB_REPOSITORY_OWNER: {os.getenv('GITHUB_REPOSITORY_OWNER')}\n"
- )
+ github_info = {
+ "RUNNER_OS": os.getenv("RUNNER_OS"),
+ "GITHUB_EVENT_NAME": os.getenv("GITHUB_EVENT_NAME"),
+ "GITHUB_WORKFLOW": os.getenv("GITHUB_WORKFLOW"),
+ "GITHUB_ACTOR": os.getenv("GITHUB_ACTOR"),
+ "GITHUB_REPOSITORY": os.getenv("GITHUB_REPOSITORY"),
+ "GITHUB_REPOSITORY_OWNER": os.getenv("GITHUB_REPOSITORY_OWNER"),
+ }
+ LOGGER.info("\n" + "\n".join(f"{k}: {v}" for k, v in github_info.items()))
+ info_dict["GitHub Info"] = github_info
+
+ return info_dict
def check_amp(model):
diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py
index 00176d3033..db84ed6945 100644
--- a/ultralytics/utils/torch_utils.py
+++ b/ultralytics/utils/torch_utils.py
@@ -123,6 +123,12 @@ def get_cpu_info():
return PERSISTENT_CACHE.get("cpu_info", "unknown")
+def get_gpu_info(index):
+ """Return a string with system GPU information, i.e. 'Tesla T4, 15102MiB'."""
+ properties = torch.cuda.get_device_properties(index)
+ return f"{properties.name}, {properties.total_memory / (1 << 20):.0f}MiB"
+
+
def select_device(device="", batch=0, newline=False, verbose=True):
"""
Selects the appropriate PyTorch device based on the provided arguments.
@@ -208,8 +214,7 @@ def select_device(device="", batch=0, newline=False, verbose=True):
)
space = " " * (len(s) + 1)
for i, d in enumerate(devices):
- p = torch.cuda.get_device_properties(i)
- s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
+ s += f"{'' if i == 0 else space}CUDA:{d} ({get_gpu_info(i)})\n" # bytes to MB
arg = "cuda:0"
elif mps and TORCH_2_0 and torch.backends.mps.is_available():
# Prefer MPS if available