Merge branch 'main' into mkdocs-exclude

mkdocs-exclude
Ultralytics Assistant 8 months ago committed by GitHub
commit a1f4841cdc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 8
      .github/workflows/ci.yaml
  2. 1
      docs/en/modes/predict.md
  3. 2
      ultralytics/__init__.py
  4. 2
      ultralytics/cfg/__init__.py
  5. 24
      ultralytics/engine/results.py
  6. 43
      ultralytics/trackers/basetrack.py
  7. 82
      ultralytics/trackers/bot_sort.py
  8. 117
      ultralytics/trackers/byte_tracker.py
  9. 19
      ultralytics/trackers/track.py
  10. 90
      ultralytics/trackers/utils/gmc.py
  11. 193
      ultralytics/trackers/utils/kalman_filter.py
  12. 52
      ultralytics/trackers/utils/matching.py
  13. 2
      ultralytics/utils/__init__.py
  14. 103
      ultralytics/utils/files.py
  15. 14
      ultralytics/utils/plotting.py

@ -336,6 +336,13 @@ jobs:
channels: conda-forge,defaults channels: conda-forge,defaults
channel-priority: true channel-priority: true
activate-environment: anaconda-client-env activate-environment: anaconda-client-env
- name: Cleanup toolcache
run: |
echo "Free space before deletion:"
df -h /
rm -rf /opt/hostedtoolcache
echo "Free space after deletion:"
df -h /
- name: Install Linux packages - name: Install Linux packages
run: | run: |
# Fix cv2 ImportError: 'libEGL.so.1: cannot open shared object file: No such file or directory' # Fix cv2 ImportError: 'libEGL.so.1: cannot open shared object file: No such file or directory'
@ -361,6 +368,7 @@ jobs:
yolo val model=yolov8n.pt data=coco8.yaml imgsz=32 yolo val model=yolov8n.pt data=coco8.yaml imgsz=32
yolo export model=yolov8n.pt format=torchscript imgsz=160 yolo export model=yolov8n.pt format=torchscript imgsz=160
- name: Test Python - name: Test Python
# Note this step must use the updated default bash environment, not a python environment
run: | run: |
python -c " python -c "
from ultralytics import YOLO from ultralytics import YOLO

@ -720,6 +720,7 @@ The `plot()` method supports various arguments to customize the output:
| `show` | `bool` | Display the annotated image directly using the default image viewer. | `False` | | `show` | `bool` | Display the annotated image directly using the default image viewer. | `False` |
| `save` | `bool` | Save the annotated image to a file specified by `filename`. | `False` | | `save` | `bool` | Save the annotated image to a file specified by `filename`. | `False` |
| `filename` | `str` | Path and name of the file to save the annotated image if `save` is `True`. | `None` | | `filename` | `str` | Path and name of the file to save the annotated image if `save` is `True`. | `None` |
| `color_mode` | `str` | Specify the color mode, e.g., 'instance' or 'class'. | `'class'` |
## Thread-Safe Inference ## Thread-Safe Inference

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.76" __version__ = "8.2.77"
import os import os

