From ef28f1078c691e84569ee74ec3f3e0177a041347 Mon Sep 17 00:00:00 2001
From: Francesco Mattioli <Francesco.mttl@gmail.com>
Date: Fri, 18 Oct 2024 12:37:02 +0200
Subject: [PATCH 1/2] Fixed build docs regex security (#17012)

---
 docs/build_docs.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/build_docs.py b/docs/build_docs.py
index 483a2dd05..281a85f51 100644
--- a/docs/build_docs.py
+++ b/docs/build_docs.py
@@ -199,7 +199,7 @@ def convert_plaintext_links_to_html(content):
         for text_node in paragraph.find_all(string=True, recursive=False):
             if text_node.parent.name not in {"a", "code"}:  # Ignore links and code blocks
                 new_text = re.sub(
-                    r'(https?://[^\s()<>]+(?:\.[^\s()<>]+)+)(?<![.,:;\'"])',
+                    r'(https?://[^\s()<>]+(?:\.[^\s()<>]+)+)',
                     r'<a href="\1">\1</a>',
                     str(text_node),
                 )

From 8d7d1fe39047a9016645a1c8ae410e70cfa7d9a4 Mon Sep 17 00:00:00 2001
From: Glenn Jocher <glenn.jocher@ultralytics.com>
Date: Fri, 18 Oct 2024 13:54:45 +0200
Subject: [PATCH 2/2] `ultralytics 8.3.16` PyTorch 2.5.0 support (#16998)

Signed-off-by: UltralyticsAssistant
Co-authored-by: UltralyticsAssistant
Co-authored-by: RizwanMunawar
Co-authored-by: Muhammad Rizwan Munawar
---
 .github/workflows/publish.yml                 |   2 +-
 docs/mkdocs_github_authors.yaml               |   3 +
 mkdocs.yml                                    |   1 +
 pyproject.toml                                |   2 +-
 tests/test_solutions.py                       |  36 +++---
 ultralytics/__init__.py                       |   2 +-
 ultralytics/data/split_dota.py                |   6 +-
 ultralytics/solutions/ai_gym.py               |  52 +++++++--
 ultralytics/solutions/analytics.py            |  77 +++++++++++--
 ultralytics/solutions/distance_calculation.py |  60 ++++++++--
 ultralytics/solutions/heatmap.py              |  64 ++++++++---
 ultralytics/solutions/object_counter.py       | 104 ++++++++++++++----
 ultralytics/solutions/parking_management.py   |  90 +++++++++++++--
 ultralytics/solutions/queue_management.py     |  67 +++++++++--
 ultralytics/solutions/solutions.py            |  95 ++++++++++++----
 ultralytics/solutions/speed_estimation.py     |  48 ++++++--
 ultralytics/solutions/streamlit_inference.py  |   5 +-
 17 files changed, 570 insertions(+), 144 deletions(-)

diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index 024eb8567..1ec1b9a93 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -18,7 +18,7 @@ jobs:
     name: Publish
     runs-on: ubuntu-latest
     permissions:
-      id-token: write # for PyPI trusted publishing
+      id-token: write  # for PyPI trusted publishing
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

diff --git a/docs/mkdocs_github_authors.yaml b/docs/mkdocs_github_authors.yaml
index 0e0423c24..2e2092138 100644
--- a/docs/mkdocs_github_authors.yaml
+++ b/docs/mkdocs_github_authors.yaml
@@ -76,6 +76,9 @@
 79740115+0xSynapse@users.noreply.github.com:
   avatar: https://avatars.githubusercontent.com/u/79740115?v=4
   username: 0xSynapse
+91465467+lalayants@users.noreply.github.com:
+  avatar: https://avatars.githubusercontent.com/u/91465467?v=4
+  username: lalayants
 Francesco.mttl@gmail.com:
   avatar: https://avatars.githubusercontent.com/u/3855193?v=4
   username: ambitious-octopus

diff --git a/mkdocs.yml b/mkdocs.yml
index 771084066..f5298dc47 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -555,6 +555,7 @@ nav:
         - utils: reference/nn/modules/utils.md
       - tasks: reference/nn/tasks.md
       - solutions:
+        - solutions: reference/solutions/solutions.md
         - ai_gym: reference/solutions/ai_gym.md
         - analytics: reference/solutions/analytics.md
         - distance_calculation: reference/solutions/distance_calculation.md

diff --git a/pyproject.toml b/pyproject.toml
index 3fb80e62a..f6cb23204 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,7 +26,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "ultralytics"
 dynamic = ["version"]
-description = "Ultralytics YOLO for SOTA object detection, multi-object tracking, instance segmentation, pose estimation and image classification."
+description = "Ultralytics YOLO 🚀 for SOTA object detection, multi-object tracking, instance segmentation, pose estimation and image classification."
 readme = "README.md"
 requires-python = ">=3.8"
 license = { "text" = "AGPL-3.0" }

diff --git a/tests/test_solutions.py b/tests/test_solutions.py
index d3ba2d5fc..e01da6d81 100644
--- a/tests/test_solutions.py
+++ b/tests/test_solutions.py
@@ -17,10 +17,15 @@ def test_major_solutions():
     cap = cv2.VideoCapture("solutions_ci_demo.mp4")
     assert cap.isOpened(), "Error reading video file"
     region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
-    counter = solutions.ObjectCounter(region=region_points, model="yolo11n.pt", show=False)
-    heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False)
-    speed = solutions.SpeedEstimator(region=region_points, model="yolo11n.pt", show=False)
-    queue = solutions.QueueManager(region=region_points, model="yolo11n.pt", show=False)
+    counter = solutions.ObjectCounter(region=region_points, model="yolo11n.pt", show=False)  # Test object counter
+    heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False)  # Test heatmaps
+    speed = solutions.SpeedEstimator(region=region_points, model="yolo11n.pt", show=False)  # Test speed estimation
+    queue = solutions.QueueManager(region=region_points, model="yolo11n.pt", show=False)  # Test queue manager
+    line_analytics = solutions.Analytics(analytics_type="line", model="yolo11n.pt", show=False)  # line analytics
+    pie_analytics = solutions.Analytics(analytics_type="pie", model="yolo11n.pt", show=False)  # pie analytics
+    bar_analytics = solutions.Analytics(analytics_type="bar", model="yolo11n.pt", show=False)  # bar analytics
+    area_analytics = solutions.Analytics(analytics_type="area", model="yolo11n.pt", show=False)  # area analytics
+    frame_count = 0  # Required for analytics
     while cap.isOpened():
         success, im0 = cap.read()
         if not success:
             break
@@ -30,24 +35,23 @@ def test_major_solutions():
         _ = heatmap.generate_heatmap(original_im0.copy())
         _ = speed.estimate_speed(original_im0.copy())
         _ = queue.process_queue(original_im0.copy())
+        _ = line_analytics.process_data(original_im0.copy(), frame_count)
+        _ = pie_analytics.process_data(original_im0.copy(), frame_count)
+        _ = bar_analytics.process_data(original_im0.copy(), frame_count)
+        _ = area_analytics.process_data(original_im0.copy(), frame_count)
     cap.release()
-    cv2.destroyAllWindows()
-
-
-@pytest.mark.slow
-def test_aigym():
-    """Test the workouts monitoring solution."""
+    # Test workouts monitoring
     safe_download(url=WORKOUTS_SOLUTION_DEMO)
-    cap = cv2.VideoCapture("solution_ci_pose_demo.mp4")
-    assert cap.isOpened(), "Error reading video file"
-    gym = solutions.AIGym(line_width=2, kpts=[5, 11, 13])
-    while cap.isOpened():
-        success, im0 = cap.read()
+    cap1 = cv2.VideoCapture("solution_ci_pose_demo.mp4")
+    assert cap1.isOpened(), "Error reading video file"
+    gym = solutions.AIGym(line_width=2, kpts=[5, 11, 13], show=False)
+    while cap1.isOpened():
+        success, im0 = cap1.read()
         if not success:
             break
         _ = gym.monitor(im0)
-    cap.release()
-    cv2.destroyAllWindows()
+    cap1.release()
 
 
 @pytest.mark.slow

diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index d83c00a02..9c0a6f394 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = "8.3.15"
+__version__ = "8.3.16"
 
 import os

diff --git a/ultralytics/data/split_dota.py b/ultralytics/data/split_dota.py
index f9acffe9b..b745b3662 100644
--- a/ultralytics/data/split_dota.py
+++ b/ultralytics/data/split_dota.py
@@ -13,9 +13,6 @@ from tqdm import tqdm
 from ultralytics.data.utils import exif_size, img2label_paths
 from ultralytics.utils.checks import check_requirements
 
-check_requirements("shapely")
-from shapely.geometry import Polygon
-
 
 def bbox_iof(polygon1, bbox2, eps=1e-6):
     """
@@ -33,6 +30,9 @@ def bbox_iof(polygon1, bbox2, eps=1e-6):
         Polygon format: [x1, y1, x2, y2, x3, y3, x4, y4].
         Bounding box format: [x_min, y_min, x_max, y_max].
     """
+    check_requirements("shapely")
+    from shapely.geometry import Polygon
+
     polygon1 = polygon1.reshape(-1, 4, 2)
     lt_point = np.min(polygon1, axis=-2)  # left-top
     rb_point = np.max(polygon1, axis=-2)  # right-bottom

diff --git a/ultralytics/solutions/ai_gym.py b/ultralytics/solutions/ai_gym.py
index 02345749c..0d131bd9d 100644
--- a/ultralytics/solutions/ai_gym.py
+++ b/ultralytics/solutions/ai_gym.py
@@ -1,16 +1,40 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-from ultralytics.solutions.solutions import BaseSolution  # Import a parent class
+from ultralytics.solutions.solutions import BaseSolution
 from ultralytics.utils.plotting import Annotator
 
 
 class AIGym(BaseSolution):
-    """A class to manage the gym steps of people in a real-time video stream based on their poses."""
+    """
+    A class to manage gym steps of people in a real-time video stream based on their poses.
+
+    This class extends BaseSolution to monitor workouts using YOLO pose estimation models. It tracks and counts
+    repetitions of exercises based on predefined angle thresholds for up and down positions.
+
+    Attributes:
+        count (List[int]): Repetition counts for each detected person.
+        angle (List[float]): Current angle of the tracked body part for each person.
+        stage (List[str]): Current exercise stage ('up', 'down', or '-') for each person.
+        initial_stage (str | None): Initial stage of the exercise.
+        up_angle (float): Angle threshold for considering the 'up' position of an exercise.
+        down_angle (float): Angle threshold for considering the 'down' position of an exercise.
+        kpts (List[int]): Indices of keypoints used for angle calculation.
+        lw (int): Line width for drawing annotations.
+        annotator (Annotator): Object for drawing annotations on the image.
+
+    Methods:
+        monitor: Processes a frame to detect poses, calculate angles, and count repetitions.
+
+    Examples:
+        >>> gym = AIGym(model="yolov8n-pose.pt")
+        >>> image = cv2.imread("gym_scene.jpg")
+        >>> processed_image = gym.monitor(image)
+        >>> cv2.imshow("Processed Image", processed_image)
+        >>> cv2.waitKey(0)
+    """
 
     def __init__(self, **kwargs):
-        """Initialization function for AiGYM class, a child class of BaseSolution class, can be used for workouts
-        monitoring.
-        """
+        """Initializes AIGym for workout monitoring using pose estimation and predefined angles."""
         # Check if the model name ends with '-pose'
         if "model" in kwargs and "-pose" not in kwargs["model"]:
             kwargs["model"] = "yolo11n-pose.pt"
@@ -31,12 +55,22 @@ class AIGym(BaseSolution):
 
     def monitor(self, im0):
         """
-        Monitor the workouts using Ultralytics YOLO Pose Model: https://docs.ultralytics.com/tasks/pose/.
+        Monitors workouts using Ultralytics YOLO Pose Model.
+
+        This function processes an input image to track and analyze human poses for workout monitoring. It uses
+        the YOLO Pose model to detect keypoints, estimate angles, and count repetitions based on predefined
+        angle thresholds.
Args: - im0 (ndarray): The input image that will be used for processing - Returns - im0 (ndarray): The processed image for more usage + im0 (ndarray): Input image for processing. + + Returns: + (ndarray): Processed image with annotations for workout monitoring. + + Examples: + >>> gym = AIGym() + >>> image = cv2.imread("workout.jpg") + >>> processed_image = gym.monitor(image) """ # Extract tracks tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"])[0] diff --git a/ultralytics/solutions/analytics.py b/ultralytics/solutions/analytics.py index 6c2f27db0..aed7beed9 100644 --- a/ultralytics/solutions/analytics.py +++ b/ultralytics/solutions/analytics.py @@ -12,10 +12,41 @@ from ultralytics.solutions.solutions import BaseSolution # Import a parent clas class Analytics(BaseSolution): - """A class to create and update various types of charts (line, bar, pie, area) for visual analytics.""" + """ + A class for creating and updating various types of charts for visual analytics. + + This class extends BaseSolution to provide functionality for generating line, bar, pie, and area charts + based on object detection and tracking data. + + Attributes: + type (str): The type of analytics chart to generate ('line', 'bar', 'pie', or 'area'). + x_label (str): Label for the x-axis. + y_label (str): Label for the y-axis. + bg_color (str): Background color of the chart frame. + fg_color (str): Foreground color of the chart frame. + title (str): Title of the chart window. + max_points (int): Maximum number of data points to display on the chart. + fontsize (int): Font size for text display. + color_cycle (cycle): Cyclic iterator for chart colors. + total_counts (int): Total count of detected objects (used for line charts). + clswise_count (Dict[str, int]): Dictionary for class-wise object counts. + fig (Figure): Matplotlib figure object for the chart. + ax (Axes): Matplotlib axes object for the chart. + canvas (FigureCanvas): Canvas for rendering the chart. + + Methods: + process_data: Processes image data and updates the chart. + update_graph: Updates the chart with new data points. + + Examples: + >>> analytics = Analytics(analytics_type="line") + >>> frame = cv2.imread("image.jpg") + >>> processed_frame = analytics.process_data(frame, frame_number=1) + >>> cv2.imshow("Analytics", processed_frame) + """ def __init__(self, **kwargs): - """Initialize the Analytics class with various chart types.""" + """Initialize Analytics class with various chart types for visual data representation.""" super().__init__(**kwargs) self.type = self.CFG["analytics_type"] # extract type of analytics @@ -31,8 +62,8 @@ class Analytics(BaseSolution): figsize = (19.2, 10.8) # Set output image size 1920 * 1080 self.color_cycle = cycle(["#DD00BA", "#042AFF", "#FF4447", "#7D24FF", "#BD00FF"]) - self.total_counts = 0 # count variable for storing total counts i.e for line - self.clswise_count = {} # dictionary for classwise counts + self.total_counts = 0 # count variable for storing total counts i.e. 
for line + self.clswise_count = {} # dictionary for class-wise counts # Ensure line and area chart if self.type in {"line", "area"}: @@ -48,15 +79,28 @@ class Analytics(BaseSolution): self.canvas = FigureCanvas(self.fig) # Set common axis properties self.ax.set_facecolor(self.bg_color) self.color_mapping = {} - self.ax.axis("equal") if self.type == "pie" else None # Ensure pie chart is circular + + if self.type == "pie": # Ensure pie chart is circular + self.ax.axis("equal") def process_data(self, im0, frame_number): """ - Process the image data, run object tracking. + Processes image data and runs object tracking to update analytics charts. Args: - im0 (ndarray): Input image for processing. - frame_number (int): Video frame # for plotting the data. + im0 (np.ndarray): Input image for processing. + frame_number (int): Video frame number for plotting the data. + + Returns: + (np.ndarray): Processed image with updated analytics chart. + + Raises: + ModuleNotFoundError: If an unsupported chart type is specified. + + Examples: + >>> analytics = Analytics(analytics_type="line") + >>> frame = np.zeros((480, 640, 3), dtype=np.uint8) + >>> processed_frame = analytics.process_data(frame, frame_number=1) """ self.extract_tracks(im0) # Extract tracks @@ -79,13 +123,22 @@ class Analytics(BaseSolution): def update_graph(self, frame_number, count_dict=None, plot="line"): """ - Update the graph (line or area) with new data for single or multiple classes. + Updates the graph with new data for single or multiple classes. Args: frame_number (int): The current frame number. - count_dict (dict, optional): Dictionary with class names as keys and counts as values for multiple classes. - If None, updates a single line graph. - plot (str): Type of the plot i.e. line, bar or area. + count_dict (Dict[str, int] | None): Dictionary with class names as keys and counts as values for multiple + classes. If None, updates a single line graph. + plot (str): Type of the plot. Options are 'line', 'bar', 'pie', or 'area'. + + Returns: + (np.ndarray): Updated image containing the graph. + + Examples: + >>> analytics = Analytics() + >>> frame_number = 10 + >>> count_dict = {"person": 5, "car": 3} + >>> updated_image = analytics.update_graph(frame_number, count_dict, plot="bar") """ if count_dict is None: # Single line update diff --git a/ultralytics/solutions/distance_calculation.py b/ultralytics/solutions/distance_calculation.py index 773b6086d..608aa97d7 100644 --- a/ultralytics/solutions/distance_calculation.py +++ b/ultralytics/solutions/distance_calculation.py @@ -4,15 +4,41 @@ import math import cv2 -from ultralytics.solutions.solutions import BaseSolution # Import a parent class +from ultralytics.solutions.solutions import BaseSolution from ultralytics.utils.plotting import Annotator, colors class DistanceCalculation(BaseSolution): - """A class to calculate distance between two objects in a real-time video stream based on their tracks.""" + """ + A class to calculate distance between two objects in a real-time video stream based on their tracks. + + This class extends BaseSolution to provide functionality for selecting objects and calculating the distance + between them in a video stream using YOLO object detection and tracking. + + Attributes: + left_mouse_count (int): Counter for left mouse button clicks. + selected_boxes (Dict[int, List[float]]): Dictionary to store selected bounding boxes and their track IDs. + annotator (Annotator): An instance of the Annotator class for drawing on the image. 
+ boxes (List[List[float]]): List of bounding boxes for detected objects. + track_ids (List[int]): List of track IDs for detected objects. + clss (List[int]): List of class indices for detected objects. + names (List[str]): List of class names that the model can detect. + centroids (List[List[int]]): List to store centroids of selected bounding boxes. + + Methods: + mouse_event_for_distance: Handles mouse events for selecting objects in the video stream. + calculate: Processes video frames and calculates the distance between selected objects. + + Examples: + >>> distance_calc = DistanceCalculation() + >>> frame = cv2.imread("frame.jpg") + >>> processed_frame = distance_calc.calculate(frame) + >>> cv2.imshow("Distance Calculation", processed_frame) + >>> cv2.waitKey(0) + """ def __init__(self, **kwargs): - """Initializes the DistanceCalculation class with the given parameters.""" + """Initializes the DistanceCalculation class for measuring object distances in video streams.""" super().__init__(**kwargs) # Mouse event information @@ -21,14 +47,18 @@ class DistanceCalculation(BaseSolution): def mouse_event_for_distance(self, event, x, y, flags, param): """ - Handles mouse events to select regions in a real-time video stream. + Handles mouse events to select regions in a real-time video stream for distance calculation. Args: - event (int): Type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN, etc.). + event (int): Type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN). x (int): X-coordinate of the mouse pointer. y (int): Y-coordinate of the mouse pointer. - flags (int): Flags associated with the event (e.g., cv2.EVENT_FLAG_CTRLKEY, cv2.EVENT_FLAG_SHIFTKEY, etc.). - param (dict): Additional parameters passed to the function. + flags (int): Flags associated with the event (e.g., cv2.EVENT_FLAG_CTRLKEY, cv2.EVENT_FLAG_SHIFTKEY). + param (Dict): Additional parameters passed to the function. + + Examples: + >>> # Assuming 'dc' is an instance of DistanceCalculation + >>> cv2.setMouseCallback("window_name", dc.mouse_event_for_distance) """ if event == cv2.EVENT_LBUTTONDOWN: self.left_mouse_count += 1 @@ -43,13 +73,23 @@ class DistanceCalculation(BaseSolution): def calculate(self, im0): """ - Processes the video frame and calculates the distance between two bounding boxes. + Processes a video frame and calculates the distance between two selected bounding boxes. + + This method extracts tracks from the input frame, annotates bounding boxes, and calculates the distance + between two user-selected objects if they have been chosen. Args: - im0 (ndarray): The image frame. + im0 (numpy.ndarray): The input image frame to process. Returns: - (ndarray): The processed image frame. + (numpy.ndarray): The processed image frame with annotations and distance calculations. 
+ + Examples: + >>> import numpy as np + >>> from ultralytics.solutions import DistanceCalculation + >>> dc = DistanceCalculation() + >>> frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) + >>> processed_frame = dc.calculate(frame) """ self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator self.extract_tracks(im0) # Extract tracks diff --git a/ultralytics/solutions/heatmap.py b/ultralytics/solutions/heatmap.py index d7dcf71cf..39352a9bd 100644 --- a/ultralytics/solutions/heatmap.py +++ b/ultralytics/solutions/heatmap.py @@ -3,15 +3,40 @@ import cv2 import numpy as np -from ultralytics.solutions.object_counter import ObjectCounter # Import object counter class +from ultralytics.solutions.object_counter import ObjectCounter from ultralytics.utils.plotting import Annotator class Heatmap(ObjectCounter): - """A class to draw heatmaps in real-time video stream based on their tracks.""" + """ + A class to draw heatmaps in real-time video streams based on object tracks. + + This class extends the ObjectCounter class to generate and visualize heatmaps of object movements in video + streams. It uses tracked object positions to create a cumulative heatmap effect over time. + + Attributes: + initialized (bool): Flag indicating whether the heatmap has been initialized. + colormap (int): OpenCV colormap used for heatmap visualization. + heatmap (np.ndarray): Array storing the cumulative heatmap data. + annotator (Annotator): Object for drawing annotations on the image. + + Methods: + heatmap_effect: Calculates and updates the heatmap effect for a given bounding box. + generate_heatmap: Generates and applies the heatmap effect to each frame. + + Examples: + >>> from ultralytics.solutions import Heatmap + >>> heatmap = Heatmap(model="yolov8n.pt", colormap=cv2.COLORMAP_JET) + >>> results = heatmap("path/to/video.mp4") + >>> for result in results: + ... print(result.speed) # Print inference speed + ... cv2.imshow("Heatmap", result.plot()) + ... if cv2.waitKey(1) & 0xFF == ord("q"): + ... break + """ def __init__(self, **kwargs): - """Initializes function for heatmap class with default values.""" + """Initializes the Heatmap class for real-time video stream heatmap generation based on object tracks.""" super().__init__(**kwargs) self.initialized = False # bool variable for heatmap initialization @@ -23,10 +48,15 @@ class Heatmap(ObjectCounter): def heatmap_effect(self, box): """ - Efficient calculation of heatmap area and effect location for applying colormap. + Efficiently calculates heatmap area and effect location for applying colormap. Args: - box (list): Bounding Box coordinates data [x0, y0, x1, y1] + box (List[float]): Bounding box coordinates [x0, y0, x1, y1]. + + Examples: + >>> heatmap = Heatmap() + >>> box = [100, 100, 200, 200] + >>> heatmap.heatmap_effect(box) """ x0, y0, x1, y1 = map(int, box) radius_squared = (min(x1 - x0, y1 - y0) // 2) ** 2 @@ -48,9 +78,15 @@ class Heatmap(ObjectCounter): Generate heatmap for each frame using Ultralytics. Args: - im0 (ndarray): Input image array for processing + im0 (np.ndarray): Input image array for processing. + Returns: - im0 (ndarray): Processed image for further usage + (np.ndarray): Processed image with heatmap overlay and object counts (if region is specified). 
+ + Examples: + >>> heatmap = Heatmap() + >>> im0 = cv2.imread("image.jpg") + >>> result = heatmap.generate_heatmap(im0) """ if not self.initialized: self.heatmap = np.zeros_like(im0, dtype=np.float32) * 0.99 @@ -70,16 +106,17 @@ class Heatmap(ObjectCounter): self.store_classwise_counts(cls) # store classwise counts in dict # Store tracking previous position and perform object counting - prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None + prev_position = None + if len(self.track_history[track_id]) > 1: + prev_position = self.track_history[track_id][-2] self.count_objects(self.track_line, box, track_id, prev_position, cls) # Perform object counting - self.display_counts(im0) if self.region is not None else None # Display the counts on the frame + if self.region is not None: + self.display_counts(im0) # Display the counts on the frame # Normalize, apply colormap to heatmap and combine with original image - im0 = ( - im0 - if self.track_data.id is None - else cv2.addWeighted( + if self.track_data.id is not None: + im0 = cv2.addWeighted( im0, 0.5, cv2.applyColorMap( @@ -88,7 +125,6 @@ class Heatmap(ObjectCounter): 0.5, 0, ) - ) self.display_output(im0) # display output with base class function return im0 # return output image for more usage diff --git a/ultralytics/solutions/object_counter.py b/ultralytics/solutions/object_counter.py index d57674642..637492073 100644 --- a/ultralytics/solutions/object_counter.py +++ b/ultralytics/solutions/object_counter.py @@ -1,18 +1,40 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -from shapely.geometry import LineString, Point - -from ultralytics.solutions.solutions import BaseSolution # Import a parent class +from ultralytics.solutions.solutions import BaseSolution from ultralytics.utils.plotting import Annotator, colors class ObjectCounter(BaseSolution): - """A class to manage the counting of objects in a real-time video stream based on their tracks.""" + """ + A class to manage the counting of objects in a real-time video stream based on their tracks. + + This class extends the BaseSolution class and provides functionality for counting objects moving in and out of a + specified region in a video stream. It supports both polygonal and linear regions for counting. + + Attributes: + in_count (int): Counter for objects moving inward. + out_count (int): Counter for objects moving outward. + counted_ids (List[int]): List of IDs of objects that have been counted. + classwise_counts (Dict[str, Dict[str, int]]): Dictionary for counts, categorized by object class. + region_initialized (bool): Flag indicating whether the counting region has been initialized. + show_in (bool): Flag to control display of inward count. + show_out (bool): Flag to control display of outward count. + + Methods: + count_objects: Counts objects within a polygonal or linear region. + store_classwise_counts: Initializes class-wise counts if not already present. + display_counts: Displays object counts on the frame. + count: Processes input data (frames or object tracks) and updates counts. + + Examples: + >>> counter = ObjectCounter() + >>> frame = cv2.imread("frame.jpg") + >>> processed_frame = counter.count(frame) + >>> print(f"Inward count: {counter.in_count}, Outward count: {counter.out_count}") + """ def __init__(self, **kwargs): - """Initialization function for Count class, a child class of BaseSolution class, can be used for counting the - objects. 
-        """
+        """Initializes the ObjectCounter class for real-time object counting in video streams."""
         super().__init__(**kwargs)
 
         self.in_count = 0  # Counter for objects moving inward
@@ -26,14 +48,23 @@ def __init__(self, **kwargs):
 
     def count_objects(self, track_line, box, track_id, prev_position, cls):
         """
-        Helper function to count objects within a polygonal region.
+        Counts objects within a polygonal or linear region based on their tracks.
 
         Args:
-            track_line (dict): last 30 frame track record
-            box (list): Bounding box data for specific track in current frame
-            track_id (int): track ID of the object
-            prev_position (tuple): last frame position coordinates of the track
-            cls (int): Class index for classwise count updates
+            track_line (List[Tuple[int, int]]): Last 30 frame track record (center points) for the object.
+            box (List[float]): Bounding box coordinates [x1, y1, x2, y2] for the specific track in the current frame.
+            track_id (int): Unique identifier for the tracked object.
+            prev_position (Tuple[float, float]): Last frame position coordinates (x, y) of the track.
+            cls (int): Class index for classwise count updates.
+
+        Examples:
+            >>> counter = ObjectCounter()
+            >>> track_line = [(100, 200), (110, 210), (120, 220)]
+            >>> box = [130, 230, 150, 250]
+            >>> track_id = 1
+            >>> prev_position = (120, 220)
+            >>> cls = 0
+            >>> counter.count_objects(track_line, box, track_id, prev_position, cls)
         """
         if prev_position is None or track_id in self.counted_ids:
             return
@@ -42,7 +73,7 @@ def count_objects(self, track_line, box, track_id, prev_position, cls):
         dx = (box[0] - prev_position[0]) * (centroid.x - prev_position[0])
         dy = (box[1] - prev_position[1]) * (centroid.y - prev_position[1])
 
-        if len(self.region) >= 3 and self.r_s.contains(Point(track_line[-1])):
+        if len(self.region) >= 3 and self.r_s.contains(self.Point(track_line[-1])):
             self.counted_ids.append(track_id)
             # For polygon region
             if dx > 0:
@@ -52,7 +83,7 @@ def count_objects(self, track_line, box, track_id, prev_position, cls):
                 self.out_count += 1
                 self.classwise_counts[self.names[cls]]["OUT"] += 1
 
-        elif len(self.region) < 3 and LineString([prev_position, box[:2]]).intersects(self.l_s):
+        elif len(self.region) < 3 and self.LineString([prev_position, box[:2]]).intersects(self.r_s):
             self.counted_ids.append(track_id)
             # For linear region
             if dx > 0 and dy > 0:
@@ -64,20 +95,34 @@ def count_objects(self, track_line, box, track_id, prev_position, cls):
 
     def store_classwise_counts(self, cls):
         """
-        Initialize class-wise counts if not already present.
+        Initialize class-wise counts for a specific object class if not already present.
+
+        This method ensures that the 'classwise_counts' dictionary contains an entry for the specified class,
+        initializing 'IN' and 'OUT' counts to zero if the class is not already present.
 
         Args:
-            cls (int): Class index for classwise count updates
+            cls (int): Class index for classwise count updates.
+
+        Examples:
+            >>> counter = ObjectCounter()
+            >>> counter.store_classwise_counts(0)  # Initialize counts for class index 0
+            >>> print(counter.classwise_counts)
+            {'person': {'IN': 0, 'OUT': 0}}
         """
         if self.names[cls] not in self.classwise_counts:
             self.classwise_counts[self.names[cls]] = {"IN": 0, "OUT": 0}
 
     def display_counts(self, im0):
         """
-        Helper function to display object counts on the frame.
+        Displays object counts on the input image or frame.
 
         Args:
-            im0 (ndarray): The input image or frame
+            im0 (numpy.ndarray): The input image or frame to display counts on.
+ + Examples: + >>> counter = ObjectCounter() + >>> frame = cv2.imread("image.jpg") + >>> counter.display_counts(frame) """ labels_dict = { str.capitalize(key): f"{'IN ' + str(value['IN']) if self.show_in else ''} " @@ -91,12 +136,21 @@ class ObjectCounter(BaseSolution): def count(self, im0): """ - Processes input data (frames or object tracks) and updates counts. + Processes input data (frames or object tracks) and updates object counts. + + This method initializes the counting region, extracts tracks, draws bounding boxes and regions, updates + object counts, and displays the results on the input image. Args: - im0 (ndarray): The input image that will be used for processing - Returns - im0 (ndarray): The processed image for more usage + im0 (numpy.ndarray): The input image or frame to be processed. + + Returns: + (numpy.ndarray): The processed image with annotations and count information. + + Examples: + >>> counter = ObjectCounter() + >>> frame = cv2.imread("path/to/image.jpg") + >>> processed_frame = counter.count(frame) """ if not self.region_initialized: self.initialize_region() @@ -122,7 +176,9 @@ class ObjectCounter(BaseSolution): ) # store previous position of track for object counting - prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None + prev_position = None + if len(self.track_history[track_id]) > 1: + prev_position = self.track_history[track_id][-2] self.count_objects(self.track_line, box, track_id, prev_position, cls) # Perform object counting self.display_counts(im0) # Display the counts on the frame diff --git a/ultralytics/solutions/parking_management.py b/ultralytics/solutions/parking_management.py index 33beb80bf..fa815938a 100644 --- a/ultralytics/solutions/parking_management.py +++ b/ultralytics/solutions/parking_management.py @@ -10,10 +10,44 @@ from ultralytics.utils.plotting import Annotator class ParkingPtsSelection: - """Class for selecting and managing parking zone points on images using a Tkinter-based UI.""" + """ + A class for selecting and managing parking zone points on images using a Tkinter-based UI. + + This class provides functionality to upload an image, select points to define parking zones, and save the + selected points to a JSON file. It uses Tkinter for the graphical user interface. + + Attributes: + tk (module): The Tkinter module for GUI operations. + filedialog (module): Tkinter's filedialog module for file selection operations. + messagebox (module): Tkinter's messagebox module for displaying message boxes. + master (tk.Tk): The main Tkinter window. + canvas (tk.Canvas): The canvas widget for displaying the image and drawing bounding boxes. + image (PIL.Image.Image): The uploaded image. + canvas_image (ImageTk.PhotoImage): The image displayed on the canvas. + rg_data (List[List[Tuple[int, int]]]): List of bounding boxes, each defined by 4 points. + current_box (List[Tuple[int, int]]): Temporary storage for the points of the current bounding box. + imgw (int): Original width of the uploaded image. + imgh (int): Original height of the uploaded image. + canvas_max_width (int): Maximum width of the canvas. + canvas_max_height (int): Maximum height of the canvas. + + Methods: + setup_ui: Sets up the Tkinter UI components. + initialize_properties: Initializes the necessary properties. + upload_image: Uploads an image, resizes it to fit the canvas, and displays it. + on_canvas_click: Handles mouse clicks to add points for bounding boxes. + draw_box: Draws a bounding box on the canvas. 
+ remove_last_bounding_box: Removes the last bounding box and redraws the canvas. + redraw_canvas: Redraws the canvas with the image and all bounding boxes. + save_to_json: Saves the bounding boxes to a JSON file. + + Examples: + >>> parking_selector = ParkingPtsSelection() + >>> # Use the GUI to upload an image, select parking zones, and save the data + """ def __init__(self): - """Class initialization method.""" + """Initializes the ParkingPtsSelection class, setting up UI and properties for parking zone point selection.""" check_requirements("tkinter") import tkinter as tk from tkinter import filedialog, messagebox @@ -24,7 +58,7 @@ class ParkingPtsSelection: self.master.mainloop() def setup_ui(self): - """Sets up the Tkinter UI components.""" + """Sets up the Tkinter UI components for the parking zone points selection interface.""" self.master = self.tk.Tk() self.master.title("Ultralytics Parking Zones Points Selector") self.master.resizable(False, False) @@ -45,14 +79,14 @@ class ParkingPtsSelection: self.tk.Button(button_frame, text=text, command=cmd).pack(side=self.tk.LEFT) def initialize_properties(self): - """Initialize the necessary properties.""" + """Initialize properties for image, canvas, bounding boxes, and dimensions.""" self.image = self.canvas_image = None self.rg_data, self.current_box = [], [] self.imgw = self.imgh = 0 self.canvas_max_width, self.canvas_max_height = 1280, 720 def upload_image(self): - """Uploads an image, resizes it to fit the canvas, and displays it.""" + """Uploads and displays an image on the canvas, resizing it to fit within specified dimensions.""" from PIL import Image, ImageTk # scope because ImageTk requires tkinter package self.image = Image.open(self.filedialog.askopenfilename(filetypes=[("Image Files", "*.png;*.jpg;*.jpeg")])) @@ -76,7 +110,7 @@ class ParkingPtsSelection: self.rg_data.clear(), self.current_box.clear() def on_canvas_click(self, event): - """Handles mouse clicks to add points for bounding boxes.""" + """Handles mouse clicks to add points for bounding boxes on the canvas.""" self.current_box.append((event.x, event.y)) self.canvas.create_oval(event.x - 3, event.y - 3, event.x + 3, event.y + 3, fill="red") if len(self.current_box) == 4: @@ -85,12 +119,12 @@ class ParkingPtsSelection: self.current_box.clear() def draw_box(self, box): - """Draws a bounding box on the canvas.""" + """Draws a bounding box on the canvas using the provided coordinates.""" for i in range(4): self.canvas.create_line(box[i], box[(i + 1) % 4], fill="blue", width=2) def remove_last_bounding_box(self): - """Removes the last bounding box and redraws the canvas.""" + """Removes the last bounding box from the list and redraws the canvas.""" if not self.rg_data: self.messagebox.showwarning("Warning", "No bounding boxes to remove.") return @@ -105,7 +139,7 @@ class ParkingPtsSelection: self.draw_box(box) def save_to_json(self): - """Saves the bounding boxes to a JSON file.""" + """Saves the selected parking zone points to a JSON file with scaled coordinates.""" scale_w, scale_h = self.imgw / self.canvas.winfo_width(), self.imgh / self.canvas.winfo_height() data = [{"points": [(int(x * scale_w), int(y * scale_h)) for x, y in box]} for box in self.rg_data] with open("bounding_boxes.json", "w") as f: @@ -114,7 +148,30 @@ class ParkingPtsSelection: class ParkingManagement(BaseSolution): - """Manages parking occupancy and availability using YOLO model for real-time monitoring and visualization.""" + """ + Manages parking occupancy and availability using YOLO model for 
real-time monitoring and visualization. + + This class extends BaseSolution to provide functionality for parking lot management, including detection of + occupied spaces, visualization of parking regions, and display of occupancy statistics. + + Attributes: + json_file (str): Path to the JSON file containing parking region details. + json (List[Dict]): Loaded JSON data containing parking region information. + pr_info (Dict[str, int]): Dictionary storing parking information (Occupancy and Available spaces). + arc (Tuple[int, int, int]): RGB color tuple for available region visualization. + occ (Tuple[int, int, int]): RGB color tuple for occupied region visualization. + dc (Tuple[int, int, int]): RGB color tuple for centroid visualization of detected objects. + + Methods: + process_data: Processes model data for parking lot management and visualization. + + Examples: + >>> from ultralytics.solutions import ParkingManagement + >>> parking_manager = ParkingManagement(model="yolov8n.pt", json_file="parking_regions.json") + >>> results = parking_manager(source="parking_lot_video.mp4") + >>> print(f"Occupied spaces: {parking_manager.pr_info['Occupancy']}") + >>> print(f"Available spaces: {parking_manager.pr_info['Available']}") + """ def __init__(self, **kwargs): """Initializes the parking management system with a YOLO model and visualization settings.""" @@ -136,10 +193,19 @@ class ParkingManagement(BaseSolution): def process_data(self, im0): """ - Process the model data for parking lot management. + Processes the model data for parking lot management. + + This function analyzes the input image, extracts tracks, and determines the occupancy status of parking + regions defined in the JSON file. It annotates the image with occupied and available parking spots, + and updates the parking information. Args: - im0 (ndarray): inference image. + im0 (np.ndarray): The input inference image. + + Examples: + >>> parking_manager = ParkingManagement(json_file="parking_regions.json") + >>> image = cv2.imread("parking_lot.jpg") + >>> parking_manager.process_data(image) """ self.extract_tracks(im0) # extract tracks from im0 es, fs = len(self.json), 0 # empty slots, filled slots diff --git a/ultralytics/solutions/queue_management.py b/ultralytics/solutions/queue_management.py index 287f337dc..ca0acb14f 100644 --- a/ultralytics/solutions/queue_management.py +++ b/ultralytics/solutions/queue_management.py @@ -1,16 +1,40 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -from shapely.geometry import Point - -from ultralytics.solutions.solutions import BaseSolution # Import a parent class +from ultralytics.solutions.solutions import BaseSolution from ultralytics.utils.plotting import Annotator, colors class QueueManager(BaseSolution): - """A class to manage the queue in a real-time video stream based on object tracks.""" + """ + Manages queue counting in real-time video streams based on object tracks. + + This class extends BaseSolution to provide functionality for tracking and counting objects within a specified + region in video frames. + + Attributes: + counts (int): The current count of objects in the queue. + rect_color (Tuple[int, int, int]): RGB color tuple for drawing the queue region rectangle. + region_length (int): The number of points defining the queue region. + annotator (Annotator): An instance of the Annotator class for drawing on frames. + track_line (List[Tuple[int, int]]): List of track line coordinates. 
+ track_history (Dict[int, List[Tuple[int, int]]]): Dictionary storing tracking history for each object. + + Methods: + initialize_region: Initializes the queue region. + process_queue: Processes a single frame for queue management. + extract_tracks: Extracts object tracks from the current frame. + store_tracking_history: Stores the tracking history for an object. + display_output: Displays the processed output. + + Examples: + >>> queue_manager = QueueManager(source="video.mp4", region=[100, 100, 200, 200, 300, 300]) + >>> for frame in video_stream: + ... processed_frame = queue_manager.process_queue(frame) + ... cv2.imshow("Queue Management", processed_frame) + """ def __init__(self, **kwargs): - """Initializes the QueueManager with specified parameters for tracking and counting objects.""" + """Initializes the QueueManager with parameters for tracking and counting objects in a video stream.""" super().__init__(**kwargs) self.initialize_region() self.counts = 0 # Queue counts Information @@ -19,12 +43,31 @@ class QueueManager(BaseSolution): def process_queue(self, im0): """ - Main function to start the queue management process. + Processes the queue management for a single frame of video. Args: - im0 (ndarray): The input image that will be used for processing - Returns - im0 (ndarray): The processed image for more usage + im0 (numpy.ndarray): Input image for processing, typically a frame from a video stream. + + Returns: + (numpy.ndarray): Processed image with annotations, bounding boxes, and queue counts. + + This method performs the following steps: + 1. Resets the queue count for the current frame. + 2. Initializes an Annotator object for drawing on the image. + 3. Extracts tracks from the image. + 4. Draws the counting region on the image. + 5. For each detected object: + - Draws bounding boxes and labels. + - Stores tracking history. + - Draws centroids and tracks. + - Checks if the object is inside the counting region and updates the count. + 6. Displays the queue count on the image. + 7. Displays the processed output. 
+ + Examples: + >>> queue_manager = QueueManager() + >>> frame = cv2.imread("frame.jpg") + >>> processed_frame = queue_manager.process_queue(frame) """ self.counts = 0 # Reset counts every frame self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator @@ -48,8 +91,10 @@ class QueueManager(BaseSolution): track_history = self.track_history.get(track_id, []) # store previous position of track and check if the object is inside the counting region - prev_position = track_history[-2] if len(track_history) > 1 else None - if self.region_length >= 3 and prev_position and self.r_s.contains(Point(self.track_line[-1])): + prev_position = None + if len(track_history) > 1: + prev_position = track_history[-2] + if self.region_length >= 3 and prev_position and self.r_s.contains(self.Point(self.track_line[-1])): self.counts += 1 # Display queue counts diff --git a/ultralytics/solutions/solutions.py b/ultralytics/solutions/solutions.py index 71a92becf..1af0c0ba0 100644 --- a/ultralytics/solutions/solutions.py +++ b/ultralytics/solutions/solutions.py @@ -9,21 +9,51 @@ from ultralytics import YOLO from ultralytics.utils import LOGGER, yaml_load from ultralytics.utils.checks import check_imshow, check_requirements -check_requirements("shapely>=2.0.0") -from shapely.geometry import LineString, Polygon - DEFAULT_SOL_CFG_PATH = Path(__file__).resolve().parents[1] / "cfg/solutions/default.yaml" class BaseSolution: - """A class to manage all the Ultralytics Solutions: https://docs.ultralytics.com/solutions/.""" + """ + A base class for managing Ultralytics Solutions. + + This class provides core functionality for various Ultralytics Solutions, including model loading, object tracking, + and region initialization. + + Attributes: + LineString (shapely.geometry.LineString): Class for creating line string geometries. + Polygon (shapely.geometry.Polygon): Class for creating polygon geometries. + Point (shapely.geometry.Point): Class for creating point geometries. + CFG (Dict): Configuration dictionary loaded from a YAML file and updated with kwargs. + region (List[Tuple[int, int]]): List of coordinate tuples defining a region of interest. + line_width (int): Width of lines used in visualizations. + model (ultralytics.YOLO): Loaded YOLO model instance. + names (Dict[int, str]): Dictionary mapping class indices to class names. + env_check (bool): Flag indicating whether the environment supports image display. + track_history (collections.defaultdict): Dictionary to store tracking history for each object. + + Methods: + extract_tracks: Apply object tracking and extract tracks from an input image. + store_tracking_history: Store object tracking history for a given track ID and bounding box. + initialize_region: Initialize the counting region and line segment based on configuration. + display_output: Display the results of processing, including showing frames or saving results. + + Examples: + >>> solution = BaseSolution(model="yolov8n.pt", region=[(0, 0), (100, 0), (100, 100), (0, 100)]) + >>> solution.initialize_region() + >>> image = cv2.imread("image.jpg") + >>> solution.extract_tracks(image) + >>> solution.display_output(image) + """ def __init__(self, **kwargs): - """ - Base initializer for all solutions. 
+ """Initializes the BaseSolution class with configuration settings and YOLO model for Ultralytics solutions.""" + check_requirements("shapely>=2.0.0") + from shapely.geometry import LineString, Point, Polygon + + self.LineString = LineString + self.Polygon = Polygon + self.Point = Point - Child classes should call this with necessary parameters. - """ # Load config and update with args self.CFG = yaml_load(DEFAULT_SOL_CFG_PATH) self.CFG.update(kwargs) @@ -42,10 +72,15 @@ class BaseSolution: def extract_tracks(self, im0): """ - Apply object tracking and extract tracks. + Applies object tracking and extracts tracks from an input image or frame. Args: - im0 (ndarray): The input image or frame + im0 (ndarray): The input image or frame. + + Examples: + >>> solution = BaseSolution() + >>> frame = cv2.imread("path/to/image.jpg") + >>> solution.extract_tracks(frame) """ self.tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"]) @@ -62,11 +97,18 @@ class BaseSolution: def store_tracking_history(self, track_id, box): """ - Store object tracking history. + Stores the tracking history of an object. + + This method updates the tracking history for a given object by appending the center point of its + bounding box to the track line. It maintains a maximum of 30 points in the tracking history. Args: - track_id (int): The track ID of the object - box (list): Bounding box coordinates of the object + track_id (int): The unique identifier for the tracked object. + box (List[float]): The bounding box coordinates of the object in the format [x1, y1, x2, y2]. + + Examples: + >>> solution = BaseSolution() + >>> solution.store_tracking_history(1, [100, 200, 300, 400]) """ # Store tracking history self.track_line = self.track_history[track_id] @@ -75,19 +117,32 @@ class BaseSolution: self.track_line.pop(0) def initialize_region(self): - """Initialize the counting region and line segment based on config.""" - self.region = [(20, 400), (1080, 404), (1080, 360), (20, 360)] if self.region is None else self.region - self.r_s = Polygon(self.region) if len(self.region) >= 3 else LineString(self.region) # region segment - self.l_s = LineString( - [(self.region[0][0], self.region[0][1]), (self.region[1][0], self.region[1][1])] - ) # line segment + """Initialize the counting region and line segment based on configuration settings.""" + if self.region is None: + self.region = [(20, 400), (1080, 404), (1080, 360), (20, 360)] + self.r_s = ( + self.Polygon(self.region) if len(self.region) >= 3 else self.LineString(self.region) + ) # region or line def display_output(self, im0): """ Display the results of the processing, which could involve showing frames, printing counts, or saving results. + This method is responsible for visualizing the output of the object detection and tracking process. It displays + the processed frame with annotations, and allows for user interaction to close the display. + Args: - im0 (ndarray): The input image or frame + im0 (numpy.ndarray): The input image or frame that has been processed and annotated. + + Examples: + >>> solution = BaseSolution() + >>> frame = cv2.imread("path/to/image.jpg") + >>> solution.display_output(frame) + + Notes: + - This method will only display output if the 'show' configuration is set to True and the environment + supports image display. + - The display can be closed by pressing the 'q' key. 
""" if self.CFG.get("show") and self.env_check: cv2.imshow("Ultralytics Solutions", im0) diff --git a/ultralytics/solutions/speed_estimation.py b/ultralytics/solutions/speed_estimation.py index decd159b5..0c4bc5f05 100644 --- a/ultralytics/solutions/speed_estimation.py +++ b/ultralytics/solutions/speed_estimation.py @@ -4,15 +4,43 @@ from time import time import numpy as np -from ultralytics.solutions.solutions import BaseSolution, LineString +from ultralytics.solutions.solutions import BaseSolution from ultralytics.utils.plotting import Annotator, colors class SpeedEstimator(BaseSolution): - """A class to estimate the speed of objects in a real-time video stream based on their tracks.""" + """ + A class to estimate the speed of objects in a real-time video stream based on their tracks. + + This class extends the BaseSolution class and provides functionality for estimating object speeds using + tracking data in video streams. + + Attributes: + spd (Dict[int, float]): Dictionary storing speed data for tracked objects. + trkd_ids (List[int]): List of tracked object IDs that have already been speed-estimated. + trk_pt (Dict[int, float]): Dictionary storing previous timestamps for tracked objects. + trk_pp (Dict[int, Tuple[float, float]]): Dictionary storing previous positions for tracked objects. + annotator (Annotator): Annotator object for drawing on images. + region (List[Tuple[int, int]]): List of points defining the speed estimation region. + track_line (List[Tuple[float, float]]): List of points representing the object's track. + r_s (LineString): LineString object representing the speed estimation region. + + Methods: + initialize_region: Initializes the speed estimation region. + estimate_speed: Estimates the speed of objects based on tracking data. + store_tracking_history: Stores the tracking history for an object. + extract_tracks: Extracts tracks from the current frame. + display_output: Displays the output with annotations. + + Examples: + >>> estimator = SpeedEstimator() + >>> frame = cv2.imread("frame.jpg") + >>> processed_frame = estimator.estimate_speed(frame) + >>> cv2.imshow("Speed Estimation", processed_frame) + """ def __init__(self, **kwargs): - """Initializes the SpeedEstimator with the given parameters.""" + """Initializes the SpeedEstimator object with speed estimation parameters and data structures.""" super().__init__(**kwargs) self.initialize_region() # Initialize speed region @@ -27,9 +55,15 @@ class SpeedEstimator(BaseSolution): Estimates the speed of objects based on tracking data. Args: - im0 (ndarray): The input image that will be used for processing - Returns - im0 (ndarray): The processed image for more usage + im0 (np.ndarray): Input image for processing. Shape is typically (H, W, C) for RGB images. + + Returns: + (np.ndarray): Processed image with speed estimations and annotations. 
+ + Examples: + >>> estimator = SpeedEstimator() + >>> image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) + >>> processed_image = estimator.estimate_speed(image) """ self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator self.extract_tracks(im0) # Extract tracks @@ -56,7 +90,7 @@ class SpeedEstimator(BaseSolution): ) # Calculate object speed and direction based on region intersection - if LineString([self.trk_pp[track_id], self.track_line[-1]]).intersects(self.l_s): + if self.LineString([self.trk_pp[track_id], self.track_line[-1]]).intersects(self.r_s): direction = "known" else: direction = "unknown" diff --git a/ultralytics/solutions/streamlit_inference.py b/ultralytics/solutions/streamlit_inference.py index f38cceb3c..dcae3add7 100644 --- a/ultralytics/solutions/streamlit_inference.py +++ b/ultralytics/solutions/streamlit_inference.py @@ -11,7 +11,7 @@ from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS def inference(model=None): - """Runs real-time object detection on video input using Ultralytics YOLO11 in a Streamlit application.""" + """Performs real-time object detection on video input using YOLO in a Streamlit web application.""" check_requirements("streamlit>=1.29.0") # scope imports for faster ultralytics package load speeds import streamlit as st @@ -108,7 +108,7 @@ def inference(model=None): st.warning("Failed to read frame from webcam. Please make sure the webcam is connected properly.") break - prev_time = time.time() + prev_time = time.time() # Store initial time for FPS calculation # Store model predictions if enable_trk == "Yes": @@ -120,7 +120,6 @@ def inference(model=None): # Calculate model FPS curr_time = time.time() fps = 1 / (curr_time - prev_time) - prev_time = curr_time # display frame org_frame.image(frame, channels="BGR")
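
Editor's note: two patterns recur throughout PATCH 2/2 and are worth seeing in isolation. First, every module-level shapely import (in split_dota.py, object_counter.py, queue_management.py, speed_estimation.py, and solutions.py) moves behind check_requirements into __init__ or function scope, and BaseSolution re-exports the geometry classes as self.LineString, self.Polygon, and self.Point. Second, region tests now use a single r_s attribute that is a Polygon for three or more points and a LineString otherwise. Below is a minimal, self-contained sketch of both ideas, not the Ultralytics API: require() is a simplified stand-in for ultralytics.utils.checks.check_requirements, and LazyRegionSolution is a hypothetical class, though its geometry rules mirror initialize_region() and count_objects() above. Running it requires shapely to be installed.

    import importlib.util


    def require(package: str) -> None:
        # Simplified stand-in for ultralytics.utils.checks.check_requirements:
        # fail fast if the optional dependency is missing, without paying the
        # import cost at module load time.
        name = package.split(">=")[0]
        if importlib.util.find_spec(name) is None:
            raise ModuleNotFoundError(f"'{package}' is required for this solution")


    class LazyRegionSolution:
        """Hypothetical mirror of the BaseSolution lazy-dependency pattern."""

        def __init__(self, region):
            require("shapely>=2.0.0")  # checked only when a solution object is built
            from shapely.geometry import LineString, Point, Polygon

            # Re-export the classes as attributes so subclasses can write
            # self.Point(...) without importing shapely themselves.
            self.LineString, self.Polygon, self.Point = LineString, Polygon, Point

            self.region = region
            # Polygon for a closed region (3+ points), LineString for a crossing
            # line, matching the single r_s attribute in initialize_region().
            self.r_s = self.Polygon(region) if len(region) >= 3 else self.LineString(region)

        def hit(self, prev_position, position) -> bool:
            # Polygonal region: count when the current point lies inside it.
            if len(self.region) >= 3:
                return self.r_s.contains(self.Point(position))
            # Linear region: count when the movement segment crosses the line,
            # as in the updated count_objects()/estimate_speed() checks.
            return self.LineString([prev_position, position]).intersects(self.r_s)


    if __name__ == "__main__":
        zone = LazyRegionSolution([(20, 400), (1080, 404), (1080, 360), (20, 360)])
        print(zone.hit((0, 0), (500, 380)))  # True: current point inside the polygon
        line = LazyRegionSolution([(20, 400), (1080, 400)])
        print(line.hit((500, 380), (500, 420)))  # True: segment crosses the line

Deferring the import keeps `import ultralytics` fast and makes shapely a genuinely optional dependency; the attribute indirection is what lets ObjectCounter, QueueManager, and SpeedEstimator drop their own `from shapely.geometry import ...` lines in the hunks above.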