@ -30,21 +30,22 @@ class FastSAMPredictor(DetectionPredictor):
full_box [ 0 ] [ 4 ] = p [ 0 ] [ critical_iou_index ] [ : , 4 ]
full_box [ 0 ] [ 6 : ] = p [ 0 ] [ critical_iou_index ] [ : , 6 : ]
p [ 0 ] [ critical_iou_index ] = full_box
if not isinstance ( orig_imgs , list ) : # input images are a torch.Tensor, not a list
orig_imgs = ops . convert_torch2numpy_batch ( orig_imgs )
results = [ ]
is_list = isinstance ( orig_imgs , list ) # input images are a list, not a torch.Tensor
proto = preds [ 1 ] [ - 1 ] if len ( preds [ 1 ] ) == 3 else preds [ 1 ] # second output is len 3 if pt, but only 1 if exported
for i , pred in enumerate ( p ) :
orig_img = orig_imgs [ i ] if is_list else orig_imgs
orig_img = orig_imgs [ i ]
img_path = self . batch [ 0 ] [ i ]
if not len ( pred ) : # save empty boxes
masks = None
elif self . args . retina_masks :
if is_list :
pred [ : , : 4 ] = ops . scale_boxes ( img . shape [ 2 : ] , pred [ : , : 4 ] , orig_img . shape )
pred [ : , : 4 ] = ops . scale_boxes ( img . shape [ 2 : ] , pred [ : , : 4 ] , orig_img . shape )
masks = ops . process_mask_native ( proto [ i ] , pred [ : , 6 : ] , pred [ : , : 4 ] , orig_img . shape [ : 2 ] ) # HWC
else :
masks = ops . process_mask ( proto [ i ] , pred [ : , 6 : ] , pred [ : , : 4 ] , img . shape [ 2 : ] , upsample = True ) # HWC
if is_list :
pred [ : , : 4 ] = ops . scale_boxes ( img . shape [ 2 : ] , pred [ : , : 4 ] , orig_img . shape )
pred [ : , : 4 ] = ops . scale_boxes ( img . shape [ 2 : ] , pred [ : , : 4 ] , orig_img . shape )
results . append ( Results ( orig_img , path = img_path , names = self . model . names , boxes = pred [ : , : 6 ] , masks = masks ) )
return results