`ultralytics 8.2.72` SAM 2 multiple-`bboxes` support (#14928)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/14980/head v8.2.72
Laughing 3 months ago committed by GitHub
parent 2187649f99
commit bea4c93278
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      ultralytics/__init__.py
  2. 25
      ultralytics/models/sam2/predict.py

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.71"
__version__ = "8.2.72"
import os

@ -102,28 +102,23 @@ class SAM2Predictor(Predictor):
if bboxes is not None:
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
bboxes *= r
bboxes = bboxes.view(-1, 2, 2) * r
bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
# NOTE: merge "boxes" and "points" into a single "points" input
# (where boxes are added at the beginning) to model.sam_prompt_encoder
if points is not None:
points = torch.cat([bboxes, points], dim=1)
labels = torch.cat([bbox_labels, labels], dim=1)
else:
points, labels = bboxes, bbox_labels
if masks is not None:
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
points = (points, labels) if points is not None else None
# TODO: Embed prompts
# if bboxes is not None:
# box_coords = bboxes.reshape(-1, 2, 2)
# box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=bboxes.device)
# box_labels = box_labels.repeat(bboxes.size(0), 1)
# # we merge "boxes" and "points" into a single "concat_points" input (where
# # boxes are added at the beginning) to sam_prompt_encoder
# if concat_points is not None:
# concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
# concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
# concat_points = (concat_coords, concat_labels)
# else:
# concat_points = (box_coords, box_labels)
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
points=points,
boxes=bboxes,
boxes=None,
masks=masks,
)
# Predict masks

Loading…
Cancel
Save