@ -81,7 +81,7 @@ CLI_HELP_MSG = f"""
5. Explore your datasets using semantic search and SQL with a simple GUI powered by Ultralytics Explorer API 5. Explore your datasets using semantic search and SQL with a simple GUI powered by Ultralytics Explorer API
yolo explorer data=data.yaml model=yolov8n.pt yolo explorer data=data.yaml model=yolov8n.pt
6. Streamlit real-time object detection on your webcam with Ultralytics YOLOv8 6. Streamlit real-time webcam inference GUI
yolo streamlit-predict yolo streamlit-predict
7. Run special commands: 7. Run special commands:

@ -460,6 +460,7 @@ class Results(SimpleClass):
show=False, show=False,
save=False, save=False,
filename=None, filename=None,
color_mode="class",
): ):
""" """
Plots detection results on an input RGB image. Plots detection results on an input RGB image.
@ -481,6 +482,7 @@ class Results(SimpleClass):
show (bool): Whether to display the annotated image. show (bool): Whether to display the annotated image.
save (bool): Whether to save the annotated image. save (bool): Whether to save the annotated image.
filename (str | None): Filename to save image if save is True. filename (str | None): Filename to save image if save is True.
color_mode (str): Specify the color mode, either 'instance' or 'class'. Defaults to 'class'.
Returns: Returns:
(np.ndarray): Annotated image as a numpy array. (np.ndarray): Annotated image as a numpy array.
@ -491,6 +493,7 @@ class Results(SimpleClass):
... im = result.plot() ... im = result.plot()
... im.show() ... im.show()
""" """
assert color_mode in {"instance", "class"}, f"Expected color_mode='instance' or 'class', not {color_mode}."
if img is None and isinstance(self.orig_img, torch.Tensor): if img is None and isinstance(self.orig_img, torch.Tensor):
img = (self.orig_img[0].detach().permute(1, 2, 0).contiguous() * 255).to(torch.uint8).cpu().numpy() img = (self.orig_img[0].detach().permute(1, 2, 0).contiguous() * 255).to(torch.uint8).cpu().numpy()
@ -519,17 +522,22 @@ class Results(SimpleClass):
.contiguous() .contiguous()
/ 255 / 255
) )
idx = pred_boxes.cls if pred_boxes else range(len(pred_masks)) idx = pred_boxes.cls if pred_boxes and color_mode == "class" else reversed(range(len(pred_masks)))
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu) annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu)
# Plot Detect results # Plot Detect results
if pred_boxes is not None and show_boxes: if pred_boxes is not None and show_boxes:
for d in reversed(pred_boxes): for i, d in enumerate(reversed(pred_boxes)):
c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item()) c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
name = ("" if id is None else f"id:{id} ") + names[c] name = ("" if id is None else f"id:{id} ") + names[c]
label = (f"{name} {conf:.2f}" if conf else name) if labels else None label = (f"{name} {conf:.2f}" if conf else name) if labels else None
box = d.xyxyxyxy.reshape(-1, 4, 2).squeeze() if is_obb else d.xyxy.squeeze() box = d.xyxyxyxy.reshape(-1, 4, 2).squeeze() if is_obb else d.xyxy.squeeze()
annotator.box_label(box, label, color=colors(c, True), rotated=is_obb) annotator.box_label(
box,
label,
color=colors(i if color_mode == "instance" else c, True),
rotated=is_obb,
)
# Plot Classify results # Plot Classify results
if pred_probs is not None and show_probs: if pred_probs is not None and show_probs:
@ -539,8 +547,14 @@ class Results(SimpleClass):
# Plot Pose results # Plot Pose results
if self.keypoints is not None: if self.keypoints is not None:
for k in reversed(self.keypoints.data): for i, k in enumerate(reversed(self.keypoints.data)):
annotator.kpts(k, self.orig_shape, radius=kpt_radius, kpt_line=kpt_line) annotator.kpts(
k,
self.orig_shape,
radius=kpt_radius,
kpt_line=kpt_line,
kpt_color=colors(i, True) if color_mode == "instance" else None,
)
# Show results # Show results
if show: if show:

@ -15,6 +15,11 @@ class TrackState:
Tracked (int): State when the object is successfully tracked in subsequent frames. Tracked (int): State when the object is successfully tracked in subsequent frames.
Lost (int): State when the object is no longer tracked. Lost (int): State when the object is no longer tracked.
Removed (int): State when the object is removed from tracking. Removed (int): State when the object is removed from tracking.
Examples:
>>> state = TrackState.New
>>> if state == TrackState.New:
...     print("Object is newly detected.")
""" """
New = 0 New = 0
@ -33,13 +38,13 @@ class BaseTrack:
is_activated (bool): Flag indicating whether the track is currently active. is_activated (bool): Flag indicating whether the track is currently active.
state (TrackState): Current state of the track. state (TrackState): Current state of the track.
history (OrderedDict): Ordered history of the track's states. history (OrderedDict): Ordered history of the track's states.
features (list): List of features extracted from the object for tracking. features (List): List of features extracted from the object for tracking.
curr_feature (any): The current feature of the object being tracked. curr_feature (Any): The current feature of the object being tracked.
score (float): The confidence score of the tracking. score (float): The confidence score of the tracking.
start_frame (int): The frame number where tracking started. start_frame (int): The frame number where tracking started.
frame_id (int): The most recent frame ID processed by the track. frame_id (int): The most recent frame ID processed by the track.
time_since_update (int): Frames passed since the last update. time_since_update (int): Frames passed since the last update.
location (tuple): The location of the object in the context of multi-camera tracking. location (Tuple): The location of the object in the context of multi-camera tracking.
Methods: Methods:
end_frame: Returns the ID of the last frame where the object was tracked. end_frame: Returns the ID of the last frame where the object was tracked.
@ -50,12 +55,26 @@ class BaseTrack:
mark_lost: Marks the track as lost. mark_lost: Marks the track as lost.
mark_removed: Marks the track as removed. mark_removed: Marks the track as removed.
reset_id: Resets the global track ID counter. reset_id: Resets the global track ID counter.
Examples:
Initialize a new track and mark it as lost:
>>> track = BaseTrack()
>>> track.mark_lost()
>>> print(track.state) # Output: 2 (TrackState.Lost)
""" """
_count = 0 _count = 0
def __init__(self): def __init__(self):
"""Initializes a new track with unique ID and foundational tracking attributes.""" """
Initializes a new track with a unique ID and foundational tracking attributes.
Examples:
Initialize a new track
>>> track = BaseTrack()
>>> print(track.track_id)
0
"""
self.track_id = 0 self.track_id = 0
self.is_activated = False self.is_activated = False
self.state = TrackState.New self.state = TrackState.New
@ -70,36 +89,36 @@ class BaseTrack:
@property @property
def end_frame(self): def end_frame(self):
"""Return the last frame ID of the track.""" """Returns the ID of the most recent frame where the object was tracked."""
return self.frame_id return self.frame_id
@staticmethod @staticmethod
def next_id(): def next_id():
"""Increment and return the global track ID counter.""" """Increment and return the next unique global track ID for object tracking."""
BaseTrack._count += 1 BaseTrack._count += 1
return BaseTrack._count return BaseTrack._count
def activate(self, *args): def activate(self, *args):
"""Abstract method to activate the track with provided arguments.""" """Activates the track with provided arguments, initializing necessary attributes for tracking."""
raise NotImplementedError raise NotImplementedError
def predict(self): def predict(self):
"""Abstract method to predict the next state of the track.""" """Predicts the next state of the track based on the current state and tracking model."""
raise NotImplementedError raise NotImplementedError
def update(self, *args, **kwargs): def update(self, *args, **kwargs):
"""Abstract method to update the track with new observations.""" """Updates the track with new observations and data, modifying its state and attributes accordingly."""
raise NotImplementedError raise NotImplementedError
def mark_lost(self): def mark_lost(self):
"""Mark the track as lost.""" """Marks the track as lost by updating its state to TrackState.Lost."""
self.state = TrackState.Lost self.state = TrackState.Lost
def mark_removed(self): def mark_removed(self):
"""Mark the track as removed.""" """Marks the track as removed by setting its state to TrackState.Removed."""
self.state = TrackState.Removed self.state = TrackState.Removed
@staticmethod @staticmethod
def reset_id(): def reset_id():
"""Reset the global track ID counter.""" """Reset the global track ID counter to its initial value."""
BaseTrack._count = 0 BaseTrack._count = 0

@ -15,6 +15,9 @@ class BOTrack(STrack):
""" """
An extended version of the STrack class for YOLOv8, adding object tracking features. An extended version of the STrack class for YOLOv8, adding object tracking features.
This class extends the STrack class to include additional functionalities for object tracking, such as feature
smoothing, Kalman filter prediction, and reactivation of tracks.
Attributes: Attributes:
shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack. shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack.
smooth_feat (np.ndarray): Smoothed feature vector. smooth_feat (np.ndarray): Smoothed feature vector.
@ -34,16 +37,35 @@ class BOTrack(STrack):
convert_coords(tlwh): Converts tlwh bounding box coordinates to xywh format. convert_coords(tlwh): Converts tlwh bounding box coordinates to xywh format.
tlwh_to_xywh(tlwh): Convert bounding box to xywh format `(center x, center y, width, height)`. tlwh_to_xywh(tlwh): Convert bounding box to xywh format `(center x, center y, width, height)`.
Usage: Examples:
bo_track = BOTrack(tlwh, score, cls, feat) Create a BOTrack instance and update its features
bo_track.predict() >>> bo_track = BOTrack(tlwh=[100, 50, 80, 40], score=0.9, cls=1, feat=np.random.rand(128))
bo_track.update(new_track, frame_id) >>> bo_track.predict()
>>> new_track = BOTrack(tlwh=[110, 60, 80, 40], score=0.85, cls=1, feat=np.random.rand(128))
>>> bo_track.update(new_track, frame_id=2)
""" """
shared_kalman = KalmanFilterXYWH() shared_kalman = KalmanFilterXYWH()
def __init__(self, tlwh, score, cls, feat=None, feat_history=50): def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
"""Initialize YOLOv8 object with temporal parameters, such as feature history, alpha and current features.""" """
Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features.
Args:
tlwh (np.ndarray): Bounding box coordinates in tlwh format (top left x, top left y, width, height).
score (float): Confidence score of the detection.
cls (int): Class ID of the detected object.
feat (np.ndarray | None): Feature vector associated with the detection.
feat_history (int): Maximum length of the feature history deque.
Examples:
Initialize a BOTrack object with bounding box, score, class ID, and feature vector
>>> tlwh = np.array([100, 50, 80, 120])
>>> score = 0.9
>>> cls = 1
>>> feat = np.random.rand(128)
>>> bo_track = BOTrack(tlwh, score, cls, feat)
"""
super().__init__(tlwh, score, cls) super().__init__(tlwh, score, cls)
self.smooth_feat = None self.smooth_feat = None
@ -54,7 +76,7 @@ class BOTrack(STrack):
self.alpha = 0.9 self.alpha = 0.9
def update_features(self, feat): def update_features(self, feat):
"""Update features vector and smooth it using exponential moving average.""" """Update the feature vector and apply exponential moving average smoothing."""
feat /= np.linalg.norm(feat) feat /= np.linalg.norm(feat)
self.curr_feat = feat self.curr_feat = feat
if self.smooth_feat is None: if self.smooth_feat is None:
@ -65,7 +87,7 @@ class BOTrack(STrack):
self.smooth_feat /= np.linalg.norm(self.smooth_feat) self.smooth_feat /= np.linalg.norm(self.smooth_feat)
def predict(self): def predict(self):
"""Predicts the mean and covariance using Kalman filter.""" """Predicts the object's future state using the Kalman filter to update its mean and covariance."""
mean_state = self.mean.copy() mean_state = self.mean.copy()
if self.state != TrackState.Tracked: if self.state != TrackState.Tracked:
mean_state[6] = 0 mean_state[6] = 0
@ -80,14 +102,14 @@ class BOTrack(STrack):
super().re_activate(new_track, frame_id, new_id) super().re_activate(new_track, frame_id, new_id)
def update(self, new_track, frame_id): def update(self, new_track, frame_id):
"""Update the YOLOv8 instance with new track and frame ID.""" """Updates the YOLOv8 instance with new track information and the current frame ID."""
if new_track.curr_feat is not None: if new_track.curr_feat is not None:
self.update_features(new_track.curr_feat) self.update_features(new_track.curr_feat)
super().update(new_track, frame_id) super().update(new_track, frame_id)
@property @property
def tlwh(self): def tlwh(self):
"""Get current position in bounding box format `(top left x, top left y, width, height)`.""" """Returns the current bounding box position in `(top left x, top left y, width, height)` format."""
if self.mean is None: if self.mean is None:
return self._tlwh.copy() return self._tlwh.copy()
ret = self.mean[:4].copy() ret = self.mean[:4].copy()
@ -96,7 +118,7 @@ class BOTrack(STrack):
@staticmethod @staticmethod
def multi_predict(stracks): def multi_predict(stracks):
"""Predicts the mean and covariance of multiple object tracks using shared Kalman filter.""" """Predicts the mean and covariance for multiple object tracks using a shared Kalman filter."""
if len(stracks) <= 0: if len(stracks) <= 0:
return return
multi_mean = np.asarray([st.mean.copy() for st in stracks]) multi_mean = np.asarray([st.mean.copy() for st in stracks])
@ -111,12 +133,12 @@ class BOTrack(STrack):
stracks[i].covariance = cov stracks[i].covariance = cov
def convert_coords(self, tlwh): def convert_coords(self, tlwh):
"""Converts Top-Left-Width-Height bounding box coordinates to X-Y-Width-Height format.""" """Converts tlwh bounding box coordinates to xywh format."""
return self.tlwh_to_xywh(tlwh) return self.tlwh_to_xywh(tlwh)
@staticmethod @staticmethod
def tlwh_to_xywh(tlwh): def tlwh_to_xywh(tlwh):
"""Convert bounding box to format `(center x, center y, width, height)`.""" """Convert bounding box from tlwh (top-left-width-height) to xywh (center-x-center-y-width-height) format."""
ret = np.asarray(tlwh).copy() ret = np.asarray(tlwh).copy()
ret[:2] += ret[2:] / 2 ret[:2] += ret[2:] / 2
return ret return ret
@ -129,9 +151,9 @@ class BOTSORT(BYTETracker):
Attributes: Attributes:
proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections. proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections.
appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections. appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections.
encoder (object): Object to handle ReID embeddings, set to None if ReID is not enabled. encoder (Any): Object to handle ReID embeddings, set to None if ReID is not enabled.
gmc (GMC): An instance of the GMC algorithm for data association. gmc (GMC): An instance of the GMC algorithm for data association.
args (object): Parsed command-line arguments containing tracking parameters. args (Any): Parsed command-line arguments containing tracking parameters.
Methods: Methods:
get_kalmanfilter(): Returns an instance of KalmanFilterXYWH for object tracking. get_kalmanfilter(): Returns an instance of KalmanFilterXYWH for object tracking.
@ -139,17 +161,29 @@ class BOTSORT(BYTETracker):
get_dists(tracks, detections): Get distances between tracks and detections using IoU and (optionally) ReID. get_dists(tracks, detections): Get distances between tracks and detections using IoU and (optionally) ReID.
multi_predict(tracks): Predict and track multiple objects with YOLOv8 model. multi_predict(tracks): Predict and track multiple objects with YOLOv8 model.
Usage: Examples:
bot_sort = BOTSORT(args, frame_rate) Initialize BOTSORT and process detections
bot_sort.init_track(dets, scores, cls, img) >>> bot_sort = BOTSORT(args, frame_rate=30)
bot_sort.multi_predict(tracks) >>> bot_sort.init_track(dets, scores, cls, img)
>>> bot_sort.multi_predict(tracks)
Note: Note:
The class is designed to work with the YOLOv8 object detection model and supports ReID only if enabled via args. The class is designed to work with the YOLOv8 object detection model and supports ReID only if enabled via args.
""" """
def __init__(self, args, frame_rate=30): def __init__(self, args, frame_rate=30):
"""Initialize YOLOv8 object with ReID module and GMC algorithm.""" """
Initialize a BOTSORT tracker with ReID module and GMC algorithm.
Args:
args (Any): Parsed command-line arguments containing tracking parameters.
frame_rate (int): Frame rate of the video being processed.
Examples:
Initialize BOTSORT with command-line arguments and a specified frame rate:
>>> args = parse_args()
>>> bot_sort = BOTSORT(args, frame_rate=30)
"""
super().__init__(args, frame_rate) super().__init__(args, frame_rate)
# ReID module # ReID module
self.proximity_thresh = args.proximity_thresh self.proximity_thresh = args.proximity_thresh
@ -161,11 +195,11 @@ class BOTSORT(BYTETracker):
self.gmc = GMC(method=args.gmc_method) self.gmc = GMC(method=args.gmc_method)
def get_kalmanfilter(self): def get_kalmanfilter(self):
"""Returns an instance of KalmanFilterXYWH for object tracking.""" """Returns an instance of KalmanFilterXYWH for predicting and updating object states in the tracking process."""
return KalmanFilterXYWH() return KalmanFilterXYWH()
def init_track(self, dets, scores, cls, img=None): def init_track(self, dets, scores, cls, img=None):
"""Initialize track with detections, scores, and classes.""" """Initialize object tracks using detection bounding boxes, scores, class labels, and optional ReID features."""
if len(dets) == 0: if len(dets) == 0:
return [] return []
if self.args.with_reid and self.encoder is not None: if self.args.with_reid and self.encoder is not None:
@ -175,7 +209,7 @@ class BOTSORT(BYTETracker):
return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections
def get_dists(self, tracks, detections): def get_dists(self, tracks, detections):
"""Get distances between tracks and detections using IoU and (optionally) ReID embeddings.""" """Calculates distances between tracks and detections using IoU and optionally ReID embeddings."""
dists = matching.iou_distance(tracks, detections) dists = matching.iou_distance(tracks, detections)
dists_mask = dists > self.proximity_thresh dists_mask = dists > self.proximity_thresh
@ -190,10 +224,10 @@ class BOTSORT(BYTETracker):
return dists return dists
def multi_predict(self, tracks): def multi_predict(self, tracks):
"""Predict and track multiple objects with YOLOv8 model.""" """Predicts the mean and covariance of multiple object tracks using a shared Kalman filter."""
BOTrack.multi_predict(tracks) BOTrack.multi_predict(tracks)
def reset(self): def reset(self):
"""Reset tracker.""" """Resets the BOTSORT tracker to its initial state, clearing all tracked objects and internal states."""
super().reset() super().reset()
self.gmc.reset_params() self.gmc.reset_params()

