|
|
|
@ -24,6 +24,7 @@ from pathlib import Path |
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
import torch |
|
|
|
|
from scipy.optimize import linear_sum_assignment |
|
|
|
|
|
|
|
|
|
from ultralytics.cfg import get_cfg, get_save_dir |
|
|
|
|
from ultralytics.data.utils import check_cls_dataset, check_det_dataset |
|
|
|
@ -204,7 +205,7 @@ class BaseValidator: |
|
|
|
|
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}") |
|
|
|
|
return stats |
|
|
|
|
|
|
|
|
|
def match_predictions(self, pred_classes, true_classes, iou): |
|
|
|
|
def match_predictions(self, pred_classes, true_classes, iou, use_scipy=False): |
|
|
|
|
""" |
|
|
|
|
Matches predictions to ground truth objects (pred_classes, true_classes) using IoU. |
|
|
|
|
|
|
|
|
@ -212,23 +213,35 @@ class BaseValidator: |
|
|
|
|
pred_classes (torch.Tensor): Predicted class indices of shape(N,). |
|
|
|
|
true_classes (torch.Tensor): Target class indices of shape(M,). |
|
|
|
|
iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground of truth |
|
|
|
|
use_scipy (bool): Whether to use scipy for matching (more precise). |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
(torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds. |
|
|
|
|
""" |
|
|
|
|
# Dx10 matrix, where D - detections, 10 - IoU thresholds |
|
|
|
|
correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool) |
|
|
|
|
# LxD matrix where L - labels (rows), D - detections (columns) |
|
|
|
|
correct_class = true_classes[:, None] == pred_classes |
|
|
|
|
for i, iouv in enumerate(self.iouv): |
|
|
|
|
x = torch.nonzero(iou.ge(iouv) & correct_class) # IoU > threshold and classes match |
|
|
|
|
if x.shape[0]: |
|
|
|
|
# Concatenate [label, detect, iou] |
|
|
|
|
matches = torch.cat((x, iou[x[:, 0], x[:, 1]].unsqueeze(1)), 1).cpu().numpy() |
|
|
|
|
if x.shape[0] > 1: |
|
|
|
|
matches = matches[matches[:, 2].argsort()[::-1]] |
|
|
|
|
matches = matches[np.unique(matches[:, 1], return_index=True)[1]] |
|
|
|
|
# matches = matches[matches[:, 2].argsort()[::-1]] |
|
|
|
|
matches = matches[np.unique(matches[:, 0], return_index=True)[1]] |
|
|
|
|
correct[matches[:, 1].astype(int), i] = True |
|
|
|
|
iou = iou * correct_class # zero out the wrong classes |
|
|
|
|
iou = iou.cpu().numpy() |
|
|
|
|
for i, threshold in enumerate(self.iouv.cpu().tolist()): |
|
|
|
|
if use_scipy: |
|
|
|
|
cost_matrix = iou * (iou >= threshold) |
|
|
|
|
if cost_matrix.any(): |
|
|
|
|
labels_idx, detections_idx = linear_sum_assignment(cost_matrix, maximize=True) |
|
|
|
|
valid = cost_matrix[labels_idx, detections_idx] > 0 |
|
|
|
|
if valid.any(): |
|
|
|
|
correct[detections_idx[valid], i] = True |
|
|
|
|
else: |
|
|
|
|
matches = np.nonzero(iou >= threshold) # IoU > threshold and classes match |
|
|
|
|
matches = np.array(matches).T |
|
|
|
|
if matches.shape[0]: |
|
|
|
|
if matches.shape[0] > 1: |
|
|
|
|
matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]] |
|
|
|
|
matches = matches[np.unique(matches[:, 1], return_index=True)[1]] |
|
|
|
|
# matches = matches[matches[:, 2].argsort()[::-1]] |
|
|
|
|
matches = matches[np.unique(matches[:, 0], return_index=True)[1]] |
|
|
|
|
correct[matches[:, 1].astype(int), i] = True |
|
|
|
|
return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device) |
|
|
|
|
|
|
|
|
|
def add_callback(self, event: str, callback): |
|
|
|
|