|
|
|
@ -38,13 +38,13 @@ def on_predict_start(predictor, persist=False): |
|
|
|
|
predictor.trackers = trackers |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def on_predict_postprocess_end(predictor): |
|
|
|
|
def on_predict_postprocess_end(predictor, persist=False): |
|
|
|
|
"""Postprocess detected boxes and update with object tracking.""" |
|
|
|
|
bs = predictor.dataset.bs |
|
|
|
|
path, im0s = predictor.batch[:2] |
|
|
|
|
|
|
|
|
|
for i in range(bs): |
|
|
|
|
if predictor.vid_path[i] != str(predictor.save_dir / Path(path[i]).name): # new video |
|
|
|
|
if not persist and predictor.vid_path[i] != str(predictor.save_dir / Path(path[i]).name): # new video |
|
|
|
|
predictor.trackers[i].reset() |
|
|
|
|
|
|
|
|
|
det = predictor.results[i].boxes.cpu().numpy() |
|
|
|
@ -67,4 +67,4 @@ def register_tracker(model, persist): |
|
|
|
|
persist (bool): Whether to persist the trackers if they already exist. |
|
|
|
|
""" |
|
|
|
|
model.add_callback('on_predict_start', partial(on_predict_start, persist=persist)) |
|
|
|
|
model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end) |
|
|
|
|
model.add_callback('on_predict_postprocess_end', partial(on_predict_postprocess_end, persist=persist)) |
|
|
|
|