diff --git a/docs/en/guides/speed-estimation.md b/docs/en/guides/speed-estimation.md
index 6a6c192de1..48a9aa09eb 100644
--- a/docs/en/guides/speed-estimation.md
+++ b/docs/en/guides/speed-estimation.md
@@ -45,40 +45,33 @@ keywords: Ultralytics YOLO11, speed estimation, object tracking, computer vision
         ```python
         import cv2
 
-        from ultralytics import YOLO, solutions
+        from ultralytics import solutions
 
-        model = YOLO("yolo11n.pt")
-        names = model.model.names
 
         cap = cv2.VideoCapture("path/to/video/file.mp4")
         assert cap.isOpened(), "Error reading video file"
         w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
 
-        # Video writer
         video_writer = cv2.VideoWriter("speed_estimation.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
 
-        line_pts = [(0, 360), (1280, 360)]
+        speed_region = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
 
-        # Init speed-estimation obj
-        speed_obj = solutions.SpeedEstimator(
-            reg_pts=line_pts,
-            names=names,
-            view_img=True,
-        )
+        speed = solutions.SpeedEstimator(model="yolo11n.pt", region=speed_region, show=True)
 
         while cap.isOpened():
             success, im0 = cap.read()
 
-            if not success:
-                print("Video frame is empty or video processing has been successfully completed.")
-                break
-
-            tracks = model.track(im0, persist=True)
+            if success:
+                out = speed.estimate_speed(im0)
+                video_writer.write(out)
+                if cv2.waitKey(1) & 0xFF == ord("q"):
+                    break
+                continue
 
-            im0 = speed_obj.estimate_speed(im0, tracks)
-            video_writer.write(im0)
+            print("Video frame is empty or video processing has been successfully completed.")
+            break
 
         cap.release()
         video_writer.release()
         cv2.destroyAllWindows()
         ```
@@ -88,13 +81,12 @@ keywords: Ultralytics YOLO11, speed estimation, object tracking, computer vision
 
 ### Arguments `SpeedEstimator`
 
-| Name               | Type   | Default                    | Description                                          |
-| ------------------ | ------ | -------------------------- | ---------------------------------------------------- |
-| `names`            | `dict` | `None`                     | Dictionary of class names.                           |
-| `reg_pts`          | `list` | `[(20, 400), (1260, 400)]` | List of region points for speed estimation.          |
-| `view_img`         | `bool` | `False`                    | Whether to display the image with annotations.       |
-| `line_thickness`   | `int`  | `2`                        | Thickness of the lines for drawing boxes and tracks. |
-| `spdl_dist_thresh` | `int`  | `10`                       | Distance threshold for speed calculation.            |
+| Name         | Type   | Default                                            | Description                                           |
+| ------------ | ------ | -------------------------------------------------- | ----------------------------------------------------- |
+| `model`      | `str`  | `None`                                             | Path to the Ultralytics YOLO model file.              |
+| `region`     | `list` | `[(20, 400), (1080, 404), (1080, 360), (20, 360)]` | List of points defining the speed estimation region.  |
+| `line_width` | `int`  | `2`                                                | Line thickness for bounding boxes and tracks.         |
+| `show`       | `bool` | `False`                                            | Flag to control whether to display the video stream.  |
 
 ### Arguments `model.track`
 
@@ -111,10 +103,7 @@ Estimating object speed with Ultralytics YOLO11 involves combining [object detec
 
 ```python
 import cv2
 
-from ultralytics import YOLO, solutions
-
-model = YOLO("yolo11n.pt")
-names = model.model.names
+from ultralytics import solutions
 
 cap = cv2.VideoCapture("path/to/video/file.mp4")
 w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
@@ -122,17 +111,16 @@ video_writer = cv2.VideoWriter("speed_estimation.avi", cv2.VideoWriter_fourcc(*"
 
 # Initialize SpeedEstimator
 speed_obj = solutions.SpeedEstimator(
-    reg_pts=[(0, 360), (1280, 360)],
-    names=names,
-    view_img=True,
+    region=[(0, 360), (1280, 360)],
+    model="yolo11n.pt",
+    show=True,
 )
 
 while cap.isOpened():
     success, im0 = cap.read()
     if not success:
         break
-    tracks = model.track(im0, persist=True, show=False)
 
-    im0 = speed_obj.estimate_speed(im0, tracks)
+    im0 = speed_obj.estimate_speed(im0)
     video_writer.write(im0)
 
 cap.release()
diff --git a/tests/test_solutions.py b/tests/test_solutions.py
index ef9ffe7d11..d3ba2d5fc2 100644
--- a/tests/test_solutions.py
+++ b/tests/test_solutions.py
@@ -14,24 +14,21 @@ WORKOUTS_SOLUTION_DEMO = "https://github.com/ultralytics/assets/releases/downloa
 
 def test_major_solutions():
     """Test the object counting, heatmap, speed estimation and queue management solution."""
     safe_download(url=MAJOR_SOLUTIONS_DEMO)
-    model = YOLO("yolo11n.pt")
-    names = model.names
     cap = cv2.VideoCapture("solutions_ci_demo.mp4")
     assert cap.isOpened(), "Error reading video file"
     region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
     counter = solutions.ObjectCounter(region=region_points, model="yolo11n.pt", show=False)
     heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False)
-    speed = solutions.SpeedEstimator(reg_pts=region_points, names=names, view_img=False)
+    speed = solutions.SpeedEstimator(region=region_points, model="yolo11n.pt", show=False)
     queue = solutions.QueueManager(region=region_points, model="yolo11n.pt", show=False)
     while cap.isOpened():
         success, im0 = cap.read()
         if not success:
             break
         original_im0 = im0.copy()
-        tracks = model.track(im0, persist=True, show=False)
         _ = counter.count(original_im0.copy())
         _ = heatmap.generate_heatmap(original_im0.copy())
-        _ = speed.estimate_speed(original_im0.copy(), tracks)
+        _ = speed.estimate_speed(original_im0.copy())
         _ = queue.process_queue(original_im0.copy())
     cap.release()
     cv2.destroyAllWindows()
diff --git a/ultralytics/cfg/solutions/default.yaml b/ultralytics/cfg/solutions/default.yaml
index fd375223cd..a98ae52749 100644
--- a/ultralytics/cfg/solutions/default.yaml
+++ b/ultralytics/cfg/solutions/default.yaml
@@ -2,15 +2,15 @@
 
 # Configuration for Ultralytics Solutions
 
-model: "yolo11n.pt" # The Ultralytics YOLO11 model to be used (e.g., yolo11n.pt for YOLO11 nano version)
+model: "yolo11n.pt" # The Ultralytics YOLO11 model to be used (e.g., yolo11n.pt for YOLO11 nano and yolov8n.pt for YOLOv8 nano)
 
-region: # Object counting, queue or speed estimation region points
-line_width: 2 # Thickness of the lines used to draw regions on the image/video frames
-show: True # Flag to control whether to display output image or not
+region: # Object counting, queue, or speed estimation region points. Default region points are [(20, 400), (1080, 404), (1080, 360), (20, 360)]
+line_width: 2 # Line thickness for drawing regions, bounding boxes, and tracks on image/video frames. Default value is 2.
+show: True # Flag to control whether to display the output image. Set it to False when deploying on embedded devices, for example.
 show_in: True # Flag to display objects moving *into* the defined region
 show_out: True # Flag to display objects moving *out of* the defined region
-classes: # To count specific classes
-up_angle: 145.0 # Workouts up_angle for counts, 145.0 is default value
-down_angle: 90 # Workouts down_angle for counts, 90 is default value
-kpts: [6, 8, 10] # Keypoints for workouts monitoring
-colormap: # Colormap for heatmap
+classes: # To detect, track, and count only specific classes, e.g., classes=0 for persons with a COCO model. Default is None (all classes).
+up_angle: 145.0 # Workouts up_angle for counts; 145.0 is the default value. Adjust it for different workouts based on keypoint positions.
+down_angle: 90 # Workouts down_angle for counts; 90 is the default value. Adjust it for different workouts based on keypoint positions.
+kpts: [6, 8, 10] # Keypoints for workouts monitoring, e.g., [6, 8, 10] is a typical choice for pushups.
+colormap: # Colormap for heatmap visualization. Only OpenCV-supported colormaps can be used; COLORMAP_PARULA is used by default.
diff --git a/ultralytics/solutions/speed_estimation.py b/ultralytics/solutions/speed_estimation.py
index 70964241fd..decd159b55 100644
--- a/ultralytics/solutions/speed_estimation.py
+++ b/ultralytics/solutions/speed_estimation.py
@@ -1,116 +1,76 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-from collections import defaultdict
 from time import time
 
-import cv2
 import numpy as np
 
-from ultralytics.utils.checks import check_imshow
+from ultralytics.solutions.solutions import BaseSolution, LineString
 from ultralytics.utils.plotting import Annotator, colors
 
 
-class SpeedEstimator:
+class SpeedEstimator(BaseSolution):
     """A class to estimate the speed of objects in a real-time video stream based on their tracks."""
 
-    def __init__(self, names, reg_pts=None, view_img=False, line_thickness=2, spdl_dist_thresh=10):
-        """
-        Initializes the SpeedEstimator with the given parameters.
-
-        Args:
-            names (dict): Dictionary of class names.
-            reg_pts (list, optional): List of region points for speed estimation. Defaults to [(20, 400), (1260, 400)].
-            view_img (bool, optional): Whether to display the image with annotations. Defaults to False.
-            line_thickness (int, optional): Thickness of the lines for drawing boxes and tracks. Defaults to 2.
-            spdl_dist_thresh (int, optional): Distance threshold for speed calculation. Defaults to 10.
- """ - # Region information - self.reg_pts = reg_pts if reg_pts is not None else [(20, 400), (1260, 400)] + def __init__(self, **kwargs): + """Initializes the SpeedEstimator with the given parameters.""" + super().__init__(**kwargs) - self.names = names # Classes names + self.initialize_region() # Initialize speed region - # Tracking information - self.trk_history = defaultdict(list) - - self.view_img = view_img # bool for displaying inference - self.tf = line_thickness # line thickness for annotator self.spd = {} # set for speed data self.trkd_ids = [] # list for already speed_estimated and tracked ID's - self.spdl = spdl_dist_thresh # Speed line distance threshold self.trk_pt = {} # set for tracks previous time self.trk_pp = {} # set for tracks previous point - # Check if the environment supports imshow - self.env_check = check_imshow(warn=True) - - def estimate_speed(self, im0, tracks): + def estimate_speed(self, im0): """ Estimates the speed of objects based on tracking data. Args: - im0 (ndarray): Image. - tracks (list): List of tracks obtained from the object tracking process. - - Returns: - (ndarray): The image with annotated boxes and tracks. + im0 (ndarray): The input image that will be used for processing + Returns + im0 (ndarray): The processed image for more usage """ - if tracks[0].boxes.id is None: - return im0 + self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator + self.extract_tracks(im0) # Extract tracks - boxes = tracks[0].boxes.xyxy.cpu() - clss = tracks[0].boxes.cls.cpu().tolist() - t_ids = tracks[0].boxes.id.int().cpu().tolist() - annotator = Annotator(im0, line_width=self.tf) - annotator.draw_region(reg_pts=self.reg_pts, color=(255, 0, 255), thickness=self.tf * 2) + self.annotator.draw_region( + reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2 + ) # Draw region - for box, t_id, cls in zip(boxes, t_ids, clss): - track = self.trk_history[t_id] - bbox_center = (float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2)) - track.append(bbox_center) + for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss): + self.store_tracking_history(track_id, box) # Store track history - if len(track) > 30: - track.pop(0) + # Check if track_id is already in self.trk_pp or trk_pt initialize if not + if track_id not in self.trk_pt: + self.trk_pt[track_id] = 0 + if track_id not in self.trk_pp: + self.trk_pp[track_id] = self.track_line[-1] - trk_pts = np.hstack(track).astype(np.int32).reshape((-1, 1, 2)) + speed_label = f"{int(self.spd[track_id])} km/h" if track_id in self.spd else self.names[int(cls)] + self.annotator.box_label(box, label=speed_label, color=colors(track_id, True)) # Draw bounding box - if t_id not in self.trk_pt: - self.trk_pt[t_id] = 0 + # Draw tracks of objects + self.annotator.draw_centroid_and_tracks( + self.track_line, color=colors(int(track_id), True), track_thickness=self.line_width + ) - speed_label = f"{int(self.spd[t_id])} km/h" if t_id in self.spd else self.names[int(cls)] - bbox_color = colors(int(t_id), True) - - annotator.box_label(box, speed_label, bbox_color) - cv2.polylines(im0, [trk_pts], isClosed=False, color=bbox_color, thickness=self.tf) - cv2.circle(im0, (int(track[-1][0]), int(track[-1][1])), self.tf * 2, bbox_color, -1) - - # Calculation of object speed - if not self.reg_pts[0][0] < track[-1][0] < self.reg_pts[1][0]: - return - if self.reg_pts[1][1] - self.spdl < track[-1][1] < self.reg_pts[1][1] + self.spdl: - direction = "known" - elif self.reg_pts[0][1] - self.spdl < 
track[-1][1] < self.reg_pts[0][1] + self.spdl: + # Calculate object speed and direction based on region intersection + if LineString([self.trk_pp[track_id], self.track_line[-1]]).intersects(self.l_s): direction = "known" else: direction = "unknown" - if self.trk_pt.get(t_id) != 0 and direction != "unknown" and t_id not in self.trkd_ids: - self.trkd_ids.append(t_id) - - time_difference = time() - self.trk_pt[t_id] + # Perform speed calculation and tracking updates if direction is valid + if direction == "known" and track_id not in self.trkd_ids: + self.trkd_ids.append(track_id) + time_difference = time() - self.trk_pt[track_id] if time_difference > 0: - self.spd[t_id] = np.abs(track[-1][1] - self.trk_pp[t_id][1]) / time_difference - - self.trk_pt[t_id] = time() - self.trk_pp[t_id] = track[-1] - - if self.view_img and self.env_check: - cv2.imshow("Ultralytics Speed Estimation", im0) - if cv2.waitKey(1) & 0xFF == ord("q"): - return + self.spd[track_id] = np.abs(self.track_line[-1][1] - self.trk_pp[track_id][1]) / time_difference - return im0 + self.trk_pt[track_id] = time() + self.trk_pp[track_id] = self.track_line[-1] + self.display_output(im0) # display output with base class function -if __name__ == "__main__": - names = {0: "person", 1: "car"} # example class names - speed_estimator = SpeedEstimator(names) + return im0 # return output image for more usage
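The refactored `estimate_speed()` above replaces the old per-axis distance-threshold bands around `reg_pts` with a single line-intersection test: a speed sample is taken the first time the segment between a track's previous and current centroid crosses the region line. Below is a minimal sketch of that trigger in isolation, assuming Shapely is installed (it backs the `LineString` imported from `ultralytics.solutions.solutions`) and assuming `self.l_s` is built from the first two region points; the coordinates are illustrative, not values taken from this diff.

```python
from shapely.geometry import LineString

# Stand-in for self.l_s, assumed here to be derived from the first two
# default region points [(20, 400), (1080, 404)]
l_s = LineString([(20, 400), (1080, 404)])

# Previous and current centroids of one tracked object (illustrative values)
prev_point, curr_point = (550, 390), (552, 410)

# The movement segment crosses the region line, so direction becomes "known"
# and SpeedEstimator.estimate_speed() would record a speed sample for this ID
if LineString([prev_point, curr_point]).intersects(l_s):
    print("Track crossed the speed region -> estimate speed")
```

Note that the sample itself, per the final hunk, is vertical pixel displacement divided by elapsed seconds (`np.abs(...) / time_difference`) rendered with a km/h label, so pixel-to-real-world calibration remains the caller's responsibility.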