Return boxes for SAM prompts inference (#16276)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
pull/16280/head
Laughing 2 months ago committed by GitHub
parent c2068df9d9
commit 02e995383d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      tests/test_cli.py
  2. 12
      ultralytics/models/sam/predict.py

@ -101,7 +101,7 @@ def test_mobilesam():
model.predict(source, points=[900, 370], labels=[1])
# Predict a segment based on a box prompt
model.predict(source, bboxes=[439, 437, 524, 709])
model.predict(source, bboxes=[439, 437, 524, 709], save=True)
# Predict all
# model(source)

@ -450,16 +450,18 @@ class Predictor(BasePredictor):
results = []
for masks, orig_img, img_path in zip([pred_masks], orig_imgs, self.batch[0]):
if pred_bboxes is not None:
pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
if len(masks) == 0:
masks = None
else:
masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
masks = masks > self.model.mask_threshold # to bool
if pred_bboxes is not None:
pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
else:
pred_bboxes = batched_mask_to_box(masks)
# NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency.
cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
# Reset segment-all mode.
self.segment_all = False

Loading…
Cancel
Save