Fix error with `torch` tensor input in `model.track()` (#19278)

Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/19249/head
Mohammed Yasin 3 weeks ago committed by GitHub
parent e981fc629d
commit 0f81777af5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 12
      ultralytics/trackers/track.py

@ -66,25 +66,23 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None
>>> predictor = YourPredictorClass()
>>> on_predict_postprocess_end(predictor, persist=True)
"""
path, im0s = predictor.batch[:2]
is_obb = predictor.args.task == "obb"
is_stream = predictor.dataset.mode == "stream"
for i in range(len(im0s)):
for i, result in enumerate(predictor.results):
tracker = predictor.trackers[i if is_stream else 0]
vid_path = predictor.save_dir / Path(path[i]).name
vid_path = predictor.save_dir / Path(result.path).name
if not persist and predictor.vid_path[i if is_stream else 0] != vid_path:
tracker.reset()
predictor.vid_path[i if is_stream else 0] = vid_path
det = (predictor.results[i].obb if is_obb else predictor.results[i].boxes).cpu().numpy()
det = (result.obb if is_obb else result.boxes).cpu().numpy()
if len(det) == 0:
continue
tracks = tracker.update(det, im0s[i])
tracks = tracker.update(det, result.orig_img)
if len(tracks) == 0:
continue
idx = tracks[:, -1].astype(int)
predictor.results[i] = predictor.results[i][idx]
predictor.results[i] = result[idx]
update_args = {"obb" if is_obb else "boxes": torch.as_tensor(tracks[:, :-1])}
predictor.results[i].update(**update_args)

Loading…
Cancel
Save