@ -25,7 +25,7 @@ class STrack(BaseTrack):
is_activated (bool): Boolean flag indicating if the track has been activated. is_activated (bool): Boolean flag indicating if the track has been activated.
score (float): Confidence score of the track. score (float): Confidence score of the track.
tracklet_len (int): Length of the tracklet. tracklet_len (int): Length of the tracklet.
cls (any): Class label for the object. cls (Any): Class label for the object.
idx (int): Index or identifier for the object. idx (int): Index or identifier for the object.
frame_id (int): Current frame ID. frame_id (int): Current frame ID.
start_frame (int): Frame where the object was first detected. start_frame (int): Frame where the object was first detected.
@ -39,12 +39,31 @@ class STrack(BaseTrack):
update(new_track, frame_id): Update the state of a matched track. update(new_track, frame_id): Update the state of a matched track.
convert_coords(tlwh): Convert bounding box to x-y-aspect-height format. convert_coords(tlwh): Convert bounding box to x-y-aspect-height format.
tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format. tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format.
Examples:
Initialize and activate a new track
>>> track = STrack(xywh=[100, 200, 50, 80, 0], score=0.9, cls='person')
>>> track.activate(kalman_filter=KalmanFilterXYAH(), frame_id=1)
""" """
shared_kalman = KalmanFilterXYAH() shared_kalman = KalmanFilterXYAH()
def __init__(self, xywh, score, cls): def __init__(self, xywh, score, cls):
"""Initialize new STrack instance.""" """
Initialize a new STrack instance.
Args:
xywh (List[float]): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where
(x, y) is the center, (w, h) are width and height, [a] is optional aspect ratio, and idx is the id.
score (float): Confidence score of the detection.
cls (Any): Class label for the detected object.
Examples:
>>> xywh = [100.0, 150.0, 50.0, 75.0, 1]
>>> score = 0.9
>>> cls = 'person'
>>> track = STrack(xywh, score, cls)
"""
super().__init__() super().__init__()
# xywh+idx or xywha+idx # xywh+idx or xywha+idx
assert len(xywh) in {5, 6}, f"expected 5 or 6 values but got {len(xywh)}" assert len(xywh) in {5, 6}, f"expected 5 or 6 values but got {len(xywh)}"
@ -60,7 +79,7 @@ class STrack(BaseTrack):
self.angle = xywh[4] if len(xywh) == 6 else None self.angle = xywh[4] if len(xywh) == 6 else None
def predict(self): def predict(self):
"""Predicts mean and covariance using Kalman filter.""" """Predicts the next state (mean and covariance) of the object using the Kalman filter."""
mean_state = self.mean.copy() mean_state = self.mean.copy()
if self.state != TrackState.Tracked: if self.state != TrackState.Tracked:
mean_state[7] = 0 mean_state[7] = 0
@ -68,7 +87,7 @@ class STrack(BaseTrack):
@staticmethod @staticmethod
def multi_predict(stracks): def multi_predict(stracks):
"""Perform multi-object predictive tracking using Kalman filter for given stracks.""" """Perform multi-object predictive tracking using Kalman filter for the provided list of STrack instances."""
if len(stracks) <= 0: if len(stracks) <= 0:
return return
multi_mean = np.asarray([st.mean.copy() for st in stracks]) multi_mean = np.asarray([st.mean.copy() for st in stracks])
@ -83,7 +102,7 @@ class STrack(BaseTrack):
@staticmethod @staticmethod
def multi_gmc(stracks, H=np.eye(2, 3)): def multi_gmc(stracks, H=np.eye(2, 3)):
"""Update state tracks positions and covariances using a homography matrix.""" """Update state tracks positions and covariances using a homography matrix for multiple tracks."""
if len(stracks) > 0: if len(stracks) > 0:
multi_mean = np.asarray([st.mean.copy() for st in stracks]) multi_mean = np.asarray([st.mean.copy() for st in stracks])
multi_covariance = np.asarray([st.covariance for st in stracks]) multi_covariance = np.asarray([st.covariance for st in stracks])
@ -101,7 +120,7 @@ class STrack(BaseTrack):
stracks[i].covariance = cov stracks[i].covariance = cov
def activate(self, kalman_filter, frame_id): def activate(self, kalman_filter, frame_id):
"""Start a new tracklet.""" """Activate a new tracklet using the provided Kalman filter and initialize its state and covariance."""
self.kalman_filter = kalman_filter self.kalman_filter = kalman_filter
self.track_id = self.next_id() self.track_id = self.next_id()
self.mean, self.covariance = self.kalman_filter.initiate(self.convert_coords(self._tlwh)) self.mean, self.covariance = self.kalman_filter.initiate(self.convert_coords(self._tlwh))
@ -114,7 +133,7 @@ class STrack(BaseTrack):
self.start_frame = frame_id self.start_frame = frame_id
def re_activate(self, new_track, frame_id, new_id=False): def re_activate(self, new_track, frame_id, new_id=False):
"""Reactivates a previously lost track with a new detection.""" """Reactivates a previously lost track using new detection data and updates its state and attributes."""
self.mean, self.covariance = self.kalman_filter.update( self.mean, self.covariance = self.kalman_filter.update(
self.mean, self.covariance, self.convert_coords(new_track.tlwh) self.mean, self.covariance, self.convert_coords(new_track.tlwh)
) )
@ -136,6 +155,12 @@ class STrack(BaseTrack):
Args: Args:
new_track (STrack): The new track containing updated information. new_track (STrack): The new track containing updated information.
frame_id (int): The ID of the current frame. frame_id (int): The ID of the current frame.
Examples:
Update the state of a track with new detection information
>>> track = STrack([100, 200, 50, 80, 1], score=0.9, cls='person')
>>> new_track = STrack([105, 205, 55, 85, 1], score=0.95, cls='person')
>>> track.update(new_track, 2)
""" """
self.frame_id = frame_id self.frame_id = frame_id
self.tracklet_len += 1 self.tracklet_len += 1
@ -158,7 +183,7 @@ class STrack(BaseTrack):
@property @property
def tlwh(self): def tlwh(self):
"""Get current position in bounding box format (top left x, top left y, width, height).""" """Returns the bounding box in top-left-width-height format from the current state estimate."""
if self.mean is None: if self.mean is None:
return self._tlwh.copy() return self._tlwh.copy()
ret = self.mean[:4].copy() ret = self.mean[:4].copy()
@ -168,16 +193,14 @@ class STrack(BaseTrack):
@property @property
def xyxy(self): def xyxy(self):
"""Convert bounding box to format (min x, min y, max x, max y), i.e., (top left, bottom right).""" """Converts bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format."""
ret = self.tlwh.copy() ret = self.tlwh.copy()
ret[2:] += ret[:2] ret[2:] += ret[:2]
return ret return ret
@staticmethod @staticmethod
def tlwh_to_xyah(tlwh): def tlwh_to_xyah(tlwh):
"""Convert bounding box to format (center x, center y, aspect ratio, height), where the aspect ratio is width / """Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format."""
height.
"""
ret = np.asarray(tlwh).copy() ret = np.asarray(tlwh).copy()
ret[:2] += ret[2:] / 2 ret[:2] += ret[2:] / 2
ret[2] /= ret[3] ret[2] /= ret[3]
@ -185,14 +208,14 @@ class STrack(BaseTrack):
@property @property
def xywh(self): def xywh(self):
"""Get current position in bounding box format (center x, center y, width, height).""" """Returns the current position of the bounding box in (center x, center y, width, height) format."""
ret = np.asarray(self.tlwh).copy() ret = np.asarray(self.tlwh).copy()
ret[:2] += ret[2:] / 2 ret[:2] += ret[2:] / 2
return ret return ret
@property @property
def xywha(self): def xywha(self):
"""Get current position in bounding box format (center x, center y, width, height, angle).""" """Returns position in (center x, center y, width, height, angle) format, warning if angle is missing."""
if self.angle is None: if self.angle is None:
LOGGER.warning("WARNING ⚠ `angle` attr not found, returning `xywh` instead.") LOGGER.warning("WARNING ⚠ `angle` attr not found, returning `xywh` instead.")
return self.xywh return self.xywh
@ -200,12 +223,12 @@ class STrack(BaseTrack):
@property @property
def result(self): def result(self):
"""Get current tracking results.""" """Returns the current tracking results in the appropriate bounding box format."""
coords = self.xyxy if self.angle is None else self.xywha coords = self.xyxy if self.angle is None else self.xywha
return coords.tolist() + [self.track_id, self.score, self.cls, self.idx] return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]
def __repr__(self): def __repr__(self):
"""Return a string representation of the BYTETracker object with start and end frames and track ID.""" """Returns a string representation of the STrack object including start frame, end frame, and track ID."""
return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})" return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"
@ -213,18 +236,18 @@ class BYTETracker:
""" """
BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking. BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.
The class is responsible for initializing, updating, and managing the tracks for detected objects in a video Responsible for initializing, updating, and managing the tracks for detected objects in a video sequence.
sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for predicting
predicting the new object locations, and performs data association. the new object locations, and performs data association.
Attributes: Attributes:
tracked_stracks (list[STrack]): List of successfully activated tracks. tracked_stracks (List[STrack]): List of successfully activated tracks.
lost_stracks (list[STrack]): List of lost tracks. lost_stracks (List[STrack]): List of lost tracks.
removed_stracks (list[STrack]): List of removed tracks. removed_stracks (List[STrack]): List of removed tracks.
frame_id (int): The current frame ID. frame_id (int): The current frame ID.
args (namespace): Command-line arguments. args (Namespace): Command-line arguments.
max_time_lost (int): The maximum frames for a track to be considered as 'lost'. max_time_lost (int): The maximum frames for a track to be considered as 'lost'.
kalman_filter (object): Kalman Filter object. kalman_filter (KalmanFilterXYAH): Kalman Filter object.
Methods: Methods:
update(results, img=None): Updates object tracker with new detections. update(results, img=None): Updates object tracker with new detections.
@ -236,10 +259,27 @@ class BYTETracker:
joint_stracks(tlista, tlistb): Combines two lists of stracks. joint_stracks(tlista, tlistb): Combines two lists of stracks.
sub_stracks(tlista, tlistb): Filters out the stracks present in the second list from the first list. sub_stracks(tlista, tlistb): Filters out the stracks present in the second list from the first list.
remove_duplicate_stracks(stracksa, stracksb): Removes duplicate stracks based on IoU. remove_duplicate_stracks(stracksa, stracksb): Removes duplicate stracks based on IoU.
Examples:
Initialize BYTETracker and update with detection results
>>> tracker = BYTETracker(args, frame_rate=30)
>>> results = yolo_model.detect(image)
>>> tracked_objects = tracker.update(results)
""" """
def __init__(self, args, frame_rate=30): def __init__(self, args, frame_rate=30):
"""Initialize a YOLOv8 object to track objects with given arguments and frame rate.""" """
Initialize a BYTETracker instance for object tracking.
Args:
args (Namespace): Command-line arguments containing tracking parameters.
frame_rate (int): Frame rate of the video sequence.
Examples:
Initialize BYTETracker with command-line arguments and a frame rate of 30
>>> args = Namespace(track_buffer=30)
>>> tracker = BYTETracker(args, frame_rate=30)
"""
self.tracked_stracks = [] # type: list[STrack] self.tracked_stracks = [] # type: list[STrack]
self.lost_stracks = [] # type: list[STrack] self.lost_stracks = [] # type: list[STrack]
self.removed_stracks = [] # type: list[STrack] self.removed_stracks = [] # type: list[STrack]
@ -251,7 +291,7 @@ class BYTETracker:
self.reset_id() self.reset_id()
def update(self, results, img=None): def update(self, results, img=None):
"""Updates object tracker with new detections and returns tracked object bounding boxes.""" """Updates the tracker with new detections and returns the current list of tracked objects."""
self.frame_id += 1 self.frame_id += 1
activated_stracks = [] activated_stracks = []
refind_stracks = [] refind_stracks = []
@ -365,31 +405,31 @@ class BYTETracker:
return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32) return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)
def get_kalmanfilter(self): def get_kalmanfilter(self):
"""Returns a Kalman filter object for tracking bounding boxes.""" """Returns a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH."""
return KalmanFilterXYAH() return KalmanFilterXYAH()
def init_track(self, dets, scores, cls, img=None): def init_track(self, dets, scores, cls, img=None):
"""Initialize object tracking with detections and scores using STrack algorithm.""" """Initializes object tracking with given detections, scores, and class labels using the STrack algorithm."""
return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections
def get_dists(self, tracks, detections): def get_dists(self, tracks, detections):
"""Calculates the distance between tracks and detections using IoU and fuses scores.""" """Calculates the distance between tracks and detections using IoU and optionally fuses scores."""
dists = matching.iou_distance(tracks, detections) dists = matching.iou_distance(tracks, detections)
if self.args.fuse_score: if self.args.fuse_score:
dists = matching.fuse_score(dists, detections) dists = matching.fuse_score(dists, detections)
return dists return dists
def multi_predict(self, tracks): def multi_predict(self, tracks):
"""Returns the predicted tracks using the YOLOv8 network.""" """Predict the next states for multiple tracks using Kalman filter."""
STrack.multi_predict(tracks) STrack.multi_predict(tracks)
@staticmethod @staticmethod
def reset_id(): def reset_id():
"""Resets the ID counter of STrack.""" """Resets the ID counter for STrack instances to ensure unique track IDs across tracking sessions."""
STrack.reset_id() STrack.reset_id()
def reset(self): def reset(self):
"""Reset tracker.""" """Resets the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter."""
self.tracked_stracks = [] # type: list[STrack] self.tracked_stracks = [] # type: list[STrack]
self.lost_stracks = [] # type: list[STrack] self.lost_stracks = [] # type: list[STrack]
self.removed_stracks = [] # type: list[STrack] self.removed_stracks = [] # type: list[STrack]
@ -399,7 +439,7 @@ class BYTETracker:
@staticmethod @staticmethod
def joint_stracks(tlista, tlistb): def joint_stracks(tlista, tlistb):
"""Combine two lists of stracks into a single one.""" """Combines two lists of STrack objects into a single list, ensuring no duplicates based on track IDs."""
exists = {} exists = {}
res = [] res = []
for t in tlista: for t in tlista:
@ -414,20 +454,13 @@ class BYTETracker:
@staticmethod @staticmethod
def sub_stracks(tlista, tlistb): def sub_stracks(tlista, tlistb):
"""DEPRECATED CODE in https://github.com/ultralytics/ultralytics/pull/1890/ """Filters out the stracks present in the second list from the first list."""
stracks = {t.track_id: t for t in tlista}
for t in tlistb:
tid = t.track_id
if stracks.get(tid, 0):
del stracks[tid]
return list(stracks.values())
"""
track_ids_b = {t.track_id for t in tlistb} track_ids_b = {t.track_id for t in tlistb}
return [t for t in tlista if t.track_id not in track_ids_b] return [t for t in tlista if t.track_id not in track_ids_b]
@staticmethod @staticmethod
def remove_duplicate_stracks(stracksa, stracksb): def remove_duplicate_stracks(stracksa, stracksb):
"""Remove duplicate stracks with non-maximum IoU distance.""" """Removes duplicate stracks from two lists based on Intersection over Union (IoU) distance."""
pdist = matching.iou_distance(stracksa, stracksb) pdist = matching.iou_distance(stracksa, stracksb)
pairs = np.where(pdist < 0.15) pairs = np.where(pdist < 0.15)
dupa, dupb = [], [] dupa, dupb = [], []

