From 48843119913960a44901eb664f6d0f30544758dd Mon Sep 17 00:00:00 2001 From: Muhammad Rizwan Munawar Date: Sun, 6 Oct 2024 22:20:58 +0500 Subject: [PATCH] Update `heatmaps` solution (#16720) Co-authored-by: UltralyticsAssistant --- docs/en/guides/heatmaps.md | 158 +++++-------- tests/test_solutions.py | 8 +- ultralytics/cfg/solutions/default.yaml | 8 +- ultralytics/solutions/heatmap.py | 282 ++++++------------------ ultralytics/solutions/object_counter.py | 36 +-- ultralytics/solutions/solutions.py | 3 +- 6 files changed, 142 insertions(+), 353 deletions(-) diff --git a/docs/en/guides/heatmaps.md b/docs/en/guides/heatmaps.md index 4e0a665fa0..f33993134f 100644 --- a/docs/en/guides/heatmaps.md +++ b/docs/en/guides/heatmaps.md @@ -41,10 +41,9 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult ```python import cv2 - from ultralytics import YOLO, solutions + from ultralytics import solutions - model = YOLO("yolo11n.pt") - cap = cv2.VideoCapture("path/to/video/file.mp4") + cap = cv2.VideoCapture("Path/to/video/file.mp4") assert cap.isOpened(), "Error reading video file" w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) @@ -52,11 +51,10 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult video_writer = cv2.VideoWriter("heatmap_output.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)) # Init heatmap - heatmap_obj = solutions.Heatmap( + heatmap = solutions.Heatmap( + show=True, + model="yolo11n.pt", colormap=cv2.COLORMAP_PARULA, - view_img=True, - shape="circle", - names=model.names, ) while cap.isOpened(): @@ -64,9 +62,7 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult if not success: print("Video frame is empty or video processing has been successfully completed.") break - tracks = model.track(im0, persist=True, show=False) - - im0 = heatmap_obj.generate_heatmap(im0, tracks) + im0 = heatmap.generate_heatmap(im0) video_writer.write(im0) cap.release() @@ -79,25 +75,24 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult ```python import cv2 - from ultralytics import YOLO, solutions + from ultralytics import solutions - model = YOLO("yolo11n.pt") - cap = cv2.VideoCapture("path/to/video/file.mp4") + cap = cv2.VideoCapture("Path/to/video/file.mp4") assert cap.isOpened(), "Error reading video file" w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) # Video writer video_writer = cv2.VideoWriter("heatmap_output.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)) - line_points = [(20, 400), (1080, 404)] # line for object counting + # line for object counting + line_points = [(20, 400), (1080, 404)] # Init heatmap - heatmap_obj = solutions.Heatmap( + heatmap = solutions.Heatmap( + show=True, + model="yolo11n.pt", colormap=cv2.COLORMAP_PARULA, - view_img=True, - shape="circle", - count_reg_pts=line_points, - names=model.names, + region=line_points, ) while cap.isOpened(): @@ -105,9 +100,7 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult if not success: print("Video frame is empty or video processing has been successfully completed.") break - - tracks = model.track(im0, persist=True, show=False) - im0 = heatmap_obj.generate_heatmap(im0, tracks) + im0 = heatmap.generate_heatmap(im0) video_writer.write(im0) cap.release() @@ -120,10 +113,9 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult ```python import cv2 - from ultralytics import YOLO, solutions + from ultralytics import solutions - model = YOLO("yolo11n.pt") - cap = cv2.VideoCapture("path/to/video/file.mp4") + cap = cv2.VideoCapture("Path/to/video/file.mp4") assert cap.isOpened(), "Error reading video file" w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) @@ -134,12 +126,11 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360), (20, 400)] # Init heatmap - heatmap_obj = solutions.Heatmap( + heatmap = solutions.Heatmap( + show=True, + model="yolo11n.pt", colormap=cv2.COLORMAP_PARULA, - view_img=True, - shape="circle", - count_reg_pts=region_points, - names=model.names, + region=region_points, ) while cap.isOpened(): @@ -147,9 +138,7 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult if not success: print("Video frame is empty or video processing has been successfully completed.") break - - tracks = model.track(im0, persist=True, show=False) - im0 = heatmap_obj.generate_heatmap(im0, tracks) + im0 = heatmap.generate_heatmap(im0) video_writer.write(im0) cap.release() @@ -162,10 +151,9 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult ```python import cv2 - from ultralytics import YOLO, solutions + from ultralytics import solutions - model = YOLO("yolo11n.pt") - cap = cv2.VideoCapture("path/to/video/file.mp4") + cap = cv2.VideoCapture("Path/to/video/file.mp4") assert cap.isOpened(), "Error reading video file" w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) @@ -176,12 +164,11 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)] # Init heatmap - heatmap_obj = solutions.Heatmap( + heatmap = solutions.Heatmap( + show=True, + model="yolo11n.pt", colormap=cv2.COLORMAP_PARULA, - view_img=True, - shape="circle", - count_reg_pts=region_points, - names=model.names, + region=region_points, ) while cap.isOpened(): @@ -189,9 +176,7 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult if not success: print("Video frame is empty or video processing has been successfully completed.") break - - tracks = model.track(im0, persist=True, show=False) - im0 = heatmap_obj.generate_heatmap(im0, tracks) + im0 = heatmap.generate_heatmap(im0) video_writer.write(im0) cap.release() @@ -199,54 +184,25 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult cv2.destroyAllWindows() ``` - === "Im0" - - ```python - import cv2 - - from ultralytics import YOLO, solutions - - model = YOLO("yolo11n.pt") # YOLO11 custom/pretrained model - - im0 = cv2.imread("path/to/image.png") # path to image file - h, w = im0.shape[:2] # image height and width - - # Heatmap Init - heatmap_obj = solutions.Heatmap( - colormap=cv2.COLORMAP_PARULA, - view_img=True, - shape="circle", - names=model.names, - ) - - results = model.track(im0, persist=True) - im0 = heatmap_obj.generate_heatmap(im0, tracks=results) - cv2.imwrite("ultralytics_output.png", im0) - ``` - === "Specific Classes" ```python import cv2 - from ultralytics import YOLO, solutions + from ultralytics import solutions - model = YOLO("yolo11n.pt") - cap = cv2.VideoCapture("path/to/video/file.mp4") + cap = cv2.VideoCapture("Path/to/video/file.mp4") assert cap.isOpened(), "Error reading video file" w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) # Video writer video_writer = cv2.VideoWriter("heatmap_output.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)) - classes_for_heatmap = [0, 2] # classes for heatmap - # Init heatmap - heatmap_obj = solutions.Heatmap( - colormap=cv2.COLORMAP_PARULA, - view_img=True, - shape="circle", - names=model.names, + heatmap = solutions.Heatmap( + show=True, + model="yolo11n.pt", + classes=[0, 2], ) while cap.isOpened(): @@ -254,9 +210,7 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult if not success: print("Video frame is empty or video processing has been successfully completed.") break - tracks = model.track(im0, persist=True, show=False, classes=classes_for_heatmap) - - im0 = heatmap_obj.generate_heatmap(im0, tracks) + im0 = heatmap.generate_heatmap(im0) video_writer.write(im0) cap.release() @@ -266,21 +220,14 @@ A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ult ### Arguments `Heatmap()` -| Name | Type | Default | Description | -| ------------------ | ---------------- | ------------------ | ----------------------------------------------------------------- | -| `names` | `list` | `None` | Dictionary of class names. | -| `colormap` | `int` | `cv2.COLORMAP_JET` | Colormap to use for the heatmap. | -| `view_img` | `bool` | `False` | Whether to display the image with the heatmap overlay. | -| `view_in_counts` | `bool` | `True` | Whether to display the count of objects entering the region. | -| `view_out_counts` | `bool` | `True` | Whether to display the count of objects exiting the region. | -| `count_reg_pts` | `list` or `None` | `None` | Points defining the counting region (either a line or a polygon). | -| `count_txt_color` | `tuple` | `(0, 0, 0)` | Text color for displaying counts. | -| `count_bg_color` | `tuple` | `(255, 255, 255)` | Background color for displaying counts. | -| `count_reg_color` | `tuple` | `(255, 0, 255)` | Color for the counting region. | -| `region_thickness` | `int` | `5` | Thickness of the region line. | -| `line_dist_thresh` | `int` | `15` | Distance threshold for line-based counting. | -| `line_thickness` | `int` | `2` | Thickness of the lines used in drawing. | -| `shape` | `str` | `"circle"` | Shape of the heatmap blobs ('circle' or 'rect'). | +| Name | Type | Default | Description | +| ------------ | ------ | ------------------ | ----------------------------------------------------------------- | +| `colormap` | `int` | `cv2.COLORMAP_JET` | Colormap to use for the heatmap. | +| `show` | `bool` | `False` | Whether to display the image with the heatmap overlay. | +| `show_in` | `bool` | `True` | Whether to display the count of objects entering the region. | +| `show_out` | `bool` | `True` | Whether to display the count of objects exiting the region. | +| `region` | `list` | `None` | Points defining the counting region (either a line or a polygon). | +| `line_width` | `int` | `2` | Thickness of the lines used in drawing. | ### Arguments `model.track` @@ -328,18 +275,16 @@ Yes, Ultralytics YOLO11 supports object tracking and heatmap generation concurre ```python import cv2 -from ultralytics import YOLO, solutions +from ultralytics import solutions -model = YOLO("yolo11n.pt") cap = cv2.VideoCapture("path/to/video/file.mp4") -heatmap_obj = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, view_img=True, shape="circle", names=model.names) +heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, show=True, model="yolo11n.pt") while cap.isOpened(): success, im0 = cap.read() if not success: break - tracks = model.track(im0, persist=True, show=False) - im0 = heatmap_obj.generate_heatmap(im0, tracks) + im0 = heatmap.generate_heatmap(im0) cv2.imshow("Heatmap", im0) if cv2.waitKey(1) & 0xFF == ord("q"): break @@ -361,19 +306,16 @@ You can visualize specific object classes by specifying the desired classes in t ```python import cv2 -from ultralytics import YOLO, solutions +from ultralytics import solutions -model = YOLO("yolo11n.pt") cap = cv2.VideoCapture("path/to/video/file.mp4") -heatmap_obj = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, view_img=True, shape="circle", names=model.names) +heatmap = solutions.Heatmap(show=True, model="yolo11n.pt", classes=[0, 2]) -classes_for_heatmap = [0, 2] # Classes to visualize while cap.isOpened(): success, im0 = cap.read() if not success: break - tracks = model.track(im0, persist=True, show=False, classes=classes_for_heatmap) - im0 = heatmap_obj.generate_heatmap(im0, tracks) + im0 = heatmap.generate_heatmap(im0) cv2.imshow("Heatmap", im0) if cv2.waitKey(1) & 0xFF == ord("q"): break diff --git a/tests/test_solutions.py b/tests/test_solutions.py index 485c795ee4..8ec68d260f 100644 --- a/tests/test_solutions.py +++ b/tests/test_solutions.py @@ -19,8 +19,8 @@ def test_major_solutions(): cap = cv2.VideoCapture("solutions_ci_demo.mp4") assert cap.isOpened(), "Error reading video file" region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)] - # counter = solutions.ObjectCounter(reg_pts=region_points, names=names, view_img=False) - heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, names=names, view_img=False) + counter = solutions.ObjectCounter(region=region_points, model="yolo11n.pt", show=False) + heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False) speed = solutions.SpeedEstimator(reg_pts=region_points, names=names, view_img=False) queue = solutions.QueueManager(names=names, reg_pts=region_points, view_img=False) while cap.isOpened(): @@ -29,8 +29,8 @@ def test_major_solutions(): break original_im0 = im0.copy() tracks = model.track(im0, persist=True, show=False) - # _ = counter.start_counting(original_im0.copy(), tracks) - _ = heatmap.generate_heatmap(original_im0.copy(), tracks) + _ = counter.count(original_im0.copy()) + _ = heatmap.generate_heatmap(original_im0.copy()) _ = speed.estimate_speed(original_im0.copy(), tracks) _ = queue.process_queue(original_im0.copy(), tracks) cap.release() diff --git a/ultralytics/cfg/solutions/default.yaml b/ultralytics/cfg/solutions/default.yaml index f22dce2c91..fd375223cd 100644 --- a/ultralytics/cfg/solutions/default.yaml +++ b/ultralytics/cfg/solutions/default.yaml @@ -10,7 +10,7 @@ show: True # Flag to control whether to display output image or not show_in: True # Flag to display objects moving *into* the defined region show_out: True # Flag to display objects moving *out of* the defined region classes: # To count specific classes - -up_angle: 145.0 # workouts up_angle for counts, 145.0 is default value -down_angle: 90 # workouts down_angle for counts, 90 is default value -kpts: [6, 8, 10] # keypoints for workouts monitoring +up_angle: 145.0 # Workouts up_angle for counts, 145.0 is default value +down_angle: 90 # Workouts down_angle for counts, 90 is default value +kpts: [6, 8, 10] # Keypoints for workouts monitoring +colormap: # Colormap for heatmap diff --git a/ultralytics/solutions/heatmap.py b/ultralytics/solutions/heatmap.py index 79e37b9c2a..30d1817d76 100644 --- a/ultralytics/solutions/heatmap.py +++ b/ultralytics/solutions/heatmap.py @@ -1,249 +1,93 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -from collections import defaultdict - import cv2 import numpy as np -from ultralytics.utils.checks import check_imshow, check_requirements +from ultralytics.solutions.object_counter import ObjectCounter # Import object counter class from ultralytics.utils.plotting import Annotator -check_requirements("shapely>=2.0.0") - -from shapely.geometry import LineString, Point, Polygon - -class Heatmap: +class Heatmap(ObjectCounter): """A class to draw heatmaps in real-time video stream based on their tracks.""" - def __init__( - self, - names, - colormap=cv2.COLORMAP_JET, - view_img=False, - view_in_counts=True, - view_out_counts=True, - count_reg_pts=None, - count_txt_color=(0, 0, 0), - count_bg_color=(255, 255, 255), - count_reg_color=(255, 0, 255), - region_thickness=5, - line_dist_thresh=15, - line_thickness=2, - shape="circle", - ): - """Initializes the heatmap class with default values for Visual, Image, track, count and heatmap parameters.""" - # Visual information - self.annotator = None - self.view_img = view_img - self.shape = shape - - self.initialized = False - self.names = names # Classes names - - # Image information - self.im0 = None - self.tf = line_thickness - self.view_in_counts = view_in_counts - self.view_out_counts = view_out_counts - - # Heatmap colormap and heatmap np array - self.colormap = colormap - self.heatmap = None - - # Predict/track information - self.boxes = [] - self.track_ids = [] - self.clss = [] - self.track_history = defaultdict(list) - - # Region & Line Information - self.counting_region = None - self.line_dist_thresh = line_dist_thresh - self.region_thickness = region_thickness - self.region_color = count_reg_color - - # Object Counting Information - self.in_counts = 0 - self.out_counts = 0 - self.count_ids = [] - self.class_wise_count = {} - self.count_txt_color = count_txt_color - self.count_bg_color = count_bg_color - self.cls_txtdisplay_gap = 50 - - # Check if environment supports imshow - self.env_check = check_imshow(warn=True) - - # Region and line selection - self.count_reg_pts = count_reg_pts - print(self.count_reg_pts) - if self.count_reg_pts is not None: - if len(self.count_reg_pts) == 2: - print("Line Counter Initiated.") - self.counting_region = LineString(self.count_reg_pts) - elif len(self.count_reg_pts) >= 3: - print("Polygon Counter Initiated.") - self.counting_region = Polygon(self.count_reg_pts) - else: - print("Invalid Region points provided, region_points must be 2 for lines or >= 3 for polygons.") - print("Using Line Counter Now") - self.counting_region = LineString(self.count_reg_pts) - - # Shape of heatmap, if not selected - if self.shape not in {"circle", "rect"}: - print("Unknown shape value provided, 'circle' & 'rect' supported") - print("Using Circular shape now") - self.shape = "circle" - - def extract_results(self, tracks): - """ - Extracts results from the provided data. + def __init__(self, **kwargs): + """Initializes function for heatmap class with default values.""" + super().__init__(**kwargs) - Args: - tracks (list): List of tracks obtained from the object tracking process. - """ - if tracks[0].boxes.id is not None: - self.boxes = tracks[0].boxes.xyxy.cpu() - self.clss = tracks[0].boxes.cls.tolist() - self.track_ids = tracks[0].boxes.id.int().tolist() + self.initialized = False # bool variable for heatmap initialization + if self.region is not None: # check if user provided the region coordinates + self.initialize_region() + + # store colormap + self.colormap = cv2.COLORMAP_PARULA if self.CFG["colormap"] is None else self.CFG["colormap"] - def generate_heatmap(self, im0, tracks): + def heatmap_effect(self, box): """ - Generate heatmap based on tracking data. + Efficient calculation of heatmap area and effect location for applying colormap. Args: - im0 (nd array): Image - tracks (list): List of tracks obtained from the object tracking process. + box (list): Bounding Box coordinates data [x0, y0, x1, y1] """ - self.im0 = im0 + x0, y0, x1, y1 = map(int, box) + radius_squared = (min(x1 - x0, y1 - y0) // 2) ** 2 - # Initialize heatmap only once - if not self.initialized: - self.heatmap = np.zeros((int(self.im0.shape[0]), int(self.im0.shape[1])), dtype=np.float32) - self.initialized = True + # Create a meshgrid with region of interest (ROI) for vectorized distance calculations + xv, yv = np.meshgrid(np.arange(x0, x1), np.arange(y0, y1)) - self.heatmap *= 0.99 # decay factor + # Calculate squared distances from the center + dist_squared = (xv - ((x0 + x1) // 2)) ** 2 + (yv - ((y0 + y1) // 2)) ** 2 - self.extract_results(tracks) - self.annotator = Annotator(self.im0, self.tf, None) + # Create a mask of points within the radius + within_radius = dist_squared <= radius_squared - if self.track_ids: - # Draw counting region - if self.count_reg_pts is not None: - self.annotator.draw_region( - reg_pts=self.count_reg_pts, color=self.region_color, thickness=self.region_thickness - ) + # Update only the values within the bounding box in a single vectorized operation + self.heatmap[y0:y1, x0:x1][within_radius] += 2 - for box, cls, track_id in zip(self.boxes, self.clss, self.track_ids): - # Store class info - if self.names[cls] not in self.class_wise_count: - self.class_wise_count[self.names[cls]] = {"IN": 0, "OUT": 0} - - if self.shape == "circle": - center = (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2)) - radius = min(int(box[2]) - int(box[0]), int(box[3]) - int(box[1])) // 2 + def generate_heatmap(self, im0): + """ + Generate heatmap for each frame using Ultralytics. - y, x = np.ogrid[0 : self.heatmap.shape[0], 0 : self.heatmap.shape[1]] - mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius**2 + Args: + im0 (ndarray): Input image array for processing + Returns: + im0 (ndarray): Processed image for further usage + """ + self.heatmap = np.zeros_like(im0, dtype=np.float32) * 0.99 if not self.initialized else self.heatmap + self.initialized = True # Initialize heatmap only once - self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += ( - 2 * mask[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] - ) + self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator + self.extract_tracks(im0) # Extract tracks - else: - self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += 2 + # Iterate over bounding boxes, track ids and classes index + for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss): + # Draw bounding box and counting region + self.heatmap_effect(box) - # Store tracking hist - track_line = self.track_history[track_id] - track_line.append((float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2))) - if len(track_line) > 30: - track_line.pop(0) + if self.region is not None: + self.annotator.draw_region(reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2) + self.store_tracking_history(track_id, box) # Store track history + self.store_classwise_counts(cls) # store classwise counts in dict + # Store tracking previous position and perform object counting prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None + self.count_objects(self.track_line, box, track_id, prev_position, cls) # Perform object counting - if self.count_reg_pts is not None: - # Count objects in any polygon - if len(self.count_reg_pts) >= 3: - is_inside = self.counting_region.contains(Point(track_line[-1])) - - if prev_position is not None and is_inside and track_id not in self.count_ids: - self.count_ids.append(track_id) - - if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0: - self.in_counts += 1 - self.class_wise_count[self.names[cls]]["IN"] += 1 - else: - self.out_counts += 1 - self.class_wise_count[self.names[cls]]["OUT"] += 1 - - # Count objects using line - elif len(self.count_reg_pts) == 2: - if prev_position is not None and track_id not in self.count_ids: - distance = Point(track_line[-1]).distance(self.counting_region) - if distance < self.line_dist_thresh and track_id not in self.count_ids: - self.count_ids.append(track_id) - - if (box[0] - prev_position[0]) * ( - self.counting_region.centroid.x - prev_position[0] - ) > 0: - self.in_counts += 1 - self.class_wise_count[self.names[cls]]["IN"] += 1 - else: - self.out_counts += 1 - self.class_wise_count[self.names[cls]]["OUT"] += 1 - - else: - for box, cls in zip(self.boxes, self.clss): - if self.shape == "circle": - center = (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2)) - radius = min(int(box[2]) - int(box[0]), int(box[3]) - int(box[1])) // 2 - - y, x = np.ogrid[0 : self.heatmap.shape[0], 0 : self.heatmap.shape[1]] - mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius**2 - - self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += ( - 2 * mask[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] - ) - - else: - self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += 2 - - if self.count_reg_pts is not None: - labels_dict = {} - - for key, value in self.class_wise_count.items(): - if value["IN"] != 0 or value["OUT"] != 0: - if not self.view_in_counts and not self.view_out_counts: - continue - elif not self.view_in_counts: - labels_dict[str.capitalize(key)] = f"OUT {value['OUT']}" - elif not self.view_out_counts: - labels_dict[str.capitalize(key)] = f"IN {value['IN']}" - else: - labels_dict[str.capitalize(key)] = f"IN {value['IN']} OUT {value['OUT']}" - - if labels_dict is not None: - self.annotator.display_analytics(self.im0, labels_dict, self.count_txt_color, self.count_bg_color, 10) + self.display_counts(im0) if self.region is not None else None # Display the counts on the frame # Normalize, apply colormap to heatmap and combine with original image - heatmap_normalized = cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX) - heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), self.colormap) - self.im0 = cv2.addWeighted(self.im0, 0.5, heatmap_colored, 0.5, 0) - - if self.env_check and self.view_img: - self.display_frames() - - return self.im0 - - def display_frames(self): - """Display frame.""" - cv2.imshow("Ultralytics Heatmap", self.im0) - - if cv2.waitKey(1) & 0xFF == ord("q"): - return - - -if __name__ == "__main__": - classes_names = {0: "person", 1: "car"} # example class names - heatmap = Heatmap(classes_names) + im0 = ( + im0 + if self.track_data.id is None + else cv2.addWeighted( + im0, + 0.5, + cv2.applyColorMap( + cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8), self.colormap + ), + 0.5, + 0, + ) + ) + + self.display_output(im0) # display output with base class function + return im0 # return output image for more usage diff --git a/ultralytics/solutions/object_counter.py b/ultralytics/solutions/object_counter.py index 073599a4c7..5fdaef258a 100644 --- a/ultralytics/solutions/object_counter.py +++ b/ultralytics/solutions/object_counter.py @@ -19,8 +19,7 @@ class ObjectCounter(BaseSolution): self.out_count = 0 # Counter for objects moving outward self.counted_ids = [] # List of IDs of objects that have been counted self.classwise_counts = {} # Dictionary for counts, categorized by object class - - self.initialize_region() # Setup region and counting areas + self.region_initialized = False # Bool variable for region initialization self.show_in = self.CFG["show_in"] self.show_out = self.CFG["show_out"] @@ -99,6 +98,10 @@ class ObjectCounter(BaseSolution): Returns im0 (ndarray): The processed image for more usage """ + if not self.region_initialized: + self.initialize_region() + self.region_initialized = True + self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator self.extract_tracks(im0) # Extract tracks @@ -107,21 +110,20 @@ class ObjectCounter(BaseSolution): ) # Draw region # Iterate over bounding boxes, track ids and classes index - if self.track_data is not None and self.track_data.id is not None: - for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss): - # Draw bounding box and counting region - self.annotator.box_label(box, label=self.names[cls], color=colors(track_id, True)) - self.store_tracking_history(track_id, box) # Store track history - self.store_classwise_counts(cls) # store classwise counts in dict - - # Draw centroid of objects - self.annotator.draw_centroid_and_tracks( - self.track_line, color=colors(int(track_id), True), track_thickness=self.line_width - ) - - # store previous position of track for object counting - prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None - self.count_objects(self.track_line, box, track_id, prev_position, cls) # Perform object counting + for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss): + # Draw bounding box and counting region + self.annotator.box_label(box, label=self.names[cls], color=colors(track_id, True)) + self.store_tracking_history(track_id, box) # Store track history + self.store_classwise_counts(cls) # store classwise counts in dict + + # Draw centroid of objects + self.annotator.draw_centroid_and_tracks( + self.track_line, color=colors(int(track_id), True), track_thickness=self.line_width + ) + + # store previous position of track for object counting + prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None + self.count_objects(self.track_line, box, track_id, prev_position, cls) # Perform object counting self.display_counts(im0) # Display the counts on the frame self.display_output(im0) # display output with base class function diff --git a/ultralytics/solutions/solutions.py b/ultralytics/solutions/solutions.py index ed53de654b..14b3f80a52 100644 --- a/ultralytics/solutions/solutions.py +++ b/ultralytics/solutions/solutions.py @@ -57,7 +57,8 @@ class BaseSolution: self.clss = self.track_data.cls.cpu().tolist() self.track_ids = self.track_data.id.int().cpu().tolist() else: - LOGGER.warning("WARNING ⚠️ tracks none, no keypoints will be considered.") + LOGGER.warning("WARNING ⚠️ no tracks found!") + self.boxes, self.clss, self.track_ids = [], [], [] def store_tracking_history(self, track_id, box): """