# Ultralytics YOLO 🚀, AGPL-3.0 license

from collections import deque

import numpy as np

from .basetrack import TrackState
from .byte_tracker import BYTETracker, STrack
from .utils import matching
from .utils.gmc import GMC
from .utils.kalman_filter import KalmanFilterXYWH


class BOTrack(STrack):
    """
    Single-object track used by BOTSORT: an STrack that can also carry a smoothed appearance embedding.

    Attributes:
        shared_kalman (KalmanFilterXYWH): Filter shared by all instances and used by `multi_predict`.
        smooth_feat (np.ndarray | None): Exponential moving average of the L2-normalized features.
        curr_feat (np.ndarray | None): Most recently received L2-normalized feature.
        features (deque): Bounded history of normalized features (at most `feat_history` entries).
        alpha (float): EMA weight given to the previous smoothed feature.
    """

    shared_kalman = KalmanFilterXYWH()

    def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
        """
        Initialize the track and, if given, its first appearance feature.

        Args:
            tlwh: Box passed unchanged to `STrack.__init__`.
            score (float): Detection confidence.
            cls: Detection class.
            feat (np.ndarray, optional): Appearance embedding for this detection.
            feat_history (int): Maximum number of past features kept in `self.features`.
        """
        super().__init__(tlwh, score, cls)

        self.smooth_feat = None
        self.curr_feat = None
        # The history buffer and EMA weight must exist before update_features() runs, because it appends to
        # self.features. They used to be created afterwards, so passing `feat` raised AttributeError.
        self.features = deque([], maxlen=feat_history)
        self.alpha = 0.9
        if feat is not None:
            self.update_features(feat)

    def update_features(self, feat):
        """
        Record a new appearance feature and update the smoothed feature with an exponential moving average.

        The input is L2-normalized into a new array, so the caller's array is not modified. A zero-norm input
        is stored unnormalized instead of producing NaNs.

        Args:
            feat (np.ndarray): Appearance embedding.
        """
        feat = self._l2_normalize(feat)
        self.curr_feat = feat
        if self.smooth_feat is None:
            self.smooth_feat = feat
        else:
            self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
        self.features.append(feat)
        # _l2_normalize returns a new array, so this never mutates curr_feat or the history entry.
        self.smooth_feat = self._l2_normalize(self.smooth_feat)

    @staticmethod
    def _l2_normalize(vec):
        """Return `vec` as a new array scaled to unit L2 norm, or an unscaled copy if its norm is zero."""
        vec = np.asarray(vec)
        norm = np.linalg.norm(vec)
        return vec / norm if norm > 0 else vec.copy()

    def predict(self):
        """Advance this track's mean and covariance one step with its own Kalman filter."""
        mean_state = self.mean.copy()
        if self.state != TrackState.Tracked:
            # State elements 6 and 7 are zeroed for tracks not currently Tracked (presumably the width/height
            # velocities of the XYWH state; confirm against KalmanFilterXYWH).
            mean_state[6] = 0
            mean_state[7] = 0
        self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

    def re_activate(self, new_track, frame_id, new_id=False):
        """Reactivate the track from `new_track`, folding in its feature (if any), and optionally assign a new ID."""
        if new_track.curr_feat is not None:
            self.update_features(new_track.curr_feat)
        super().re_activate(new_track, frame_id, new_id)

    def update(self, new_track, frame_id):
        """Update the track with a matched `new_track` at `frame_id`, folding in its feature (if any)."""
        if new_track.curr_feat is not None:
            self.update_features(new_track.curr_feat)
        super().update(new_track, frame_id)

    @property
    def tlwh(self):
        """Get current position in bounding box format `(top left x, top left y, width, height)`.
        """
        if self.mean is None:
            return self._tlwh.copy()
        ret = self.mean[:4].copy()
        # mean[:4] is treated as (center x, center y, w, h); shift the center to the top-left corner.
        ret[:2] -= ret[2:] / 2
        return ret

    @staticmethod
    def multi_predict(stracks):
        """Predict the mean and covariance of several tracks in one batched call to the shared Kalman filter."""
        if len(stracks) <= 0:
            return
        multi_mean = np.asarray([st.mean.copy() for st in stracks])
        multi_covariance = np.asarray([st.covariance for st in stracks])
        for i, st in enumerate(stracks):
            if st.state != TrackState.Tracked:
                # Same zeroing of state elements 6 and 7 as in `predict`.
                multi_mean[i][6] = 0
                multi_mean[i][7] = 0
        multi_mean, multi_covariance = BOTrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
        for st, mean, cov in zip(stracks, multi_mean, multi_covariance):
            st.mean = mean
            st.covariance = cov

    def convert_coords(self, tlwh):
        """Convert a `(top left x, top left y, w, h)` box to `(center x, center y, w, h)`."""
        return self.tlwh_to_xywh(tlwh)

    @staticmethod
    def tlwh_to_xywh(tlwh):
        """Convert bounding box to format `(center x, center y, width, height)`.
        """
        ret = np.asarray(tlwh).copy()
        ret[:2] += ret[2:] / 2
        return ret


class BOTSORT(BYTETracker):
    """
    BYTETracker variant that creates BOTrack tracks, holds a GMC module and has an optional ReID branch.

    Attributes:
        proximity_thresh (float): Pairs whose IoU distance exceeds this get their embedding distance set to 1.0.
        appearance_thresh (float): Halved embedding distances above this are set to 1.0.
        encoder: ReID feature extractor. Always None here because BoT-SORT ReID is not supported yet.
        gmc (GMC): GMC instance constructed with `method=args.cmc_method`.
    """

    def __init__(self, args, frame_rate=30):
        """
        Initialize the tracker with its ReID thresholds and GMC module.

        Args:
            args: Tracker configuration; reads `proximity_thresh`, `appearance_thresh`, `with_reid` and
                `cmc_method` (plus whatever BYTETracker reads).
            frame_rate (int): Frame rate forwarded to BYTETracker.
        """
        super().__init__(args, frame_rate)
        # ReID module
        self.proximity_thresh = args.proximity_thresh
        self.appearance_thresh = args.appearance_thresh
        # BoT-SORT ReID is not supported yet, so no encoder is ever built. The attribute is defined
        # unconditionally (previously only when args.with_reid was set) so it always exists.
        self.encoder = None
        self.gmc = GMC(method=args.cmc_method)

    def get_kalmanfilter(self):
        """Return a new KalmanFilterXYWH instance for object tracking."""
        return KalmanFilterXYWH()

    def init_track(self, dets, scores, cls, img=None):
        """
        Create one BOTrack per detection.

        Args:
            dets: Detection boxes, passed unchanged to BOTrack as its `tlwh` argument.
            scores: Detection confidences.
            cls: Detection classes.
            img: Frame image, used only by the ReID encoder when one is available.

        Returns:
            (list[BOTrack]): The new tracks (empty list if there are no detections).
        """
        if len(dets) == 0:
            return []
        if self.args.with_reid and self.encoder is not None:
            features_keep = self.encoder.inference(img, dets)
            return [BOTrack(box, s, c, f) for (box, s, c, f) in zip(dets, scores, cls, features_keep)]  # detections
        else:
            return [BOTrack(box, s, c) for (box, s, c) in zip(dets, scores, cls)]  # detections

    def get_dists(self, tracks, detections):
        """
        Compute the track/detection cost matrix from IoU distance and, if ReID is active, embedding distance.

        Returns:
            (np.ndarray): Cost matrix. When ReID is active it is the element-wise minimum of the score-fused IoU
                distance and the gated embedding distance.
        """
        dists = matching.iou_distance(tracks, detections)
        # Mask computed on raw IoU distance, before score fusion, to gate the embedding distances below.
        dists_mask = (dists > self.proximity_thresh)

        # TODO: mot20
        # if not self.args.mot20:
        dists = matching.fuse_score(dists, detections)

        if self.args.with_reid and self.encoder is not None:
            emb_dists = matching.embedding_distance(tracks, detections) / 2.0
            emb_dists[emb_dists > self.appearance_thresh] = 1.0
            emb_dists[dists_mask] = 1.0
            dists = np.minimum(dists, emb_dists)
        return dists

    def multi_predict(self, tracks):
        """Run batched Kalman prediction on `tracks` via `BOTrack.multi_predict`."""
        BOTrack.multi_predict(tracks)