From 5e38c7d71be0ad0c2548dbd91c94da68be13e8a5 Mon Sep 17 00:00:00 2001
From: Andy <39454881+yermandy@users.noreply.github.com>
Date: Mon, 3 Jul 2023 16:20:49 +0200
Subject: [PATCH] Improve NMS speed (#3467)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher
---
 ultralytics/yolo/utils/ops.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/ultralytics/yolo/utils/ops.py b/ultralytics/yolo/utils/ops.py
index b998df443f..b53370a7b7 100644
--- a/ultralytics/yolo/utils/ops.py
+++ b/ultralytics/yolo/utils/ops.py
@@ -200,12 +200,16 @@ def non_max_suppression(
     multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
     merge = False  # use merge-NMS
 
+    prediction = prediction.clone()  # don't modify original
+    prediction = prediction.transpose(-1, -2)  # to (batch, boxes, items)
+    prediction[..., :4] = xywh2xyxy(prediction[..., :4])  # xywh to xyxy
+
     t = time.time()
     output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
     for xi, x in enumerate(prediction):  # image index, image inference
         # Apply constraints
         # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
-        x = x.transpose(0, -1)[xc[xi]]  # confidence
+        x = x[xc[xi]]  # confidence
 
         # Cat apriori labels if autolabelling
         if labels and len(labels[xi]):
@@ -221,9 +225,9 @@ def non_max_suppression(
 
         # Detections matrix nx6 (xyxy, conf, cls)
         box, cls, mask = x.split((4, nc, nm), 1)
-        box = xywh2xyxy(box)  # center_x, center_y, width, height) to (x1, y1, x2, y2)
+
         if multi_label:
-            i, j = (cls > conf_thres).nonzero(as_tuple=False).T
+            i, j = torch.where(cls > conf_thres)
             x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
         else:  # best class only
             conf, j = cls.max(1, keepdim=True)
@@ -241,7 +245,9 @@ def non_max_suppression(
         n = x.shape[0]  # number of boxes
         if not n:  # no boxes
             continue
-        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes
+
+        if n > max_nms:  # excess boxes
+            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes
 
         # Batched NMS
         c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes