Introduced `BaseSolution` class for Ultralytics solutions (#16671)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/16690/head
Muhammad Rizwan Munawar 1 month ago committed by GitHub
parent e5d3427a52
commit 70ba988c68
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 140
      docs/en/guides/object-counting.md
  2. 16
      docs/en/reference/solutions/solutions.md
  3. 4
      tests/test_solutions.py
  4. 12
      ultralytics/cfg/solutions/default.yaml
  5. 318
      ultralytics/solutions/object_counter.py
  6. 88
      ultralytics/solutions/solutions.py

@ -53,9 +53,8 @@ Object counting with [Ultralytics YOLO11](https://github.com/ultralytics/ultraly
```python ```python
import cv2 import cv2
from ultralytics import YOLO, solutions from ultralytics import solutions
model = YOLO("yolo11n.pt")
cap = cv2.VideoCapture("path/to/video/file.mp4") cap = cv2.VideoCapture("path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file" assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
@ -68,21 +67,18 @@ Object counting with [Ultralytics YOLO11](https://github.com/ultralytics/ultraly
# Init Object Counter # Init Object Counter
counter = solutions.ObjectCounter( counter = solutions.ObjectCounter(
view_img=True, show=True,
reg_pts=region_points, region=region_points,
names=model.names, model="yolo11n.pt",
draw_tracks=True,
line_thickness=2,
) )
# Process video
while cap.isOpened(): while cap.isOpened():
success, im0 = cap.read() success, im0 = cap.read()
if not success: if not success:
print("Video frame is empty or video processing has been successfully completed.") print("Video frame is empty or video processing has been successfully completed.")
break break
tracks = model.track(im0, persist=True, show=False) im0 = counter.count(im0)
im0 = counter.start_counting(im0, tracks)
video_writer.write(im0) video_writer.write(im0)
cap.release() cap.release()
@ -95,34 +91,32 @@ Object counting with [Ultralytics YOLO11](https://github.com/ultralytics/ultraly
```python ```python
import cv2 import cv2
from ultralytics import YOLO, solutions from ultralytics import solutions
model = YOLO("yolo11n-obb.pt")
cap = cv2.VideoCapture("path/to/video/file.mp4") cap = cv2.VideoCapture("path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file" assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
# Define region points # line or region points
region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)] line_points = [(20, 400), (1080, 400)]
# Video writer # Video writer
video_writer = cv2.VideoWriter("object_counting_output.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)) video_writer = cv2.VideoWriter("object_counting_output.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
# Init Object Counter # Init Object Counter
counter = solutions.ObjectCounter( counter = solutions.ObjectCounter(
view_img=True, show=True,
reg_pts=region_points, region=line_points,
names=model.names, model="yolo11n-obb.pt",
line_thickness=2,
) )
# Process video
while cap.isOpened(): while cap.isOpened():
success, im0 = cap.read() success, im0 = cap.read()
if not success: if not success:
print("Video frame is empty or video processing has been successfully completed.") print("Video frame is empty or video processing has been successfully completed.")
break break
tracks = model.track(im0, persist=True, show=False) im0 = counter.count(im0)
im0 = counter.start_counting(im0, tracks)
video_writer.write(im0) video_writer.write(im0)
cap.release() cap.release()
@ -135,14 +129,13 @@ Object counting with [Ultralytics YOLO11](https://github.com/ultralytics/ultraly
```python ```python
import cv2 import cv2
from ultralytics import YOLO, solutions from ultralytics import solutions
model = YOLO("yolo11n.pt")
cap = cv2.VideoCapture("path/to/video/file.mp4") cap = cv2.VideoCapture("path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file" assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
# Define region points as a polygon with 5 points # Define region points
region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360), (20, 400)] region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360), (20, 400)]
# Video writer # Video writer
@ -150,20 +143,18 @@ Object counting with [Ultralytics YOLO11](https://github.com/ultralytics/ultraly
# Init Object Counter # Init Object Counter
counter = solutions.ObjectCounter( counter = solutions.ObjectCounter(
view_img=True, show=True,
reg_pts=region_points, region=region_points,
names=model.names, model="yolo11n.pt",
draw_tracks=True,
line_thickness=2,
) )
# Process video
while cap.isOpened(): while cap.isOpened():
success, im0 = cap.read() success, im0 = cap.read()
if not success: if not success:
print("Video frame is empty or video processing has been successfully completed.") print("Video frame is empty or video processing has been successfully completed.")
break break
tracks = model.track(im0, persist=True, show=False) im0 = counter.count(im0)
im0 = counter.start_counting(im0, tracks)
video_writer.write(im0) video_writer.write(im0)
cap.release() cap.release()
@ -176,14 +167,13 @@ Object counting with [Ultralytics YOLO11](https://github.com/ultralytics/ultraly
```python ```python
import cv2 import cv2
from ultralytics import YOLO, solutions from ultralytics import solutions
model = YOLO("yolo11n.pt")
cap = cv2.VideoCapture("path/to/video/file.mp4") cap = cv2.VideoCapture("path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file" assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
# Define line points # Define region points
line_points = [(20, 400), (1080, 400)] line_points = [(20, 400), (1080, 400)]
# Video writer # Video writer
@ -191,20 +181,18 @@ Object counting with [Ultralytics YOLO11](https://github.com/ultralytics/ultraly
# Init Object Counter # Init Object Counter
counter = solutions.ObjectCounter( counter = solutions.ObjectCounter(
view_img=True, show=True,
reg_pts=line_points, region=line_points,
names=model.names, model="yolo11n.pt",
draw_tracks=True,
line_thickness=2,
) )
# Process video
while cap.isOpened(): while cap.isOpened():
success, im0 = cap.read() success, im0 = cap.read()
if not success: if not success:
print("Video frame is empty or video processing has been successfully completed.") print("Video frame is empty or video processing has been successfully completed.")
break break
tracks = model.track(im0, persist=True, show=False) im0 = counter.count(im0)
im0 = counter.start_counting(im0, tracks)
video_writer.write(im0) video_writer.write(im0)
cap.release() cap.release()
@ -217,35 +205,29 @@ Object counting with [Ultralytics YOLO11](https://github.com/ultralytics/ultraly
```python ```python
import cv2 import cv2
from ultralytics import YOLO, solutions from ultralytics import solutions
model = YOLO("yolo11n.pt")
cap = cv2.VideoCapture("path/to/video/file.mp4") cap = cv2.VideoCapture("path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file" assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
line_points = [(20, 400), (1080, 400)] # line or region points
classes_to_count = [0, 2] # person and car classes for count
# Video writer # Video writer
video_writer = cv2.VideoWriter("object_counting_output.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)) video_writer = cv2.VideoWriter("object_counting_output.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
# Init Object Counter # Init Object Counter
counter = solutions.ObjectCounter( counter = solutions.ObjectCounter(
view_img=True, show=True,
reg_pts=line_points, model="yolo11n.pt",
names=model.names, classes=[0, 1],
draw_tracks=True,
line_thickness=2,
) )
# Process video
while cap.isOpened(): while cap.isOpened():
success, im0 = cap.read() success, im0 = cap.read()
if not success: if not success:
print("Video frame is empty or video processing has been successfully completed.") print("Video frame is empty or video processing has been successfully completed.")
break break
tracks = model.track(im0, persist=True, show=False, classes=classes_to_count) im0 = counter.count(im0)
im0 = counter.start_counting(im0, tracks)
video_writer.write(im0) video_writer.write(im0)
cap.release() cap.release()
@ -253,23 +235,18 @@ Object counting with [Ultralytics YOLO11](https://github.com/ultralytics/ultraly
cv2.destroyAllWindows() cv2.destroyAllWindows()
``` ```
???+ tip "Region is Movable"
You can move the region anywhere in the frame by clicking on its edges
### Argument `ObjectCounter` ### Argument `ObjectCounter`
Here's a table with the `ObjectCounter` arguments: Here's a table with the `ObjectCounter` arguments:
| Name | Type | Default | Description | | Name | Type | Default | Description |
| ----------------- | ------ | -------------------------- | ---------------------------------------------------------------------- | | ------------ | ------ | -------------------------- | ---------------------------------------------------------------------- |
| `names` | `dict` | `None` | Dictionary of classes names. | | `model` | `str` | `None` | Path to Ultralytics YOLO Model File |
| `reg_pts` | `list` | `[(20, 400), (1260, 400)]` | List of points defining the counting region. | | `region` | `list` | `[(20, 400), (1260, 400)]` | List of points defining the counting region. |
| `line_thickness` | `int` | `2` | Line thickness for bounding boxes. | | `line_width` | `int` | `2` | Line thickness for bounding boxes. |
| `view_img` | `bool` | `False` | Flag to control whether to display the video stream. | | `show` | `bool` | `False` | Flag to control whether to display the video stream. |
| `view_in_counts` | `bool` | `True` | Flag to control whether to display the in counts on the video stream. | | `show_in` | `bool` | `True` | Flag to control whether to display the in counts on the video stream. |
| `view_out_counts` | `bool` | `True` | Flag to control whether to display the out counts on the video stream. | | `show_out` | `bool` | `True` | Flag to control whether to display the out counts on the video stream. |
| `draw_tracks` | `bool` | `False` | Flag to control whether to draw the object tracks. |
### Arguments `model.track` ### Arguments `model.track`
@ -282,38 +259,34 @@ Here's a table with the `ObjectCounter` arguments:
To count objects in a video using Ultralytics YOLO11, you can follow these steps: To count objects in a video using Ultralytics YOLO11, you can follow these steps:
1. Import the necessary libraries (`cv2`, `ultralytics`). 1. Import the necessary libraries (`cv2`, `ultralytics`).
2. Load a pretrained YOLO11 model. 2. Define the counting region (e.g., a polygon, line, etc.).
3. Define the counting region (e.g., a polygon, line, etc.). 3. Set up the video capture and initialize the object counter.
4. Set up the video capture and initialize the object counter. 4. Process each frame to track objects and count them within the defined region.
5. Process each frame to track objects and count them within the defined region.
Here's a simple example for counting in a region: Here's a simple example for counting in a region:
```python ```python
import cv2 import cv2
from ultralytics import YOLO, solutions from ultralytics import solutions
def count_objects_in_region(video_path, output_video_path, model_path): def count_objects_in_region(video_path, output_video_path, model_path):
"""Count objects in a specific region within a video.""" """Count objects in a specific region within a video."""
model = YOLO(model_path)
cap = cv2.VideoCapture(video_path) cap = cv2.VideoCapture(video_path)
assert cap.isOpened(), "Error reading video file" assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
video_writer = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)) video_writer = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
counter = solutions.ObjectCounter(
view_img=True, reg_pts=region_points, names=model.names, draw_tracks=True, line_thickness=2 region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
) counter = solutions.ObjectCounter(show=True, region=region_points, model=model_path)
while cap.isOpened(): while cap.isOpened():
success, im0 = cap.read() success, im0 = cap.read()
if not success: if not success:
print("Video frame is empty or video processing has been successfully completed.") print("Video frame is empty or video processing has been successfully completed.")
break break
tracks = model.track(im0, persist=True, show=False) im0 = counter.start_counting(im0)
im0 = counter.start_counting(im0, tracks)
video_writer.write(im0) video_writer.write(im0)
cap.release() cap.release()
@ -343,28 +316,25 @@ To count specific classes of objects using Ultralytics YOLO11, you need to speci
```python ```python
import cv2 import cv2
from ultralytics import YOLO, solutions from ultralytics import solutions
def count_specific_classes(video_path, output_video_path, model_path, classes_to_count): def count_specific_classes(video_path, output_video_path, model_path, classes_to_count):
"""Count specific classes of objects in a video.""" """Count specific classes of objects in a video."""
model = YOLO(model_path)
cap = cv2.VideoCapture(video_path) cap = cv2.VideoCapture(video_path)
assert cap.isOpened(), "Error reading video file" assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
line_points = [(20, 400), (1080, 400)]
video_writer = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)) video_writer = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
counter = solutions.ObjectCounter(
view_img=True, reg_pts=line_points, names=model.names, draw_tracks=True, line_thickness=2 line_points = [(20, 400), (1080, 400)]
) counter = solutions.ObjectCounter(show=True, region=line_points, model=model_path, classes=classes_to_count)
while cap.isOpened(): while cap.isOpened():
success, im0 = cap.read() success, im0 = cap.read()
if not success: if not success:
print("Video frame is empty or video processing has been successfully completed.") print("Video frame is empty or video processing has been successfully completed.")
break break
tracks = model.track(im0, persist=True, show=False, classes=classes_to_count) im0 = counter.start_counting(im0)
im0 = counter.start_counting(im0, tracks)
video_writer.write(im0) video_writer.write(im0)
cap.release() cap.release()

@ -0,0 +1,16 @@
---
description: Explore the Ultralytics Solution Base class for real-time object counting,virtual gym, heatmaps, speed estimation using Ultralytics YOLO. Learn to implement Ultralytics solutions effectively.
keywords: Ultralytics, Solutions, Object counting, Speed Estimation, Heatmaps, Queue Management, AI Gym, YOLO, pose detection, gym step counting, real-time pose estimation, Python
---
# Reference for `ultralytics/solutions/solutions.py`
!!! note
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/solutions.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/solutions.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/solutions/solutions.py) 🛠. Thank you 🙏!
<br>
## ::: ultralytics.solutions.solutions.BaseSolution
<br><br>

@ -19,7 +19,7 @@ def test_major_solutions():
cap = cv2.VideoCapture("solutions_ci_demo.mp4") cap = cv2.VideoCapture("solutions_ci_demo.mp4")
assert cap.isOpened(), "Error reading video file" assert cap.isOpened(), "Error reading video file"
region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)] region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
counter = solutions.ObjectCounter(reg_pts=region_points, names=names, view_img=False) # counter = solutions.ObjectCounter(reg_pts=region_points, names=names, view_img=False)
heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, names=names, view_img=False) heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, names=names, view_img=False)
speed = solutions.SpeedEstimator(reg_pts=region_points, names=names, view_img=False) speed = solutions.SpeedEstimator(reg_pts=region_points, names=names, view_img=False)
queue = solutions.QueueManager(names=names, reg_pts=region_points, view_img=False) queue = solutions.QueueManager(names=names, reg_pts=region_points, view_img=False)
@ -29,7 +29,7 @@ def test_major_solutions():
break break
original_im0 = im0.copy() original_im0 = im0.copy()
tracks = model.track(im0, persist=True, show=False) tracks = model.track(im0, persist=True, show=False)
_ = counter.start_counting(original_im0.copy(), tracks) # _ = counter.start_counting(original_im0.copy(), tracks)
_ = heatmap.generate_heatmap(original_im0.copy(), tracks) _ = heatmap.generate_heatmap(original_im0.copy(), tracks)
_ = speed.estimate_speed(original_im0.copy(), tracks) _ = speed.estimate_speed(original_im0.copy(), tracks)
_ = queue.process_queue(original_im0.copy(), tracks) _ = queue.process_queue(original_im0.copy(), tracks)

@ -0,0 +1,12 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# Configuration for Ultralytics Solutions
model: "yolo11n.pt" # The Ultralytics YOLO11 model to be used (e.g., yolo11n.pt for YOLO11 nano version)
region: # Object counting, queue or speed estimation region points
line_width: 2 # Thickness of the lines used to draw regions on the image/video frames
show: True # Flag to control whether to display output image or not
show_in: True # Flag to display objects moving *into* the defined region
show_out: True # Flag to display objects moving *out of* the defined region
classes: # To count specific classes

@ -1,243 +1,129 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
from collections import defaultdict from shapely.geometry import LineString, Point
import cv2 from ultralytics.solutions.solutions import BaseSolution # Import a parent class
from ultralytics.utils.checks import check_imshow, check_requirements
from ultralytics.utils.plotting import Annotator, colors from ultralytics.utils.plotting import Annotator, colors
check_requirements("shapely>=2.0.0")
from shapely.geometry import LineString, Point, Polygon class ObjectCounter(BaseSolution):
"""A class to manage the counting of objects in a real-time video stream based on their tracks."""
def __init__(self, **kwargs):
"""Initialization function for Count class, a child class of BaseSolution class, can be used for counting the
objects.
"""
super().__init__(**kwargs)
class ObjectCounter: self.in_count = 0 # Counter for objects moving inward
"""A class to manage the counting of objects in a real-time video stream based on their tracks.""" self.out_count = 0 # Counter for objects moving outward
self.counted_ids = [] # List of IDs of objects that have been counted
self.classwise_counts = {} # Dictionary for counts, categorized by object class
def __init__( self.initialize_region() # Setup region and counting areas
self,
names, self.show_in = self.CFG["show_in"]
reg_pts=None, self.show_out = self.CFG["show_out"]
line_thickness=2,
view_img=False, def count_objects(self, track_line, box, track_id, prev_position, cls):
view_in_counts=True,
view_out_counts=True,
draw_tracks=False,
):
""" """
Initializes the ObjectCounter with various tracking and counting parameters. Helper function to count objects within a polygonal region.
Args: Args:
names (dict): Dictionary of class names. track_line (dict): last 30 frame track record
reg_pts (list): List of points defining the counting region. box (list): Bounding box data for specific track in current frame
line_thickness (int): Line thickness for bounding boxes. track_id (int): track ID of the object
view_img (bool): Flag to control whether to display the video stream. prev_position (tuple): last frame position coordinates of the track
view_in_counts (bool): Flag to control whether to display the in counts on the video stream. cls (int): Class index for classwise count updates
view_out_counts (bool): Flag to control whether to display the out counts on the video stream.
draw_tracks (bool): Flag to control whether to draw the object tracks.
""" """
# Mouse events if prev_position is None or track_id in self.counted_ids:
self.is_drawing = False return
self.selected_point = None
centroid = self.r_s.centroid
# Region & Line Information dx = (box[0] - prev_position[0]) * (centroid.x - prev_position[0])
self.reg_pts = [(20, 400), (1260, 400)] if reg_pts is None else reg_pts dy = (box[1] - prev_position[1]) * (centroid.y - prev_position[1])
self.counting_region = None
if len(self.region) >= 3 and self.r_s.contains(Point(track_line[-1])):
# Image and annotation Information self.counted_ids.append(track_id)
self.im0 = None # For polygon region
self.tf = line_thickness if dx > 0:
self.view_img = view_img self.in_count += 1
self.view_in_counts = view_in_counts self.classwise_counts[self.names[cls]]["IN"] += 1
self.view_out_counts = view_out_counts else:
self.out_count += 1
self.names = names # Classes names self.classwise_counts[self.names[cls]]["OUT"] += 1
self.window_name = "Ultralytics YOLOv8 Object Counter"
elif len(self.region) < 3 and LineString([prev_position, box[:2]]).intersects(self.l_s):
# Object counting Information self.counted_ids.append(track_id)
self.in_counts = 0 # For linear region
self.out_counts = 0 if dx > 0 and dy > 0:
self.count_ids = [] self.in_count += 1
self.class_wise_count = {} self.classwise_counts[self.names[cls]]["IN"] += 1
else:
# Tracks info self.out_count += 1
self.track_history = defaultdict(list) self.classwise_counts[self.names[cls]]["OUT"] += 1
self.draw_tracks = draw_tracks
def store_classwise_counts(self, cls):
# Check if environment supports imshow
self.env_check = check_imshow(warn=True)
# Initialize counting region
if len(self.reg_pts) == 2:
print("Line Counter Initiated.")
self.counting_region = LineString(self.reg_pts)
elif len(self.reg_pts) >= 3:
print("Polygon Counter Initiated.")
self.counting_region = Polygon(self.reg_pts)
else:
print("Invalid Region points provided, region_points must be 2 for lines or >= 3 for polygons.")
print("Using Line Counter Now")
self.counting_region = LineString(self.reg_pts)
# Define the counting line segment
self.counting_line_segment = LineString(
[
(self.reg_pts[0][0], self.reg_pts[0][1]),
(self.reg_pts[1][0], self.reg_pts[1][1]),
]
)
def mouse_event_for_region(self, event, x, y, flags, params):
""" """
Handles mouse events for defining and moving the counting region in a real-time video stream. Initialize class-wise counts if not already present.
Args: Args:
event (int): The type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN, etc.). cls (int): Class index for classwise count updates
x (int): The x-coordinate of the mouse pointer.
y (int): The y-coordinate of the mouse pointer.
flags (int): Any associated event flags (e.g., cv2.EVENT_FLAG_CTRLKEY, cv2.EVENT_FLAG_SHIFTKEY, etc.).
params (dict): Additional parameters for the function.
""" """
if event == cv2.EVENT_LBUTTONDOWN: if self.names[cls] not in self.classwise_counts:
for i, point in enumerate(self.reg_pts): self.classwise_counts[self.names[cls]] = {"IN": 0, "OUT": 0}
if (
isinstance(point, (tuple, list))
and len(point) >= 2
and (abs(x - point[0]) < 10 and abs(y - point[1]) < 10)
):
self.selected_point = i
self.is_drawing = True
break
elif event == cv2.EVENT_MOUSEMOVE:
if self.is_drawing and self.selected_point is not None:
self.reg_pts[self.selected_point] = (x, y)
self.counting_region = Polygon(self.reg_pts)
elif event == cv2.EVENT_LBUTTONUP:
self.is_drawing = False
self.selected_point = None
def extract_and_process_tracks(self, tracks):
"""Extracts and processes tracks for object counting in a video stream."""
# Annotator Init and region drawing
annotator = Annotator(self.im0, self.tf, self.names)
# Draw region or line
annotator.draw_region(reg_pts=self.reg_pts, color=(104, 0, 123), thickness=self.tf * 2)
# Extract tracks for OBB or object detection
track_data = tracks[0].obb or tracks[0].boxes
if track_data and track_data.id is not None:
boxes = track_data.xyxy.cpu()
clss = track_data.cls.cpu().tolist()
track_ids = track_data.id.int().cpu().tolist()
# Extract tracks
for box, track_id, cls in zip(boxes, track_ids, clss):
# Draw bounding box
annotator.box_label(box, label=self.names[cls], color=colors(int(track_id), True))
# Store class info
if self.names[cls] not in self.class_wise_count:
self.class_wise_count[self.names[cls]] = {"IN": 0, "OUT": 0}
# Draw Tracks
track_line = self.track_history[track_id]
track_line.append((float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2)))
if len(track_line) > 30:
track_line.pop(0)
# Draw track trails
if self.draw_tracks:
annotator.draw_centroid_and_tracks(
track_line,
color=colors(int(track_id), True),
track_thickness=self.tf,
)
prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None def display_counts(self, im0):
"""
Helper function to display object counts on the frame.
# Count objects in any polygon Args:
if len(self.reg_pts) >= 3: im0 (ndarray): The input image or frame
is_inside = self.counting_region.contains(Point(track_line[-1])) """
labels_dict = {
if prev_position is not None and is_inside and track_id not in self.count_ids: str.capitalize(key): f"{'IN ' + str(value['IN']) if self.show_in else ''} "
self.count_ids.append(track_id) f"{'OUT ' + str(value['OUT']) if self.show_out else ''}".strip()
for key, value in self.classwise_counts.items()
if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0: if value["IN"] != 0 or value["OUT"] != 0
self.in_counts += 1 }
self.class_wise_count[self.names[cls]]["IN"] += 1
else:
self.out_counts += 1
self.class_wise_count[self.names[cls]]["OUT"] += 1
# Count objects using line
elif len(self.reg_pts) == 2:
if (
prev_position is not None
and track_id not in self.count_ids
and LineString([(prev_position[0], prev_position[1]), (box[0], box[1])]).intersects(
self.counting_line_segment
)
):
self.count_ids.append(track_id)
# Determine the direction of movement (IN or OUT)
dx = (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0])
dy = (box[1] - prev_position[1]) * (self.counting_region.centroid.y - prev_position[1])
if dx > 0 and dy > 0:
self.in_counts += 1
self.class_wise_count[self.names[cls]]["IN"] += 1
else:
self.out_counts += 1
self.class_wise_count[self.names[cls]]["OUT"] += 1
labels_dict = {}
for key, value in self.class_wise_count.items():
if value["IN"] != 0 or value["OUT"] != 0:
if not self.view_in_counts and not self.view_out_counts:
continue
elif not self.view_in_counts:
labels_dict[str.capitalize(key)] = f"OUT {value['OUT']}"
elif not self.view_out_counts:
labels_dict[str.capitalize(key)] = f"IN {value['IN']}"
else:
labels_dict[str.capitalize(key)] = f"IN {value['IN']} OUT {value['OUT']}"
if labels_dict: if labels_dict:
annotator.display_analytics(self.im0, labels_dict, (104, 31, 17), (255, 255, 255), 10) self.annotator.display_analytics(im0, labels_dict, (104, 31, 17), (255, 255, 255), 10)
def display_frames(self): def count(self, im0):
"""Displays the current frame with annotations and regions in a window."""
if self.env_check:
cv2.namedWindow(self.window_name)
if len(self.reg_pts) == 4: # only add mouse event If user drawn region
cv2.setMouseCallback(self.window_name, self.mouse_event_for_region, {"region_points": self.reg_pts})
cv2.imshow(self.window_name, self.im0)
# Break Window
if cv2.waitKey(1) & 0xFF == ord("q"):
return
def start_counting(self, im0, tracks):
""" """
Main function to start the object counting process. Processes input data (frames or object tracks) and updates counts.
Args: Args:
im0 (ndarray): Current frame from the video stream. im0 (ndarray): The input image that will be used for processing
tracks (list): List of tracks obtained from the object tracking process. Returns
im0 (ndarray): The processed image for more usage
""" """
self.im0 = im0 # store image self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator
self.extract_and_process_tracks(tracks) # draw region even if no objects self.extract_tracks(im0) # Extract tracks
if self.view_img: self.annotator.draw_region(
self.display_frames() reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2
return self.im0 ) # Draw region
# Iterate over bounding boxes, track ids and classes index
if self.track_data is not None and self.track_data.id is not None:
for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
# Draw bounding box and counting region
self.annotator.box_label(box, label=self.names[cls], color=colors(track_id, True))
self.store_tracking_history(track_id, box) # Store track history
self.store_classwise_counts(cls) # store classwise counts in dict
# Draw centroid of objects
self.annotator.draw_centroid_and_tracks(
self.track_line, color=colors(int(track_id), True), track_thickness=self.line_width
)
# store previous position of track for object counting
prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
self.count_objects(self.track_line, box, track_id, prev_position, cls) # Perform object counting
self.display_counts(im0) # Display the counts on the frame
self.display_output(im0) # display output with base class function
if __name__ == "__main__": return im0 # return output image for more usage
classes_names = {0: "person", 1: "car"} # example class names
ObjectCounter(classes_names)

