@ -8,7 +8,6 @@ import torch
import torch . nn as nn
from torch . nn . init import constant_ , xavier_uniform_
from ultralytics . utils import MACOS
from ultralytics . utils . tal import TORCH_1_10 , dist2bbox , dist2rbox , make_anchors
from . block import DFL , BNContrastiveHead , ContrastiveHead , Proto
@ -133,38 +132,26 @@ class Detect(nn.Module):
@staticmethod
def postprocess ( preds : torch . Tensor , max_det : int , nc : int = 80 ) :
"""
Post - processes the predictions obtained from a YOLOv10 model .
Post - processes YOLO model predictions .
Args :
preds ( torch . Tensor ) : The predictions obtained from the model . It should have a shape of ( batch_size , num_boxes , 4 + num_classes ) .
max_det ( int ) : The maximum number of detections to keep .
nc ( int , optional ) : The number of classes . Defaults to 80.
preds ( torch . Tensor ) : Raw predictions with shape ( batch_size , num_anchors , 4 + nc ) with last dimension
format [ x , y , w , h , class_probs ] .
max_det ( int ) : Maximum detections per image .
nc ( int , optional ) : Number of classes . Default : 80.
Returns :
( torch . Tensor ) : The post - processed predictions with shape ( batch_size , max_det , 6 ) ,
including bounding boxes , scores and cls .
( torch . Tensor ) : Processed predictions with shape ( batch_size , min ( max_det , num_anchors ) , 6 ) and last
dimension format [ x , y , w , h , max_class_prob , class_index ] .
"""
assert 4 + nc == preds . shape [ - 1 ]
batch_size , anchors , predictions = preds . shape # i.e. shape(16,8400,84)
boxes , scores = preds . split ( [ 4 , nc ] , dim = - 1 )
max_scores = scores . amax ( dim = - 1 )
max_scores , index = torch . topk ( max_scores , min ( max_det , max_scores . shape [ 1 ] ) , axis = - 1 )
index = index . unsqueeze ( - 1 )
boxes = torch . gather ( boxes , dim = 1 , index = index . repeat ( 1 , 1 , boxes . shape [ - 1 ] ) )
scores = torch . gather ( scores , dim = 1 , index = index . repeat ( 1 , 1 , scores . shape [ - 1 ] ) )
# NOTE: simplify result but slightly lower mAP
# scores, labels = scores.max(dim=-1)
# return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
scores , index = torch . topk ( scores . flatten ( 1 ) , max_det , axis = - 1 )
labels = index % nc
index = index / / nc
# Set int64 dtype for MPS and CoreML compatibility to avoid 'gather_along_axis' ops error
if MACOS :
index = index . to ( torch . int64 )
boxes = boxes . gather ( dim = 1 , index = index . unsqueeze ( - 1 ) . repeat ( 1 , 1 , boxes . shape [ - 1 ] ) )
return torch . cat ( [ boxes , scores . unsqueeze ( - 1 ) , labels . unsqueeze ( - 1 ) . to ( boxes . dtype ) ] , dim = - 1 )
index = scores . amax ( dim = - 1 ) . topk ( min ( max_det , anchors ) ) [ 1 ] . unsqueeze ( - 1 )
boxes = boxes . gather ( dim = 1 , index = index . repeat ( 1 , 1 , 4 ) )
scores = scores . gather ( dim = 1 , index = index . repeat ( 1 , 1 , nc ) )
scores , index = scores . flatten ( 1 ) . topk ( max_det )
i = torch . arange ( batch_size ) [ . . . , None ] # batch indices
return torch . cat ( [ boxes [ i , index / / nc ] , scores [ . . . , None ] , ( index % nc ) [ . . . , None ] . float ( ) ] , dim = - 1 )
class Segment ( Detect ) :