diff --git a/docs/en/reference/utils/ops.md b/docs/en/reference/utils/ops.md
index 43a3e0efe2..4cd9d5f304 100644
--- a/docs/en/reference/utils/ops.md
+++ b/docs/en/reference/utils/ops.md
@@ -119,6 +119,10 @@ keywords: Ultralytics YOLO, Utility Operations, segment2box, make_divisible, cli
+## ::: ultralytics.utils.ops.regularize_rboxes
+
+
+
## ::: ultralytics.utils.ops.masks2segments
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 3a91aec13e..2f90a6812e 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = "8.1.4"
+__version__ = "8.1.5"
from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO
diff --git a/ultralytics/engine/results.py b/ultralytics/engine/results.py
index 5a59ea334f..4faaeb99ab 100644
--- a/ultralytics/engine/results.py
+++ b/ultralytics/engine/results.py
@@ -115,7 +115,7 @@ class Results(SimpleClass):
if v is not None:
return len(v)
- def update(self, boxes=None, masks=None, probs=None):
+ def update(self, boxes=None, masks=None, probs=None, obb=None):
"""Update the boxes, masks, and probs attributes of the Results object."""
if boxes is not None:
self.boxes = Boxes(ops.clip_boxes(boxes, self.orig_shape), self.orig_shape)
@@ -123,6 +123,8 @@ class Results(SimpleClass):
self.masks = Masks(masks, self.orig_shape)
if probs is not None:
self.probs = probs
+ if obb is not None:
+ self.obb = OBB(obb, self.orig_shape)
def _apply(self, fn, *args, **kwargs):
"""
diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py
index b12f7b1424..8a4d8c59a4 100644
--- a/ultralytics/hub/session.py
+++ b/ultralytics/hub/session.py
@@ -225,14 +225,14 @@ class HUBTrainingSession:
break # Timeout reached, exit loop
response = request_func(*args, **kwargs)
- if progress_total:
- self._show_upload_progress(progress_total, response)
-
if response is None:
LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}")
time.sleep(2**i) # Exponential backoff before retrying
continue # Skip further processing and retry
+ if progress_total:
+ self._show_upload_progress(progress_total, response)
+
if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
return response # Success, no need to retry
diff --git a/ultralytics/models/yolo/obb/predict.py b/ultralytics/models/yolo/obb/predict.py
index c01d1df848..bb8d4d3f13 100644
--- a/ultralytics/models/yolo/obb/predict.py
+++ b/ultralytics/models/yolo/obb/predict.py
@@ -45,8 +45,9 @@ class OBBPredictor(DetectionPredictor):
results = []
for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
- pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape, xywh=True)
+ rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
+ rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
# xywh, r, conf, cls
- obb = torch.cat([pred[:, :4], pred[:, -1:], pred[:, 4:6]], dim=-1)
+ obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)
results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb))
return results
diff --git a/ultralytics/trackers/byte_tracker.py b/ultralytics/trackers/byte_tracker.py
index 619ea12cdf..7e10b8d850 100644
--- a/ultralytics/trackers/byte_tracker.py
+++ b/ultralytics/trackers/byte_tracker.py
@@ -5,6 +5,8 @@ import numpy as np
from .basetrack import BaseTrack, TrackState
from .utils import matching
from .utils.kalman_filter import KalmanFilterXYAH
+from ..utils.ops import xywh2ltwh
+from ..utils import LOGGER
class STrack(BaseTrack):
@@ -35,18 +37,18 @@ class STrack(BaseTrack):
activate(kalman_filter, frame_id): Activate a new tracklet.
re_activate(new_track, frame_id, new_id): Reactivate a previously lost tracklet.
update(new_track, frame_id): Update the state of a matched track.
- convert_coords(tlwh): Convert bounding box to x-y-angle-height format.
+ convert_coords(tlwh): Convert bounding box to x-y-aspect-height format.
tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format.
- tlbr_to_tlwh(tlbr): Convert tlbr bounding box to tlwh format.
- tlwh_to_tlbr(tlwh): Convert tlwh bounding box to tlbr format.
"""
shared_kalman = KalmanFilterXYAH()
- def __init__(self, tlwh, score, cls):
+ def __init__(self, xywh, score, cls):
"""Initialize new STrack instance."""
super().__init__()
- self._tlwh = np.asarray(self.tlbr_to_tlwh(tlwh[:-1]), dtype=np.float32)
+ # xywh+idx or xywha+idx
+ assert len(xywh) in [5, 6], f"expected 5 or 6 values but got {len(xywh)}"
+ self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32)
self.kalman_filter = None
self.mean, self.covariance = None, None
self.is_activated = False
@@ -54,7 +56,8 @@ class STrack(BaseTrack):
self.score = score
self.tracklet_len = 0
self.cls = cls
- self.idx = tlwh[-1]
+ self.idx = xywh[-1]
+ self.angle = xywh[4] if len(xywh) == 6 else None
def predict(self):
"""Predicts mean and covariance using Kalman filter."""
@@ -123,6 +126,7 @@ class STrack(BaseTrack):
self.track_id = self.next_id()
self.score = new_track.score
self.cls = new_track.cls
+ self.angle = new_track.angle
self.idx = new_track.idx
def update(self, new_track, frame_id):
@@ -145,10 +149,11 @@ class STrack(BaseTrack):
self.score = new_track.score
self.cls = new_track.cls
+ self.angle = new_track.angle
self.idx = new_track.idx
def convert_coords(self, tlwh):
- """Convert a bounding box's top-left-width-height format to its x-y-angle-height equivalent."""
+ """Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent."""
return self.tlwh_to_xyah(tlwh)
@property
@@ -162,7 +167,7 @@ class STrack(BaseTrack):
return ret
@property
- def tlbr(self):
+ def xyxy(self):
"""Convert bounding box to format (min x, min y, max x, max y), i.e., (top left, bottom right)."""
ret = self.tlwh.copy()
ret[2:] += ret[:2]
@@ -178,19 +183,26 @@ class STrack(BaseTrack):
ret[2] /= ret[3]
return ret
- @staticmethod
- def tlbr_to_tlwh(tlbr):
- """Converts top-left bottom-right format to top-left width height format."""
- ret = np.asarray(tlbr).copy()
- ret[2:] -= ret[:2]
+ @property
+ def xywh(self):
+ """Get current position in bounding box format (center x, center y, width, height)."""
+ ret = np.asarray(self.tlwh).copy()
+ ret[:2] += ret[2:] / 2
return ret
- @staticmethod
- def tlwh_to_tlbr(tlwh):
- """Converts tlwh bounding box format to tlbr format."""
- ret = np.asarray(tlwh).copy()
- ret[2:] += ret[:2]
- return ret
+ @property
+ def xywha(self):
+ """Get current position in bounding box format (center x, center y, width, height, angle)."""
+ if self.angle is None:
+ LOGGER.warning("WARNING ⚠️ `angle` attr not found, returning `xywh` instead.")
+ return self.xywh
+ return np.concatenate([self.xywh, self.angle[None]])
+
+ @property
+ def result(self):
+ """Get current tracking results."""
+ coords = self.xyxy if self.angle is None else self.xywha
+ return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]
def __repr__(self):
"""Return a string representation of the BYTETracker object with start and end frames and track ID."""
@@ -247,7 +259,7 @@ class BYTETracker:
removed_stracks = []
scores = results.conf
- bboxes = results.xyxy
+ bboxes = results.xywhr if hasattr(results, "xywhr") else results.xywh
# Add index
bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)
cls = results.cls
@@ -349,10 +361,8 @@ class BYTETracker:
self.removed_stracks.extend(removed_stracks)
if len(self.removed_stracks) > 1000:
self.removed_stracks = self.removed_stracks[-999:] # clip remove stracks to 1000 maximum
- return np.asarray(
- [x.tlbr.tolist() + [x.track_id, x.score, x.cls, x.idx] for x in self.tracked_stracks if x.is_activated],
- dtype=np.float32,
- )
+
+ return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)
def get_kalmanfilter(self):
"""Returns a Kalman filter object for tracking bounding boxes."""
diff --git a/ultralytics/trackers/track.py b/ultralytics/trackers/track.py
index ab48349236..c80c54da0f 100644
--- a/ultralytics/trackers/track.py
+++ b/ultralytics/trackers/track.py
@@ -25,8 +25,6 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
Raises:
AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
"""
- if predictor.args.task == "obb":
- raise NotImplementedError("ERROR ❌ OBB task does not support track mode!")
if hasattr(predictor, "trackers") and persist:
return
@@ -54,11 +52,12 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None
bs = predictor.dataset.bs
path, im0s = predictor.batch[:2]
+ is_obb = predictor.args.task == "obb"
for i in range(bs):
if not persist and predictor.vid_path[i] != str(predictor.save_dir / Path(path[i]).name): # new video
predictor.trackers[i].reset()
- det = predictor.results[i].boxes.cpu().numpy()
+ det = (predictor.results[i].obb if is_obb else predictor.results[i].boxes).cpu().numpy()
if len(det) == 0:
continue
tracks = predictor.trackers[i].update(det, im0s[i])
@@ -66,7 +65,10 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None
continue
idx = tracks[:, -1].astype(int)
predictor.results[i] = predictor.results[i][idx]
- predictor.results[i].update(boxes=torch.as_tensor(tracks[:, :-1]))
+
+ update_args = dict()
+ update_args["obb" if is_obb else "boxes"] = torch.as_tensor(tracks[:, :-1])
+ predictor.results[i].update(**update_args)
def register_tracker(model: object, persist: bool) -> None:
diff --git a/ultralytics/trackers/utils/matching.py b/ultralytics/trackers/utils/matching.py
index d5d89b2c4f..fa72b8b8ea 100644
--- a/ultralytics/trackers/utils/matching.py
+++ b/ultralytics/trackers/utils/matching.py
@@ -4,7 +4,7 @@ import numpy as np
import scipy
from scipy.spatial.distance import cdist
-from ultralytics.utils.metrics import bbox_ioa
+from ultralytics.utils.metrics import bbox_ioa, batch_probiou
try:
import lap # for linear_assignment
@@ -74,14 +74,22 @@ def iou_distance(atracks: list, btracks: list) -> np.ndarray:
atlbrs = atracks
btlbrs = btracks
else:
- atlbrs = [track.tlbr for track in atracks]
- btlbrs = [track.tlbr for track in btracks]
+ atlbrs = [track.xywha if track.angle is not None else track.xyxy for track in atracks]
+ btlbrs = [track.xywha if track.angle is not None else track.xyxy for track in btracks]
ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
if len(atlbrs) and len(btlbrs):
- ious = bbox_ioa(
- np.ascontiguousarray(atlbrs, dtype=np.float32), np.ascontiguousarray(btlbrs, dtype=np.float32), iou=True
- )
+ if len(atlbrs[0]) == 5 and len(btlbrs[0]) == 5:
+ ious = batch_probiou(
+ np.ascontiguousarray(atlbrs, dtype=np.float32),
+ np.ascontiguousarray(btlbrs, dtype=np.float32),
+ ).numpy()
+ else:
+ ious = bbox_ioa(
+ np.ascontiguousarray(atlbrs, dtype=np.float32),
+ np.ascontiguousarray(btlbrs, dtype=np.float32),
+ iou=True,
+ )
return 1 - ious # cost matrix
diff --git a/ultralytics/utils/callbacks/hub.py b/ultralytics/utils/callbacks/hub.py
index 6002431a08..8d93093cc2 100644
--- a/ultralytics/utils/callbacks/hub.py
+++ b/ultralytics/utils/callbacks/hub.py
@@ -46,7 +46,7 @@ def on_model_save(trainer):
# Upload checkpoints with rate limiting
is_best = trainer.best_fitness == trainer.fitness
if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]:
- LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_file}")
+ LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_id}")
session.upload_model(trainer.epoch, trainer.last, is_best)
session.timers["ckpt"] = time() # reset timer
diff --git a/ultralytics/utils/metrics.py b/ultralytics/utils/metrics.py
index 676ef73773..f10ae329f4 100644
--- a/ultralytics/utils/metrics.py
+++ b/ultralytics/utils/metrics.py
@@ -239,13 +239,16 @@ def batch_probiou(obb1, obb2, eps=1e-7):
Calculate the prob iou between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.
Args:
- obb1 (torch.Tensor): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
- obb2 (torch.Tensor): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.
+ obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
+ obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
Returns:
(torch.Tensor): A tensor of shape (N, M) representing obb similarities.
"""
+ obb1 = torch.from_numpy(obb1) if isinstance(obb1, np.ndarray) else obb1
+ obb2 = torch.from_numpy(obb2) if isinstance(obb2, np.ndarray) else obb2
+
x1, y1 = obb1[..., :2].split(1, dim=-1)
x2, y2 = (x.squeeze(-1)[None] for x in obb2[..., :2].split(1, dim=-1))
a1, b1, c1 = _get_covariance_matrix(obb1)
diff --git a/ultralytics/utils/ops.py b/ultralytics/utils/ops.py
index 94f5b3b53b..5632fd9030 100644
--- a/ultralytics/utils/ops.py
+++ b/ultralytics/utils/ops.py
@@ -774,6 +774,24 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False
return coords
+def regularize_rboxes(rboxes):
+ """
+ Regularize rotated boxes in range [0, pi/2].
+
+ Args:
+ rboxes (torch.Tensor): (N, 5), xywhr.
+
+ Returns:
+ (torch.Tensor): The regularized boxes.
+ """
+ x, y, w, h, t = rboxes.unbind(dim=-1)
+ # Swap edge and angle if h >= w
+ w_ = torch.where(w > h, w, h)
+ h_ = torch.where(w > h, h, w)
+ t = torch.where(w > h, t, t + math.pi / 2) % math.pi
+ return torch.stack([x, y, w_, h_, t], dim=-1) # regularized boxes
+
+
def masks2segments(masks, strategy="largest"):
"""
It takes a list of masks(n,h,w) and returns a list of segments(n,xy)