@@ -200,12 +200,16 @@ def non_max_suppression(
     multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
     merge = False  # use merge-NMS

+    prediction = prediction.clone()  # don't modify original
+    prediction = prediction.transpose(-1, -2)  # to (batch, boxes, items)
+    prediction[..., :4] = xywh2xyxy(prediction[..., :4])  # xywh to xyxy
+
     t = time.time()
     output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
     for xi, x in enumerate(prediction):  # image index, image inference
         # Apply constraints
         # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
-        x = x.transpose(0, -1)[xc[xi]]  # confidence
+        x = x[xc[xi]]  # confidence

         # Cat apriori labels if autolabelling
         if labels and len(labels[xi]):
@@ -221,9 +225,9 @@
         # Detections matrix nx6 (xyxy, conf, cls)
         box, cls, mask = x.split((4, nc, nm), 1)
-        box = xywh2xyxy(box)  # (center_x, center_y, width, height) to (x1, y1, x2, y2)
+
         if multi_label:
-            i, j = (cls > conf_thres).nonzero(as_tuple=False).T
+            i, j = torch.where(cls > conf_thres)
             x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
         else:  # best class only
             conf, j = cls.max(1, keepdim=True)

@@ -241,7 +245,9 @@
         n = x.shape[0]  # number of boxes
         if not n:  # no boxes
             continue
-        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes
+        if n > max_nms:  # excess boxes
+            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes
+
         # Batched NMS
         c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes