diff --git a/docs/en/models/sam-2.md b/docs/en/models/sam-2.md
index 562a130029..5120498e24 100644
--- a/docs/en/models/sam-2.md
+++ b/docs/en/models/sam-2.md
@@ -142,11 +142,20 @@ SAM 2 can be utilized across a broad spectrum of tasks, including real-time vide
         # Display model information (optional)
         model.info()

-        # Segment with bounding box prompt
+        # Run inference with bboxes prompt
         results = model("path/to/image.jpg", bboxes=[100, 100, 200, 200])

-        # Segment with point prompt
-        results = model("path/to/image.jpg", points=[150, 150], labels=[1])
+        # Run inference with single point
+        results = model(points=[900, 370], labels=[1])
+
+        # Run inference with multiple points
+        results = model(points=[[400, 370], [900, 370]], labels=[1, 1])
+
+        # Run inference with multiple points prompt per object
+        results = model(points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
+
+        # Run inference with negative points prompt
+        results = model(points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
         ```

 #### Segment Everything
diff --git a/docs/en/models/sam.md b/docs/en/models/sam.md
index 1a5c0db4a7..f9acad72df 100644
--- a/docs/en/models/sam.md
+++ b/docs/en/models/sam.md
@@ -59,16 +59,16 @@ The Segment Anything Model can be employed for a multitude of downstream tasks t
         results = model("ultralytics/assets/zidane.jpg", bboxes=[439, 437, 524, 709])

         # Run inference with single point
-        results = predictor(points=[900, 370], labels=[1])
+        results = model(points=[900, 370], labels=[1])

         # Run inference with multiple points
-        results = predictor(points=[[400, 370], [900, 370]], labels=[1, 1])
+        results = model(points=[[400, 370], [900, 370]], labels=[1, 1])

         # Run inference with multiple points prompt per object
-        results = predictor(points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
+        results = model(points=[[[400, 370], [900, 370]]], labels=[[1, 1]])

         # Run inference with negative points prompt
-        results = predictor(points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
+        results = model(points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
         ```

 !!! example "Segment everything"
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 5360c25e18..06ee07e308 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = "8.3.12"
+__version__ = "8.3.13"

 import os

diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py
index 978f7cfd68..4002e092b6 100644
--- a/ultralytics/models/sam/predict.py
+++ b/ultralytics/models/sam/predict.py
@@ -235,7 +235,42 @@ class Predictor(BasePredictor):
         """
         features = self.get_im_features(im) if self.features is None else self.features

-        src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
+        bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
+        points = (points, labels) if points is not None else None
+        # Embed prompts
+        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)
+
+        # Predict masks
+        pred_masks, pred_scores = self.model.mask_decoder(
+            image_embeddings=features,
+            image_pe=self.model.prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=multimask_output,
+        )
+
+        # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
+        # `d` could be 1 or 3 depends on `multimask_output`.
+        return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
+
+    def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
+        """
+        Prepares and transforms the input prompts for processing based on the destination shape.
+
+        Args:
+            dst_shape (tuple): The target shape (height, width) for the prompts.
+            bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
+            points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
+            labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
+            masks (np.ndarray | List | None): Masks for the objects, where each mask is a 2D array.
+
+        Raises:
+            AssertionError: If the number of points does not match the number of labels when labels are passed.
+
+        Returns:
+            (tuple): A tuple containing transformed bounding boxes, points, labels, and masks.
+        """
+        src_shape = self.batch[1][0].shape[:2]
         r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
         # Transform input prompts
         if points is not None:
@@ -258,23 +293,7 @@ class Predictor(BasePredictor):
             bboxes *= r
         if masks is not None:
             masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
-
-        points = (points, labels) if points is not None else None
-        # Embed prompts
-        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)
-
-        # Predict masks
-        pred_masks, pred_scores = self.model.mask_decoder(
-            image_embeddings=features,
-            image_pe=self.model.prompt_encoder.get_dense_pe(),
-            sparse_prompt_embeddings=sparse_embeddings,
-            dense_prompt_embeddings=dense_embeddings,
-            multimask_output=multimask_output,
-        )
-
-        # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
-        # `d` could be 1 or 3 depends on `multimask_output`.
-        return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
+        return bboxes, points, labels, masks

     def generate(
         self,
@@ -693,34 +712,7 @@ class SAM2Predictor(Predictor):
         """
         features = self.get_im_features(im) if self.features is None else self.features

-        src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
-        r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
-        # Transform input prompts
-        if points is not None:
-            points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
-            points = points[None] if points.ndim == 1 else points
-            # Assuming labels are all positive if users don't pass labels.
-            if labels is None:
-                labels = torch.ones(points.shape[0])
-            labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
-            points *= r
-            # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
-            points, labels = points[:, None], labels[:, None]
-        if bboxes is not None:
-            bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
-            bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
-            bboxes = bboxes.view(-1, 2, 2) * r
-            bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
-            # NOTE: merge "boxes" and "points" into a single "points" input
-            # (where boxes are added at the beginning) to model.sam_prompt_encoder
-            if points is not None:
-                points = torch.cat([bboxes, points], dim=1)
-                labels = torch.cat([bbox_labels, labels], dim=1)
-            else:
-                points, labels = bboxes, bbox_labels
-        if masks is not None:
-            masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
-
+        bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
         points = (points, labels) if points is not None else None

         sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
@@ -744,6 +736,36 @@ class SAM2Predictor(Predictor):
         # `d` could be 1 or 3 depends on `multimask_output`.
         return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)

+    def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
+        """
+        Prepares and transforms the input prompts for processing based on the destination shape.
+
+        Args:
+            dst_shape (tuple): The target shape (height, width) for the prompts.
+            bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
+            points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
+            labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
+            masks (np.ndarray | List | None): Masks for the objects, where each mask is a 2D array.
+
+        Raises:
+            AssertionError: If the number of points does not match the number of labels when labels are passed.
+
+        Returns:
+            (tuple): A tuple containing transformed bounding boxes, points, labels, and masks.
+        """
+        bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks)
+        if bboxes is not None:
+            bboxes = bboxes.view(-1, 2, 2)
+            bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
+            # NOTE: merge "boxes" and "points" into a single "points" input
+            # (where boxes are added at the beginning) to model.sam_prompt_encoder
+            if points is not None:
+                points = torch.cat([bboxes, points], dim=1)
+                labels = torch.cat([bbox_labels, labels], dim=1)
+            else:
+                points, labels = bboxes, bbox_labels
+        return bboxes, points, labels, masks
+
     def set_image(self, image):
         """
         Preprocesses and sets a single image for inference using the SAM2 model.