`ultralytics 8.3.12` SAM and SAM2 multi-point prompts (#16643)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/16881/head v8.3.12
Mohammed Yasin 1 month ago committed by GitHub
parent 1c4d788aa1
commit b89d6f4070
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 24
      docs/en/models/mobile-sam.md
  2. 30
      docs/en/models/sam.md
  3. 5
      tests/test_cli.py
  4. 14
      tests/test_cuda.py
  5. 2
      ultralytics/__init__.py
  6. 17
      ultralytics/models/sam/predict.py

@ -90,8 +90,17 @@ You can download the model [here](https://github.com/ChaoningZhang/MobileSAM/blo
# Load the model
model = SAM("mobile_sam.pt")
# Predict a segment based on a point prompt
# Predict a segment based on a single point prompt
model.predict("ultralytics/assets/zidane.jpg", points=[900, 370], labels=[1])
# Predict multiple segments based on multiple points prompt
model.predict("ultralytics/assets/zidane.jpg", points=[[400, 370], [900, 370]], labels=[1, 1])
# Predict a segment based on multiple points prompt per object
model.predict("ultralytics/assets/zidane.jpg", points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
# Predict a segment using both positive and negative prompts.
model.predict("ultralytics/assets/zidane.jpg", points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
```
### Box Prompt
@ -106,8 +115,17 @@ You can download the model [here](https://github.com/ChaoningZhang/MobileSAM/blo
# Load the model
model = SAM("mobile_sam.pt")
# Predict a segment based on a box prompt
model.predict("ultralytics/assets/zidane.jpg", bboxes=[439, 437, 524, 709])
# Predict a segment based on a single point prompt
model.predict("ultralytics/assets/zidane.jpg", points=[900, 370], labels=[1])
# Predict mutiple segments based on multiple points prompt
model.predict("ultralytics/assets/zidane.jpg", points=[[400, 370], [900, 370]], labels=[1, 1])
# Predict a segment based on multiple points prompt per object
model.predict("ultralytics/assets/zidane.jpg", points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
# Predict a segment using both positive and negative prompts.
model.predict("ultralytics/assets/zidane.jpg", points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
```
We have implemented `MobileSAM` and `SAM` using the same API. For more usage information, please see the [SAM page](sam.md).

@ -58,8 +58,17 @@ The Segment Anything Model can be employed for a multitude of downstream tasks t
# Run inference with bboxes prompt
results = model("ultralytics/assets/zidane.jpg", bboxes=[439, 437, 524, 709])
# Run inference with points prompt
results = model("ultralytics/assets/zidane.jpg", points=[900, 370], labels=[1])
# Run inference with single point
results = predictor(points=[900, 370], labels=[1])
# Run inference with multiple points
results = predictor(points=[[400, 370], [900, 370]], labels=[1, 1])
# Run inference with multiple points prompt per object
results = predictor(points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
# Run inference with negative points prompt
results = predictor(points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
```
!!! example "Segment everything"
@ -107,8 +116,16 @@ The Segment Anything Model can be employed for a multitude of downstream tasks t
predictor.set_image("ultralytics/assets/zidane.jpg") # set with image file
predictor.set_image(cv2.imread("ultralytics/assets/zidane.jpg")) # set with np.ndarray
results = predictor(bboxes=[439, 437, 524, 709])
# Run inference with single point prompt
results = predictor(points=[900, 370], labels=[1])
# Run inference with multiple points prompt
results = predictor(points=[[400, 370], [900, 370]], labels=[[1, 1]])
# Run inference with negative points prompt
results = predictor(points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
# Reset image
predictor.reset_image()
```
@ -245,6 +262,15 @@ model("ultralytics/assets/zidane.jpg", bboxes=[439, 437, 524, 709])
# Segment with points prompt
model("ultralytics/assets/zidane.jpg", points=[900, 370], labels=[1])
# Segment with multiple points prompt
model("ultralytics/assets/zidane.jpg", points=[[400, 370], [900, 370]], labels=[[1, 1]])
# Segment with multiple points prompt per object
model("ultralytics/assets/zidane.jpg", points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
# Segment with negative points prompt.
model("ultralytics/assets/zidane.jpg", points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
```
Alternatively, you can run inference with SAM in the command line interface (CLI):

@ -97,9 +97,12 @@ def test_mobilesam():
# Source
source = ASSETS / "zidane.jpg"
# Predict a segment based on a point prompt
# Predict a segment based on a 1D point prompt and 1D labels.
model.predict(source, points=[900, 370], labels=[1])
# Predict a segment based on 3D points and 2D labels (multiple points per object).
model.predict(source, points=[[[900, 370], [1000, 100]]], labels=[[1, 1]])
# Predict a segment based on a box prompt
model.predict(source, bboxes=[439, 437, 524, 709], save=True)

@ -127,9 +127,21 @@ def test_predict_sam():
# Run inference with bboxes prompt
model(SOURCE, bboxes=[439, 437, 524, 709], device=0)
# Run inference with points prompt
# Run inference with no labels
model(ASSETS / "zidane.jpg", points=[900, 370], device=0)
# Run inference with 1D points and 1D labels
model(ASSETS / "zidane.jpg", points=[900, 370], labels=[1], device=0)
# Run inference with 2D points and 1D labels
model(ASSETS / "zidane.jpg", points=[[900, 370]], labels=[1], device=0)
# Run inference with multiple 2D points and 1D labels
model(ASSETS / "zidane.jpg", points=[[400, 370], [900, 370]], labels=[1, 1], device=0)
# Run inference with 3D points and 2D labels (multiple points per object)
model(ASSETS / "zidane.jpg", points=[[[900, 370], [1000, 100]]], labels=[[1, 1]], device=0)
# Create SAMPredictor
overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model=WEIGHTS_DIR / "mobile_sam.pt")
predictor = SAMPredictor(overrides=overrides)

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.3.11"
__version__ = "8.3.12"
import os

@ -213,11 +213,14 @@ class Predictor(BasePredictor):
Args:
im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
labels (np.ndarray | List | None): Point prompt labels with shape (N,). 1 for foreground, 0 for background.
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256.
multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
Raises:
AssertionError: If the number of points don't match the number of labels, in case labels were passed.
Returns:
(tuple): Tuple containing:
- np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
@ -240,11 +243,15 @@ class Predictor(BasePredictor):
points = points[None] if points.ndim == 1 else points
# Assuming labels are all positive if users don't pass labels.
if labels is None:
labels = np.ones(points.shape[0])
labels = np.ones(points.shape[:-1])
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
assert (
points.shape[-2] == labels.shape[-1]
), f"Number of points {points.shape[-2]} should match number of labels {labels.shape[-1]}."
points *= r
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
points, labels = points[:, None, :], labels[:, None]
if points.ndim == 2:
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
points, labels = points[:, None, :], labels[:, None]
if bboxes is not None:
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes

Loading…
Cancel
Save