diff --git a/tests/test_cli.py b/tests/test_cli.py
index 9582dde8f1..3eadf3c24e 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -101,7 +101,7 @@ def test_mobilesam():
     model.predict(source, points=[900, 370], labels=[1])
 
     # Predict a segment based on a box prompt
-    model.predict(source, bboxes=[439, 437, 524, 709])
+    model.predict(source, bboxes=[439, 437, 524, 709], save=True)
 
     # Predict all
     # model(source)
diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py
index 8ecb069eeb..686ef70c63 100644
--- a/ultralytics/models/sam/predict.py
+++ b/ultralytics/models/sam/predict.py
@@ -450,16 +450,18 @@ class Predictor(BasePredictor):
 
         results = []
         for masks, orig_img, img_path in zip([pred_masks], orig_imgs, self.batch[0]):
-            if pred_bboxes is not None:
-                pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
-                cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
-                pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
-
             if len(masks) == 0:
                 masks = None
             else:
                 masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
                 masks = masks > self.model.mask_threshold  # to bool
+            if pred_bboxes is not None:
+                pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
+            else:
+                pred_bboxes = batched_mask_to_box(masks)
+            # NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency.
+            cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
+            pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
             results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
         # Reset segment-all mode.
         self.segment_all = False
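For context on the `predict.py` change: in the removed lines, the `cls`/`torch.cat` block only ran when `pred_bboxes` was not `None`, so otherwise `Results` received `boxes=None`. The patch reorders the logic so masks are scaled and thresholded first, then falls back to `batched_mask_to_box(masks)` to derive one xyxy box per mask before appending the score and placeholder class columns. A minimal standalone sketch of that mask-to-box derivation, assuming `(N, H, W)` boolean masks (a simplified stand-in, not the upstream implementation):

```python
import torch


def mask_to_box_sketch(masks: torch.Tensor) -> torch.Tensor:
    """Derive (N, 4) xyxy boxes from (N, H, W) boolean masks.

    Simplified stand-in for `batched_mask_to_box`; like the upstream helper,
    an all-False mask maps to an all-zero box.
    """
    boxes = torch.zeros((masks.shape[0], 4), dtype=torch.float32, device=masks.device)
    for i, mask in enumerate(masks):
        ys, xs = torch.where(mask)  # row/col indices of foreground pixels
        if ys.numel():  # empty masks keep the zero box
            boxes[i] = torch.stack([xs.min(), ys.min(), xs.max(), ys.max()]).float()
    return boxes
```

The `save=True` added in the test exercises the plotting path, which now has boxes to draw for prompted predictions. A usage sketch mirroring the updated test (the image path is illustrative):

```python
from ultralytics import SAM

model = SAM("mobile_sam.pt")

# Box-prompted prediction, as in the updated test; results[0].boxes is now
# populated with mask-derived boxes instead of being None.
results = model.predict("image.jpg", bboxes=[439, 437, 524, 709], save=True)
print(results[0].boxes.xyxy)
```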