`ultralytics 8.1.5` add OBB Tracking support (#7731)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Hassaan Farooq <103611273+hassaanfarooq01@users.noreply.github.com>
pull/7770/head v8.1.5
Laughing 10 months ago committed by GitHub
parent 12a741c76f
commit f56dd0f48e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 4
      docs/en/reference/utils/ops.md
  2. 2
      ultralytics/__init__.py
  3. 4
      ultralytics/engine/results.py
  4. 6
      ultralytics/hub/session.py
  5. 5
      ultralytics/models/yolo/obb/predict.py
  6. 58
      ultralytics/trackers/byte_tracker.py
  7. 10
      ultralytics/trackers/track.py
  8. 20
      ultralytics/trackers/utils/matching.py
  9. 2
      ultralytics/utils/callbacks/hub.py
  10. 7
      ultralytics/utils/metrics.py
  11. 18
      ultralytics/utils/ops.py

@ -119,6 +119,10 @@ keywords: Ultralytics YOLO, Utility Operations, segment2box, make_divisible, cli
<br><br> <br><br>
## ::: ultralytics.utils.ops.regularize_rboxes
<br><br>
## ::: ultralytics.utils.ops.masks2segments ## ::: ultralytics.utils.ops.masks2segments
<br><br> <br><br>

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.1.4" __version__ = "8.1.5"
from ultralytics.data.explorer.explorer import Explorer from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO from ultralytics.models import RTDETR, SAM, YOLO

@ -115,7 +115,7 @@ class Results(SimpleClass):
if v is not None: if v is not None:
return len(v) return len(v)
def update(self, boxes=None, masks=None, probs=None): def update(self, boxes=None, masks=None, probs=None, obb=None):
"""Update the boxes, masks, and probs attributes of the Results object.""" """Update the boxes, masks, and probs attributes of the Results object."""
if boxes is not None: if boxes is not None:
self.boxes = Boxes(ops.clip_boxes(boxes, self.orig_shape), self.orig_shape) self.boxes = Boxes(ops.clip_boxes(boxes, self.orig_shape), self.orig_shape)
@ -123,6 +123,8 @@ class Results(SimpleClass):
self.masks = Masks(masks, self.orig_shape) self.masks = Masks(masks, self.orig_shape)
if probs is not None: if probs is not None:
self.probs = probs self.probs = probs
if obb is not None:
self.obb = OBB(obb, self.orig_shape)
def _apply(self, fn, *args, **kwargs): def _apply(self, fn, *args, **kwargs):
""" """

@ -225,14 +225,14 @@ class HUBTrainingSession:
break # Timeout reached, exit loop break # Timeout reached, exit loop
response = request_func(*args, **kwargs) response = request_func(*args, **kwargs)
if progress_total:
self._show_upload_progress(progress_total, response)
if response is None: if response is None:
LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}") LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}")
time.sleep(2**i) # Exponential backoff before retrying time.sleep(2**i) # Exponential backoff before retrying
continue # Skip further processing and retry continue # Skip further processing and retry
if progress_total:
self._show_upload_progress(progress_total, response)
if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES: if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
return response # Success, no need to retry return response # Success, no need to retry

@ -45,8 +45,9 @@ class OBBPredictor(DetectionPredictor):
results = [] results = []
for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]): for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape, xywh=True) rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
# xywh, r, conf, cls # xywh, r, conf, cls
obb = torch.cat([pred[:, :4], pred[:, -1:], pred[:, 4:6]], dim=-1) obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)
results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb)) results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb))
return results return results

