|
|
@ -199,6 +199,7 @@ def non_max_suppression( |
|
|
|
max_nms (int): The maximum number of boxes into torchvision.ops.nms(). |
|
|
|
max_nms (int): The maximum number of boxes into torchvision.ops.nms(). |
|
|
|
max_wh (int): The maximum box width and height in pixels. |
|
|
|
max_wh (int): The maximum box width and height in pixels. |
|
|
|
in_place (bool): If True, the input prediction tensor will be modified in place. |
|
|
|
in_place (bool): If True, the input prediction tensor will be modified in place. |
|
|
|
|
|
|
|
rotated (bool): If Oriented Bounding Boxes (OBB) are being passed for NMS. |
|
|
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
Returns: |
|
|
|
(List[torch.Tensor]): A list of length batch_size, where each element is a tensor of |
|
|
|
(List[torch.Tensor]): A list of length batch_size, where each element is a tensor of |
|
|
@ -212,11 +213,16 @@ def non_max_suppression( |
|
|
|
assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0" |
|
|
|
assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0" |
|
|
|
if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out) |
|
|
|
if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out) |
|
|
|
prediction = prediction[0] # select only inference output |
|
|
|
prediction = prediction[0] # select only inference output |
|
|
|
|
|
|
|
if classes is not None: |
|
|
|
|
|
|
|
classes = torch.tensor(classes, device=prediction.device) |
|
|
|
|
|
|
|
|
|
|
|
if prediction.shape[-1] == 6: # end-to-end model |
|
|
|
if prediction.shape[-1] == 6: # end-to-end model (BNC, i.e. 1,300,6) |
|
|
|
return [pred[pred[:, 4] > conf_thres] for pred in prediction] |
|
|
|
output = [pred[pred[:, 4] > conf_thres] for pred in prediction] |
|
|
|
|
|
|
|
if classes is not None: |
|
|
|
|
|
|
|
output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output] |
|
|
|
|
|
|
|
return output |
|
|
|
|
|
|
|
|
|
|
|
bs = prediction.shape[0] # batch size |
|
|
|
bs = prediction.shape[0] # batch size (BCN, i.e. 1,84,6300) |
|
|
|
nc = nc or (prediction.shape[1] - 4) # number of classes |
|
|
|
nc = nc or (prediction.shape[1] - 4) # number of classes |
|
|
|
nm = prediction.shape[1] - nc - 4 # number of masks |
|
|
|
nm = prediction.shape[1] - nc - 4 # number of masks |
|
|
|
mi = 4 + nc # mask start index |
|
|
|
mi = 4 + nc # mask start index |
|
|
@ -265,7 +271,7 @@ def non_max_suppression( |
|
|
|
|
|
|
|
|
|
|
|
# Filter by class |
|
|
|
# Filter by class |
|
|
|
if classes is not None: |
|
|
|
if classes is not None: |
|
|
|
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] |
|
|
|
x = x[(x[:, 5:6] == classes).any(1)] |
|
|
|
|
|
|
|
|
|
|
|
# Check shape |
|
|
|
# Check shape |
|
|
|
n = x.shape[0] # number of boxes |
|
|
|
n = x.shape[0] # number of boxes |
|
|
|