Merge branch 'main' into cli-info

cli-info
Burhan 2 months ago committed by GitHub
commit 9a9b644f99
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 4
      .github/workflows/ci.yaml
  2. 60
      docs/en/guides/speed-estimation.md
  3. 7
      tests/test_solutions.py
  4. 2
      ultralytics/__init__.py
  5. 18
      ultralytics/cfg/solutions/default.yaml
  6. 15
      ultralytics/engine/trainer.py
  7. 4
      ultralytics/nn/autobackend.py
  8. 4
      ultralytics/nn/tasks.py
  9. 118
      ultralytics/solutions/speed_estimation.py

@ -56,7 +56,7 @@ jobs:
shell: bash # for Windows compatibility
run: |
python -m pip install --upgrade pip wheel
pip install -e . --extra-index-url https://download.pytorch.org/whl/cpu
pip install . --extra-index-url https://download.pytorch.org/whl/cpu
- name: Check environment
run: |
yolo checks
@ -213,7 +213,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Install requirements
run: pip install -e . pytest-cov
run: pip install . pytest-cov
- name: Check environment
run: |
yolo checks

@ -45,40 +45,33 @@ keywords: Ultralytics YOLO11, speed estimation, object tracking, computer vision
```python
import cv2
from ultralytics import YOLO, solutions
from ultralytics import solutions
model = YOLO("yolo11n.pt")
names = model.model.names
cap = cv2.VideoCapture("Path/to/video/file.mp4")
cap = cv2.VideoCapture("path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file"
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
# Video writer
video_writer = cv2.VideoWriter("speed_estimation.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
video_writer = cv2.VideoWriter("speed_management.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
line_pts = [(0, 360), (1280, 360)]
speed_region = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
# Init speed-estimation obj
speed_obj = solutions.SpeedEstimator(
reg_pts=line_pts,
names=names,
view_img=True,
)
speed = solutions.SpeedEstimator(model="yolo11n.pt", region=speed_region, show=True)
while cap.isOpened():
success, im0 = cap.read()
if not success:
print("Video frame is empty or video processing has been successfully completed.")
break
tracks = model.track(im0, persist=True)
if success:
out = speed.estimate_speed(im0)
video_writer.write(im0)
if cv2.waitKey(1) & 0xFF == ord("q"):
break
continue
im0 = speed_obj.estimate_speed(im0, tracks)
video_writer.write(im0)
print("Video frame is empty or video processing has been successfully completed.")
break
cap.release()
video_writer.release()
cv2.destroyAllWindows()
```
@ -88,13 +81,12 @@ keywords: Ultralytics YOLO11, speed estimation, object tracking, computer vision
### Arguments `SpeedEstimator`
| Name | Type | Default | Description |
| ------------------ | ------ | -------------------------- | ---------------------------------------------------- |
| `names` | `dict` | `None` | Dictionary of class names. |
| `reg_pts` | `list` | `[(20, 400), (1260, 400)]` | List of region points for speed estimation. |
| `view_img` | `bool` | `False` | Whether to display the image with annotations. |
| `line_thickness` | `int` | `2` | Thickness of the lines for drawing boxes and tracks. |
| `spdl_dist_thresh` | `int` | `10` | Distance threshold for speed calculation. |
| Name | Type | Default | Description |
| ------------ | ------ | -------------------------- | ---------------------------------------------------- |
| `model` | `str` | `None` | Path to Ultralytics YOLO Model File |
| `region` | `list` | `[(20, 400), (1260, 400)]` | List of points defining the counting region. |
| `line_width` | `int` | `2` | Line thickness for bounding boxes. |
| `show` | `bool` | `False` | Flag to control whether to display the video stream. |
### Arguments `model.track`
@ -111,10 +103,7 @@ Estimating object speed with Ultralytics YOLO11 involves combining [object detec
```python
import cv2
from ultralytics import YOLO, solutions
model = YOLO("yolo11n.pt")
names = model.model.names
from ultralytics import solutions
cap = cv2.VideoCapture("path/to/video/file.mp4")
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
@ -122,17 +111,16 @@ video_writer = cv2.VideoWriter("speed_estimation.avi", cv2.VideoWriter_fourcc(*"
# Initialize SpeedEstimator
speed_obj = solutions.SpeedEstimator(
reg_pts=[(0, 360), (1280, 360)],
names=names,
view_img=True,
region=[(0, 360), (1280, 360)],
model="yolo11n.pt",
show=True,
)
while cap.isOpened():
success, im0 = cap.read()
if not success:
break
tracks = model.track(im0, persist=True, show=False)
im0 = speed_obj.estimate_speed(im0, tracks)
im0 = speed_obj.estimate_speed(im0)
video_writer.write(im0)
cap.release()

@ -14,24 +14,21 @@ WORKOUTS_SOLUTION_DEMO = "https://github.com/ultralytics/assets/releases/downloa
def test_major_solutions():
"""Test the object counting, heatmap, speed estimation and queue management solution."""
safe_download(url=MAJOR_SOLUTIONS_DEMO)
model = YOLO("yolo11n.pt")
names = model.names
cap = cv2.VideoCapture("solutions_ci_demo.mp4")
assert cap.isOpened(), "Error reading video file"
region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
counter = solutions.ObjectCounter(region=region_points, model="yolo11n.pt", show=False)
heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False)
speed = solutions.SpeedEstimator(reg_pts=region_points, names=names, view_img=False)
speed = solutions.SpeedEstimator(region=region_points, model="yolo11n.pt", show=False)
queue = solutions.QueueManager(region=region_points, model="yolo11n.pt", show=False)
while cap.isOpened():
success, im0 = cap.read()
if not success:
break
original_im0 = im0.copy()
tracks = model.track(im0, persist=True, show=False)
_ = counter.count(original_im0.copy())
_ = heatmap.generate_heatmap(original_im0.copy())
_ = speed.estimate_speed(original_im0.copy(), tracks)
_ = speed.estimate_speed(original_im0.copy())
_ = queue.process_queue(original_im0.copy())
cap.release()
cv2.destroyAllWindows()

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.3.8"
__version__ = "8.3.9"
import os

@ -2,15 +2,15 @@
# Configuration for Ultralytics Solutions
model: "yolo11n.pt" # The Ultralytics YOLO11 model to be used (e.g., yolo11n.pt for YOLO11 nano version)
model: "yolo11n.pt" # The Ultralytics YOLO11 model to be used (e.g., yolo11n.pt for YOLO11 nano version and yolov8n.pt for YOLOv8 nano version)
region: # Object counting, queue or speed estimation region points
line_width: 2 # Thickness of the lines used to draw regions on the image/video frames
show: True # Flag to control whether to display output image or not
region: # Object counting, queue or speed estimation region points. Default region points are [(20, 400), (1080, 404), (1080, 360), (20, 360)]
line_width: 2 # Width of the annotator used to draw regions on the image/video frames + bounding boxes and tracks drawing. Default value is 2.
show: True # Flag to control whether to display output image or not, you can set this as False i.e. when deploying it on some embedded devices.
show_in: True # Flag to display objects moving *into* the defined region
show_out: True # Flag to display objects moving *out of* the defined region
classes: # To count specific classes
up_angle: 145.0 # Workouts up_angle for counts, 145.0 is default value
down_angle: 90 # Workouts down_angle for counts, 90 is default value
kpts: [6, 8, 10] # Keypoints for workouts monitoring
colormap: # Colormap for heatmap
classes: # To count specific classes. i.e, if you want to detect, track and count the person with COCO model, you can use classes=0, Default its None
up_angle: 145.0 # Workouts up_angle for counts, 145.0 is default value. You can adjust it for different workouts, based on position of keypoints.
down_angle: 90 # Workouts down_angle for counts, 90 is default value. You can change it for different workouts, based on position of keypoints.
kpts: [6, 8, 10] # Keypoints for workouts monitoring, i.e. If you want to consider keypoints for pushups that have mostly values of [6, 8, 10].
colormap: # Colormap for heatmap, Only OPENCV supported colormaps can be used. By default COLORMAP_PARULA will be used for visualization.

@ -469,11 +469,9 @@ class BaseTrainer:
if RANK in {-1, 0}:
# Do final val with best.pt
epochs = epoch - self.start_epoch + 1 # total training epochs
seconds = time.time() - self.train_time_start # total training seconds
LOGGER.info(f"\n{epochs} epochs completed in {seconds / 3600:.3f} hours.")
seconds = time.time() - self.train_time_start
LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
self.final_eval()
self.validator.metrics.training = {"epochs": epochs, "seconds": seconds} # add training speed
if self.args.plots:
self.plot_metrics()
self.run_callbacks("on_train_end")
@ -504,7 +502,7 @@ class BaseTrainer:
"""Read results.csv into a dict using pandas."""
import pandas as pd # scope for faster 'import ultralytics'
return {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
return pd.read_csv(self.csv).to_dict(orient="list")
def save_model(self):
"""Save model training checkpoints with additional metadata."""
@ -654,10 +652,11 @@ class BaseTrainer:
def save_metrics(self, metrics):
"""Saves training metrics to a CSV file."""
keys, vals = list(metrics.keys()), list(metrics.values())
n = len(metrics) + 1 # number of cols
s = "" if self.csv.exists() else (("%23s," * n % tuple(["epoch"] + keys)).rstrip(",") + "\n") # header
n = len(metrics) + 2 # number of cols
s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n") # header
t = time.time() - self.train_time_start
with open(self.csv, "a") as f:
f.write(s + ("%23.5g," * n % tuple([self.epoch + 1] + vals)).rstrip(",") + "\n")
f.write(s + ("%.6g," * n % tuple([self.epoch + 1, t] + vals)).rstrip(",") + "\n")
def plot_metrics(self):
"""Plot and display metrics visually."""

@ -265,8 +265,8 @@ class AutoBackend(nn.Module):
if -1 in tuple(model.get_tensor_shape(name)):
dynamic = True
context.set_input_shape(name, tuple(model.get_tensor_profile_shape(name, 0)[1]))
if dtype == np.float16:
fp16 = True
if dtype == np.float16:
fp16 = True
else:
output_names.append(name)
shape = tuple(context.get_tensor_shape(name))

@ -1061,10 +1061,10 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
t = str(m)[8:-2].replace("__main__.", "") # module type
m.np = sum(x.numel() for x in m_.parameters()) # number params
m_.np = sum(x.numel() for x in m_.parameters()) # number params
m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
if verbose:
LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}") # print
LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f} {t:<45}{str(args):<30}") # print
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
layers.append(m_)
if i == 0:

@ -1,116 +1,76 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from collections import defaultdict
from time import time
import cv2
import numpy as np
from ultralytics.utils.checks import check_imshow
from ultralytics.solutions.solutions import BaseSolution, LineString
from ultralytics.utils.plotting import Annotator, colors
class SpeedEstimator:
class SpeedEstimator(BaseSolution):
"""A class to estimate the speed of objects in a real-time video stream based on their tracks."""
def __init__(self, names, reg_pts=None, view_img=False, line_thickness=2, spdl_dist_thresh=10):
"""
Initializes the SpeedEstimator with the given parameters.
Args:
names (dict): Dictionary of class names.
reg_pts (list, optional): List of region points for speed estimation. Defaults to [(20, 400), (1260, 400)].
view_img (bool, optional): Whether to display the image with annotations. Defaults to False.
line_thickness (int, optional): Thickness of the lines for drawing boxes and tracks. Defaults to 2.
spdl_dist_thresh (int, optional): Distance threshold for speed calculation. Defaults to 10.
"""
# Region information
self.reg_pts = reg_pts if reg_pts is not None else [(20, 400), (1260, 400)]
def __init__(self, **kwargs):
"""Initializes the SpeedEstimator with the given parameters."""
super().__init__(**kwargs)
self.names = names # Classes names
self.initialize_region() # Initialize speed region
# Tracking information
self.trk_history = defaultdict(list)
self.view_img = view_img # bool for displaying inference
self.tf = line_thickness # line thickness for annotator
self.spd = {} # set for speed data
self.trkd_ids = [] # list for already speed_estimated and tracked ID's
self.spdl = spdl_dist_thresh # Speed line distance threshold
self.trk_pt = {} # set for tracks previous time
self.trk_pp = {} # set for tracks previous point
# Check if the environment supports imshow
self.env_check = check_imshow(warn=True)
def estimate_speed(self, im0, tracks):
def estimate_speed(self, im0):
"""
Estimates the speed of objects based on tracking data.
Args:
im0 (ndarray): Image.
tracks (list): List of tracks obtained from the object tracking process.
Returns:
(ndarray): The image with annotated boxes and tracks.
im0 (ndarray): The input image that will be used for processing
Returns
im0 (ndarray): The processed image for more usage
"""
if tracks[0].boxes.id is None:
return im0
self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator
self.extract_tracks(im0) # Extract tracks
boxes = tracks[0].boxes.xyxy.cpu()
clss = tracks[0].boxes.cls.cpu().tolist()
t_ids = tracks[0].boxes.id.int().cpu().tolist()
annotator = Annotator(im0, line_width=self.tf)
annotator.draw_region(reg_pts=self.reg_pts, color=(255, 0, 255), thickness=self.tf * 2)
self.annotator.draw_region(
reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2
) # Draw region
for box, t_id, cls in zip(boxes, t_ids, clss):
track = self.trk_history[t_id]
bbox_center = (float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2))
track.append(bbox_center)
for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
self.store_tracking_history(track_id, box) # Store track history
if len(track) > 30:
track.pop(0)
# Check if track_id is already in self.trk_pp or trk_pt initialize if not
if track_id not in self.trk_pt:
self.trk_pt[track_id] = 0
if track_id not in self.trk_pp:
self.trk_pp[track_id] = self.track_line[-1]
trk_pts = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
speed_label = f"{int(self.spd[track_id])} km/h" if track_id in self.spd else self.names[int(cls)]
self.annotator.box_label(box, label=speed_label, color=colors(track_id, True)) # Draw bounding box
if t_id not in self.trk_pt:
self.trk_pt[t_id] = 0
# Draw tracks of objects
self.annotator.draw_centroid_and_tracks(
self.track_line, color=colors(int(track_id), True), track_thickness=self.line_width
)
speed_label = f"{int(self.spd[t_id])} km/h" if t_id in self.spd else self.names[int(cls)]
bbox_color = colors(int(t_id), True)
annotator.box_label(box, speed_label, bbox_color)
cv2.polylines(im0, [trk_pts], isClosed=False, color=bbox_color, thickness=self.tf)
cv2.circle(im0, (int(track[-1][0]), int(track[-1][1])), self.tf * 2, bbox_color, -1)
# Calculation of object speed
if not self.reg_pts[0][0] < track[-1][0] < self.reg_pts[1][0]:
return
if self.reg_pts[1][1] - self.spdl < track[-1][1] < self.reg_pts[1][1] + self.spdl:
direction = "known"
elif self.reg_pts[0][1] - self.spdl < track[-1][1] < self.reg_pts[0][1] + self.spdl:
# Calculate object speed and direction based on region intersection
if LineString([self.trk_pp[track_id], self.track_line[-1]]).intersects(self.l_s):
direction = "known"
else:
direction = "unknown"
if self.trk_pt.get(t_id) != 0 and direction != "unknown" and t_id not in self.trkd_ids:
self.trkd_ids.append(t_id)
time_difference = time() - self.trk_pt[t_id]
# Perform speed calculation and tracking updates if direction is valid
if direction == "known" and track_id not in self.trkd_ids:
self.trkd_ids.append(track_id)
time_difference = time() - self.trk_pt[track_id]
if time_difference > 0:
self.spd[t_id] = np.abs(track[-1][1] - self.trk_pp[t_id][1]) / time_difference
self.trk_pt[t_id] = time()
self.trk_pp[t_id] = track[-1]
if self.view_img and self.env_check:
cv2.imshow("Ultralytics Speed Estimation", im0)
if cv2.waitKey(1) & 0xFF == ord("q"):
return
self.spd[track_id] = np.abs(self.track_line[-1][1] - self.trk_pp[track_id][1]) / time_difference
return im0
self.trk_pt[track_id] = time()
self.trk_pp[track_id] = self.track_line[-1]
self.display_output(im0) # display output with base class function
if __name__ == "__main__":
names = {0: "person", 1: "car"} # example class names
speed_estimator = SpeedEstimator(names)
return im0 # return output image for more usage

Loading…
Cancel
Save