@ -5,6 +5,8 @@ import numpy as np
from .basetrack import BaseTrack, TrackState from .basetrack import BaseTrack, TrackState
from .utils import matching from .utils import matching
from .utils.kalman_filter import KalmanFilterXYAH from .utils.kalman_filter import KalmanFilterXYAH
from ..utils.ops import xywh2ltwh
from ..utils import LOGGER
class STrack(BaseTrack): class STrack(BaseTrack):
@ -35,18 +37,18 @@ class STrack(BaseTrack):
activate(kalman_filter, frame_id): Activate a new tracklet. activate(kalman_filter, frame_id): Activate a new tracklet.
re_activate(new_track, frame_id, new_id): Reactivate a previously lost tracklet. re_activate(new_track, frame_id, new_id): Reactivate a previously lost tracklet.
update(new_track, frame_id): Update the state of a matched track. update(new_track, frame_id): Update the state of a matched track.
convert_coords(tlwh): Convert bounding box to x-y-angle-height format. convert_coords(tlwh): Convert bounding box to x-y-aspect-height format.
tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format. tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format.
tlbr_to_tlwh(tlbr): Convert tlbr bounding box to tlwh format.
tlwh_to_tlbr(tlwh): Convert tlwh bounding box to tlbr format.
""" """
shared_kalman = KalmanFilterXYAH() shared_kalman = KalmanFilterXYAH()
def __init__(self, tlwh, score, cls): def __init__(self, xywh, score, cls):
"""Initialize new STrack instance.""" """Initialize new STrack instance."""
super().__init__() super().__init__()
self._tlwh = np.asarray(self.tlbr_to_tlwh(tlwh[:-1]), dtype=np.float32) # xywh+idx or xywha+idx
assert len(xywh) in [5, 6], f"expected 5 or 6 values but got {len(xywh)}"
self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32)
self.kalman_filter = None self.kalman_filter = None
self.mean, self.covariance = None, None self.mean, self.covariance = None, None
self.is_activated = False self.is_activated = False
@ -54,7 +56,8 @@ class STrack(BaseTrack):
self.score = score self.score = score
self.tracklet_len = 0 self.tracklet_len = 0
self.cls = cls self.cls = cls
self.idx = tlwh[-1] self.idx = xywh[-1]
self.angle = xywh[4] if len(xywh) == 6 else None
def predict(self): def predict(self):
"""Predicts mean and covariance using Kalman filter.""" """Predicts mean and covariance using Kalman filter."""
@ -123,6 +126,7 @@ class STrack(BaseTrack):
self.track_id = self.next_id() self.track_id = self.next_id()
self.score = new_track.score self.score = new_track.score
self.cls = new_track.cls self.cls = new_track.cls
self.angle = new_track.angle
self.idx = new_track.idx self.idx = new_track.idx
def update(self, new_track, frame_id): def update(self, new_track, frame_id):
@ -145,10 +149,11 @@ class STrack(BaseTrack):
self.score = new_track.score self.score = new_track.score
self.cls = new_track.cls self.cls = new_track.cls
self.angle = new_track.angle
self.idx = new_track.idx self.idx = new_track.idx
def convert_coords(self, tlwh): def convert_coords(self, tlwh):
"""Convert a bounding box's top-left-width-height format to its x-y-angle-height equivalent.""" """Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent."""
return self.tlwh_to_xyah(tlwh) return self.tlwh_to_xyah(tlwh)
@property @property
@ -162,7 +167,7 @@ class STrack(BaseTrack):
return ret return ret
@property @property
def tlbr(self): def xyxy(self):
"""Convert bounding box to format (min x, min y, max x, max y), i.e., (top left, bottom right).""" """Convert bounding box to format (min x, min y, max x, max y), i.e., (top left, bottom right)."""
ret = self.tlwh.copy() ret = self.tlwh.copy()
ret[2:] += ret[:2] ret[2:] += ret[:2]
@ -178,19 +183,26 @@ class STrack(BaseTrack):
ret[2] /= ret[3] ret[2] /= ret[3]
return ret return ret
@staticmethod @property
def tlbr_to_tlwh(tlbr): def xywh(self):
"""Converts top-left bottom-right format to top-left width height format.""" """Get current position in bounding box format (center x, center y, width, height)."""
ret = np.asarray(tlbr).copy() ret = np.asarray(self.tlwh).copy()
ret[2:] -= ret[:2] ret[:2] += ret[2:] / 2
return ret return ret
@staticmethod @property
def tlwh_to_tlbr(tlwh): def xywha(self):
"""Converts tlwh bounding box format to tlbr format.""" """Get current position in bounding box format (center x, center y, width, height, angle)."""
ret = np.asarray(tlwh).copy() if self.angle is None:
ret[2:] += ret[:2] LOGGER.warning("WARNING ⚠ `angle` attr not found, returning `xywh` instead.")
return ret return self.xywh
return np.concatenate([self.xywh, self.angle[None]])
@property
def result(self):
"""Get current tracking results."""
coords = self.xyxy if self.angle is None else self.xywha
return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]
def __repr__(self): def __repr__(self):
"""Return a string representation of the BYTETracker object with start and end frames and track ID.""" """Return a string representation of the BYTETracker object with start and end frames and track ID."""
@ -247,7 +259,7 @@ class BYTETracker:
removed_stracks = [] removed_stracks = []
scores = results.conf scores = results.conf
bboxes = results.xyxy bboxes = results.xywhr if hasattr(results, "xywhr") else results.xywh
# Add index # Add index
bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1) bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)
cls = results.cls cls = results.cls
@ -349,10 +361,8 @@ class BYTETracker:
self.removed_stracks.extend(removed_stracks) self.removed_stracks.extend(removed_stracks)
if len(self.removed_stracks) > 1000: if len(self.removed_stracks) > 1000:
self.removed_stracks = self.removed_stracks[-999:] # clip remove stracks to 1000 maximum self.removed_stracks = self.removed_stracks[-999:] # clip remove stracks to 1000 maximum
return np.asarray(
[x.tlbr.tolist() + [x.track_id, x.score, x.cls, x.idx] for x in self.tracked_stracks if x.is_activated], return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)
dtype=np.float32,
)
def get_kalmanfilter(self): def get_kalmanfilter(self):
"""Returns a Kalman filter object for tracking bounding boxes.""" """Returns a Kalman filter object for tracking bounding boxes."""