@ -21,10 +21,15 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
Args: Args:
predictor (object): The predictor object to initialize trackers for. predictor (object): The predictor object to initialize trackers for.
persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False. persist (bool): Whether to persist the trackers if they already exist.
Raises: Raises:
AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'. AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
Examples:
Initialize trackers for a predictor object:
>>> predictor = SomePredictorClass()
>>> on_predict_start(predictor, persist=True)
""" """
if hasattr(predictor, "trackers") and persist: if hasattr(predictor, "trackers") and persist:
return return
@ -51,7 +56,12 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None
Args: Args:
predictor (object): The predictor object containing the predictions. predictor (object): The predictor object containing the predictions.
persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False. persist (bool): Whether to persist the trackers if they already exist.
Examples:
Postprocess predictions and update with tracking
>>> predictor = YourPredictorClass()
>>> on_predict_postprocess_end(predictor, persist=True)
""" """
path, im0s = predictor.batch[:2] path, im0s = predictor.batch[:2]
@ -84,6 +94,11 @@ def register_tracker(model: object, persist: bool) -> None:
Args: Args:
model (object): The model object to register tracking callbacks for. model (object): The model object to register tracking callbacks for.
persist (bool): Whether to persist the trackers if they already exist. persist (bool): Whether to persist the trackers if they already exist.
Examples:
Register tracking callbacks to a YOLO model
>>> model = YOLOModel()
>>> register_tracker(model, persist=True)
""" """
model.add_callback("on_predict_start", partial(on_predict_start, persist=persist)) model.add_callback("on_predict_start", partial(on_predict_start, persist=persist))
model.add_callback("on_predict_postprocess_end", partial(on_predict_postprocess_end, persist=persist)) model.add_callback("on_predict_postprocess_end", partial(on_predict_postprocess_end, persist=persist))

