`ultralytics 8.3.13` SAM prompt-inference refactor (#16894)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/16937/head v8.3.13
Laughing authored 1 month ago, committed by GitHub
parent 5a9ce863e4
commit 15e6133534
  1. docs/en/models/sam-2.md (15 changes)
  2. docs/en/models/sam.md (8 changes)
  3. ultralytics/__init__.py (2 changes)
  4. ultralytics/models/sam/predict.py (114 changes)

@@ -142,11 +142,20 @@ SAM 2 can be utilized across a broad spectrum of tasks, including real-time video
# Display model information (optional)
model.info()
# Segment with bounding box prompt
# Run inference with bboxes prompt
results = model("path/to/image.jpg", bboxes=[100, 100, 200, 200])
# Segment with point prompt
results = model("path/to/image.jpg", points=[150, 150], labels=[1])
# Run inference with single point
results = model(points=[900, 370], labels=[1])
# Run inference with multiple points
results = model(points=[[400, 370], [900, 370]], labels=[1, 1])
# Run inference with multiple points prompt per object
results = model(points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
# Run inference with negative points prompt
results = model(points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
```
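The snippet above assumes a SAM 2 `model` has already been loaded. For reference, a minimal self-contained version of the bboxes/points calls might look like this (the `sam2_b.pt` weight name and image path are illustrative):

```python
from ultralytics import SAM

# Load a SAM 2 model (weight name is illustrative; any SAM 2 checkpoint should work)
model = SAM("sam2_b.pt")

# Display model information (optional)
model.info()

# Run inference with a bboxes prompt (XYXY pixel coordinates)
results = model("path/to/image.jpg", bboxes=[100, 100, 200, 200])

# Run inference with a single positive point prompt
results = model("path/to/image.jpg", points=[150, 150], labels=[1])
```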
#### Segment Everything

@@ -59,16 +59,16 @@ The Segment Anything Model can be employed for a multitude of downstream tasks t
results = model("ultralytics/assets/zidane.jpg", bboxes=[439, 437, 524, 709])
# Run inference with single point
results = predictor(points=[900, 370], labels=[1])
results = model(points=[900, 370], labels=[1])
# Run inference with multiple points
results = predictor(points=[[400, 370], [900, 370]], labels=[1, 1])
results = model(points=[[400, 370], [900, 370]], labels=[1, 1])
# Run inference with multiple points prompt per object
results = predictor(points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
results = model(points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
# Run inference with negative points prompt
results = predictor(points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
results = model(points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
```
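As the comments above suggest, the nesting depth of `points` and `labels` controls how prompts are grouped into objects; a rough summary of the shapes used in the snippet (this reading is inferred from the examples, not stated explicitly in the diff):

```python
# Illustrative summary of the point-prompt shapes used above (not part of the diff)
single_point = [900, 370]                             # one point, one object       -> labels=[1]
one_point_per_object = [[400, 370], [900, 370]]       # two objects, one point each -> labels=[1, 1]
multi_points_one_object = [[[400, 370], [900, 370]]]  # one object, two points      -> labels=[[1, 1]]
with_negative_point = [[[400, 370], [900, 370]]]      # label 0 marks background    -> labels=[[1, 0]]
```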
!!! example "Segment everything"

@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.3.12"
__version__ = "8.3.13"
import os

@@ -235,7 +235,42 @@ class Predictor(BasePredictor):
"""
features = self.get_im_features(im) if self.features is None else self.features
src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
points = (points, labels) if points is not None else None
# Embed prompts
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)
# Predict masks
pred_masks, pred_scores = self.model.mask_decoder(
image_embeddings=features,
image_pe=self.model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
)
# (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
# `d` could be 1 or 3 depending on `multimask_output`.
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
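The flatten in the return merges the per-prompt mask dimension `d` into the batch dimension. A standalone shape check (illustrative values only):

```python
import torch

# d is 3 when multimask_output=True, otherwise 1
pred_masks = torch.zeros(2, 3, 256, 256)  # (N, d, H, W) for N=2 prompts
pred_scores = torch.zeros(2, 3)           # (N, d)

print(pred_masks.flatten(0, 1).shape)   # torch.Size([6, 256, 256]) -> (N*d, H, W)
print(pred_scores.flatten(0, 1).shape)  # torch.Size([6])           -> (N*d,)
```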
def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
"""
Prepares and transforms the input prompts for processing based on the destination shape.
Args:
dst_shape (tuple): The target shape (height, width) for the prompts.
bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
masks (np.ndarray | List | None): Masks for the objects, where each mask is a 2D array.
Raises:
AssertionError: If the number of points does not match the number of labels, when labels are passed.
Returns:
(tuple): A tuple containing transformed bounding boxes, points, labels, and masks.
"""
src_shape = self.batch[1][0].shape[:2]
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
# Transform input prompts
if points is not None:
@@ -258,23 +293,7 @@ class Predictor(BasePredictor):
bboxes *= r
if masks is not None:
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
points = (points, labels) if points is not None else None
# Embed prompts
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)
# Predict masks
pred_masks, pred_scores = self.model.mask_decoder(
image_embeddings=features,
image_pe=self.model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
)
# (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
# `d` could be 1 or 3 depending on `multimask_output`.
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
return bboxes, points, labels, masks
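For reference, the coordinate handling that `_prepare_prompts` applies to point prompts can be reproduced outside the class; the sketch below is a simplified standalone version of the point branch (same scale ratio `r`, hypothetical shapes and defaults, no assertions), not the method itself:

```python
import torch

def scale_points(points, labels=None, src_shape=(1080, 810), dst_shape=(1024, 1024), device="cpu"):
    """Simplified sketch of the point branch of Predictor._prepare_prompts (illustrative only)."""
    # Same letterbox scale ratio as above: model input size / original image size
    r = min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
    points = torch.as_tensor(points, dtype=torch.float32, device=device)
    points = points[None] if points.ndim == 1 else points  # (2,) -> (1, 2)
    if labels is None:
        labels = torch.ones(points.shape[0])  # assume all points are foreground
    labels = torch.as_tensor(labels, dtype=torch.int32, device=device)
    points *= r  # map pixel coordinates from the source image to the model input size
    if points.ndim == 2:
        points, labels = points[:, None], labels[:, None]  # (N, 2) -> (N, 1, 2), (N,) -> (N, 1)
    return points, labels

pts, lbls = scale_points([900, 370], [1])
print(pts.shape, lbls.shape)  # torch.Size([1, 1, 2]) torch.Size([1, 1])
```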
def generate(
self,
@@ -693,34 +712,7 @@ class SAM2Predictor(Predictor):
"""
features = self.get_im_features(im) if self.features is None else self.features
src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
# Transform input prompts
if points is not None:
points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
points = points[None] if points.ndim == 1 else points
# Assuming labels are all positive if users don't pass labels.
if labels is None:
labels = torch.ones(points.shape[0])
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
points *= r
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
points, labels = points[:, None], labels[:, None]
if bboxes is not None:
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
bboxes = bboxes.view(-1, 2, 2) * r
bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
# NOTE: merge "boxes" and "points" into a single "points" input
# (where boxes are added at the beginning) to model.sam_prompt_encoder
if points is not None:
points = torch.cat([bboxes, points], dim=1)
labels = torch.cat([bbox_labels, labels], dim=1)
else:
points, labels = bboxes, bbox_labels
if masks is not None:
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
points = (points, labels) if points is not None else None
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
@@ -744,6 +736,36 @@
# `d` could be 1 or 3 depending on `multimask_output`.
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
"""
Prepares and transforms the input prompts for processing based on the destination shape.
Args:
dst_shape (tuple): The target shape (height, width) for the prompts.
bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
masks (np.ndarray | List | None): Masks for the objects, where each mask is a 2D array.
Raises:
AssertionError: If the number of points does not match the number of labels, when labels are passed.
Returns:
(tuple): A tuple containing transformed bounding boxes, points, labels, and masks.
"""
bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks)
if bboxes is not None:
bboxes = bboxes.view(-1, 2, 2)
bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
# NOTE: merge "boxes" and "points" into a single "points" input
# (where boxes are added at the beginning) to model.sam_prompt_encoder
if points is not None:
points = torch.cat([bboxes, points], dim=1)
labels = torch.cat([bbox_labels, labels], dim=1)
else:
points, labels = bboxes, bbox_labels
return bboxes, points, labels, masks
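The SAM 2 override above turns each box into a pair of corner points labeled 2 and 3 (the prompt-encoder convention for box corners used by the `bbox_labels` tensor in the diff) and prepends them to any per-object point prompts. A standalone sketch of that merge, with `merge_boxes_into_points` as a hypothetical helper name:

```python
import torch

def merge_boxes_into_points(bboxes, points=None, labels=None):
    """Sketch of the box handling in SAM2Predictor._prepare_prompts (simplified, standalone)."""
    bboxes = torch.as_tensor(bboxes, dtype=torch.float32)
    bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
    box_points = bboxes.view(-1, 2, 2)  # (N, 4) XYXY -> (N, 2, 2) top-left / bottom-right points
    box_labels = torch.tensor([[2, 3]], dtype=torch.int32).expand(len(box_points), -1)
    if points is not None:
        # points/labels are assumed already in (N, k, 2) / (N, k) form, as the base class produces
        points = torch.cat([box_points, torch.as_tensor(points, dtype=torch.float32)], dim=1)
        labels = torch.cat([box_labels, torch.as_tensor(labels, dtype=torch.int32)], dim=1)
    else:
        points, labels = box_points, box_labels
    return points, labels

pts, lbls = merge_boxes_into_points([100, 100, 200, 200])
print(pts.shape, lbls.shape)  # torch.Size([1, 2, 2]) torch.Size([1, 2])
```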
def set_image(self, image):
"""
Preprocesses and sets a single image for inference using the SAM2 model.