@ -25,8 +25,6 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
Raises: Raises:
AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'. AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
""" """
if predictor.args.task == "obb":
raise NotImplementedError("ERROR ❌ OBB task does not support track mode!")
if hasattr(predictor, "trackers") and persist: if hasattr(predictor, "trackers") and persist:
return return
@ -54,11 +52,12 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None
bs = predictor.dataset.bs bs = predictor.dataset.bs
path, im0s = predictor.batch[:2] path, im0s = predictor.batch[:2]
is_obb = predictor.args.task == "obb"
for i in range(bs): for i in range(bs):
if not persist and predictor.vid_path[i] != str(predictor.save_dir / Path(path[i]).name): # new video if not persist and predictor.vid_path[i] != str(predictor.save_dir / Path(path[i]).name): # new video
predictor.trackers[i].reset() predictor.trackers[i].reset()
det = predictor.results[i].boxes.cpu().numpy() det = (predictor.results[i].obb if is_obb else predictor.results[i].boxes).cpu().numpy()
if len(det) == 0: if len(det) == 0:
continue continue
tracks = predictor.trackers[i].update(det, im0s[i]) tracks = predictor.trackers[i].update(det, im0s[i])
@ -66,7 +65,10 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None
continue continue
idx = tracks[:, -1].astype(int) idx = tracks[:, -1].astype(int)
predictor.results[i] = predictor.results[i][idx] predictor.results[i] = predictor.results[i][idx]
predictor.results[i].update(boxes=torch.as_tensor(tracks[:, :-1]))
update_args = dict()
update_args["obb" if is_obb else "boxes"] = torch.as_tensor(tracks[:, :-1])
predictor.results[i].update(**update_args)
def register_tracker(model: object, persist: bool) -> None: def register_tracker(model: object, persist: bool) -> None:

@ -4,7 +4,7 @@ import numpy as np
import scipy import scipy
from scipy.spatial.distance import cdist from scipy.spatial.distance import cdist
from ultralytics.utils.metrics import bbox_ioa from ultralytics.utils.metrics import bbox_ioa, batch_probiou
try: try:
import lap # for linear_assignment import lap # for linear_assignment
@ -74,14 +74,22 @@ def iou_distance(atracks: list, btracks: list) -> np.ndarray:
atlbrs = atracks atlbrs = atracks
btlbrs = btracks btlbrs = btracks
else: else:
atlbrs = [track.tlbr for track in atracks] atlbrs = [track.xywha if track.angle is not None else track.xyxy for track in atracks]
btlbrs = [track.tlbr for track in btracks] btlbrs = [track.xywha if track.angle is not None else track.xyxy for track in btracks]
ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32) ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
if len(atlbrs) and len(btlbrs): if len(atlbrs) and len(btlbrs):
ious = bbox_ioa( if len(atlbrs[0]) == 5 and len(btlbrs[0]) == 5:
np.ascontiguousarray(atlbrs, dtype=np.float32), np.ascontiguousarray(btlbrs, dtype=np.float32), iou=True ious = batch_probiou(
) np.ascontiguousarray(atlbrs, dtype=np.float32),
np.ascontiguousarray(btlbrs, dtype=np.float32),
).numpy()
else:
ious = bbox_ioa(
np.ascontiguousarray(atlbrs, dtype=np.float32),
np.ascontiguousarray(btlbrs, dtype=np.float32),
iou=True,
)
return 1 - ious # cost matrix return 1 - ious # cost matrix

@ -46,7 +46,7 @@ def on_model_save(trainer):
# Upload checkpoints with rate limiting # Upload checkpoints with rate limiting
is_best = trainer.best_fitness == trainer.fitness is_best = trainer.best_fitness == trainer.fitness
if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]: if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]:
LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_file}") LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_id}")
session.upload_model(trainer.epoch, trainer.last, is_best) session.upload_model(trainer.epoch, trainer.last, is_best)
session.timers["ckpt"] = time() # reset timer session.timers["ckpt"] = time() # reset timer

@ -239,13 +239,16 @@ def batch_probiou(obb1, obb2, eps=1e-7):
Calculate the prob iou between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf. Calculate the prob iou between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.
Args: Args:
obb1 (torch.Tensor): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format. obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
obb2 (torch.Tensor): A tensor of shape (M, 5) representing predicted obbs, with xywhr format. obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7. eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
Returns: Returns:
(torch.Tensor): A tensor of shape (N, M) representing obb similarities. (torch.Tensor): A tensor of shape (N, M) representing obb similarities.
""" """
obb1 = torch.from_numpy(obb1) if isinstance(obb1, np.ndarray) else obb1
obb2 = torch.from_numpy(obb2) if isinstance(obb2, np.ndarray) else obb2
x1, y1 = obb1[..., :2].split(1, dim=-1) x1, y1 = obb1[..., :2].split(1, dim=-1)
x2, y2 = (x.squeeze(-1)[None] for x in obb2[..., :2].split(1, dim=-1)) x2, y2 = (x.squeeze(-1)[None] for x in obb2[..., :2].split(1, dim=-1))
a1, b1, c1 = _get_covariance_matrix(obb1) a1, b1, c1 = _get_covariance_matrix(obb1)

@ -774,6 +774,24 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False
return coords return coords
def regularize_rboxes(rboxes):
    """
    Regularize rotated boxes so that width is the longer edge and the angle lies in [0, pi).

    For every box whose height is greater than or equal to its width, the two edge lengths are swapped and the
    angle is advanced by pi/2. All angles are then wrapped modulo pi, so the returned angle can exceed pi/2.

    Args:
        rboxes (torch.Tensor): (N, 5), xywhr.

    Returns:
        (torch.Tensor): The regularized boxes in xywhr layout, stacked back along the last dimension.
    """
    x, y, w, h, t = rboxes.unbind(dim=-1)
    # Evaluate the comparison once and reuse it for all three selections.
    wider = w > h
    # Swap edge and angle if h >= w
    w_ = torch.where(wider, w, h)
    h_ = torch.where(wider, h, w)
    # The modulo is applied to every box, not only the swapped ones.
    t = torch.where(wider, t, t + math.pi / 2) % math.pi
    return torch.stack([x, y, w_, h_, t], dim=-1)  # regularized boxes
def masks2segments(masks, strategy="largest"): def masks2segments(masks, strategy="largest"):
""" """
It takes a list of masks(n,h,w) and returns a list of segments(n,xy) It takes a list of masks(n,h,w) and returns a list of segments(n,xy)

Loading…
Cancel
Save