@ -19,27 +19,39 @@ class GMC:
method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'. method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
downscale (int): Factor by which to downscale the frames for processing. downscale (int): Factor by which to downscale the frames for processing.
prevFrame (np.ndarray): Stores the previous frame for tracking. prevFrame (np.ndarray): Stores the previous frame for tracking.
prevKeyPoints (list): Stores the keypoints from the previous frame. prevKeyPoints (List): Stores the keypoints from the previous frame.
prevDescriptors (np.ndarray): Stores the descriptors from the previous frame. prevDescriptors (np.ndarray): Stores the descriptors from the previous frame.
initializedFirstFrame (bool): Flag to indicate if the first frame has been processed. initializedFirstFrame (bool): Flag to indicate if the first frame has been processed.
Methods: Methods:
__init__(self, method='sparseOptFlow', downscale=2): Initializes a GMC object with the specified method __init__: Initializes a GMC object with the specified method and downscale factor.
and downscale factor. apply: Applies the chosen method to a raw frame and optionally uses provided detections.
apply(self, raw_frame, detections=None): Applies the chosen method to a raw frame and optionally uses applyEcc: Applies the ECC algorithm to a raw frame.
provided detections. applyFeatures: Applies feature-based methods like ORB or SIFT to a raw frame.
applyEcc(self, raw_frame, detections=None): Applies the ECC algorithm to a raw frame. applySparseOptFlow: Applies the Sparse Optical Flow method to a raw frame.
applyFeatures(self, raw_frame, detections=None): Applies feature-based methods like ORB or SIFT to a raw frame. reset_params: Resets the internal parameters of the GMC object.
applySparseOptFlow(self, raw_frame, detections=None): Applies the Sparse Optical Flow method to a raw frame.
Examples:
Create a GMC object and apply it to a frame
>>> gmc = GMC(method='sparseOptFlow', downscale=2)
>>> frame = np.array([[1, 2, 3], [4, 5, 6]])
>>> processed_frame = gmc.apply(frame)
>>> print(processed_frame)
array([[1, 2, 3],
[4, 5, 6]])
""" """
def __init__(self, method: str = "sparseOptFlow", downscale: int = 2) -> None: def __init__(self, method: str = "sparseOptFlow", downscale: int = 2) -> None:
""" """
Initialize a video tracker with specified parameters. Initialize a Generalized Motion Compensation (GMC) object with tracking method and downscale factor.
Args: Args:
method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'. method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
downscale (int): Downscale factor for processing frames. downscale (int): Downscale factor for processing frames.
Examples:
Initialize a GMC object with the 'sparseOptFlow' method and a downscale factor of 2
>>> gmc = GMC(method='sparseOptFlow', downscale=2)
""" """
super().__init__() super().__init__()
@ -79,20 +91,21 @@ class GMC:
def apply(self, raw_frame: np.array, detections: list = None) -> np.array: def apply(self, raw_frame: np.array, detections: list = None) -> np.array:
""" """
Apply object detection on a raw frame using specified method. Apply object detection on a raw frame using the specified method.
Args: Args:
raw_frame (np.ndarray): The raw frame to be processed. raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
detections (list): List of detections to be used in the processing. detections (List | None): List of detections to be used in the processing.
Returns: Returns:
(np.ndarray): Processed frame. (np.ndarray): Processed frame with applied object detection.
Examples: Examples:
>>> gmc = GMC() >>> gmc = GMC(method='sparseOptFlow')
>>> gmc.apply(np.array([[1, 2, 3], [4, 5, 6]])) >>> raw_frame = np.random.rand(480, 640, 3)
array([[1, 2, 3], >>> processed_frame = gmc.apply(raw_frame)
[4, 5, 6]]) >>> print(processed_frame.shape)
(480, 640, 3)
""" """
if self.method in {"orb", "sift"}: if self.method in {"orb", "sift"}:
return self.applyFeatures(raw_frame, detections) return self.applyFeatures(raw_frame, detections)
@ -105,19 +118,20 @@ class GMC:
def applyEcc(self, raw_frame: np.array) -> np.array: def applyEcc(self, raw_frame: np.array) -> np.array:
""" """
Apply ECC algorithm to a raw frame. Apply the ECC (Enhanced Correlation Coefficient) algorithm to a raw frame for motion compensation.
Args: Args:
raw_frame (np.ndarray): The raw frame to be processed. raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
Returns: Returns:
(np.ndarray): Processed frame. (np.ndarray): The processed frame with the applied ECC transformation.
Examples: Examples:
>>> gmc = GMC() >>> gmc = GMC(method='ecc')
>>> gmc.applyEcc(np.array([[1, 2, 3], [4, 5, 6]])) >>> processed_frame = gmc.applyEcc(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]))
array([[1, 2, 3], >>> print(processed_frame)
[4, 5, 6]]) [[1. 0. 0.]
[0. 1. 0.]]
""" """
height, width, _ = raw_frame.shape height, width, _ = raw_frame.shape
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
@ -127,8 +141,6 @@ class GMC:
if self.downscale > 1.0: if self.downscale > 1.0:
frame = cv2.GaussianBlur(frame, (3, 3), 1.5) frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale)) frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
width = width // self.downscale
height = height // self.downscale
# Handle first frame # Handle first frame
if not self.initializedFirstFrame: if not self.initializedFirstFrame:
@ -154,17 +166,18 @@ class GMC:
Apply feature-based methods like ORB or SIFT to a raw frame. Apply feature-based methods like ORB or SIFT to a raw frame.
Args: Args:
raw_frame (np.ndarray): The raw frame to be processed. raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
detections (list): List of detections to be used in the processing. detections (List | None): List of detections to be used in the processing.
Returns: Returns:
(np.ndarray): Processed frame. (np.ndarray): Processed frame.
Examples: Examples:
>>> gmc = GMC() >>> gmc = GMC(method='orb')
>>> gmc.applyFeatures(np.array([[1, 2, 3], [4, 5, 6]])) >>> raw_frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
array([[1, 2, 3], >>> processed_frame = gmc.applyFeatures(raw_frame)
[4, 5, 6]]) >>> print(processed_frame.shape)
(2, 3)
""" """
height, width, _ = raw_frame.shape height, width, _ = raw_frame.shape
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
@ -296,16 +309,17 @@ class GMC:
Apply Sparse Optical Flow method to a raw frame. Apply Sparse Optical Flow method to a raw frame.
Args: Args:
raw_frame (np.ndarray): The raw frame to be processed. raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
Returns: Returns:
(np.ndarray): Processed frame. (np.ndarray): Processed frame with shape (2, 3).
Examples: Examples:
>>> gmc = GMC() >>> gmc = GMC()
>>> gmc.applySparseOptFlow(np.array([[1, 2, 3], [4, 5, 6]])) >>> result = gmc.applySparseOptFlow(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]))
array([[1, 2, 3], >>> print(result)
[4, 5, 6]]) [[1. 0. 0.]
[0. 1. 0.]]
""" """
height, width, _ = raw_frame.shape height, width, _ = raw_frame.shape
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
@ -356,7 +370,7 @@ class GMC:
return H return H
def reset_params(self) -> None: def reset_params(self) -> None:
"""Reset parameters.""" """Reset the internal parameters including previous frame, keypoints, and descriptors."""
self.prevFrame = None self.prevFrame = None
self.prevKeyPoints = None self.prevKeyPoints = None
self.prevDescriptors = None self.prevDescriptors = None