@ -0,0 +1,88 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from collections import defaultdict
from pathlib import Path
import cv2
from shapely.geometry import LineString, Polygon
from ultralytics import YOLO
from ultralytics.utils import yaml_load
from ultralytics.utils.checks import check_imshow
DEFAULT_SOL_CFG_PATH = Path(__file__).resolve().parents[1] / "cfg/solutions/default.yaml"
class BaseSolution:
"""A class to manage all the Ultralytics Solutions: https://docs.ultralytics.com/solutions/."""
def __init__(self, **kwargs):
"""
Base initializer for all solutions.
Child classes should call this with necessary parameters.
"""
# Load config and update with args
self.CFG = yaml_load(DEFAULT_SOL_CFG_PATH)
self.CFG.update(kwargs)
print("Ultralytics Solutions: ✅", self.CFG)
self.region = self.CFG["region"] # Store region data for other classes usage
self.line_width = self.CFG["line_width"] # Store line_width for usage
# Load Model and store classes names
self.model = YOLO(self.CFG["model"])
self.names = self.model.names
# Initialize environment and region setup
self.env_check = check_imshow(warn=True)
self.track_history = defaultdict(list)
def extract_tracks(self, im0):
"""
Apply object tracking and extract tracks.
Args:
im0 (ndarray): The input image or frame
"""
self.tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"])
# Extract tracks for OBB or object detection
self.track_data = self.tracks[0].obb or self.tracks[0].boxes
if self.track_data and self.track_data.id is not None:
self.boxes = self.track_data.xyxy.cpu()
self.clss = self.track_data.cls.cpu().tolist()
self.track_ids = self.track_data.id.int().cpu().tolist()
def store_tracking_history(self, track_id, box):
"""
Store object tracking history.
Args:
track_id (int): The track ID of the object
box (list): Bounding box coordinates of the object
"""
# Store tracking history
self.track_line = self.track_history[track_id]
self.track_line.append(((box[0] + box[2]) / 2, (box[1] + box[3]) / 2))
if len(self.track_line) > 30:
self.track_line.pop(0)
def initialize_region(self):
"""Initialize the counting region and line segment based on config."""
self.region = [(20, 400), (1260, 400)] if self.region is None else self.region
self.r_s = Polygon(self.region) if len(self.region) >= 3 else LineString(self.region)
self.l_s = LineString([(self.region[0][0], self.region[0][1]), (self.region[1][0], self.region[1][1])])
def display_output(self, im0):
"""
Display the results of the processing, which could involve showing frames, printing counts, or saving results.
Args:
im0 (ndarray): The input image or frame
"""
if self.CFG.get("show") and self.env_check:
cv2.imshow("Ultralytics Solutions", im0)
if cv2.waitKey(1) & 0xFF == ord("q"):
return
Loading…
Cancel
Save