|
|
|
@ -235,7 +235,42 @@ class Predictor(BasePredictor): |
|
|
|
|
""" |
|
|
|
|
features = self.get_im_features(im) if self.features is None else self.features |
|
|
|
|
|
|
|
|
|
src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:] |
|
|
|
|
bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks) |
|
|
|
|
points = (points, labels) if points is not None else None |
|
|
|
|
# Embed prompts |
|
|
|
|
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks) |
|
|
|
|
|
|
|
|
|
# Predict masks |
|
|
|
|
pred_masks, pred_scores = self.model.mask_decoder( |
|
|
|
|
image_embeddings=features, |
|
|
|
|
image_pe=self.model.prompt_encoder.get_dense_pe(), |
|
|
|
|
sparse_prompt_embeddings=sparse_embeddings, |
|
|
|
|
dense_prompt_embeddings=dense_embeddings, |
|
|
|
|
multimask_output=multimask_output, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, ) |
|
|
|
|
# `d` could be 1 or 3 depends on `multimask_output`. |
|
|
|
|
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1) |
|
|
|
|
|
|
|
|
|
def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None): |
|
|
|
|
""" |
|
|
|
|
Prepares and transforms the input prompts for processing based on the destination shape. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
dst_shape (tuple): The target shape (height, width) for the prompts. |
|
|
|
|
bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4). |
|
|
|
|
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels. |
|
|
|
|
labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background. |
|
|
|
|
masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array. |
|
|
|
|
|
|
|
|
|
Raises: |
|
|
|
|
AssertionError: If the number of points don't match the number of labels, in case labels were passed. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
(tuple): A tuple containing transformed bounding boxes, points, labels, and masks. |
|
|
|
|
""" |
|
|
|
|
src_shape = self.batch[1][0].shape[:2] |
|
|
|
|
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1]) |
|
|
|
|
# Transform input prompts |
|
|
|
|
if points is not None: |
|
|
|
@ -258,23 +293,7 @@ class Predictor(BasePredictor): |
|
|
|
|
bboxes *= r |
|
|
|
|
if masks is not None: |
|
|
|
|
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1) |
|
|
|
|
|
|
|
|
|
points = (points, labels) if points is not None else None |
|
|
|
|
# Embed prompts |
|
|
|
|
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks) |
|
|
|
|
|
|
|
|
|
# Predict masks |
|
|
|
|
pred_masks, pred_scores = self.model.mask_decoder( |
|
|
|
|
image_embeddings=features, |
|
|
|
|
image_pe=self.model.prompt_encoder.get_dense_pe(), |
|
|
|
|
sparse_prompt_embeddings=sparse_embeddings, |
|
|
|
|
dense_prompt_embeddings=dense_embeddings, |
|
|
|
|
multimask_output=multimask_output, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, ) |
|
|
|
|
# `d` could be 1 or 3 depends on `multimask_output`. |
|
|
|
|
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1) |
|
|
|
|
return bboxes, points, labels, masks |
|
|
|
|
|
|
|
|
|
def generate( |
|
|
|
|
self, |
|
|
|
@ -693,34 +712,7 @@ class SAM2Predictor(Predictor): |
|
|
|
|
""" |
|
|
|
|
features = self.get_im_features(im) if self.features is None else self.features |
|
|
|
|
|
|
|
|
|
src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:] |
|
|
|
|
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1]) |
|
|
|
|
# Transform input prompts |
|
|
|
|
if points is not None: |
|
|
|
|
points = torch.as_tensor(points, dtype=torch.float32, device=self.device) |
|
|
|
|
points = points[None] if points.ndim == 1 else points |
|
|
|
|
# Assuming labels are all positive if users don't pass labels. |
|
|
|
|
if labels is None: |
|
|
|
|
labels = torch.ones(points.shape[0]) |
|
|
|
|
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device) |
|
|
|
|
points *= r |
|
|
|
|
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1) |
|
|
|
|
points, labels = points[:, None], labels[:, None] |
|
|
|
|
if bboxes is not None: |
|
|
|
|
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device) |
|
|
|
|
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes |
|
|
|
|
bboxes = bboxes.view(-1, 2, 2) * r |
|
|
|
|
bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1) |
|
|
|
|
# NOTE: merge "boxes" and "points" into a single "points" input |
|
|
|
|
# (where boxes are added at the beginning) to model.sam_prompt_encoder |
|
|
|
|
if points is not None: |
|
|
|
|
points = torch.cat([bboxes, points], dim=1) |
|
|
|
|
labels = torch.cat([bbox_labels, labels], dim=1) |
|
|
|
|
else: |
|
|
|
|
points, labels = bboxes, bbox_labels |
|
|
|
|
if masks is not None: |
|
|
|
|
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1) |
|
|
|
|
|
|
|
|
|
bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks) |
|
|
|
|
points = (points, labels) if points is not None else None |
|
|
|
|
|
|
|
|
|
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( |
|
|
|
@ -744,6 +736,36 @@ class SAM2Predictor(Predictor): |
|
|
|
|
# `d` could be 1 or 3 depends on `multimask_output`. |
|
|
|
|
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1) |
|
|
|
|
|
|
|
|
|
def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None): |
|
|
|
|
""" |
|
|
|
|
Prepares and transforms the input prompts for processing based on the destination shape. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
dst_shape (tuple): The target shape (height, width) for the prompts. |
|
|
|
|
bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4). |
|
|
|
|
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels. |
|
|
|
|
labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background. |
|
|
|
|
masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array. |
|
|
|
|
|
|
|
|
|
Raises: |
|
|
|
|
AssertionError: If the number of points don't match the number of labels, in case labels were passed. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
(tuple): A tuple containing transformed bounding boxes, points, labels, and masks. |
|
|
|
|
""" |
|
|
|
|
bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks) |
|
|
|
|
if bboxes is not None: |
|
|
|
|
bboxes = bboxes.view(-1, 2, 2) |
|
|
|
|
bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1) |
|
|
|
|
# NOTE: merge "boxes" and "points" into a single "points" input |
|
|
|
|
# (where boxes are added at the beginning) to model.sam_prompt_encoder |
|
|
|
|
if points is not None: |
|
|
|
|
points = torch.cat([bboxes, points], dim=1) |
|
|
|
|
labels = torch.cat([bbox_labels, labels], dim=1) |
|
|
|
|
else: |
|
|
|
|
points, labels = bboxes, bbox_labels |
|
|
|
|
return bboxes, points, labels, masks |
|
|
|
|
|
|
|
|
|
def set_image(self, image): |
|
|
|
|
""" |
|
|
|
|
Preprocesses and sets a single image for inference using the SAM2 model. |
|
|
|
|