|
|
@ -102,28 +102,23 @@ class SAM2Predictor(Predictor): |
|
|
|
if bboxes is not None: |
|
|
|
if bboxes is not None: |
|
|
|
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device) |
|
|
|
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device) |
|
|
|
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes |
|
|
|
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes |
|
|
|
bboxes *= r |
|
|
|
bboxes = bboxes.view(-1, 2, 2) * r |
|
|
|
|
|
|
|
bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1) |
|
|
|
|
|
|
|
# NOTE: merge "boxes" and "points" into a single "points" input |
|
|
|
|
|
|
|
# (where boxes are added at the beginning) to model.sam_prompt_encoder |
|
|
|
|
|
|
|
if points is not None: |
|
|
|
|
|
|
|
points = torch.cat([bboxes, points], dim=1) |
|
|
|
|
|
|
|
labels = torch.cat([bbox_labels, labels], dim=1) |
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
points, labels = bboxes, bbox_labels |
|
|
|
if masks is not None: |
|
|
|
if masks is not None: |
|
|
|
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1) |
|
|
|
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1) |
|
|
|
|
|
|
|
|
|
|
|
points = (points, labels) if points is not None else None |
|
|
|
points = (points, labels) if points is not None else None |
|
|
|
# TODO: Embed prompts |
|
|
|
|
|
|
|
# if bboxes is not None: |
|
|
|
|
|
|
|
# box_coords = bboxes.reshape(-1, 2, 2) |
|
|
|
|
|
|
|
# box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=bboxes.device) |
|
|
|
|
|
|
|
# box_labels = box_labels.repeat(bboxes.size(0), 1) |
|
|
|
|
|
|
|
# # we merge "boxes" and "points" into a single "concat_points" input (where |
|
|
|
|
|
|
|
# # boxes are added at the beginning) to sam_prompt_encoder |
|
|
|
|
|
|
|
# if concat_points is not None: |
|
|
|
|
|
|
|
# concat_coords = torch.cat([box_coords, concat_points[0]], dim=1) |
|
|
|
|
|
|
|
# concat_labels = torch.cat([box_labels, concat_points[1]], dim=1) |
|
|
|
|
|
|
|
# concat_points = (concat_coords, concat_labels) |
|
|
|
|
|
|
|
# else: |
|
|
|
|
|
|
|
# concat_points = (box_coords, box_labels) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( |
|
|
|
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( |
|
|
|
points=points, |
|
|
|
points=points, |
|
|
|
boxes=bboxes, |
|
|
|
boxes=None, |
|
|
|
masks=masks, |
|
|
|
masks=masks, |
|
|
|
) |
|
|
|
) |
|
|
|
# Predict masks |
|
|
|
# Predict masks |
|
|
|