@ -6,17 +6,49 @@ import scipy.linalg
class KalmanFilterXYAH: class KalmanFilterXYAH:
""" """
For bytetrack. A simple Kalman filter for tracking bounding boxes in image space. A KalmanFilterXYAH class for tracking bounding boxes in image space using a Kalman filter.
The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect Implements a simple Kalman filter for tracking bounding boxes in image space. The 8-dimensional state space
ratio a, height h, and their respective velocities. (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect ratio a, height h, and their
respective velocities. Object motion follows a constant velocity model, and bounding box location (x, y, a, h) is
Object motion follows a constant velocity model. The bounding box location (x, y, a, h) is taken as direct taken as a direct observation of the state space (linear observation model).
observation of the state space (linear observation model).
Attributes:
_motion_mat (np.ndarray): The motion matrix for the Kalman filter.
_update_mat (np.ndarray): The update matrix for the Kalman filter.
_std_weight_position (float): Standard deviation weight for position.
_std_weight_velocity (float): Standard deviation weight for velocity.
Methods:
initiate: Creates a track from an unassociated measurement.
predict: Runs the Kalman filter prediction step.
project: Projects the state distribution to measurement space.
multi_predict: Runs the Kalman filter prediction step (vectorized version).
update: Runs the Kalman filter correction step.
gating_distance: Computes the gating distance between state distribution and measurements.
Examples:
Initialize the Kalman filter and create a track from a measurement
>>> kf = KalmanFilterXYAH()
>>> measurement = np.array([100, 200, 1.5, 50])
>>> mean, covariance = kf.initiate(measurement)
>>> print(mean)
>>> print(covariance)
""" """
def __init__(self): def __init__(self):
"""Initialize Kalman filter model matrices with motion and observation uncertainty weights.""" """
Initialize Kalman filter model matrices with motion and observation uncertainty weights.
The Kalman filter is initialized with an 8-dimensional state space (x, y, a, h, vx, vy, va, vh), where (x, y)
represents the bounding box center position, 'a' is the aspect ratio, 'h' is the height, and their respective
velocities are (vx, vy, va, vh). The filter uses a constant velocity model for object motion and a linear
observation model for bounding box location.
Examples:
Initialize a Kalman filter for tracking:
>>> kf = KalmanFilterXYAH()
"""
ndim, dt = 4, 1.0 ndim, dt = 4, 1.0
# Create Kalman filter model matrices # Create Kalman filter model matrices
@ -32,15 +64,20 @@ class KalmanFilterXYAH:
def initiate(self, measurement: np.ndarray) -> tuple: def initiate(self, measurement: np.ndarray) -> tuple:
""" """
Create track from unassociated measurement. Create a track from an unassociated measurement.
Args: Args:
measurement (ndarray): Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a, measurement (ndarray): Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a,
and height h. and height h.
Returns: Returns:
(tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) (tuple[ndarray, ndarray]): Returns the mean vector (8-dimensional) and covariance matrix (8x8 dimensional)
of the new track. Unobserved velocities are initialized to 0 mean. of the new track. Unobserved velocities are initialized to 0 mean.
Examples:
>>> kf = KalmanFilterXYAH()
>>> measurement = np.array([100, 50, 1.5, 200])
>>> mean, covariance = kf.initiate(measurement)
""" """
mean_pos = measurement mean_pos = measurement
mean_vel = np.zeros_like(mean_pos) mean_vel = np.zeros_like(mean_pos)
@ -64,12 +101,18 @@ class KalmanFilterXYAH:
Run Kalman filter prediction step. Run Kalman filter prediction step.
Args: Args:
mean (ndarray): The 8 dimensional mean vector of the object state at the previous time step. mean (ndarray): The 8-dimensional mean vector of the object state at the previous time step.
covariance (ndarray): The 8x8 dimensional covariance matrix of the object state at the previous time step. covariance (ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.
Returns: Returns:
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved (tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
velocities are initialized to 0 mean. velocities are initialized to 0 mean.
Examples:
>>> kf = KalmanFilterXYAH()
>>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
>>> covariance = np.eye(8)
>>> predicted_mean, predicted_covariance = kf.predict(mean, covariance)
""" """
std_pos = [ std_pos = [
self._std_weight_position * mean[3], self._std_weight_position * mean[3],
@ -100,6 +143,12 @@ class KalmanFilterXYAH:
Returns: Returns:
(tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate. (tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate.
Examples:
>>> kf = KalmanFilterXYAH()
>>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
>>> covariance = np.eye(8)
>>> projected_mean, projected_covariance = kf.project(mean, covariance)
""" """
std = [ std = [
self._std_weight_position * mean[3], self._std_weight_position * mean[3],
@ -115,15 +164,21 @@ class KalmanFilterXYAH:
def multi_predict(self, mean: np.ndarray, covariance: np.ndarray) -> tuple: def multi_predict(self, mean: np.ndarray, covariance: np.ndarray) -> tuple:
""" """
Run Kalman filter prediction step (Vectorized version). Run Kalman filter prediction step for multiple object states (Vectorized version).
Args: Args:
mean (ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step. mean (ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
covariance (ndarray): The Nx8x8 covariance matrix of the object states at the previous time step. covariance (ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.
Returns: Returns:
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved (tuple[ndarray, ndarray]): Returns the mean matrix and covariance matrix of the predicted states.
velocities are initialized to 0 mean. The mean matrix has shape (N, 8) and the covariance matrix has shape (N, 8, 8). Unobserved velocities
are initialized to 0 mean.
Examples:
>>> mean = np.random.rand(10, 8) # 10 object states
>>> covariance = np.random.rand(10, 8, 8) # Covariance matrices for 10 object states
>>> predicted_mean, predicted_covariance = kalman_filter.multi_predict(mean, covariance)
""" """
std_pos = [ std_pos = [
self._std_weight_position * mean[:, 3], self._std_weight_position * mean[:, 3],
@ -160,6 +215,13 @@ class KalmanFilterXYAH:
Returns: Returns:
(tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution. (tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution.
Examples:
>>> kf = KalmanFilterXYAH()
>>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
>>> covariance = np.eye(8)
>>> measurement = np.array([1, 1, 1, 1])
>>> new_mean, new_covariance = kf.update(mean, covariance, measurement)
""" """
projected_mean, projected_cov = self.project(mean, covariance) projected_mean, projected_cov = self.project(mean, covariance)
@ -182,23 +244,31 @@ class KalmanFilterXYAH:
metric: str = "maha", metric: str = "maha",
) -> np.ndarray: ) -> np.ndarray:
""" """
Compute gating distance between state distribution and measurements. A suitable distance threshold can be Compute gating distance between state distribution and measurements.
obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of freedom,
otherwise 2. A suitable distance threshold can be obtained from `chi2inv95`. If `only_position` is False, the chi-square
distribution has 4 degrees of freedom, otherwise 2.
Args: Args:
mean (ndarray): Mean vector over the state distribution (8 dimensional). mean (ndarray): Mean vector over the state distribution (8 dimensional).
covariance (ndarray): Covariance of the state distribution (8x8 dimensional). covariance (ndarray): Covariance of the state distribution (8x8 dimensional).
measurements (ndarray): An Nx4 matrix of N measurements, each in format (x, y, a, h) where (x, y) measurements (ndarray): An (N, 4) matrix of N measurements, each in format (x, y, a, h) where (x, y) is the
is the bounding box center position, a the aspect ratio, and h the height. bounding box center position, a the aspect ratio, and h the height.
only_position (bool, optional): If True, distance computation is done with respect to the bounding box only_position (bool): If True, distance computation is done with respect to box center position only.
center position only. Defaults to False. metric (str): The metric to use for calculating the distance. Options are 'gaussian' for the squared
metric (str, optional): The metric to use for calculating the distance. Options are 'gaussian' for the Euclidean distance and 'maha' for the squared Mahalanobis distance.
squared Euclidean distance and 'maha' for the squared Mahalanobis distance. Defaults to 'maha'.
Returns: Returns:
(np.ndarray): Returns an array of length N, where the i-th element contains the squared distance between (np.ndarray): Returns an array of length N, where the i-th element contains the squared distance between
(mean, covariance) and `measurements[i]`. (mean, covariance) and `measurements[i]`.
Examples:
Compute gating distance using Mahalanobis metric:
>>> kf = KalmanFilterXYAH()
>>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
>>> covariance = np.eye(8)
>>> measurements = np.array([[1, 1, 1, 1], [2, 2, 1, 1]])
>>> distances = kf.gating_distance(mean, covariance, measurements, only_position=False, metric='maha')
""" """
mean, covariance = self.project(mean, covariance) mean, covariance = self.project(mean, covariance)
if only_position: if only_position:
@ -218,13 +288,33 @@ class KalmanFilterXYAH:
class KalmanFilterXYWH(KalmanFilterXYAH): class KalmanFilterXYWH(KalmanFilterXYAH):
""" """
For BoT-SORT. A simple Kalman filter for tracking bounding boxes in image space. A KalmanFilterXYWH class for tracking bounding boxes in image space using a Kalman filter.
The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y), width
w, height h, and their respective velocities.
Object motion follows a constant velocity model. The bounding box location (x, y, w, h) is taken as direct Implements a Kalman filter for tracking bounding boxes with state space (x, y, w, h, vx, vy, vw, vh), where
(x, y) is the center position, w is the width, h is the height, and vx, vy, vw, vh are their respective velocities.
The object motion follows a constant velocity model, and the bounding box location (x, y, w, h) is taken as a direct
observation of the state space (linear observation model). observation of the state space (linear observation model).
Attributes:
_motion_mat (np.ndarray): The motion matrix for the Kalman filter.
_update_mat (np.ndarray): The update matrix for the Kalman filter.
_std_weight_position (float): Standard deviation weight for position.
_std_weight_velocity (float): Standard deviation weight for velocity.
Methods:
initiate: Creates a track from an unassociated measurement.
predict: Runs the Kalman filter prediction step.
project: Projects the state distribution to measurement space.
multi_predict: Runs the Kalman filter prediction step in a vectorized manner.
update: Runs the Kalman filter correction step.
Examples:
Create a Kalman filter and initialize a track
>>> kf = KalmanFilterXYWH()
>>> measurement = np.array([100, 50, 20, 40])
>>> mean, covariance = kf.initiate(measurement)
>>> print(mean)
>>> print(covariance)
""" """
def initiate(self, measurement: np.ndarray) -> tuple: def initiate(self, measurement: np.ndarray) -> tuple:
@ -237,6 +327,22 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
Returns: Returns:
(tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) (tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional)
of the new track. Unobserved velocities are initialized to 0 mean. of the new track. Unobserved velocities are initialized to 0 mean.
Examples:
>>> kf = KalmanFilterXYWH()
>>> measurement = np.array([100, 50, 20, 40])
>>> mean, covariance = kf.initiate(measurement)
>>> print(mean)
[100. 50. 20. 40. 0. 0. 0. 0.]
>>> print(covariance)
[[ 4. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 4. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 4. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 4. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0.25 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0.25 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0.25 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0.25]]
""" """
mean_pos = measurement mean_pos = measurement
mean_vel = np.zeros_like(mean_pos) mean_vel = np.zeros_like(mean_pos)
@ -260,12 +366,18 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
Run Kalman filter prediction step. Run Kalman filter prediction step.
Args: Args:
mean (ndarray): The 8 dimensional mean vector of the object state at the previous time step. mean (ndarray): The 8-dimensional mean vector of the object state at the previous time step.
covariance (ndarray): The 8x8 dimensional covariance matrix of the object state at the previous time step. covariance (ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.
Returns: Returns:
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved (tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
velocities are initialized to 0 mean. velocities are initialized to 0 mean.
Examples:
>>> kf = KalmanFilterXYWH()
>>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
>>> covariance = np.eye(8)
>>> predicted_mean, predicted_covariance = kf.predict(mean, covariance)
""" """
std_pos = [ std_pos = [
self._std_weight_position * mean[2], self._std_weight_position * mean[2],
@ -296,6 +408,12 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
Returns: Returns:
(tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate. (tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate.
Examples:
>>> kf = KalmanFilterXYWH()
>>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
>>> covariance = np.eye(8)
>>> projected_mean, projected_cov = kf.project(mean, covariance)
""" """
std = [ std = [
self._std_weight_position * mean[2], self._std_weight_position * mean[2],
@ -320,6 +438,12 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
Returns: Returns:
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved (tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
velocities are initialized to 0 mean. velocities are initialized to 0 mean.
Examples:
>>> mean = np.random.rand(5, 8) # 5 objects with 8-dimensional state vectors
>>> covariance = np.random.rand(5, 8, 8) # 5 objects with 8x8 covariance matrices
>>> kf = KalmanFilterXYWH()
>>> predicted_mean, predicted_covariance = kf.multi_predict(mean, covariance)
""" """
std_pos = [ std_pos = [
self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 2],
@ -356,5 +480,12 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
Returns: Returns:
(tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution. (tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution.
Examples:
>>> kf = KalmanFilterXYWH()
>>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
>>> covariance = np.eye(8)
>>> measurement = np.array([0.5, 0.5, 1.2, 1.2])
>>> new_mean, new_covariance = kf.update(mean, covariance, measurement)
""" """
return super().update(mean, covariance, measurement) return super().update(mean, covariance, measurement)

@ -19,18 +19,23 @@ except (ImportError, AssertionError, AttributeError):
def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True) -> tuple: def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True) -> tuple:
""" """
Perform linear assignment using scipy or lap.lapjv. Perform linear assignment using either the scipy or lap.lapjv method.
Args: Args:
cost_matrix (np.ndarray): The matrix containing cost values for assignments. cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).
thresh (float): Threshold for considering an assignment valid. thresh (float): Threshold for considering an assignment valid.
use_lap (bool, optional): Whether to use lap.lapjv. Defaults to True. use_lap (bool): Use lap.lapjv for the assignment. If False, scipy.optimize.linear_sum_assignment is used.
Returns: Returns:
Tuple with: (tuple): A tuple containing:
- matched indices - matched_indices (np.ndarray): Array of matched indices of shape (K, 2), where K is the number of matches.
- unmatched indices from 'a' - unmatched_a (np.ndarray): Array of unmatched indices from the first set, with shape (N - K,).
- unmatched indices from 'b' - unmatched_b (np.ndarray): Array of unmatched indices from the second set, with shape (M - K,).
Examples:
>>> cost_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> thresh = 5.0
>>> matched_indices, unmatched_a, unmatched_b = linear_assignment(cost_matrix, thresh, use_lap=True)
""" """
if cost_matrix.size == 0: if cost_matrix.size == 0:
@ -68,6 +73,12 @@ def iou_distance(atracks: list, btracks: list) -> np.ndarray:
Returns: Returns:
(np.ndarray): Cost matrix computed based on IoU. (np.ndarray): Cost matrix computed based on IoU.
Examples:
Compute IoU distance between two sets of tracks
>>> atracks = [np.array([0, 0, 10, 10]), np.array([20, 20, 30, 30])]
>>> btracks = [np.array([5, 5, 15, 15]), np.array([25, 25, 35, 35])]
>>> cost_matrix = iou_distance(atracks, btracks)
""" """
if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray): if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray):
@ -98,12 +109,19 @@ def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -
Compute distance between tracks and detections based on embeddings. Compute distance between tracks and detections based on embeddings.
Args: Args:
tracks (list[STrack]): List of tracks. tracks (list[STrack]): List of tracks, where each track contains embedding features.
detections (list[BaseTrack]): List of detections. detections (list[BaseTrack]): List of detections, where each detection contains embedding features.
metric (str, optional): Metric for distance computation. Defaults to 'cosine'. metric (str): Metric for distance computation. Supported metrics include 'cosine', 'euclidean', etc.
Returns: Returns:
(np.ndarray): Cost matrix computed based on embeddings. (np.ndarray): Cost matrix computed based on embeddings with shape (N, M), where N is the number of tracks
and M is the number of detections.
Examples:
Compute the embedding distance between tracks and detections using cosine metric
>>> tracks = [STrack(...), STrack(...)] # List of track objects with embedding features
>>> detections = [BaseTrack(...), BaseTrack(...)] # List of detection objects with embedding features
>>> cost_matrix = embedding_distance(tracks, detections, metric='cosine')
""" """
cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32) cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
@ -122,11 +140,17 @@ def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray:
Fuses cost matrix with detection scores to produce a single similarity matrix. Fuses cost matrix with detection scores to produce a single similarity matrix.
Args: Args:
cost_matrix (np.ndarray): The matrix containing cost values for assignments. cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).
detections (list[BaseTrack]): List of detections with scores. detections (list[BaseTrack]): List of detections, each containing a score attribute.
Returns: Returns:
(np.ndarray): Fused similarity matrix. (np.ndarray): Fused similarity matrix with shape (N, M).
Examples:
Fuse a cost matrix with detection scores
>>> cost_matrix = np.random.rand(5, 10) # 5 tracks and 10 detections
>>> detections = [BaseTrack(score=np.random.rand()) for _ in range(10)]
>>> fused_matrix = fuse_score(cost_matrix, detections)
""" """
if cost_matrix.size == 0: if cost_matrix.size == 0:

@ -47,7 +47,7 @@ PYTHON_VERSION = platform.python_version()
TORCH_VERSION = torch.__version__ TORCH_VERSION = torch.__version__
TORCHVISION_VERSION = importlib.metadata.version("torchvision") # faster than importing torchvision TORCHVISION_VERSION = importlib.metadata.version("torchvision") # faster than importing torchvision
HELP_MSG = """ HELP_MSG = """
Usage examples for running Ultralytics YOLO: Examples for running Ultralytics:
1. Install the ultralytics package: 1. Install the ultralytics package:

@ -11,19 +11,44 @@ from pathlib import Path
class WorkingDirectory(contextlib.ContextDecorator): class WorkingDirectory(contextlib.ContextDecorator):
"""Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager.""" """
A context manager and decorator for temporarily changing the working directory.
This class allows for the temporary change of the working directory using a context manager or decorator.
It ensures that the original working directory is restored after the context or decorated function completes.
Attributes:
dir (Path): The new directory to switch to.
cwd (Path): The original current working directory before the switch.
Methods:
__enter__: Changes the current directory to the specified directory.
__exit__: Restores the original working directory on context exit.
Examples:
Using as a context manager:
>>> with WorkingDirectory('/path/to/new/dir'):
...     # Perform operations in the new directory
...     pass
Using as a decorator:
>>> @WorkingDirectory('/path/to/new/dir')
... def some_function():
...     # Perform operations in the new directory
...     pass
"""
def __init__(self, new_dir): def __init__(self, new_dir):
"""Sets the working directory to 'new_dir' upon instantiation.""" """Sets the working directory to 'new_dir' upon instantiation for use with context managers or decorators."""
self.dir = new_dir # new dir self.dir = new_dir # new dir
self.cwd = Path.cwd().resolve() # current dir self.cwd = Path.cwd().resolve() # current dir
def __enter__(self): def __enter__(self):
"""Changes the current directory to the specified directory.""" """Changes the current working directory to the specified directory upon entering the context."""
os.chdir(self.dir) os.chdir(self.dir)
def __exit__(self, exc_type, exc_val, exc_tb): # noqa def __exit__(self, exc_type, exc_val, exc_tb): # noqa
"""Restore the current working directory on context exit.""" """Restores the original working directory when exiting the context."""
os.chdir(self.cwd) os.chdir(self.cwd)
@ -35,18 +60,16 @@ def spaces_in_path(path):
file/directory back to its original location. file/directory back to its original location.
Args: Args:
path (str | Path): The original path. path (str | Path): The original path that may contain spaces.
Yields: Yields:
(Path): Temporary path with spaces replaced by underscores if spaces were present, otherwise the original path. (Path): Temporary path with spaces replaced by underscores if spaces were present, otherwise the original path.
Example: Examples:
```python Use the context manager to handle paths with spaces:
with ultralytics.utils.files import spaces_in_path >>> from ultralytics.utils.files import spaces_in_path
>>> with spaces_in_path('/path/with spaces') as new_path:
with spaces_in_path('/path/with spaces') as new_path: >>> # Your code here
# Your code here
```
""" """
# If path has spaces, replace them with underscores # If path has spaces, replace them with underscores
@ -84,21 +107,35 @@ def spaces_in_path(path):
def increment_path(path, exist_ok=False, sep="", mkdir=False): def increment_path(path, exist_ok=False, sep="", mkdir=False):
""" """
Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc. Increments a file or directory path, i.e., runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
If the path exists and exist_ok is not set to True, the path will be incremented by appending a number and sep to If the path exists and `exist_ok` is not True, the path will be incremented by appending a number and `sep` to
the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the
number will be appended directly to the end of the path. If mkdir is set to True, the path will be created as a number will be appended directly to the end of the path. If `mkdir` is set to True, the path will be created as a
directory if it does not already exist. directory if it does not already exist.
Args: Args:
path (str, pathlib.Path): Path to increment. path (str | pathlib.Path): Path to increment.
exist_ok (bool, optional): If True, the path will not be incremented and returned as-is. Defaults to False. exist_ok (bool): If True, the path will not be incremented and returned as-is.
sep (str, optional): Separator to use between the path and the incrementation number. Defaults to ''. sep (str): Separator to use between the path and the incrementation number.
mkdir (bool, optional): Create a directory if it does not exist. Defaults to False. mkdir (bool): Create a directory if it does not exist.
Returns: Returns:
(pathlib.Path): Incremented path. (pathlib.Path): Incremented path.
Examples:
Increment a directory path:
>>> from pathlib import Path
>>> path = Path("runs/exp")
>>> new_path = increment_path(path)
>>> print(new_path)
runs/exp2
Increment a file path:
>>> path = Path("runs/exp/results.txt")
>>> new_path = increment_path(path)
>>> print(new_path)
runs/exp/results2.txt
""" """
path = Path(path) # os-agnostic path = Path(path) # os-agnostic
if path.exists() and not exist_ok: if path.exists() and not exist_ok:
@ -118,19 +155,19 @@ def increment_path(path, exist_ok=False, sep="", mkdir=False):
def file_age(path=__file__): def file_age(path=__file__):
"""Return days since last file update.""" """Return days since the last modification of the specified file."""
dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta
return dt.days # + dt.seconds / 86400 # fractional days return dt.days # + dt.seconds / 86400 # fractional days
def file_date(path=__file__): def file_date(path=__file__):
"""Return human-readable file modification date, i.e. '2021-3-26'.""" """Returns the file modification date in 'YYYY-M-D' format."""
t = datetime.fromtimestamp(Path(path).stat().st_mtime) t = datetime.fromtimestamp(Path(path).stat().st_mtime)
return f"{t.year}-{t.month}-{t.day}" return f"{t.year}-{t.month}-{t.day}"
def file_size(path): def file_size(path):
"""Return file/dir size (MB).""" """Returns the size of a file or directory in megabytes (MB)."""
if isinstance(path, (str, Path)): if isinstance(path, (str, Path)):
mb = 1 << 20 # bytes to MiB (1024 ** 2) mb = 1 << 20 # bytes to MiB (1024 ** 2)
path = Path(path) path = Path(path)
@ -142,7 +179,7 @@ def file_size(path):
def get_latest_run(search_dir="."): def get_latest_run(search_dir="."):
"""Return path to most recent 'last.pt' in /runs (i.e. to --resume from).""" """Returns the path to the most recent 'last.pt' file in the specified directory for resuming training."""
last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True) last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
return max(last_list, key=os.path.getctime) if last_list else "" return max(last_list, key=os.path.getctime) if last_list else ""
@ -152,17 +189,15 @@ def update_models(model_names=("yolov8n.pt",), source_dir=Path("."), update_name
Updates and re-saves specified YOLO models in an 'updated_models' subdirectory. Updates and re-saves specified YOLO models in an 'updated_models' subdirectory.
Args: Args:
model_names (tuple, optional): Model filenames to update, defaults to ("yolov8n.pt"). model_names (Tuple[str, ...]): Model filenames to update.
source_dir (Path, optional): Directory containing models and target subdirectory, defaults to current directory. source_dir (Path): Directory containing models and target subdirectory.
update_names (bool, optional): Update model names from a data YAML. update_names (bool): Update model names from a data YAML.
Example: Examples:
```python Update specified YOLO models and save them in 'updated_models' subdirectory:
from ultralytics.utils.files import update_models >>> from ultralytics.utils.files import update_models
>>> model_names = ("yolov8n.pt", "yolov8s.pt")
model_names = (f"rtdetr-{size}.pt" for size in "lx") >>> update_models(model_names, source_dir=Path("/models"), update_names=True)
update_models(model_names)
```
""" """
from ultralytics import YOLO from ultralytics import YOLO
from ultralytics.nn.autobackend import default_class_names from ultralytics.nn.autobackend import default_class_names

@ -369,7 +369,7 @@ class Annotator:
# Convert im back to PIL and update draw # Convert im back to PIL and update draw
self.fromarray(self.im) self.fromarray(self.im)
def kpts(self, kpts, shape=(640, 640), radius=5, kpt_line=True, conf_thres=0.25): def kpts(self, kpts, shape=(640, 640), radius=5, kpt_line=True, conf_thres=0.25, kpt_color=None):
""" """
Plot keypoints on the image. Plot keypoints on the image.
@ -379,6 +379,7 @@ class Annotator:
radius (int, optional): Radius of the drawn keypoints. Default is 5. radius (int, optional): Radius of the drawn keypoints. Default is 5.
kpt_line (bool, optional): If True, the function will draw lines connecting keypoints kpt_line (bool, optional): If True, the function will draw lines connecting keypoints
for human pose. Default is True. for human pose. Default is True.
kpt_color (tuple, optional): The color of the keypoints (B, G, R).
Note: Note:
`kpt_line=True` currently only supports human pose plotting. `kpt_line=True` currently only supports human pose plotting.
@ -391,7 +392,7 @@ class Annotator:
is_pose = nkpt == 17 and ndim in {2, 3} is_pose = nkpt == 17 and ndim in {2, 3}
kpt_line &= is_pose # `kpt_line=True` for now only supports human pose plotting kpt_line &= is_pose # `kpt_line=True` for now only supports human pose plotting
for i, k in enumerate(kpts): for i, k in enumerate(kpts):
color_k = [int(x) for x in self.kpt_color[i]] if is_pose else colors(i) color_k = kpt_color or (self.kpt_color[i].tolist() if is_pose else colors(i))
x_coord, y_coord = k[0], k[1] x_coord, y_coord = k[0], k[1]
if x_coord % shape[1] != 0 and y_coord % shape[0] != 0: if x_coord % shape[1] != 0 and y_coord % shape[0] != 0:
if len(k) == 3: if len(k) == 3:
@ -414,7 +415,14 @@ class Annotator:
continue continue
if pos2[0] % shape[1] == 0 or pos2[1] % shape[0] == 0 or pos2[0] < 0 or pos2[1] < 0: if pos2[0] % shape[1] == 0 or pos2[1] % shape[0] == 0 or pos2[0] < 0 or pos2[1] < 0:
continue continue
cv2.line(self.im, pos1, pos2, [int(x) for x in self.limb_color[i]], thickness=2, lineType=cv2.LINE_AA) cv2.line(
self.im,
pos1,
pos2,
kpt_color or self.limb_color[i].tolist(),
thickness=2,
lineType=cv2.LINE_AA,
)
if self.pil: if self.pil:
# Convert im back to PIL and update draw # Convert im back to PIL and update draw
self.fromarray(self.im) self.fromarray(self.im)

Loading…
Cancel
Save