|
|
|
@ -30,11 +30,10 @@ class Sam(nn.Module): |
|
|
|
|
SAM predicts object masks from an image and input prompts. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
image_encoder (ImageEncoderViT): The backbone used to encode the |
|
|
|
|
image into image embeddings that allow for efficient mask prediction. |
|
|
|
|
image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for |
|
|
|
|
efficient mask prediction. |
|
|
|
|
prompt_encoder (PromptEncoder): Encodes various types of input prompts. |
|
|
|
|
mask_decoder (MaskDecoder): Predicts masks from the image embeddings |
|
|
|
|
and encoded prompts. |
|
|
|
|
mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts. |
|
|
|
|
pixel_mean (list(float)): Mean values for normalizing pixels in the input image. |
|
|
|
|
pixel_std (list(float)): Std values for normalizing pixels in the input image. |
|
|
|
|
""" |
|
|
|
@ -65,34 +64,25 @@ class Sam(nn.Module): |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
batched_input (list(dict)): A list over input images, each a dictionary with the following keys. A prompt |
|
|
|
|
key can be excluded if it is not present. |
|
|
|
|
'image': The image as a torch tensor in 3xHxW format, |
|
|
|
|
already transformed for input to the model. |
|
|
|
|
'original_size': (tuple(int, int)) The original size of |
|
|
|
|
the image before transformation, as (H, W). |
|
|
|
|
'point_coords': (torch.Tensor) Batched point prompts for |
|
|
|
|
this image, with shape BxNx2. Already transformed to the |
|
|
|
|
input frame of the model. |
|
|
|
|
'point_labels': (torch.Tensor) Batched labels for point prompts, |
|
|
|
|
with shape BxN. |
|
|
|
|
'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. |
|
|
|
|
Already transformed to the input frame of the model. |
|
|
|
|
'mask_inputs': (torch.Tensor) Batched mask inputs to the model, |
|
|
|
|
in the form Bx1xHxW. |
|
|
|
|
key can be excluded if it is not present. |
|
|
|
|
'image': The image as a torch tensor in 3xHxW format, already transformed for input to the model. |
|
|
|
|
'original_size': (tuple(int, int)) The original size of the image before transformation, as (H, W). |
|
|
|
|
'point_coords': (torch.Tensor) Batched point prompts for this image, with shape BxNx2. Already |
|
|
|
|
transformed to the input frame of the model. |
|
|
|
|
'point_labels': (torch.Tensor) Batched labels for point prompts, with shape BxN. |
|
|
|
|
'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. Already transformed to the input frame of |
|
|
|
|
the model. |
|
|
|
|
'mask_inputs': (torch.Tensor) Batched mask inputs to the model, in the form Bx1xHxW. |
|
|
|
|
multimask_output (bool): Whether the model should predict multiple disambiguating masks, or return a single |
|
|
|
|
mask. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
(list(dict)): A list over input images, where each element is as dictionary with the following keys. |
|
|
|
|
'masks': (torch.Tensor) Batched binary mask predictions, |
|
|
|
|
with shape BxCxHxW, where B is the number of input prompts, |
|
|
|
|
C is determined by multimask_output, and (H, W) is the |
|
|
|
|
original size of the image. |
|
|
|
|
'iou_predictions': (torch.Tensor) The model's predictions |
|
|
|
|
of mask quality, in shape BxC. |
|
|
|
|
'low_res_logits': (torch.Tensor) Low resolution logits with |
|
|
|
|
shape BxCxHxW, where H=W=256. Can be passed as mask input |
|
|
|
|
to subsequent iterations of prediction. |
|
|
|
|
'masks': (torch.Tensor) Batched binary mask predictions, with shape BxCxHxW, where B is the number of |
|
|
|
|
input prompts, C is determined by multimask_output, and (H, W) is the original size of the image. |
|
|
|
|
'iou_predictions': (torch.Tensor) The model's predictions of mask quality, in shape BxC. |
|
|
|
|
'low_res_logits': (torch.Tensor) Low resolution logits with shape BxCxHxW, where H=W=256. Can be passed |
|
|
|
|
as mask input to subsequent iterations of prediction. |
|
|
|
|
""" |
|
|
|
|
input_images = torch.stack([self.preprocess(x['image']) for x in batched_input], dim=0) |
|
|
|
|
image_embeddings = self.image_encoder(input_images) |
|
|
|
@ -137,16 +127,12 @@ class Sam(nn.Module): |
|
|
|
|
Remove padding and upscale masks to the original image size. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
masks (torch.Tensor): Batched masks from the mask_decoder, |
|
|
|
|
in BxCxHxW format. |
|
|
|
|
input_size (tuple(int, int)): The size of the image input to the |
|
|
|
|
model, in (H, W) format. Used to remove padding. |
|
|
|
|
original_size (tuple(int, int)): The original size of the image |
|
|
|
|
before resizing for input to the model, in (H, W) format. |
|
|
|
|
masks (torch.Tensor): Batched masks from the mask_decoder, in BxCxHxW format. |
|
|
|
|
input_size (tuple(int, int)): The size of the model input image, in (H, W) format. Used to remove padding. |
|
|
|
|
original_size (tuple(int, int)): The original image size before resizing for input to the model, in (H, W). |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
(torch.Tensor): Batched masks in BxCxHxW format, where (H, W) |
|
|
|
|
is given by original_size. |
|
|
|
|
(torch.Tensor): Batched masks in BxCxHxW format, where (H, W) is given by original_size. |
|
|
|
|
""" |
|
|
|
|
masks = F.interpolate( |
|
|
|
|
masks, |
|
|
|
|