|
|
|
@ -6,11 +6,10 @@ |
|
|
|
|
# This source code is licensed under the license found in the |
|
|
|
|
# LICENSE file in the root directory of this source tree. |
|
|
|
|
|
|
|
|
|
from typing import Any, Dict, List, Tuple |
|
|
|
|
from typing import List |
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
from torch import nn |
|
|
|
|
from torch.nn import functional as F |
|
|
|
|
|
|
|
|
|
from .decoders import MaskDecoder |
|
|
|
|
from .encoders import ImageEncoderViT, PromptEncoder |
|
|
|
@ -31,6 +30,9 @@ class Sam(nn.Module): |
|
|
|
|
""" |
|
|
|
|
SAM predicts object masks from an image and input prompts. |
|
|
|
|
|
|
|
|
|
Note: |
|
|
|
|
All forward() operations moved to SAMPredictor. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for |
|
|
|
|
efficient mask prediction. |
|
|
|
@ -45,109 +47,3 @@ class Sam(nn.Module): |
|
|
|
|
self.mask_decoder = mask_decoder |
|
|
|
|
self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(-1, 1, 1), False) |
|
|
|
|
self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(-1, 1, 1), False) |
|
|
|
|
|
|
|
|
|
@property |
|
|
|
|
def device(self) -> Any: |
|
|
|
|
return self.pixel_mean.device |
|
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
|
|
def forward( |
|
|
|
|
self, |
|
|
|
|
batched_input: List[Dict[str, Any]], |
|
|
|
|
multimask_output: bool, |
|
|
|
|
) -> List[Dict[str, torch.Tensor]]: |
|
|
|
|
""" |
|
|
|
|
Predicts masks end-to-end from provided images and prompts. If prompts are not known in advance, using |
|
|
|
|
SamPredictor is recommended over calling the model directly. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
batched_input (list(dict)): A list over input images, each a dictionary with the following keys. A prompt |
|
|
|
|
key can be excluded if it is not present. |
|
|
|
|
'image': The image as a torch tensor in 3xHxW format, already transformed for input to the model. |
|
|
|
|
'original_size': (tuple(int, int)) The original size of the image before transformation, as (H, W). |
|
|
|
|
'point_coords': (torch.Tensor) Batched point prompts for this image, with shape BxNx2. Already |
|
|
|
|
transformed to the input frame of the model. |
|
|
|
|
'point_labels': (torch.Tensor) Batched labels for point prompts, with shape BxN. |
|
|
|
|
'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. Already transformed to the input frame of |
|
|
|
|
the model. |
|
|
|
|
'mask_inputs': (torch.Tensor) Batched mask inputs to the model, in the form Bx1xHxW. |
|
|
|
|
multimask_output (bool): Whether the model should predict multiple disambiguating masks, or return a single |
|
|
|
|
mask. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
(list(dict)): A list over input images, where each element is as dictionary with the following keys. |
|
|
|
|
'masks': (torch.Tensor) Batched binary mask predictions, with shape BxCxHxW, where B is the number of |
|
|
|
|
input prompts, C is determined by multimask_output, and (H, W) is the original size of the image. |
|
|
|
|
'iou_predictions': (torch.Tensor) The model's predictions of mask quality, in shape BxC. |
|
|
|
|
'low_res_logits': (torch.Tensor) Low resolution logits with shape BxCxHxW, where H=W=256. Can be passed |
|
|
|
|
as mask input to subsequent iterations of prediction. |
|
|
|
|
""" |
|
|
|
|
input_images = torch.stack([self.preprocess(x['image']) for x in batched_input], dim=0) |
|
|
|
|
image_embeddings = self.image_encoder(input_images) |
|
|
|
|
|
|
|
|
|
outputs = [] |
|
|
|
|
for image_record, curr_embedding in zip(batched_input, image_embeddings): |
|
|
|
|
if 'point_coords' in image_record: |
|
|
|
|
points = (image_record['point_coords'], image_record['point_labels']) |
|
|
|
|
else: |
|
|
|
|
points = None |
|
|
|
|
sparse_embeddings, dense_embeddings = self.prompt_encoder( |
|
|
|
|
points=points, |
|
|
|
|
boxes=image_record.get('boxes', None), |
|
|
|
|
masks=image_record.get('mask_inputs', None), |
|
|
|
|
) |
|
|
|
|
low_res_masks, iou_predictions = self.mask_decoder( |
|
|
|
|
image_embeddings=curr_embedding.unsqueeze(0), |
|
|
|
|
image_pe=self.prompt_encoder.get_dense_pe(), |
|
|
|
|
sparse_prompt_embeddings=sparse_embeddings, |
|
|
|
|
dense_prompt_embeddings=dense_embeddings, |
|
|
|
|
multimask_output=multimask_output, |
|
|
|
|
) |
|
|
|
|
masks = self.postprocess_masks( |
|
|
|
|
low_res_masks, |
|
|
|
|
input_size=image_record['image'].shape[-2:], |
|
|
|
|
original_size=image_record['original_size'], |
|
|
|
|
) |
|
|
|
|
masks = masks > self.mask_threshold |
|
|
|
|
outputs.append({ |
|
|
|
|
'masks': masks, |
|
|
|
|
'iou_predictions': iou_predictions, |
|
|
|
|
'low_res_logits': low_res_masks, }) |
|
|
|
|
return outputs |
|
|
|
|
|
|
|
|
|
def postprocess_masks( |
|
|
|
|
self, |
|
|
|
|
masks: torch.Tensor, |
|
|
|
|
input_size: Tuple[int, ...], |
|
|
|
|
original_size: Tuple[int, ...], |
|
|
|
|
) -> torch.Tensor: |
|
|
|
|
""" |
|
|
|
|
Remove padding and upscale masks to the original image size. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
masks (torch.Tensor): Batched masks from the mask_decoder, in BxCxHxW format. |
|
|
|
|
input_size (tuple(int, int)): The size of the model input image, in (H, W) format. Used to remove padding. |
|
|
|
|
original_size (tuple(int, int)): The original image size before resizing for input to the model, in (H, W). |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
(torch.Tensor): Batched masks in BxCxHxW format, where (H, W) is given by original_size. |
|
|
|
|
""" |
|
|
|
|
masks = F.interpolate( |
|
|
|
|
masks, |
|
|
|
|
(self.image_encoder.img_size, self.image_encoder.img_size), |
|
|
|
|
mode='bilinear', |
|
|
|
|
align_corners=False, |
|
|
|
|
) |
|
|
|
|
masks = masks[..., :input_size[0], :input_size[1]] |
|
|
|
|
return F.interpolate(masks, original_size, mode='bilinear', align_corners=False) |
|
|
|
|
|
|
|
|
|
def preprocess(self, x: torch.Tensor) -> torch.Tensor: |
|
|
|
|
"""Normalize pixel values and pad to a square input.""" |
|
|
|
|
# Normalize colors |
|
|
|
|
x = (x - self.pixel_mean) / self.pixel_std |
|
|
|
|
|
|
|
|
|
# Pad |
|
|
|
|
h, w = x.shape[-2:] |
|
|
|
|
padh = self.image_encoder.img_size - h |
|
|
|
|
padw = self.image_encoder.img_size - w |
|
|
|
|
return F.pad(x, (0, padw, 0, padh)) |
|
|
|
|