commit b1f13a334c
101 changed files with 1736 additions and 1903 deletions
@@ -1,17 +1,17 @@

Before:

| Argument | Type | Default | Description |
| --------------- | -------------- | ---------------------- | ----------- |
| `source` | `str` | `'ultralytics/assets'` | Specifies the data source for inference. Can be an image path, video file, directory, URL, or device ID for live feeds. Supports a wide range of formats and sources, enabling flexible application across [different types of input](/modes/predict.md/#inference-sources). |
| `conf` | `float` | `0.25` | Sets the minimum confidence threshold for detections. Objects detected with confidence below this threshold will be disregarded. Adjusting this value can help reduce false positives. |
| `iou` | `float` | `0.7` | [Intersection Over Union](https://www.ultralytics.com/glossary/intersection-over-union-iou) (IoU) threshold for Non-Maximum Suppression (NMS). Lower values result in fewer detections by eliminating overlapping boxes, useful for reducing duplicates. |
| `imgsz` | `int or tuple` | `640` | Defines the image size for inference. Can be a single integer `640` for square resizing or a (height, width) tuple. Proper sizing can improve detection [accuracy](https://www.ultralytics.com/glossary/accuracy) and processing speed. |
| `half` | `bool` | `False` | Enables half-[precision](https://www.ultralytics.com/glossary/precision) (FP16) inference, which can speed up model inference on supported GPUs with minimal impact on accuracy. |
| `device` | `str` | `None` | Specifies the device for inference (e.g., `cpu`, `cuda:0` or `0`). Allows users to select between CPU, a specific GPU, or other compute devices for model execution. |
| `max_det` | `int` | `300` | Maximum number of detections allowed per image. Limits the total number of objects the model can detect in a single inference, preventing excessive outputs in dense scenes. |
| `vid_stride` | `int` | `1` | Frame stride for video inputs. Allows skipping frames in videos to speed up processing at the cost of temporal resolution. A value of 1 processes every frame, higher values skip frames. |
| `stream_buffer` | `bool` | `False` | Determines the frame processing strategy for video streams. If `False`, processes only the most recent frame, minimizing latency (optimized for real-time applications). If `True`, processes all frames in order, ensuring no frames are skipped. |
| `visualize` | `bool` | `False` | Activates visualization of model features during inference, providing insights into what the model is "seeing". Useful for debugging and model interpretation. |
| `augment` | `bool` | `False` | Enables test-time augmentation (TTA) for predictions, potentially improving detection robustness at the cost of inference speed. |
| `agnostic_nms` | `bool` | `False` | Enables class-agnostic Non-Maximum Suppression (NMS), which merges overlapping boxes of different classes. Useful in multi-class detection scenarios where class overlap is common. |
| `classes` | `list[int]` | `None` | Filters predictions to a set of class IDs. Only detections belonging to the specified classes will be returned. Useful for focusing on relevant objects in multi-class detection tasks. |
| `retina_masks` | `bool` | `False` | Uses high-resolution segmentation masks if available in the model. This can enhance mask quality for segmentation tasks, providing finer detail. |
| `embed` | `list[int]` | `None` | Specifies the layers from which to extract feature vectors or [embeddings](https://www.ultralytics.com/glossary/embeddings). Useful for downstream tasks like clustering or similarity search. |

After:

| Argument | Type | Default | Description |
| --------------- | -------------- | ---------------------- | ----------- |
| `source` | `str` | `'ultralytics/assets'` | Specifies the data source for inference. Can be an image path, video file, directory, URL, or device ID for live feeds. Supports a wide range of formats and sources, enabling flexible application across [different types of input](/modes/predict.md/#inference-sources). |
| `conf` | `float` | `0.25` | Sets the minimum confidence threshold for detections. Objects detected with confidence below this threshold will be disregarded. Adjusting this value can help reduce false positives. |
| `iou` | `float` | `0.7` | [Intersection Over Union](https://www.ultralytics.com/glossary/intersection-over-union-iou) (IoU) threshold for Non-Maximum Suppression (NMS). Lower values result in fewer detections by eliminating overlapping boxes, useful for reducing duplicates. |
| `imgsz` | `int or tuple` | `640` | Defines the image size for inference. Can be a single integer `640` for square resizing or a (height, width) tuple. Proper sizing can improve detection [accuracy](https://www.ultralytics.com/glossary/accuracy) and processing speed. |
| `half` | `bool` | `False` | Enables half-[precision](https://www.ultralytics.com/glossary/precision) (FP16) inference, which can speed up model inference on supported GPUs with minimal impact on accuracy. |
| `device` | `str` | `None` | Specifies the device for inference (e.g., `cpu`, `cuda:0` or `0`). Allows users to select between CPU, a specific GPU, or other compute devices for model execution. |
| `max_det` | `int` | `300` | Maximum number of detections allowed per image. Limits the total number of objects the model can detect in a single inference, preventing excessive outputs in dense scenes. |
| `vid_stride` | `int` | `1` | Frame stride for video inputs. Allows skipping frames in videos to speed up processing at the cost of temporal resolution. A value of 1 processes every frame, higher values skip frames. |
| `stream_buffer` | `bool` | `False` | Determines whether to queue incoming frames for video streams. If `False`, old frames get dropped to accommodate new frames (optimized for real-time applications). If `True`, queues new frames in a buffer, ensuring no frames get skipped, but will cause latency if inference FPS is lower than stream FPS. |
| `visualize` | `bool` | `False` | Activates visualization of model features during inference, providing insights into what the model is "seeing". Useful for debugging and model interpretation. |
| `augment` | `bool` | `False` | Enables test-time augmentation (TTA) for predictions, potentially improving detection robustness at the cost of inference speed. |
| `agnostic_nms` | `bool` | `False` | Enables class-agnostic Non-Maximum Suppression (NMS), which merges overlapping boxes of different classes. Useful in multi-class detection scenarios where class overlap is common. |
| `classes` | `list[int]` | `None` | Filters predictions to a set of class IDs. Only detections belonging to the specified classes will be returned. Useful for focusing on relevant objects in multi-class detection tasks. |
| `retina_masks` | `bool` | `False` | Uses high-resolution segmentation masks if available in the model. This can enhance mask quality for segmentation tasks, providing finer detail. |
| `embed` | `list[int]` | `None` | Specifies the layers from which to extract feature vectors or [embeddings](https://www.ultralytics.com/glossary/embeddings). Useful for downstream tasks like clustering or similarity search. |
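
As a quick illustration of how these arguments combine, here is a minimal prediction sketch using the standard `ultralytics` Python API (the weights file is illustrative; values mirror the defaults in the table unless noted):

from ultralytics import YOLO

# Load a model and run prediction with a few of the arguments documented above
model = YOLO("yolo11n.pt")  # illustrative weights file
results = model.predict(
    source="ultralytics/assets",  # default source from the table
    conf=0.25,  # minimum confidence threshold
    iou=0.7,  # NMS IoU threshold
    imgsz=640,  # inference image size
    max_det=300,  # cap detections per image
    classes=[0],  # hypothetical filter: keep only class 0
    stream_buffer=False,  # drop old frames on video streams (new behavior described above)
)
print(len(results))  # one Results object per image/frame
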
@@ -1,66 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

import PIL
import pytest

from ultralytics import Explorer
from ultralytics.utils import ASSETS
from ultralytics.utils.torch_utils import TORCH_1_13


@pytest.mark.slow
@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13")
def test_similarity():
    """Test the correctness and response length of similarity calculations and SQL queries in the Explorer."""
    exp = Explorer(data="coco8.yaml")
    exp.create_embeddings_table()
    similar = exp.get_similar(idx=1)
    assert len(similar) == 4
    similar = exp.get_similar(img=ASSETS / "bus.jpg")
    assert len(similar) == 4
    similar = exp.get_similar(idx=[1, 2], limit=2)
    assert len(similar) == 2
    sim_idx = exp.similarity_index()
    assert len(sim_idx) == 4
    sql = exp.sql_query("WHERE labels LIKE '%zebra%'")
    assert len(sql) == 1


@pytest.mark.slow
@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13")
def test_det():
    """Test detection functionalities and verify embedding table includes bounding boxes."""
    exp = Explorer(data="coco8.yaml", model="yolo11n.pt")
    exp.create_embeddings_table(force=True)
    assert len(exp.table.head()["bboxes"]) > 0
    similar = exp.get_similar(idx=[1, 2], limit=10)
    assert len(similar) > 0
    # This is a loose test, just checks errors not correctness
    similar = exp.plot_similar(idx=[1, 2], limit=10)
    assert isinstance(similar, PIL.Image.Image)


@pytest.mark.slow
@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13")
def test_seg():
    """Test segmentation functionalities and ensure the embedding table includes segmentation masks."""
    exp = Explorer(data="coco8-seg.yaml", model="yolo11n-seg.pt")
    exp.create_embeddings_table(force=True)
    assert len(exp.table.head()["masks"]) > 0
    similar = exp.get_similar(idx=[1, 2], limit=10)
    assert len(similar) > 0
    similar = exp.plot_similar(idx=[1, 2], limit=10)
    assert isinstance(similar, PIL.Image.Image)


@pytest.mark.slow
@pytest.mark.skipif(not TORCH_1_13, reason="Explorer requires torch>=1.13")
def test_pose():
    """Test pose estimation functionality and verify the embedding table includes keypoints."""
    exp = Explorer(data="coco8-pose.yaml", model="yolo11n-pose.pt")
    exp.create_embeddings_table(force=True)
    assert len(exp.table.head()["keypoints"]) > 0
    similar = exp.get_similar(idx=[1, 2], limit=10)
    assert len(similar) > 0
    similar = exp.plot_similar(idx=[1, 2], limit=10)
    assert isinstance(similar, PIL.Image.Image)
@@ -0,0 +1,16 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

# Configuration for Ultralytics Solutions

model: "yolo11n.pt" # The Ultralytics YOLO11 model to be used (e.g., yolo11n.pt for YOLO11 nano, yolov8n.pt for YOLOv8 nano)

region: # Object counting, queue, or speed estimation region points. Default region points are [(20, 400), (1080, 404), (1080, 360), (20, 360)]
line_width: 2 # Width of the lines used to draw regions, bounding boxes, and tracks on image/video frames. Default is 2.
show: True # Flag to control whether to display the output image; set to False when deploying on embedded devices, for example.
show_in: True # Flag to display objects moving *into* the defined region
show_out: True # Flag to display objects moving *out of* the defined region
classes: # To count specific classes, e.g., classes=0 to detect, track, and count only persons with a COCO model. Default is None.
up_angle: 145.0 # Workouts up_angle for counts; 145.0 is the default. Adjust it for different workouts based on keypoint positions.
down_angle: 90 # Workouts down_angle for counts; 90 is the default. Adjust it for different workouts based on keypoint positions.
kpts: [6, 8, 10] # Keypoints used for workout monitoring, e.g., [6, 8, 10] is a typical choice for push-ups.
colormap: # Colormap for the heatmap. Only OpenCV-supported colormaps can be used; COLORMAP_PARULA is used by default.
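
A sketch of how these defaults reach the refactored solution classes below: constructor keyword arguments override the YAML values, and anything omitted falls back to this file via `self.CFG` (class and argument names are taken from this diff; the region points are the defaults listed above):

from ultralytics.solutions.object_counter import ObjectCounter

# Keyword arguments override the solutions config above; anything not passed
# falls back to the YAML values (model, line_width, show, classes, ...).
counter = ObjectCounter(
    model="yolo11n.pt",
    region=[(20, 400), (1080, 404), (1080, 360), (20, 360)],  # default region from the config
    show_in=True,
    show_out=True,
)
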
@@ -1,127 +1,79 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

import cv2

from ultralytics.utils.checks import check_imshow
from ultralytics.solutions.solutions import BaseSolution  # Import a parent class
from ultralytics.utils.plotting import Annotator


class AIGym:
class AIGym(BaseSolution):
    """A class to manage the gym steps of people in a real-time video stream based on their poses."""

    def __init__(
        self,
        kpts_to_check,
        line_thickness=2,
        view_img=False,
        pose_up_angle=145.0,
        pose_down_angle=90.0,
        pose_type="pullup",
    ):
    def __init__(self, **kwargs):
        """Initialization function for the AIGym class, a child class of BaseSolution that can be used for workouts
        monitoring.
        """
        Initializes the AIGym class with the specified parameters.

        Args:
            kpts_to_check (list): Indices of keypoints to check.
            line_thickness (int, optional): Thickness of the lines drawn. Defaults to 2.
            view_img (bool, optional): Flag to display the image. Defaults to False.
            pose_up_angle (float, optional): Angle threshold for the 'up' pose. Defaults to 145.0.
            pose_down_angle (float, optional): Angle threshold for the 'down' pose. Defaults to 90.0.
            pose_type (str, optional): Type of pose to detect ('pullup', 'pushup', 'abworkout'). Defaults to "pullup".
        # Check if the model name ends with '-pose'
        if "model" in kwargs and "-pose" not in kwargs["model"]:
            kwargs["model"] = "yolo11n-pose.pt"
        elif "model" not in kwargs:
            kwargs["model"] = "yolo11n-pose.pt"

        super().__init__(**kwargs)
        self.count = []  # List for counts, necessary where there are multiple objects in frame
        self.angle = []  # List for angle, necessary where there are multiple objects in frame
        self.stage = []  # List for stage, necessary where there are multiple objects in frame

        # Extract details from CFG single time for usage later
        self.initial_stage = None
        self.up_angle = float(self.CFG["up_angle"])  # Pose up predefined angle to consider up pose
        self.down_angle = float(self.CFG["down_angle"])  # Pose down predefined angle to consider down pose
        self.kpts = self.CFG["kpts"]  # User selected kpts of workouts storage for further usage
        self.lw = self.CFG["line_width"]  # Store line_width for usage

    def monitor(self, im0):
        """
        # Image and line thickness
        self.im0 = None
        self.tf = line_thickness

        # Keypoints and count information
        self.keypoints = None
        self.poseup_angle = pose_up_angle
        self.posedown_angle = pose_down_angle
        self.threshold = 0.001

        # Store stage, count and angle information
        self.angle = None
        self.count = None
        self.stage = None
        self.pose_type = pose_type
        self.kpts_to_check = kpts_to_check

        # Visual Information
        self.view_img = view_img
        self.annotator = None

        # Check if environment supports imshow
        self.env_check = check_imshow(warn=True)
        self.count = []
        self.angle = []
        self.stage = []

    def start_counting(self, im0, results):
        """
        Function used to count the gym steps.
        Monitor the workouts using Ultralytics YOLOv8 Pose Model: https://docs.ultralytics.com/tasks/pose/.

        Args:
            im0 (ndarray): Current frame from the video stream.
            results (list): Pose estimation data.
            im0 (ndarray): The input image that will be used for processing
        Returns:
            im0 (ndarray): The processed image for more usage
        """
        self.im0 = im0

        if not len(results[0]):
            return self.im0

        if len(results[0]) > len(self.count):
            new_human = len(results[0]) - len(self.count)
            self.count += [0] * new_human
            self.angle += [0] * new_human
            self.stage += ["-"] * new_human

        self.keypoints = results[0].keypoints.data
        self.annotator = Annotator(im0, line_width=self.tf)

        for ind, k in enumerate(reversed(self.keypoints)):
            # Estimate angle and draw specific points based on pose type
            if self.pose_type in {"pushup", "pullup", "abworkout", "squat"}:
                self.angle[ind] = self.annotator.estimate_pose_angle(
                    k[int(self.kpts_to_check[0])].cpu(),
                    k[int(self.kpts_to_check[1])].cpu(),
                    k[int(self.kpts_to_check[2])].cpu(),
                )
                self.im0 = self.annotator.draw_specific_points(k, self.kpts_to_check, shape=(640, 640), radius=10)

                # Check and update pose stages and counts based on angle
                if self.pose_type in {"abworkout", "pullup"}:
                    if self.angle[ind] > self.poseup_angle:
                        self.stage[ind] = "down"
                    if self.angle[ind] < self.posedown_angle and self.stage[ind] == "down":
                        self.stage[ind] = "up"
                        self.count[ind] += 1

                elif self.pose_type in {"pushup", "squat"}:
                    if self.angle[ind] > self.poseup_angle:
                        self.stage[ind] = "up"
                    if self.angle[ind] < self.posedown_angle and self.stage[ind] == "up":
                        self.stage[ind] = "down"
        # Extract tracks
        tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"])[0]

        if tracks.boxes.id is not None:
            # Extract and check keypoints
            if len(tracks) > len(self.count):
                new_human = len(tracks) - len(self.count)
                self.angle += [0] * new_human
                self.count += [0] * new_human
                self.stage += ["-"] * new_human

            # Initialize annotator
            self.annotator = Annotator(im0, line_width=self.lw)

            # Enumerate over keypoints
            for ind, k in enumerate(reversed(tracks.keypoints.data)):
                # Get keypoints and estimate the angle
                kpts = [k[int(self.kpts[i])].cpu() for i in range(3)]
                self.angle[ind] = self.annotator.estimate_pose_angle(*kpts)
                im0 = self.annotator.draw_specific_points(k, self.kpts, radius=self.lw * 3)

                # Determine stage and count logic based on angle thresholds
                if self.angle[ind] < self.down_angle:
                    if self.stage[ind] == "up":
                        self.count[ind] += 1
                    self.stage[ind] = "down"
                elif self.angle[ind] > self.up_angle:
                    self.stage[ind] = "up"

                # Display angle, count, and stage text
                self.annotator.plot_angle_and_count_and_stage(
                    angle_text=self.angle[ind],
                    count_text=self.count[ind],
                    stage_text=self.stage[ind],
                    center_kpt=k[int(self.kpts_to_check[1])],
                    angle_text=self.angle[ind],  # angle text for display
                    count_text=self.count[ind],  # count text for workouts
                    stage_text=self.stage[ind],  # stage position text
                    center_kpt=k[int(self.kpts[1])],  # center keypoint for display
                )

                # Draw keypoints
                self.annotator.kpts(k, shape=(640, 640), radius=1, kpt_line=True)

        # Display the image if environment supports it and view_img is True
        if self.env_check and self.view_img:
            cv2.imshow("Ultralytics YOLOv8 AI GYM", self.im0)
            if cv2.waitKey(1) & 0xFF == ord("q"):
                return

        return self.im0


if __name__ == "__main__":
    kpts_to_check = [0, 1, 2]  # example keypoints
    aigym = AIGym(kpts_to_check)
        self.display_output(im0)  # Display output image, if environment supports display
        return im0  # return an image for writing or further usage
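
A minimal usage sketch for the refactored class (the `monitor` method and keyword names come from this diff; the module path follows the ultralytics.solutions layout implied by the imports, and the video path is illustrative):

import cv2

from ultralytics.solutions.ai_gym import AIGym

# kpts/up_angle/down_angle override the solutions config defaults
gym = AIGym(model="yolo11n-pose.pt", kpts=[6, 8, 10], up_angle=145.0, down_angle=90.0)

cap = cv2.VideoCapture("workouts.mp4")  # illustrative input video
while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    frame = gym.monitor(frame)  # tracks poses and updates per-person rep counts
cap.release()
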
@@ -1,259 +1,93 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

from collections import defaultdict

import cv2
import numpy as np

from ultralytics.utils.checks import check_imshow, check_requirements
from ultralytics.solutions.object_counter import ObjectCounter  # Import object counter class
from ultralytics.utils.plotting import Annotator

check_requirements("shapely>=2.0.0")

from shapely.geometry import LineString, Point, Polygon


class Heatmap:
class Heatmap(ObjectCounter):
    """A class to draw heatmaps in real-time video stream based on their tracks."""

    def __init__(
        self,
        names,
        imw=0,
        imh=0,
        colormap=cv2.COLORMAP_JET,
        heatmap_alpha=0.5,
        view_img=False,
        view_in_counts=True,
        view_out_counts=True,
        count_reg_pts=None,
        count_txt_color=(0, 0, 0),
        count_bg_color=(255, 255, 255),
        count_reg_color=(255, 0, 255),
        region_thickness=5,
        line_dist_thresh=15,
        line_thickness=2,
        decay_factor=0.99,
        shape="circle",
    ):
        """Initializes the heatmap class with default values for Visual, Image, track, count and heatmap parameters."""
        # Visual information
        self.annotator = None
        self.view_img = view_img
        self.shape = shape

        self.initialized = False
        self.names = names  # Classes names

        # Image information
        self.imw = imw
        self.imh = imh
        self.im0 = None
        self.tf = line_thickness
        self.view_in_counts = view_in_counts
        self.view_out_counts = view_out_counts

        # Heatmap colormap and heatmap np array
        self.colormap = colormap
        self.heatmap = None
        self.heatmap_alpha = heatmap_alpha

        # Predict/track information
        self.boxes = []
        self.track_ids = []
        self.clss = []
        self.track_history = defaultdict(list)

        # Region & Line Information
        self.counting_region = None
        self.line_dist_thresh = line_dist_thresh
        self.region_thickness = region_thickness
        self.region_color = count_reg_color

        # Object Counting Information
        self.in_counts = 0
        self.out_counts = 0
        self.count_ids = []
        self.class_wise_count = {}
        self.count_txt_color = count_txt_color
        self.count_bg_color = count_bg_color
        self.cls_txtdisplay_gap = 50

        # Decay factor
        self.decay_factor = decay_factor
    def __init__(self, **kwargs):
        """Initialization function for the heatmap class with default values."""
        super().__init__(**kwargs)

        # Check if environment supports imshow
        self.env_check = check_imshow(warn=True)
        self.initialized = False  # bool variable for heatmap initialization
        if self.region is not None:  # check if user provided the region coordinates
            self.initialize_region()

        # Region and line selection
        self.count_reg_pts = count_reg_pts
        print(self.count_reg_pts)
        if self.count_reg_pts is not None:
            if len(self.count_reg_pts) == 2:
                print("Line Counter Initiated.")
                self.counting_region = LineString(self.count_reg_pts)
            elif len(self.count_reg_pts) >= 3:
                print("Polygon Counter Initiated.")
                self.counting_region = Polygon(self.count_reg_pts)
            else:
                print("Invalid Region points provided, region_points must be 2 for lines or >= 3 for polygons.")
                print("Using Line Counter Now")
                self.counting_region = LineString(self.count_reg_pts)
        # store colormap
        self.colormap = cv2.COLORMAP_PARULA if self.CFG["colormap"] is None else self.CFG["colormap"]

        # Shape of heatmap, if not selected
        if self.shape not in {"circle", "rect"}:
            print("Unknown shape value provided, 'circle' & 'rect' supported")
            print("Using Circular shape now")
            self.shape = "circle"

    def extract_results(self, tracks):
    def heatmap_effect(self, box):
        """
        Extracts results from the provided data.
        Efficient calculation of heatmap area and effect location for applying colormap.

        Args:
            tracks (list): List of tracks obtained from the object tracking process.
        """
        if tracks[0].boxes.id is not None:
            self.boxes = tracks[0].boxes.xyxy.cpu()
            self.clss = tracks[0].boxes.cls.tolist()
            self.track_ids = tracks[0].boxes.id.int().tolist()

    def generate_heatmap(self, im0, tracks):
            box (list): Bounding Box coordinates data [x0, y0, x1, y1]
        """
        Generate heatmap based on tracking data.
        x0, y0, x1, y1 = map(int, box)
        radius_squared = (min(x1 - x0, y1 - y0) // 2) ** 2

        Args:
            im0 (ndarray): Image
            tracks (list): List of tracks obtained from the object tracking process.
        """
        self.im0 = im0
        # Create a meshgrid with region of interest (ROI) for vectorized distance calculations
        xv, yv = np.meshgrid(np.arange(x0, x1), np.arange(y0, y1))

        # Initialize heatmap only once
        if not self.initialized:
            self.heatmap = np.zeros((int(self.im0.shape[0]), int(self.im0.shape[1])), dtype=np.float32)
            self.initialized = True
        # Calculate squared distances from the center
        dist_squared = (xv - ((x0 + x1) // 2)) ** 2 + (yv - ((y0 + y1) // 2)) ** 2

        self.heatmap *= self.decay_factor  # decay factor
        # Create a mask of points within the radius
        within_radius = dist_squared <= radius_squared

        self.extract_results(tracks)
        self.annotator = Annotator(self.im0, self.tf, None)
        # Update only the values within the bounding box in a single vectorized operation
        self.heatmap[y0:y1, x0:x1][within_radius] += 2

        if self.track_ids:
            # Draw counting region
            if self.count_reg_pts is not None:
                self.annotator.draw_region(
                    reg_pts=self.count_reg_pts, color=self.region_color, thickness=self.region_thickness
                )

            for box, cls, track_id in zip(self.boxes, self.clss, self.track_ids):
                # Store class info
                if self.names[cls] not in self.class_wise_count:
                    self.class_wise_count[self.names[cls]] = {"IN": 0, "OUT": 0}

                if self.shape == "circle":
                    center = (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2))
                    radius = min(int(box[2]) - int(box[0]), int(box[3]) - int(box[1])) // 2
    def generate_heatmap(self, im0):
        """
        Generate heatmap for each frame using Ultralytics.

                    y, x = np.ogrid[0 : self.heatmap.shape[0], 0 : self.heatmap.shape[1]]
                    mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius**2
        Args:
            im0 (ndarray): Input image array for processing
        Returns:
            im0 (ndarray): Processed image for further usage
        """
        self.heatmap = np.zeros_like(im0, dtype=np.float32) * 0.99 if not self.initialized else self.heatmap
        self.initialized = True  # Initialize heatmap only once

                    self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += (
                        2 * mask[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])]
                    )
        self.annotator = Annotator(im0, line_width=self.line_width)  # Initialize annotator
        self.extract_tracks(im0)  # Extract tracks

                else:
                    self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += 2
        # Iterate over bounding boxes, track ids and classes index
        for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
            # Draw bounding box and counting region
            self.heatmap_effect(box)

                # Store tracking hist
                track_line = self.track_history[track_id]
                track_line.append((float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2)))
                if len(track_line) > 30:
                    track_line.pop(0)
            if self.region is not None:
                self.annotator.draw_region(reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2)
                self.store_tracking_history(track_id, box)  # Store track history
                self.store_classwise_counts(cls)  # store classwise counts in dict

                # Store tracking previous position and perform object counting
                prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
                self.count_objects(self.track_line, box, track_id, prev_position, cls)  # Perform object counting

                if self.count_reg_pts is not None:
                    # Count objects in any polygon
                    if len(self.count_reg_pts) >= 3:
                        is_inside = self.counting_region.contains(Point(track_line[-1]))

                        if prev_position is not None and is_inside and track_id not in self.count_ids:
                            self.count_ids.append(track_id)

                            if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
                                self.in_counts += 1
                                self.class_wise_count[self.names[cls]]["IN"] += 1
                            else:
                                self.out_counts += 1
                                self.class_wise_count[self.names[cls]]["OUT"] += 1

                    # Count objects using line
                    elif len(self.count_reg_pts) == 2:
                        if prev_position is not None and track_id not in self.count_ids:
                            distance = Point(track_line[-1]).distance(self.counting_region)
                            if distance < self.line_dist_thresh and track_id not in self.count_ids:
                                self.count_ids.append(track_id)

                                if (box[0] - prev_position[0]) * (
                                    self.counting_region.centroid.x - prev_position[0]
                                ) > 0:
                                    self.in_counts += 1
                                    self.class_wise_count[self.names[cls]]["IN"] += 1
                                else:
                                    self.out_counts += 1
                                    self.class_wise_count[self.names[cls]]["OUT"] += 1

        else:
            for box, cls in zip(self.boxes, self.clss):
                if self.shape == "circle":
                    center = (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2))
                    radius = min(int(box[2]) - int(box[0]), int(box[3]) - int(box[1])) // 2

                    y, x = np.ogrid[0 : self.heatmap.shape[0], 0 : self.heatmap.shape[1]]
                    mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius**2

                    self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += (
                        2 * mask[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])]
                    )

                else:
                    self.heatmap[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] += 2

        if self.count_reg_pts is not None:
            labels_dict = {}

            for key, value in self.class_wise_count.items():
                if value["IN"] != 0 or value["OUT"] != 0:
                    if not self.view_in_counts and not self.view_out_counts:
                        continue
                    elif not self.view_in_counts:
                        labels_dict[str.capitalize(key)] = f"OUT {value['OUT']}"
                    elif not self.view_out_counts:
                        labels_dict[str.capitalize(key)] = f"IN {value['IN']}"
                    else:
                        labels_dict[str.capitalize(key)] = f"IN {value['IN']} OUT {value['OUT']}"

            if labels_dict is not None:
                self.annotator.display_analytics(self.im0, labels_dict, self.count_txt_color, self.count_bg_color, 10)
        self.display_counts(im0) if self.region is not None else None  # Display the counts on the frame

        # Normalize, apply colormap to heatmap and combine with original image
        heatmap_normalized = cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX)
        heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), self.colormap)
        self.im0 = cv2.addWeighted(self.im0, 1 - self.heatmap_alpha, heatmap_colored, self.heatmap_alpha, 0)

        if self.env_check and self.view_img:
            self.display_frames()

        return self.im0

    def display_frames(self):
        """Display frame."""
        cv2.imshow("Ultralytics Heatmap", self.im0)

        if cv2.waitKey(1) & 0xFF == ord("q"):
            return


if __name__ == "__main__":
    classes_names = {0: "person", 1: "car"}  # example class names
    heatmap = Heatmap(classes_names)
        im0 = (
            im0
            if self.track_data.id is None
            else cv2.addWeighted(
                im0,
                0.5,
                cv2.applyColorMap(
                    cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8), self.colormap
                ),
                0.5,
                0,
            )
        )

        self.display_output(im0)  # display output with base class function
        return im0  # return output image for more usage
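
A usage sketch for the refactored heatmap (the class and its `generate_heatmap` method come from this diff; a counting region can optionally be passed via `region=`; the video path and colormap choice are illustrative):

import cv2

from ultralytics.solutions.heatmap import Heatmap

heatmap = Heatmap(model="yolo11n.pt", colormap=cv2.COLORMAP_PARULA, show=True)

cap = cv2.VideoCapture("traffic.mp4")  # illustrative input video
while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    frame = heatmap.generate_heatmap(frame)  # accumulate per-box effects and overlay the colormap
cap.release()
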
@@ -1,243 +1,131 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

from collections import defaultdict
from shapely.geometry import LineString, Point

import cv2

from ultralytics.utils.checks import check_imshow, check_requirements
from ultralytics.solutions.solutions import BaseSolution  # Import a parent class
from ultralytics.utils.plotting import Annotator, colors

check_requirements("shapely>=2.0.0")

from shapely.geometry import LineString, Point, Polygon
class ObjectCounter(BaseSolution):
    """A class to manage the counting of objects in a real-time video stream based on their tracks."""

    def __init__(self, **kwargs):
        """Initialization function for the ObjectCounter class, a child class of BaseSolution that can be used for
        counting objects.
        """
        super().__init__(**kwargs)


class ObjectCounter:
    """A class to manage the counting of objects in a real-time video stream based on their tracks."""
        self.in_count = 0  # Counter for objects moving inward
        self.out_count = 0  # Counter for objects moving outward
        self.counted_ids = []  # List of IDs of objects that have been counted
        self.classwise_counts = {}  # Dictionary for counts, categorized by object class
        self.region_initialized = False  # Bool variable for region initialization

    def __init__(
        self,
        names,
        reg_pts=None,
        line_thickness=2,
        view_img=False,
        view_in_counts=True,
        view_out_counts=True,
        draw_tracks=False,
    ):
        self.show_in = self.CFG["show_in"]
        self.show_out = self.CFG["show_out"]

    def count_objects(self, track_line, box, track_id, prev_position, cls):
        """
        Initializes the ObjectCounter with various tracking and counting parameters.
        Helper function to count objects within a polygonal region.

        Args:
            names (dict): Dictionary of class names.
            reg_pts (list): List of points defining the counting region.
            line_thickness (int): Line thickness for bounding boxes.
            view_img (bool): Flag to control whether to display the video stream.
            view_in_counts (bool): Flag to control whether to display the in counts on the video stream.
            view_out_counts (bool): Flag to control whether to display the out counts on the video stream.
            draw_tracks (bool): Flag to control whether to draw the object tracks.
            track_line (dict): Last 30 frames of the track record
            box (list): Bounding box data for the specific track in the current frame
            track_id (int): Track ID of the object
            prev_position (tuple): Last frame position coordinates of the track
            cls (int): Class index for classwise count updates
        """
        # Mouse events
        self.is_drawing = False
        self.selected_point = None

        # Region & Line Information
        self.reg_pts = [(20, 400), (1260, 400)] if reg_pts is None else reg_pts
        self.counting_region = None

        # Image and annotation Information
        self.im0 = None
        self.tf = line_thickness
        self.view_img = view_img
        self.view_in_counts = view_in_counts
        self.view_out_counts = view_out_counts

        self.names = names  # Classes names
        self.window_name = "Ultralytics YOLOv8 Object Counter"

        # Object counting Information
        self.in_counts = 0
        self.out_counts = 0
        self.count_ids = []
        self.class_wise_count = {}

        # Tracks info
        self.track_history = defaultdict(list)
        self.draw_tracks = draw_tracks

        # Check if environment supports imshow
        self.env_check = check_imshow(warn=True)

        # Initialize counting region
        if len(self.reg_pts) == 2:
            print("Line Counter Initiated.")
            self.counting_region = LineString(self.reg_pts)
        elif len(self.reg_pts) >= 3:
            print("Polygon Counter Initiated.")
            self.counting_region = Polygon(self.reg_pts)
        else:
            print("Invalid Region points provided, region_points must be 2 for lines or >= 3 for polygons.")
            print("Using Line Counter Now")
            self.counting_region = LineString(self.reg_pts)

        # Define the counting line segment
        self.counting_line_segment = LineString(
            [
                (self.reg_pts[0][0], self.reg_pts[0][1]),
                (self.reg_pts[1][0], self.reg_pts[1][1]),
            ]
        )

    def mouse_event_for_region(self, event, x, y, flags, params):
        if prev_position is None or track_id in self.counted_ids:
            return

        centroid = self.r_s.centroid
        dx = (box[0] - prev_position[0]) * (centroid.x - prev_position[0])
        dy = (box[1] - prev_position[1]) * (centroid.y - prev_position[1])

        if len(self.region) >= 3 and self.r_s.contains(Point(track_line[-1])):
            self.counted_ids.append(track_id)
            # For polygon region
            if dx > 0:
                self.in_count += 1
                self.classwise_counts[self.names[cls]]["IN"] += 1
            else:
                self.out_count += 1
                self.classwise_counts[self.names[cls]]["OUT"] += 1

        elif len(self.region) < 3 and LineString([prev_position, box[:2]]).intersects(self.l_s):
            self.counted_ids.append(track_id)
            # For linear region
            if dx > 0 and dy > 0:
                self.in_count += 1
                self.classwise_counts[self.names[cls]]["IN"] += 1
            else:
                self.out_count += 1
                self.classwise_counts[self.names[cls]]["OUT"] += 1

    def store_classwise_counts(self, cls):
        """
        Handles mouse events for defining and moving the counting region in a real-time video stream.
        Initialize class-wise counts if not already present.

        Args:
            event (int): The type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN, etc.).
            x (int): The x-coordinate of the mouse pointer.
            y (int): The y-coordinate of the mouse pointer.
            flags (int): Any associated event flags (e.g., cv2.EVENT_FLAG_CTRLKEY, cv2.EVENT_FLAG_SHIFTKEY, etc.).
            params (dict): Additional parameters for the function.
            cls (int): Class index for classwise count updates
        """
        if self.names[cls] not in self.classwise_counts:
            self.classwise_counts[self.names[cls]] = {"IN": 0, "OUT": 0}

    def display_counts(self, im0):
        """
        if event == cv2.EVENT_LBUTTONDOWN:
            for i, point in enumerate(self.reg_pts):
                if (
                    isinstance(point, (tuple, list))
                    and len(point) >= 2
                    and (abs(x - point[0]) < 10 and abs(y - point[1]) < 10)
                ):
                    self.selected_point = i
                    self.is_drawing = True
                    break

        elif event == cv2.EVENT_MOUSEMOVE:
            if self.is_drawing and self.selected_point is not None:
                self.reg_pts[self.selected_point] = (x, y)
                self.counting_region = Polygon(self.reg_pts)

        elif event == cv2.EVENT_LBUTTONUP:
            self.is_drawing = False
            self.selected_point = None

    def extract_and_process_tracks(self, tracks):
        """Extracts and processes tracks for object counting in a video stream."""
        # Annotator Init and region drawing
        annotator = Annotator(self.im0, self.tf, self.names)

        # Draw region or line
        annotator.draw_region(reg_pts=self.reg_pts, color=(104, 0, 123), thickness=self.tf * 2)

        # Extract tracks for OBB or object detection
        track_data = tracks[0].obb or tracks[0].boxes

        if track_data and track_data.id is not None:
            boxes = track_data.xyxy.cpu()
            clss = track_data.cls.cpu().tolist()
            track_ids = track_data.id.int().cpu().tolist()

            # Extract tracks
            for box, track_id, cls in zip(boxes, track_ids, clss):
                # Draw bounding box
                annotator.box_label(box, label=self.names[cls], color=colors(int(track_id), True))

                # Store class info
                if self.names[cls] not in self.class_wise_count:
                    self.class_wise_count[self.names[cls]] = {"IN": 0, "OUT": 0}

                # Draw Tracks
                track_line = self.track_history[track_id]
                track_line.append((float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2)))
                if len(track_line) > 30:
                    track_line.pop(0)

                # Draw track trails
                if self.draw_tracks:
                    annotator.draw_centroid_and_tracks(
                        track_line,
                        color=colors(int(track_id), True),
                        track_thickness=self.tf,
                    )

                prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None

                # Count objects in any polygon
                if len(self.reg_pts) >= 3:
                    is_inside = self.counting_region.contains(Point(track_line[-1]))

                    if prev_position is not None and is_inside and track_id not in self.count_ids:
                        self.count_ids.append(track_id)

                        if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
                            self.in_counts += 1
                            self.class_wise_count[self.names[cls]]["IN"] += 1
                        else:
                            self.out_counts += 1
                            self.class_wise_count[self.names[cls]]["OUT"] += 1

                # Count objects using line
                elif len(self.reg_pts) == 2:
                    if (
                        prev_position is not None
                        and track_id not in self.count_ids
                        and LineString([(prev_position[0], prev_position[1]), (box[0], box[1])]).intersects(
                            self.counting_line_segment
                        )
                    ):
                        self.count_ids.append(track_id)

                        # Determine the direction of movement (IN or OUT)
                        dx = (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0])
                        dy = (box[1] - prev_position[1]) * (self.counting_region.centroid.y - prev_position[1])
                        if dx > 0 and dy > 0:
                            self.in_counts += 1
                            self.class_wise_count[self.names[cls]]["IN"] += 1
                        else:
                            self.out_counts += 1
                            self.class_wise_count[self.names[cls]]["OUT"] += 1

        labels_dict = {}

        for key, value in self.class_wise_count.items():
            if value["IN"] != 0 or value["OUT"] != 0:
                if not self.view_in_counts and not self.view_out_counts:
                    continue
                elif not self.view_in_counts:
                    labels_dict[str.capitalize(key)] = f"OUT {value['OUT']}"
                elif not self.view_out_counts:
                    labels_dict[str.capitalize(key)] = f"IN {value['IN']}"
                else:
                    labels_dict[str.capitalize(key)] = f"IN {value['IN']} OUT {value['OUT']}"
        Helper function to display object counts on the frame.

        Args:
            im0 (ndarray): The input image or frame
        """
        labels_dict = {
            str.capitalize(key): f"{'IN ' + str(value['IN']) if self.show_in else ''} "
            f"{'OUT ' + str(value['OUT']) if self.show_out else ''}".strip()
            for key, value in self.classwise_counts.items()
            if value["IN"] != 0 or value["OUT"] != 0
        }

        if labels_dict:
            annotator.display_analytics(self.im0, labels_dict, (104, 31, 17), (255, 255, 255), 10)

    def display_frames(self):
        """Displays the current frame with annotations and regions in a window."""
        if self.env_check:
            cv2.namedWindow(self.window_name)
            if len(self.reg_pts) == 4:  # only add mouse event if user drew a region
                cv2.setMouseCallback(self.window_name, self.mouse_event_for_region, {"region_points": self.reg_pts})
            cv2.imshow(self.window_name, self.im0)
            # Break Window
            if cv2.waitKey(1) & 0xFF == ord("q"):
                return

    def start_counting(self, im0, tracks):
            self.annotator.display_analytics(im0, labels_dict, (104, 31, 17), (255, 255, 255), 10)

    def count(self, im0):
        """
        Main function to start the object counting process.
        Processes input data (frames or object tracks) and updates counts.

        Args:
            im0 (ndarray): Current frame from the video stream.
            tracks (list): List of tracks obtained from the object tracking process.
            im0 (ndarray): The input image that will be used for processing
        Returns:
            im0 (ndarray): The processed image for more usage
        """
        self.im0 = im0  # store image
        self.extract_and_process_tracks(tracks)  # draw region even if no objects
        if not self.region_initialized:
            self.initialize_region()
            self.region_initialized = True

        self.annotator = Annotator(im0, line_width=self.line_width)  # Initialize annotator
        self.extract_tracks(im0)  # Extract tracks

        self.annotator.draw_region(
            reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2
        )  # Draw region

        # Iterate over bounding boxes, track ids and classes index
        for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
            # Draw bounding box and counting region
            self.annotator.box_label(box, label=self.names[cls], color=colors(track_id, True))
            self.store_tracking_history(track_id, box)  # Store track history
            self.store_classwise_counts(cls)  # store classwise counts in dict

            # Draw tracks of objects
            self.annotator.draw_centroid_and_tracks(
                self.track_line, color=colors(int(track_id), True), track_thickness=self.line_width
            )

        if self.view_img:
            self.display_frames()
        return self.im0
            # store previous position of track for object counting
            prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
            self.count_objects(self.track_line, box, track_id, prev_position, cls)  # Perform object counting

        self.display_counts(im0)  # Display the counts on the frame
        self.display_output(im0)  # display output with base class function


if __name__ == "__main__":
    classes_names = {0: "person", 1: "car"}  # example class names
    ObjectCounter(classes_names)
        return im0  # return output image for more usage
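
A per-frame counting loop as a sketch (the `count` method, `region` keyword, and `in_count`/`out_count` attributes come from this diff; the video source and line region are illustrative):

import cv2

from ultralytics.solutions.object_counter import ObjectCounter

counter = ObjectCounter(model="yolo11n.pt", region=[(20, 400), (1260, 400)])  # two points -> line counter

cap = cv2.VideoCapture("people.mp4")  # illustrative input video
while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    frame = counter.count(frame)  # draws boxes/region and updates in/out counts
cap.release()
print(counter.in_count, counter.out_count)  # totals accumulated across the video
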
@@ -1,127 +1,64 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

from collections import defaultdict
from shapely.geometry import Point

import cv2

from ultralytics.utils.checks import check_imshow, check_requirements
from ultralytics.solutions.solutions import BaseSolution  # Import a parent class
from ultralytics.utils.plotting import Annotator, colors

check_requirements("shapely>=2.0.0")

from shapely.geometry import Point, Polygon


class QueueManager:
class QueueManager(BaseSolution):
    """A class to manage the queue in a real-time video stream based on object tracks."""

    def __init__(
        self,
        names,
        reg_pts=None,
        line_thickness=2,
        view_img=False,
        draw_tracks=False,
    ):
    def __init__(self, **kwargs):
        """Initializes the QueueManager with specified parameters for tracking and counting objects."""
        super().__init__(**kwargs)
        self.initialize_region()
        self.counts = 0  # Queue counts information
        self.rect_color = (255, 255, 255)  # Rectangle color
        self.region_length = len(self.region)  # Store region length for further usage

    def process_queue(self, im0):
        """
        Initializes the QueueManager with specified parameters for tracking and counting objects.
        Main function to start the queue management process.

        Args:
            names (dict): A dictionary mapping class IDs to class names.
            reg_pts (list of tuples, optional): Points defining the counting region polygon. Defaults to a predefined
                rectangle.
            line_thickness (int, optional): Thickness of the annotation lines. Defaults to 2.
            view_img (bool, optional): Whether to display the image frames. Defaults to False.
            draw_tracks (bool, optional): Whether to draw tracks of the objects. Defaults to False.
            im0 (ndarray): The input image that will be used for processing
        Returns:
            im0 (ndarray): The processed image for more usage
        """
        # Region & Line Information
        self.reg_pts = reg_pts if reg_pts is not None else [(20, 60), (20, 680), (1120, 680), (1120, 60)]
        self.counting_region = (
            Polygon(self.reg_pts) if len(self.reg_pts) >= 3 else Polygon([(20, 60), (20, 680), (1120, 680), (1120, 60)])
        )

        # annotation Information
        self.tf = line_thickness
        self.view_img = view_img

        self.names = names  # Class names

        # Object counting Information
        self.counts = 0

        # Tracks info
        self.track_history = defaultdict(list)
        self.draw_tracks = draw_tracks

        # Check if environment supports imshow
        self.env_check = check_imshow(warn=True)

    def extract_and_process_tracks(self, tracks, im0):
        """Extracts and processes tracks for queue management in a video stream."""
        # Initialize annotator and draw the queue region
        annotator = Annotator(im0, self.tf, self.names)
        self.counts = 0  # Reset counts every frame
        if tracks[0].boxes.id is not None:
            boxes = tracks[0].boxes.xyxy.cpu()
            clss = tracks[0].boxes.cls.cpu().tolist()
            track_ids = tracks[0].boxes.id.int().cpu().tolist()
        self.annotator = Annotator(im0, line_width=self.line_width)  # Initialize annotator
        self.extract_tracks(im0)  # Extract tracks

            # Extract tracks
            for box, track_id, cls in zip(boxes, track_ids, clss):
                # Draw bounding box
                annotator.box_label(box, label=self.names[cls], color=colors(int(track_id), True))
        self.annotator.draw_region(
            reg_pts=self.region, color=self.rect_color, thickness=self.line_width * 2
        )  # Draw region

                # Update track history
                track_line = self.track_history[track_id]
                track_line.append((float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2)))
                if len(track_line) > 30:
                    track_line.pop(0)
        for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
            # Draw bounding box and counting region
            self.annotator.box_label(box, label=self.names[cls], color=colors(track_id, True))
            self.store_tracking_history(track_id, box)  # Store track history

                # Draw track trails if enabled
                if self.draw_tracks:
                    annotator.draw_centroid_and_tracks(
                        track_line,
                        color=colors(int(track_id), True),
                        track_thickness=self.line_thickness,
                    )

                prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None

                # Check if the object is inside the counting region
                if len(self.reg_pts) >= 3:
                    is_inside = self.counting_region.contains(Point(track_line[-1]))
                    if prev_position is not None and is_inside:
                        self.counts += 1

        # Display queue counts
        label = f"Queue Counts : {str(self.counts)}"
        if label is not None:
            annotator.queue_counts_display(
                label,
                points=self.reg_pts,
                region_color=(255, 0, 255),
                txt_color=(104, 31, 17),
            # Draw tracks of objects
            self.annotator.draw_centroid_and_tracks(
                self.track_line, color=colors(int(track_id), True), track_thickness=self.line_width
            )

        if self.env_check and self.view_img:
            annotator.draw_region(reg_pts=self.reg_pts, thickness=self.tf * 2, color=(255, 0, 255))
            cv2.imshow("Ultralytics YOLOv8 Queue Manager", im0)
            # Close window on 'q' key press
            if cv2.waitKey(1) & 0xFF == ord("q"):
                return
            # Cache frequently accessed attributes
            track_history = self.track_history.get(track_id, [])

    def process_queue(self, im0, tracks):
        """
        Main function to start the queue management process.

        Args:
            im0 (ndarray): Current frame from the video stream.
            tracks (list): List of tracks obtained from the object tracking process.
        """
        self.extract_and_process_tracks(tracks, im0)  # Extract and process tracks
        return im0
            # store previous position of track and check if the object is inside the counting region
            prev_position = track_history[-2] if len(track_history) > 1 else None
            if self.region_length >= 3 and prev_position and self.r_s.contains(Point(self.track_line[-1])):
                self.counts += 1

        # Display queue counts
        self.annotator.queue_counts_display(
            f"Queue Counts : {str(self.counts)}",
            points=self.region,
            region_color=self.rect_color,
            txt_color=(104, 31, 17),
        )
        self.display_output(im0)  # display output with base class function


if __name__ == "__main__":
    classes_names = {0: "person", 1: "car"}  # example class names
    queue_manager = QueueManager(classes_names)
        return im0  # return output image for more usage
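
A queue-monitoring sketch (`process_queue` and the per-frame `counts` reset come from this diff; the module path, video source, and region polygon are illustrative):

import cv2

from ultralytics.solutions.queue_management import QueueManager

queue = QueueManager(model="yolo11n.pt", region=[(20, 60), (20, 680), (1120, 680), (1120, 60)])

cap = cv2.VideoCapture("checkout.mp4")  # illustrative input video
while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    frame = queue.process_queue(frame)  # counts is reset and recomputed every frame
cap.release()
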
@@ -1,116 +1,76 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

from collections import defaultdict
from time import time

import cv2
import numpy as np

from ultralytics.utils.checks import check_imshow
from ultralytics.solutions.solutions import BaseSolution, LineString
from ultralytics.utils.plotting import Annotator, colors


class SpeedEstimator:
class SpeedEstimator(BaseSolution):
    """A class to estimate the speed of objects in a real-time video stream based on their tracks."""

    def __init__(self, names, reg_pts=None, view_img=False, line_thickness=2, spdl_dist_thresh=10):
        """
        Initializes the SpeedEstimator with the given parameters.

        Args:
            names (dict): Dictionary of class names.
            reg_pts (list, optional): List of region points for speed estimation. Defaults to [(20, 400), (1260, 400)].
            view_img (bool, optional): Whether to display the image with annotations. Defaults to False.
            line_thickness (int, optional): Thickness of the lines for drawing boxes and tracks. Defaults to 2.
            spdl_dist_thresh (int, optional): Distance threshold for speed calculation. Defaults to 10.
        """
        # Region information
        self.reg_pts = reg_pts if reg_pts is not None else [(20, 400), (1260, 400)]
    def __init__(self, **kwargs):
        """Initializes the SpeedEstimator with the given parameters."""
        super().__init__(**kwargs)

        self.names = names  # Classes names
        self.initialize_region()  # Initialize speed region

        # Tracking information
        self.trk_history = defaultdict(list)

        self.view_img = view_img  # bool for displaying inference
        self.tf = line_thickness  # line thickness for annotator
        self.spd = {}  # set for speed data
        self.trkd_ids = []  # list for already speed-estimated and tracked IDs
        self.spdl = spdl_dist_thresh  # Speed line distance threshold
        self.trk_pt = {}  # set for tracks' previous time
        self.trk_pp = {}  # set for tracks' previous point

        # Check if the environment supports imshow
        self.env_check = check_imshow(warn=True)

    def estimate_speed(self, im0, tracks):
    def estimate_speed(self, im0):
        """
        Estimates the speed of objects based on tracking data.

        Args:
            im0 (ndarray): Image.
            tracks (list): List of tracks obtained from the object tracking process.

        Returns:
            (ndarray): The image with annotated boxes and tracks.
            im0 (ndarray): The input image that will be used for processing
        Returns:
            im0 (ndarray): The processed image for more usage
        """
        if tracks[0].boxes.id is None:
            return im0
        self.annotator = Annotator(im0, line_width=self.line_width)  # Initialize annotator
        self.extract_tracks(im0)  # Extract tracks

        boxes = tracks[0].boxes.xyxy.cpu()
        clss = tracks[0].boxes.cls.cpu().tolist()
        t_ids = tracks[0].boxes.id.int().cpu().tolist()
        annotator = Annotator(im0, line_width=self.tf)
        annotator.draw_region(reg_pts=self.reg_pts, color=(255, 0, 255), thickness=self.tf * 2)
        self.annotator.draw_region(
            reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2
        )  # Draw region

        for box, t_id, cls in zip(boxes, t_ids, clss):
            track = self.trk_history[t_id]
            bbox_center = (float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2))
            track.append(bbox_center)
        for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
            self.store_tracking_history(track_id, box)  # Store track history

            if len(track) > 30:
                track.pop(0)
            # Check if track_id is already in self.trk_pp or trk_pt and initialize if not
            if track_id not in self.trk_pt:
                self.trk_pt[track_id] = 0
            if track_id not in self.trk_pp:
                self.trk_pp[track_id] = self.track_line[-1]

            trk_pts = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
            speed_label = f"{int(self.spd[track_id])} km/h" if track_id in self.spd else self.names[int(cls)]
            self.annotator.box_label(box, label=speed_label, color=colors(track_id, True))  # Draw bounding box

            if t_id not in self.trk_pt:
                self.trk_pt[t_id] = 0
            # Draw tracks of objects
            self.annotator.draw_centroid_and_tracks(
                self.track_line, color=colors(int(track_id), True), track_thickness=self.line_width
            )

            speed_label = f"{int(self.spd[t_id])} km/h" if t_id in self.spd else self.names[int(cls)]
            bbox_color = colors(int(t_id), True)

            annotator.box_label(box, speed_label, bbox_color)
            cv2.polylines(im0, [trk_pts], isClosed=False, color=bbox_color, thickness=self.tf)
            cv2.circle(im0, (int(track[-1][0]), int(track[-1][1])), self.tf * 2, bbox_color, -1)

            # Calculation of object speed
            if not self.reg_pts[0][0] < track[-1][0] < self.reg_pts[1][0]:
                return
            if self.reg_pts[1][1] - self.spdl < track[-1][1] < self.reg_pts[1][1] + self.spdl:
                direction = "known"
            elif self.reg_pts[0][1] - self.spdl < track[-1][1] < self.reg_pts[0][1] + self.spdl:
            # Calculate object speed and direction based on region intersection
            if LineString([self.trk_pp[track_id], self.track_line[-1]]).intersects(self.l_s):
                direction = "known"
            else:
                direction = "unknown"

            if self.trk_pt.get(t_id) != 0 and direction != "unknown" and t_id not in self.trkd_ids:
                self.trkd_ids.append(t_id)

                time_difference = time() - self.trk_pt[t_id]
            # Perform speed calculation and tracking updates if direction is valid
            if direction == "known" and track_id not in self.trkd_ids:
                self.trkd_ids.append(track_id)
                time_difference = time() - self.trk_pt[track_id]
                if time_difference > 0:
                    self.spd[t_id] = np.abs(track[-1][1] - self.trk_pp[t_id][1]) / time_difference

            self.trk_pt[t_id] = time()
            self.trk_pp[t_id] = track[-1]

        if self.view_img and self.env_check:
            cv2.imshow("Ultralytics Speed Estimation", im0)
            if cv2.waitKey(1) & 0xFF == ord("q"):
                return
                    self.spd[track_id] = np.abs(self.track_line[-1][1] - self.trk_pp[track_id][1]) / time_difference

        return im0
            self.trk_pt[track_id] = time()
            self.trk_pp[track_id] = self.track_line[-1]

        self.display_output(im0)  # display output with base class function


if __name__ == "__main__":
    names = {0: "person", 1: "car"}  # example class names
    speed_estimator = SpeedEstimator(names)
        return im0  # return output image for more usage
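
A speed-estimation sketch (`estimate_speed` comes from this diff; since `spd` is derived from pixel displacement over time, the km/h labels are approximate; the module path, video source, and line region are illustrative):

import cv2

from ultralytics.solutions.speed_estimation import SpeedEstimator

speed = SpeedEstimator(model="yolo11n.pt", region=[(20, 400), (1260, 400)])  # line crossed by traffic

cap = cv2.VideoCapture("highway.mp4")  # illustrative input video
while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    frame = speed.estimate_speed(frame)  # labels each tracked box with its estimated speed
cap.release()
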
Some files were not shown because too many files have changed in this diff.