diff --git a/docs/en/reference/utils/ops.md b/docs/en/reference/utils/ops.md index 43a3e0efe2..4cd9d5f304 100644 --- a/docs/en/reference/utils/ops.md +++ b/docs/en/reference/utils/ops.md @@ -119,6 +119,10 @@ keywords: Ultralytics YOLO, Utility Operations, segment2box, make_divisible, cli

+## ::: ultralytics.utils.ops.regularize_rboxes + +

+ ## ::: ultralytics.utils.ops.masks2segments

diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 3a91aec13e..2f90a6812e 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.1.4" +__version__ = "8.1.5" from ultralytics.data.explorer.explorer import Explorer from ultralytics.models import RTDETR, SAM, YOLO diff --git a/ultralytics/engine/results.py b/ultralytics/engine/results.py index 5a59ea334f..4faaeb99ab 100644 --- a/ultralytics/engine/results.py +++ b/ultralytics/engine/results.py @@ -115,7 +115,7 @@ class Results(SimpleClass): if v is not None: return len(v) - def update(self, boxes=None, masks=None, probs=None): + def update(self, boxes=None, masks=None, probs=None, obb=None): """Update the boxes, masks, and probs attributes of the Results object.""" if boxes is not None: self.boxes = Boxes(ops.clip_boxes(boxes, self.orig_shape), self.orig_shape) @@ -123,6 +123,8 @@ class Results(SimpleClass): self.masks = Masks(masks, self.orig_shape) if probs is not None: self.probs = probs + if obb is not None: + self.obb = OBB(obb, self.orig_shape) def _apply(self, fn, *args, **kwargs): """ diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py index b12f7b1424..8a4d8c59a4 100644 --- a/ultralytics/hub/session.py +++ b/ultralytics/hub/session.py @@ -225,14 +225,14 @@ class HUBTrainingSession: break # Timeout reached, exit loop response = request_func(*args, **kwargs) - if progress_total: - self._show_upload_progress(progress_total, response) - if response is None: LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}") time.sleep(2**i) # Exponential backoff before retrying continue # Skip further processing and retry + if progress_total: + self._show_upload_progress(progress_total, response) + if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES: return response # Success, no need to retry diff --git a/ultralytics/models/yolo/obb/predict.py b/ultralytics/models/yolo/obb/predict.py index c01d1df848..bb8d4d3f13 100644 --- a/ultralytics/models/yolo/obb/predict.py +++ b/ultralytics/models/yolo/obb/predict.py @@ -45,8 +45,9 @@ class OBBPredictor(DetectionPredictor): results = [] for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]): - pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape, xywh=True) + rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1)) + rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True) # xywh, r, conf, cls - obb = torch.cat([pred[:, :4], pred[:, -1:], pred[:, 4:6]], dim=-1) + obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1) results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb)) return results diff --git a/ultralytics/trackers/byte_tracker.py b/ultralytics/trackers/byte_tracker.py index 619ea12cdf..7e10b8d850 100644 --- a/ultralytics/trackers/byte_tracker.py +++ b/ultralytics/trackers/byte_tracker.py @@ -5,6 +5,8 @@ import numpy as np from .basetrack import BaseTrack, TrackState from .utils import matching from .utils.kalman_filter import KalmanFilterXYAH +from ..utils.ops import xywh2ltwh +from ..utils import LOGGER class STrack(BaseTrack): @@ -35,18 +37,18 @@ class STrack(BaseTrack): activate(kalman_filter, frame_id): Activate a new tracklet. re_activate(new_track, frame_id, new_id): Reactivate a previously lost tracklet. update(new_track, frame_id): Update the state of a matched track. - convert_coords(tlwh): Convert bounding box to x-y-angle-height format. + convert_coords(tlwh): Convert bounding box to x-y-aspect-height format. tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format. - tlbr_to_tlwh(tlbr): Convert tlbr bounding box to tlwh format. - tlwh_to_tlbr(tlwh): Convert tlwh bounding box to tlbr format. """ shared_kalman = KalmanFilterXYAH() - def __init__(self, tlwh, score, cls): + def __init__(self, xywh, score, cls): """Initialize new STrack instance.""" super().__init__() - self._tlwh = np.asarray(self.tlbr_to_tlwh(tlwh[:-1]), dtype=np.float32) + # xywh+idx or xywha+idx + assert len(xywh) in [5, 6], f"expected 5 or 6 values but got {len(xywh)}" + self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32) self.kalman_filter = None self.mean, self.covariance = None, None self.is_activated = False @@ -54,7 +56,8 @@ class STrack(BaseTrack): self.score = score self.tracklet_len = 0 self.cls = cls - self.idx = tlwh[-1] + self.idx = xywh[-1] + self.angle = xywh[4] if len(xywh) == 6 else None def predict(self): """Predicts mean and covariance using Kalman filter.""" @@ -123,6 +126,7 @@ class STrack(BaseTrack): self.track_id = self.next_id() self.score = new_track.score self.cls = new_track.cls + self.angle = new_track.angle self.idx = new_track.idx def update(self, new_track, frame_id): @@ -145,10 +149,11 @@ class STrack(BaseTrack): self.score = new_track.score self.cls = new_track.cls + self.angle = new_track.angle self.idx = new_track.idx def convert_coords(self, tlwh): - """Convert a bounding box's top-left-width-height format to its x-y-angle-height equivalent.""" + """Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent.""" return self.tlwh_to_xyah(tlwh) @property @@ -162,7 +167,7 @@ class STrack(BaseTrack): return ret @property - def tlbr(self): + def xyxy(self): """Convert bounding box to format (min x, min y, max x, max y), i.e., (top left, bottom right).""" ret = self.tlwh.copy() ret[2:] += ret[:2] @@ -178,19 +183,26 @@ class STrack(BaseTrack): ret[2] /= ret[3] return ret - @staticmethod - def tlbr_to_tlwh(tlbr): - """Converts top-left bottom-right format to top-left width height format.""" - ret = np.asarray(tlbr).copy() - ret[2:] -= ret[:2] + @property + def xywh(self): + """Get current position in bounding box format (center x, center y, width, height).""" + ret = np.asarray(self.tlwh).copy() + ret[:2] += ret[2:] / 2 return ret - @staticmethod - def tlwh_to_tlbr(tlwh): - """Converts tlwh bounding box format to tlbr format.""" - ret = np.asarray(tlwh).copy() - ret[2:] += ret[:2] - return ret + @property + def xywha(self): + """Get current position in bounding box format (center x, center y, width, height, angle).""" + if self.angle is None: + LOGGER.warning("WARNING ⚠️ `angle` attr not found, returning `xywh` instead.") + return self.xywh + return np.concatenate([self.xywh, self.angle[None]]) + + @property + def result(self): + """Get current tracking results.""" + coords = self.xyxy if self.angle is None else self.xywha + return coords.tolist() + [self.track_id, self.score, self.cls, self.idx] def __repr__(self): """Return a string representation of the BYTETracker object with start and end frames and track ID.""" @@ -247,7 +259,7 @@ class BYTETracker: removed_stracks = [] scores = results.conf - bboxes = results.xyxy + bboxes = results.xywhr if hasattr(results, "xywhr") else results.xywh # Add index bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1) cls = results.cls @@ -349,10 +361,8 @@ class BYTETracker: self.removed_stracks.extend(removed_stracks) if len(self.removed_stracks) > 1000: self.removed_stracks = self.removed_stracks[-999:] # clip remove stracks to 1000 maximum - return np.asarray( - [x.tlbr.tolist() + [x.track_id, x.score, x.cls, x.idx] for x in self.tracked_stracks if x.is_activated], - dtype=np.float32, - ) + + return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32) def get_kalmanfilter(self): """Returns a Kalman filter object for tracking bounding boxes.""" diff --git a/ultralytics/trackers/track.py b/ultralytics/trackers/track.py index ab48349236..c80c54da0f 100644 --- a/ultralytics/trackers/track.py +++ b/ultralytics/trackers/track.py @@ -25,8 +25,6 @@ def on_predict_start(predictor: object, persist: bool = False) -> None: Raises: AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'. """ - if predictor.args.task == "obb": - raise NotImplementedError("ERROR ❌ OBB task does not support track mode!") if hasattr(predictor, "trackers") and persist: return @@ -54,11 +52,12 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None bs = predictor.dataset.bs path, im0s = predictor.batch[:2] + is_obb = predictor.args.task == "obb" for i in range(bs): if not persist and predictor.vid_path[i] != str(predictor.save_dir / Path(path[i]).name): # new video predictor.trackers[i].reset() - det = predictor.results[i].boxes.cpu().numpy() + det = (predictor.results[i].obb if is_obb else predictor.results[i].boxes).cpu().numpy() if len(det) == 0: continue tracks = predictor.trackers[i].update(det, im0s[i]) @@ -66,7 +65,10 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None continue idx = tracks[:, -1].astype(int) predictor.results[i] = predictor.results[i][idx] - predictor.results[i].update(boxes=torch.as_tensor(tracks[:, :-1])) + + update_args = dict() + update_args["obb" if is_obb else "boxes"] = torch.as_tensor(tracks[:, :-1]) + predictor.results[i].update(**update_args) def register_tracker(model: object, persist: bool) -> None: diff --git a/ultralytics/trackers/utils/matching.py b/ultralytics/trackers/utils/matching.py index d5d89b2c4f..fa72b8b8ea 100644 --- a/ultralytics/trackers/utils/matching.py +++ b/ultralytics/trackers/utils/matching.py @@ -4,7 +4,7 @@ import numpy as np import scipy from scipy.spatial.distance import cdist -from ultralytics.utils.metrics import bbox_ioa +from ultralytics.utils.metrics import bbox_ioa, batch_probiou try: import lap # for linear_assignment @@ -74,14 +74,22 @@ def iou_distance(atracks: list, btracks: list) -> np.ndarray: atlbrs = atracks btlbrs = btracks else: - atlbrs = [track.tlbr for track in atracks] - btlbrs = [track.tlbr for track in btracks] + atlbrs = [track.xywha if track.angle is not None else track.xyxy for track in atracks] + btlbrs = [track.xywha if track.angle is not None else track.xyxy for track in btracks] ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32) if len(atlbrs) and len(btlbrs): - ious = bbox_ioa( - np.ascontiguousarray(atlbrs, dtype=np.float32), np.ascontiguousarray(btlbrs, dtype=np.float32), iou=True - ) + if len(atlbrs[0]) == 5 and len(btlbrs[0]) == 5: + ious = batch_probiou( + np.ascontiguousarray(atlbrs, dtype=np.float32), + np.ascontiguousarray(btlbrs, dtype=np.float32), + ).numpy() + else: + ious = bbox_ioa( + np.ascontiguousarray(atlbrs, dtype=np.float32), + np.ascontiguousarray(btlbrs, dtype=np.float32), + iou=True, + ) return 1 - ious # cost matrix diff --git a/ultralytics/utils/callbacks/hub.py b/ultralytics/utils/callbacks/hub.py index 6002431a08..8d93093cc2 100644 --- a/ultralytics/utils/callbacks/hub.py +++ b/ultralytics/utils/callbacks/hub.py @@ -46,7 +46,7 @@ def on_model_save(trainer): # Upload checkpoints with rate limiting is_best = trainer.best_fitness == trainer.fitness if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]: - LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_file}") + LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_id}") session.upload_model(trainer.epoch, trainer.last, is_best) session.timers["ckpt"] = time() # reset timer diff --git a/ultralytics/utils/metrics.py b/ultralytics/utils/metrics.py index 676ef73773..f10ae329f4 100644 --- a/ultralytics/utils/metrics.py +++ b/ultralytics/utils/metrics.py @@ -239,13 +239,16 @@ def batch_probiou(obb1, obb2, eps=1e-7): Calculate the prob iou between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf. Args: - obb1 (torch.Tensor): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format. - obb2 (torch.Tensor): A tensor of shape (M, 5) representing predicted obbs, with xywhr format. + obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format. + obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format. eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7. Returns: (torch.Tensor): A tensor of shape (N, M) representing obb similarities. """ + obb1 = torch.from_numpy(obb1) if isinstance(obb1, np.ndarray) else obb1 + obb2 = torch.from_numpy(obb2) if isinstance(obb2, np.ndarray) else obb2 + x1, y1 = obb1[..., :2].split(1, dim=-1) x2, y2 = (x.squeeze(-1)[None] for x in obb2[..., :2].split(1, dim=-1)) a1, b1, c1 = _get_covariance_matrix(obb1) diff --git a/ultralytics/utils/ops.py b/ultralytics/utils/ops.py index 94f5b3b53b..5632fd9030 100644 --- a/ultralytics/utils/ops.py +++ b/ultralytics/utils/ops.py @@ -774,6 +774,24 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False return coords +def regularize_rboxes(rboxes): + """ + Regularize rotated boxes in range [0, pi/2]. + + Args: + rboxes (torch.Tensor): (N, 5), xywhr. + + Returns: + (torch.Tensor): The regularized boxes. + """ + x, y, w, h, t = rboxes.unbind(dim=-1) + # Swap edge and angle if h >= w + w_ = torch.where(w > h, w, h) + h_ = torch.where(w > h, h, w) + t = torch.where(w > h, t, t + math.pi / 2) % math.pi + return torch.stack([x, y, w_, h_, t], dim=-1) # regularized boxes + + def masks2segments(masks, strategy="largest"): """ It takes a list of masks(n,h,w) and returns a list of segments(n,xy)