def match_predictions(self, pred_classes, true_classes, iou, use_scipy=False):
    """
    Match predictions to ground-truth objects using IoU, once per threshold in `self.iouv`.

    A prediction can only be matched to a label of the same class, and at each threshold every label and every
    prediction takes part in at most one match.

    Args:
        pred_classes (torch.Tensor): Predicted class indices of shape (N,).
        true_classes (torch.Tensor): Target class indices of shape (M,).
        iou (torch.Tensor): Pairwise IoU values of shape (M, N): rows index ground-truth labels, columns index
            predictions (it is multiplied element-wise with the (M, N) class-agreement mask below).
        use_scipy (bool): If True, solve the one-to-one assignment with `scipy.optimize.linear_sum_assignment`
            (more precise) instead of the default greedy de-duplication. scipy is only imported in that case.

    Returns:
        (torch.Tensor): Boolean tensor of shape (N, T), where T = `self.iouv.shape[0]`, placed on
            `pred_classes.device`. Entry [d, t] is True when prediction d is matched at threshold t.
    """
    # (N, T) result matrix: N detections x T IoU thresholds
    correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0]), dtype=bool)
    # (M, N) mask: True where label (row) and detection (column) have the same class
    correct_class = true_classes[:, None] == pred_classes
    # Zero out IoU of class-mismatched pairs, then do the matching on the CPU with NumPy
    iou = (iou * correct_class).cpu().numpy()

    if use_scipy:
        # Imported lazily so scipy is only required when this optional path is requested
        from scipy.optimize import linear_sum_assignment

    for i, threshold in enumerate(self.iouv.cpu().tolist()):
        if use_scipy:
            # Pairs below the threshold get zero gain, so they can be recognised after the assignment
            cost_matrix = iou * (iou >= threshold)
            if cost_matrix.any():
                labels_idx, detections_idx = linear_sum_assignment(cost_matrix, maximize=True)
                # The solver also pairs up zero-gain entries; keep only genuine matches
                valid = cost_matrix[labels_idx, detections_idx] > 0
                correct[detections_idx[valid], i] = True
            continue

        # Rows of [label, detection] with IoU >= threshold (class-mismatched pairs were zeroed above)
        matches = np.argwhere(iou >= threshold)
        if matches.shape[0] == 0:
            continue
        if matches.shape[0] > 1:
            # Sort candidates by IoU, highest first
            matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
            # np.unique returns the first occurrence of each value: keep the best-IoU label per detection
            matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
            # Rows are now ordered by detection index: keep the first remaining detection per label
            matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
        correct[matches[:, 1], i] = True

    return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)