commit
8cee9df7f2
44 changed files with 4516 additions and 3598 deletions
@ -1,6 +1,6 @@ |
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
# Ultralytics YOLO 🚀, AGPL-3.0 license |
||||||
|
|
||||||
from .model import SAM |
from .model import SAM |
||||||
from .predict import Predictor |
from .predict import Predictor, SAM2Predictor |
||||||
|
|
||||||
__all__ = "SAM", "Predictor" # tuple or list |
__all__ = "SAM", "Predictor", "SAM2Predictor" # tuple or list |
||||||
|
File diff suppressed because it is too large
Load Diff
@ -1,6 +0,0 @@ |
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
|
||||||
|
|
||||||
from .model import SAM2 |
|
||||||
from .predict import SAM2Predictor |
|
||||||
|
|
||||||
__all__ = "SAM2", "SAM2Predictor" # tuple or list |
|
@ -1,156 +0,0 @@ |
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
|
||||||
|
|
||||||
import torch |
|
||||||
|
|
||||||
from ultralytics.utils.downloads import attempt_download_asset |
|
||||||
|
|
||||||
from .modules.encoders import FpnNeck, Hiera, ImageEncoder, MemoryEncoder |
|
||||||
from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer |
|
||||||
from .modules.sam2 import SAM2Model |
|
||||||
|
|
||||||
|
|
||||||
def build_sam2_t(checkpoint=None): |
|
||||||
"""Build and return a Segment Anything Model (SAM2) tiny-size model with specified architecture parameters.""" |
|
||||||
return _build_sam2( |
|
||||||
encoder_embed_dim=96, |
|
||||||
encoder_stages=[1, 2, 7, 2], |
|
||||||
encoder_num_heads=1, |
|
||||||
encoder_global_att_blocks=[5, 7, 9], |
|
||||||
encoder_window_spec=[8, 4, 14, 7], |
|
||||||
encoder_backbone_channel_list=[768, 384, 192, 96], |
|
||||||
checkpoint=checkpoint, |
|
||||||
) |
|
||||||
|
|
||||||
|
|
||||||
def build_sam2_s(checkpoint=None): |
|
||||||
"""Builds and returns a small-size Segment Anything Model (SAM2) with specified architecture parameters.""" |
|
||||||
return _build_sam2( |
|
||||||
encoder_embed_dim=96, |
|
||||||
encoder_stages=[1, 2, 11, 2], |
|
||||||
encoder_num_heads=1, |
|
||||||
encoder_global_att_blocks=[7, 10, 13], |
|
||||||
encoder_window_spec=[8, 4, 14, 7], |
|
||||||
encoder_backbone_channel_list=[768, 384, 192, 96], |
|
||||||
checkpoint=checkpoint, |
|
||||||
) |
|
||||||
|
|
||||||
|
|
||||||
def build_sam2_b(checkpoint=None): |
|
||||||
"""Builds and returns a Segment Anything Model (SAM2) base-size model with specified architecture parameters.""" |
|
||||||
return _build_sam2( |
|
||||||
encoder_embed_dim=112, |
|
||||||
encoder_stages=[2, 3, 16, 3], |
|
||||||
encoder_num_heads=2, |
|
||||||
encoder_global_att_blocks=[12, 16, 20], |
|
||||||
encoder_window_spec=[8, 4, 14, 7], |
|
||||||
encoder_window_spatial_size=[14, 14], |
|
||||||
encoder_backbone_channel_list=[896, 448, 224, 112], |
|
||||||
checkpoint=checkpoint, |
|
||||||
) |
|
||||||
|
|
||||||
|
|
||||||
def build_sam2_l(checkpoint=None): |
|
||||||
"""Build and return a Segment Anything Model (SAM2) large-size model with specified architecture parameters.""" |
|
||||||
return _build_sam2( |
|
||||||
encoder_embed_dim=144, |
|
||||||
encoder_stages=[2, 6, 36, 4], |
|
||||||
encoder_num_heads=2, |
|
||||||
encoder_global_att_blocks=[23, 33, 43], |
|
||||||
encoder_window_spec=[8, 4, 16, 8], |
|
||||||
encoder_backbone_channel_list=[1152, 576, 288, 144], |
|
||||||
checkpoint=checkpoint, |
|
||||||
) |
|
||||||
|
|
||||||
|
|
||||||
def _build_sam2( |
|
||||||
encoder_embed_dim=1280, |
|
||||||
encoder_stages=[2, 6, 36, 4], |
|
||||||
encoder_num_heads=2, |
|
||||||
encoder_global_att_blocks=[7, 15, 23, 31], |
|
||||||
encoder_backbone_channel_list=[1152, 576, 288, 144], |
|
||||||
encoder_window_spatial_size=[7, 7], |
|
||||||
encoder_window_spec=[8, 4, 16, 8], |
|
||||||
checkpoint=None, |
|
||||||
): |
|
||||||
"""Builds a SAM2 model with specified architecture parameters and optional checkpoint loading.""" |
|
||||||
image_encoder = ImageEncoder( |
|
||||||
trunk=Hiera( |
|
||||||
embed_dim=encoder_embed_dim, |
|
||||||
num_heads=encoder_num_heads, |
|
||||||
stages=encoder_stages, |
|
||||||
global_att_blocks=encoder_global_att_blocks, |
|
||||||
window_pos_embed_bkg_spatial_size=encoder_window_spatial_size, |
|
||||||
window_spec=encoder_window_spec, |
|
||||||
), |
|
||||||
neck=FpnNeck( |
|
||||||
d_model=256, |
|
||||||
backbone_channel_list=encoder_backbone_channel_list, |
|
||||||
fpn_top_down_levels=[2, 3], |
|
||||||
fpn_interp_model="nearest", |
|
||||||
), |
|
||||||
scalp=1, |
|
||||||
) |
|
||||||
memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer()) |
|
||||||
memory_encoder = MemoryEncoder(out_dim=64) |
|
||||||
|
|
||||||
sam2 = SAM2Model( |
|
||||||
image_encoder=image_encoder, |
|
||||||
memory_attention=memory_attention, |
|
||||||
memory_encoder=memory_encoder, |
|
||||||
num_maskmem=7, |
|
||||||
image_size=1024, |
|
||||||
sigmoid_scale_for_mem_enc=20.0, |
|
||||||
sigmoid_bias_for_mem_enc=-10.0, |
|
||||||
use_mask_input_as_output_without_sam=True, |
|
||||||
directly_add_no_mem_embed=True, |
|
||||||
use_high_res_features_in_sam=True, |
|
||||||
multimask_output_in_sam=True, |
|
||||||
iou_prediction_use_sigmoid=True, |
|
||||||
use_obj_ptrs_in_encoder=True, |
|
||||||
add_tpos_enc_to_obj_ptrs=True, |
|
||||||
only_obj_ptrs_in_the_past_for_eval=True, |
|
||||||
pred_obj_scores=True, |
|
||||||
pred_obj_scores_mlp=True, |
|
||||||
fixed_no_obj_ptr=True, |
|
||||||
multimask_output_for_tracking=True, |
|
||||||
use_multimask_token_for_obj_ptr=True, |
|
||||||
multimask_min_pt_num=0, |
|
||||||
multimask_max_pt_num=1, |
|
||||||
use_mlp_for_obj_ptr_proj=True, |
|
||||||
compile_image_encoder=False, |
|
||||||
sam_mask_decoder_extra_args=dict( |
|
||||||
dynamic_multimask_via_stability=True, |
|
||||||
dynamic_multimask_stability_delta=0.05, |
|
||||||
dynamic_multimask_stability_thresh=0.98, |
|
||||||
), |
|
||||||
) |
|
||||||
|
|
||||||
if checkpoint is not None: |
|
||||||
checkpoint = attempt_download_asset(checkpoint) |
|
||||||
with open(checkpoint, "rb") as f: |
|
||||||
state_dict = torch.load(f)["model"] |
|
||||||
sam2.load_state_dict(state_dict) |
|
||||||
sam2.eval() |
|
||||||
return sam2 |
|
||||||
|
|
||||||
|
|
||||||
sam_model_map = { |
|
||||||
"sam2_t.pt": build_sam2_t, |
|
||||||
"sam2_s.pt": build_sam2_s, |
|
||||||
"sam2_b.pt": build_sam2_b, |
|
||||||
"sam2_l.pt": build_sam2_l, |
|
||||||
} |
|
||||||
|
|
||||||
|
|
||||||
def build_sam2(ckpt="sam_b.pt"): |
|
||||||
"""Constructs a Segment Anything Model (SAM2) based on the specified checkpoint, with various size options.""" |
|
||||||
model_builder = None |
|
||||||
ckpt = str(ckpt) # to allow Path ckpt types |
|
||||||
for k in sam_model_map.keys(): |
|
||||||
if ckpt.endswith(k): |
|
||||||
model_builder = sam_model_map.get(k) |
|
||||||
|
|
||||||
if not model_builder: |
|
||||||
raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}") |
|
||||||
|
|
||||||
return model_builder(ckpt) |
|
@ -1,97 +0,0 @@ |
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
|
||||||
""" |
|
||||||
SAM2 model interface. |
|
||||||
|
|
||||||
This module provides an interface to the Segment Anything Model (SAM2) from Ultralytics, designed for real-time image |
|
||||||
segmentation tasks. The SAM2 model allows for promptable segmentation with unparalleled versatility in image analysis, |
|
||||||
and has been trained on the SA-1B dataset. It features zero-shot performance capabilities, enabling it to adapt to new |
|
||||||
image distributions and tasks without prior knowledge. |
|
||||||
|
|
||||||
Key Features: |
|
||||||
- Promptable segmentation |
|
||||||
- Real-time performance |
|
||||||
- Zero-shot transfer capabilities |
|
||||||
- Trained on SA-1B dataset |
|
||||||
""" |
|
||||||
|
|
||||||
from ultralytics.models.sam import SAM |
|
||||||
|
|
||||||
from .build import build_sam2 |
|
||||||
from .predict import SAM2Predictor |
|
||||||
|
|
||||||
|
|
||||||
class SAM2(SAM): |
|
||||||
""" |
|
||||||
SAM2 class for real-time image segmentation using the Segment Anything Model (SAM2). |
|
||||||
|
|
||||||
This class extends the SAM base class, providing an interface to the SAM2 model for promptable segmentation |
|
||||||
tasks. It supports loading pre-trained weights and offers zero-shot performance capabilities. |
|
||||||
|
|
||||||
Attributes: |
|
||||||
model (torch.nn.Module): The loaded SAM2 model. |
|
||||||
task_map (Dict[str, Type[SAM2Predictor]]): Mapping of 'segment' task to SAM2Predictor. |
|
||||||
|
|
||||||
Methods: |
|
||||||
__init__: Initializes the SAM2 model with pre-trained weights. |
|
||||||
_load: Loads specified weights into the SAM2 model. |
|
||||||
|
|
||||||
Examples: |
|
||||||
>>> sam2 = SAM2("sam2_b.pt") |
|
||||||
>>> sam2._load('path/to/sam2_weights.pt') |
|
||||||
>>> task_map = sam2.task_map |
|
||||||
>>> print(task_map) |
|
||||||
{'segment': SAM2Predictor} |
|
||||||
|
|
||||||
Notes: |
|
||||||
- Supports .pt and .pth file extensions for model weights. |
|
||||||
- Offers zero-shot transfer capabilities for new image distributions and tasks. |
|
||||||
""" |
|
||||||
|
|
||||||
def __init__(self, model="sam2_b.pt") -> None: |
|
||||||
""" |
|
||||||
Initializes the SAM2 model with a pre-trained model file. |
|
||||||
|
|
||||||
Args: |
|
||||||
model (str): Path to the pre-trained SAM2 model file. File should have a .pt or .pth extension. |
|
||||||
|
|
||||||
Raises: |
|
||||||
NotImplementedError: If the model file extension is not .pt or .pth. |
|
||||||
|
|
||||||
Examples: |
|
||||||
>>> sam2 = SAM2("sam2_b.pt") |
|
||||||
""" |
|
||||||
super().__init__(model=model) |
|
||||||
|
|
||||||
def _load(self, weights: str, task=None): |
|
||||||
""" |
|
||||||
Loads the specified weights into the SAM2 model. |
|
||||||
|
|
||||||
This method is responsible for loading pre-trained weights into the SAM2 model. It supports loading |
|
||||||
weights from files with .pt or .pth extensions. |
|
||||||
|
|
||||||
Args: |
|
||||||
weights (str): Path to the weights file. Should be a file with .pt or .pth extension. |
|
||||||
task (str | None): Task name. If provided, it may be used to configure model-specific settings. |
|
||||||
|
|
||||||
Examples: |
|
||||||
>>> sam2_model = SAM2() |
|
||||||
>>> sam2_model._load('path/to/sam2_weights.pt') |
|
||||||
""" |
|
||||||
self.model = build_sam2(weights) |
|
||||||
|
|
||||||
@property |
|
||||||
def task_map(self): |
|
||||||
""" |
|
||||||
Provides a mapping from the 'segment' task to its corresponding 'Predictor'. |
|
||||||
|
|
||||||
Returns: |
|
||||||
(Dict[str, Type[SAM2Predictor]]): A dictionary mapping the 'segment' task to its corresponding |
|
||||||
SAM2Predictor class. |
|
||||||
|
|
||||||
Examples: |
|
||||||
>>> sam2 = SAM2() |
|
||||||
>>> task_map = sam2.task_map |
|
||||||
>>> print(task_map) |
|
||||||
{'segment': SAM2Predictor} |
|
||||||
""" |
|
||||||
return {"segment": {"predictor": SAM2Predictor}} |
|
@ -1 +0,0 @@ |
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
|
@ -1,305 +0,0 @@ |
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
|
||||||
|
|
||||||
from typing import List, Optional, Tuple, Type |
|
||||||
|
|
||||||
import torch |
|
||||||
from torch import nn |
|
||||||
|
|
||||||
from ultralytics.nn.modules import MLP, LayerNorm2d |
|
||||||
|
|
||||||
|
|
||||||
class MaskDecoder(nn.Module): |
|
||||||
"""Transformer-based decoder predicting instance segmentation masks from image and prompt embeddings.""" |
|
||||||
|
|
||||||
def __init__( |
|
||||||
self, |
|
||||||
transformer_dim: int, |
|
||||||
transformer: nn.Module, |
|
||||||
num_multimask_outputs: int = 3, |
|
||||||
activation: Type[nn.Module] = nn.GELU, |
|
||||||
iou_head_depth: int = 3, |
|
||||||
iou_head_hidden_dim: int = 256, |
|
||||||
use_high_res_features: bool = False, |
|
||||||
iou_prediction_use_sigmoid=False, |
|
||||||
dynamic_multimask_via_stability=False, |
|
||||||
dynamic_multimask_stability_delta=0.05, |
|
||||||
dynamic_multimask_stability_thresh=0.98, |
|
||||||
pred_obj_scores: bool = False, |
|
||||||
pred_obj_scores_mlp: bool = False, |
|
||||||
use_multimask_token_for_obj_ptr: bool = False, |
|
||||||
) -> None: |
|
||||||
""" |
|
||||||
Initializes the MaskDecoder module for predicting instance segmentation masks. |
|
||||||
|
|
||||||
Args: |
|
||||||
transformer_dim (int): Channel dimension of the transformer. |
|
||||||
transformer (nn.Module): Transformer used to predict masks. |
|
||||||
num_multimask_outputs (int): Number of masks to predict when disambiguating masks. |
|
||||||
activation (Type[nn.Module]): Type of activation to use when upscaling masks. |
|
||||||
iou_head_depth (int): Depth of the MLP used to predict mask quality. |
|
||||||
iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality. |
|
||||||
use_high_res_features (bool): Whether to use high-resolution features. |
|
||||||
iou_prediction_use_sigmoid (bool): Whether to use sigmoid for IOU prediction. |
|
||||||
dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability. |
|
||||||
dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability. |
|
||||||
dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability. |
|
||||||
pred_obj_scores (bool): Whether to predict object scores. |
|
||||||
pred_obj_scores_mlp (bool): Whether to use MLP for object score prediction. |
|
||||||
use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer. |
|
||||||
|
|
||||||
Attributes: |
|
||||||
transformer_dim (int): Channel dimension of the transformer. |
|
||||||
transformer (nn.Module): Transformer used to predict masks. |
|
||||||
num_multimask_outputs (int): Number of masks to predict when disambiguating masks. |
|
||||||
iou_token (nn.Embedding): Embedding for IOU token. |
|
||||||
num_mask_tokens (int): Total number of mask tokens. |
|
||||||
mask_tokens (nn.Embedding): Embedding for mask tokens. |
|
||||||
pred_obj_scores (bool): Whether to predict object scores. |
|
||||||
obj_score_token (nn.Embedding): Embedding for object score token. |
|
||||||
use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer. |
|
||||||
output_upscaling (nn.Sequential): Upscaling layers for output. |
|
||||||
use_high_res_features (bool): Whether to use high-resolution features. |
|
||||||
conv_s0 (nn.Conv2d): Convolutional layer for high-resolution features (s0). |
|
||||||
conv_s1 (nn.Conv2d): Convolutional layer for high-resolution features (s1). |
|
||||||
output_hypernetworks_mlps (nn.ModuleList): List of MLPs for output hypernetworks. |
|
||||||
iou_prediction_head (MLP): MLP for IOU prediction. |
|
||||||
pred_obj_score_head (nn.Linear | MLP): Linear layer or MLP for object score prediction. |
|
||||||
dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability. |
|
||||||
dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability. |
|
||||||
""" |
|
||||||
super().__init__() |
|
||||||
self.transformer_dim = transformer_dim |
|
||||||
self.transformer = transformer |
|
||||||
|
|
||||||
self.num_multimask_outputs = num_multimask_outputs |
|
||||||
|
|
||||||
self.iou_token = nn.Embedding(1, transformer_dim) |
|
||||||
self.num_mask_tokens = num_multimask_outputs + 1 |
|
||||||
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) |
|
||||||
|
|
||||||
self.pred_obj_scores = pred_obj_scores |
|
||||||
if self.pred_obj_scores: |
|
||||||
self.obj_score_token = nn.Embedding(1, transformer_dim) |
|
||||||
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr |
|
||||||
|
|
||||||
self.output_upscaling = nn.Sequential( |
|
||||||
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), |
|
||||||
LayerNorm2d(transformer_dim // 4), |
|
||||||
activation(), |
|
||||||
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), |
|
||||||
activation(), |
|
||||||
) |
|
||||||
self.use_high_res_features = use_high_res_features |
|
||||||
if use_high_res_features: |
|
||||||
self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1) |
|
||||||
self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1) |
|
||||||
|
|
||||||
self.output_hypernetworks_mlps = nn.ModuleList( |
|
||||||
[MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)] |
|
||||||
) |
|
||||||
|
|
||||||
self.iou_prediction_head = MLP( |
|
||||||
transformer_dim, |
|
||||||
iou_head_hidden_dim, |
|
||||||
self.num_mask_tokens, |
|
||||||
iou_head_depth, |
|
||||||
sigmoid=iou_prediction_use_sigmoid, |
|
||||||
) |
|
||||||
if self.pred_obj_scores: |
|
||||||
self.pred_obj_score_head = nn.Linear(transformer_dim, 1) |
|
||||||
if pred_obj_scores_mlp: |
|
||||||
self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) |
|
||||||
|
|
||||||
# When outputting a single mask, optionally we can dynamically fall back to the best |
|
||||||
# multimask output token if the single mask output token gives low stability scores. |
|
||||||
self.dynamic_multimask_via_stability = dynamic_multimask_via_stability |
|
||||||
self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta |
|
||||||
self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh |
|
||||||
|
|
||||||
def forward( |
|
||||||
self, |
|
||||||
image_embeddings: torch.Tensor, |
|
||||||
image_pe: torch.Tensor, |
|
||||||
sparse_prompt_embeddings: torch.Tensor, |
|
||||||
dense_prompt_embeddings: torch.Tensor, |
|
||||||
multimask_output: bool, |
|
||||||
repeat_image: bool, |
|
||||||
high_res_features: Optional[List[torch.Tensor]] = None, |
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
||||||
""" |
|
||||||
Predicts masks given image and prompt embeddings. |
|
||||||
|
|
||||||
Args: |
|
||||||
image_embeddings (torch.Tensor): Embeddings from the image encoder. |
|
||||||
image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings. |
|
||||||
sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes. |
|
||||||
dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs. |
|
||||||
multimask_output (bool): Whether to return multiple masks or a single mask. |
|
||||||
repeat_image (bool): Flag to repeat the image embeddings. |
|
||||||
high_res_features (List[torch.Tensor] | None): Optional high-resolution features. |
|
||||||
|
|
||||||
Returns: |
|
||||||
(Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing: |
|
||||||
- masks (torch.Tensor): Batched predicted masks. |
|
||||||
- iou_pred (torch.Tensor): Batched predictions of mask quality. |
|
||||||
- sam_tokens_out (torch.Tensor): Batched SAM token for mask output. |
|
||||||
|
|
||||||
Examples: |
|
||||||
>>> image_embeddings = torch.rand(1, 256, 64, 64) |
|
||||||
>>> image_pe = torch.rand(1, 256, 64, 64) |
|
||||||
>>> sparse_prompt_embeddings = torch.rand(1, 2, 256) |
|
||||||
>>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64) |
|
||||||
>>> decoder = MaskDecoder(256, transformer) |
|
||||||
>>> masks, iou_pred, sam_tokens_out = decoder.forward(image_embeddings, image_pe, |
|
||||||
... sparse_prompt_embeddings, dense_prompt_embeddings, True, False) |
|
||||||
""" |
|
||||||
masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( |
|
||||||
image_embeddings=image_embeddings, |
|
||||||
image_pe=image_pe, |
|
||||||
sparse_prompt_embeddings=sparse_prompt_embeddings, |
|
||||||
dense_prompt_embeddings=dense_prompt_embeddings, |
|
||||||
repeat_image=repeat_image, |
|
||||||
high_res_features=high_res_features, |
|
||||||
) |
|
||||||
|
|
||||||
# Select the correct mask or masks for output |
|
||||||
if multimask_output: |
|
||||||
masks = masks[:, 1:, :, :] |
|
||||||
iou_pred = iou_pred[:, 1:] |
|
||||||
elif self.dynamic_multimask_via_stability and not self.training: |
|
||||||
masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) |
|
||||||
else: |
|
||||||
masks = masks[:, 0:1, :, :] |
|
||||||
iou_pred = iou_pred[:, 0:1] |
|
||||||
|
|
||||||
if multimask_output and self.use_multimask_token_for_obj_ptr: |
|
||||||
sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape |
|
||||||
else: |
|
||||||
# Take the mask output token. Here we *always* use the token for single mask output. |
|
||||||
# At test time, even if we track after 1-click (and using multimask_output=True), |
|
||||||
# we still take the single mask token here. The rationale is that we always track |
|
||||||
# after multiple clicks during training, so the past tokens seen during training |
|
||||||
# are always the single mask token (and we'll let it be the object-memory token). |
|
||||||
sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape |
|
||||||
|
|
||||||
# Prepare output |
|
||||||
return masks, iou_pred, sam_tokens_out, object_score_logits |
|
||||||
|
|
||||||
def predict_masks( |
|
||||||
self, |
|
||||||
image_embeddings: torch.Tensor, |
|
||||||
image_pe: torch.Tensor, |
|
||||||
sparse_prompt_embeddings: torch.Tensor, |
|
||||||
dense_prompt_embeddings: torch.Tensor, |
|
||||||
repeat_image: bool, |
|
||||||
high_res_features: Optional[List[torch.Tensor]] = None, |
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
||||||
"""Predicts instance segmentation masks from image and prompt embeddings using a transformer architecture.""" |
|
||||||
# Concatenate output tokens |
|
||||||
s = 0 |
|
||||||
if self.pred_obj_scores: |
|
||||||
output_tokens = torch.cat( |
|
||||||
[ |
|
||||||
self.obj_score_token.weight, |
|
||||||
self.iou_token.weight, |
|
||||||
self.mask_tokens.weight, |
|
||||||
], |
|
||||||
dim=0, |
|
||||||
) |
|
||||||
s = 1 |
|
||||||
else: |
|
||||||
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) |
|
||||||
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) |
|
||||||
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) |
|
||||||
|
|
||||||
# Expand per-image data in batch direction to be per-mask |
|
||||||
if repeat_image: |
|
||||||
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) |
|
||||||
else: |
|
||||||
assert image_embeddings.shape[0] == tokens.shape[0] |
|
||||||
src = image_embeddings |
|
||||||
src = src + dense_prompt_embeddings |
|
||||||
assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" |
|
||||||
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) |
|
||||||
b, c, h, w = src.shape |
|
||||||
|
|
||||||
# Run the transformer |
|
||||||
hs, src = self.transformer(src, pos_src, tokens) |
|
||||||
iou_token_out = hs[:, s, :] |
|
||||||
mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] |
|
||||||
|
|
||||||
# Upscale mask embeddings and predict masks using the mask tokens |
|
||||||
src = src.transpose(1, 2).view(b, c, h, w) |
|
||||||
if not self.use_high_res_features: |
|
||||||
upscaled_embedding = self.output_upscaling(src) |
|
||||||
else: |
|
||||||
dc1, ln1, act1, dc2, act2 = self.output_upscaling |
|
||||||
feat_s0, feat_s1 = high_res_features |
|
||||||
upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) |
|
||||||
upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) |
|
||||||
|
|
||||||
hyper_in_list: List[torch.Tensor] = [] |
|
||||||
for i in range(self.num_mask_tokens): |
|
||||||
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) |
|
||||||
hyper_in = torch.stack(hyper_in_list, dim=1) |
|
||||||
b, c, h, w = upscaled_embedding.shape |
|
||||||
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) |
|
||||||
|
|
||||||
# Generate mask quality predictions |
|
||||||
iou_pred = self.iou_prediction_head(iou_token_out) |
|
||||||
if self.pred_obj_scores: |
|
||||||
assert s == 1 |
|
||||||
object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) |
|
||||||
else: |
|
||||||
# Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 |
|
||||||
object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) |
|
||||||
|
|
||||||
return masks, iou_pred, mask_tokens_out, object_score_logits |
|
||||||
|
|
||||||
def _get_stability_scores(self, mask_logits): |
|
||||||
"""Computes mask stability scores based on IoU between upper and lower thresholds.""" |
|
||||||
mask_logits = mask_logits.flatten(-2) |
|
||||||
stability_delta = self.dynamic_multimask_stability_delta |
|
||||||
area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() |
|
||||||
area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() |
|
||||||
stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) |
|
||||||
return stability_scores |
|
||||||
|
|
||||||
def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): |
|
||||||
""" |
|
||||||
Dynamically selects the most stable mask output based on stability scores and IoU predictions. |
|
||||||
|
|
||||||
When outputting a single mask, if the stability score from the current single-mask output (based on output token |
|
||||||
0) falls below a threshold, we instead select from multi-mask outputs (based on output token 1~3) the mask with |
|
||||||
the highest predicted IoU score. |
|
||||||
|
|
||||||
This is intended to ensure a valid mask for both clicking and tracking. |
|
||||||
""" |
|
||||||
# The best mask from multimask output tokens (1~3) |
|
||||||
multimask_logits = all_mask_logits[:, 1:, :, :] |
|
||||||
multimask_iou_scores = all_iou_scores[:, 1:] |
|
||||||
best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) |
|
||||||
batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device) |
|
||||||
best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] |
|
||||||
best_multimask_logits = best_multimask_logits.unsqueeze(1) |
|
||||||
best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] |
|
||||||
best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) |
|
||||||
|
|
||||||
# The mask from singlemask output token 0 and its stability score |
|
||||||
singlemask_logits = all_mask_logits[:, 0:1, :, :] |
|
||||||
singlemask_iou_scores = all_iou_scores[:, 0:1] |
|
||||||
stability_scores = self._get_stability_scores(singlemask_logits) |
|
||||||
is_stable = stability_scores >= self.dynamic_multimask_stability_thresh |
|
||||||
|
|
||||||
# Dynamically fall back to best multimask output upon low stability scores. |
|
||||||
mask_logits_out = torch.where( |
|
||||||
is_stable[..., None, None].expand_as(singlemask_logits), |
|
||||||
singlemask_logits, |
|
||||||
best_multimask_logits, |
|
||||||
) |
|
||||||
iou_scores_out = torch.where( |
|
||||||
is_stable.expand_as(singlemask_iou_scores), |
|
||||||
singlemask_iou_scores, |
|
||||||
best_multimask_iou_scores, |
|
||||||
) |
|
||||||
return mask_logits_out, iou_scores_out |
|
@ -1,332 +0,0 @@ |
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
|
||||||
|
|
||||||
from typing import List, Optional, Tuple |
|
||||||
|
|
||||||
import torch |
|
||||||
import torch.nn as nn |
|
||||||
import torch.nn.functional as F |
|
||||||
|
|
||||||
from ultralytics.models.sam.modules.encoders import PatchEmbed |
|
||||||
|
|
||||||
from .sam2_blocks import CXBlock, Fuser, MaskDownSampler, MultiScaleBlock, PositionEmbeddingSine |
|
||||||
|
|
||||||
|
|
||||||
class MemoryEncoder(nn.Module): |
|
||||||
"""Encodes pixel features and masks into a memory representation for efficient image segmentation.""" |
|
||||||
|
|
||||||
def __init__( |
|
||||||
self, |
|
||||||
out_dim, |
|
||||||
in_dim=256, # in_dim of pix_feats |
|
||||||
): |
|
||||||
"""Initializes the MemoryEncoder module for encoding pixel features and masks in SAM-like models.""" |
|
||||||
super().__init__() |
|
||||||
|
|
||||||
self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1) |
|
||||||
|
|
||||||
self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) |
|
||||||
self.fuser = Fuser(CXBlock(dim=256), num_layers=2) |
|
||||||
self.position_encoding = PositionEmbeddingSine(num_pos_feats=64) |
|
||||||
self.out_proj = nn.Identity() |
|
||||||
if out_dim != in_dim: |
|
||||||
self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) |
|
||||||
|
|
||||||
def forward( |
|
||||||
self, |
|
||||||
pix_feat: torch.Tensor, |
|
||||||
masks: torch.Tensor, |
|
||||||
skip_mask_sigmoid: bool = False, |
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
||||||
"""Processes pixel features and masks, fusing them to generate encoded memory representations.""" |
|
||||||
if not skip_mask_sigmoid: |
|
||||||
masks = F.sigmoid(masks) |
|
||||||
masks = self.mask_downsampler(masks) |
|
||||||
|
|
||||||
# Fuse pix_feats and downsampled masks, in case the visual features are on CPU, cast them to CUDA |
|
||||||
pix_feat = pix_feat.to(masks.device) |
|
||||||
|
|
||||||
x = self.pix_feat_proj(pix_feat) |
|
||||||
x = x + masks |
|
||||||
x = self.fuser(x) |
|
||||||
x = self.out_proj(x) |
|
||||||
|
|
||||||
pos = self.position_encoding(x).to(x.dtype) |
|
||||||
|
|
||||||
return {"vision_features": x, "vision_pos_enc": [pos]} |
|
||||||
|
|
||||||
|
|
||||||
class ImageEncoder(nn.Module): |
|
||||||
"""Encodes images using a trunk-neck architecture, producing multiscale features and positional encodings.""" |
|
||||||
|
|
||||||
def __init__( |
|
||||||
self, |
|
||||||
trunk: nn.Module, |
|
||||||
neck: nn.Module, |
|
||||||
scalp: int = 0, |
|
||||||
): |
|
||||||
"""Initializes an image encoder with a trunk, neck, and optional scalp for feature extraction.""" |
|
||||||
super().__init__() |
|
||||||
self.trunk = trunk |
|
||||||
self.neck = neck |
|
||||||
self.scalp = scalp |
|
||||||
assert ( |
|
||||||
self.trunk.channel_list == self.neck.backbone_channel_list |
|
||||||
), f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match." |
|
||||||
|
|
||||||
def forward(self, sample: torch.Tensor): |
|
||||||
"""Processes image input through trunk and neck, returning features, positional encodings, and FPN outputs.""" |
|
||||||
features, pos = self.neck(self.trunk(sample)) |
|
||||||
if self.scalp > 0: |
|
||||||
# Discard the lowest resolution features |
|
||||||
features, pos = features[: -self.scalp], pos[: -self.scalp] |
|
||||||
|
|
||||||
src = features[-1] |
|
||||||
output = { |
|
||||||
"vision_features": src, |
|
||||||
"vision_pos_enc": pos, |
|
||||||
"backbone_fpn": features, |
|
||||||
} |
|
||||||
return output |
|
||||||
|
|
||||||
|
|
||||||
class FpnNeck(nn.Module): |
|
||||||
"""Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.""" |
|
||||||
|
|
||||||
def __init__( |
|
||||||
self, |
|
||||||
d_model: int, |
|
||||||
backbone_channel_list: List[int], |
|
||||||
kernel_size: int = 1, |
|
||||||
stride: int = 1, |
|
||||||
padding: int = 0, |
|
||||||
fpn_interp_model: str = "bilinear", |
|
||||||
fuse_type: str = "sum", |
|
||||||
fpn_top_down_levels: Optional[List[int]] = None, |
|
||||||
): |
|
||||||
""" |
|
||||||
Initializes a modified Feature Pyramid Network (FPN) neck. |
|
||||||
|
|
||||||
This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing, |
|
||||||
similar to ViT positional embedding interpolation. |
|
||||||
|
|
||||||
Args: |
|
||||||
d_model (int): Dimension of the model. |
|
||||||
backbone_channel_list (List[int]): List of channel dimensions from the backbone. |
|
||||||
kernel_size (int): Kernel size for the convolutional layers. |
|
||||||
stride (int): Stride for the convolutional layers. |
|
||||||
padding (int): Padding for the convolutional layers. |
|
||||||
fpn_interp_model (str): Interpolation mode for FPN feature resizing. |
|
||||||
fuse_type (str): Type of feature fusion, either 'sum' or 'avg'. |
|
||||||
fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs. |
|
||||||
|
|
||||||
Attributes: |
|
||||||
position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding. |
|
||||||
convs (nn.ModuleList): List of convolutional layers for each backbone level. |
|
||||||
backbone_channel_list (List[int]): List of channel dimensions from the backbone. |
|
||||||
fpn_interp_model (str): Interpolation mode for FPN feature resizing. |
|
||||||
fuse_type (str): Type of feature fusion. |
|
||||||
fpn_top_down_levels (List[int]): Levels with top-down feature propagation. |
|
||||||
|
|
||||||
Examples: |
|
||||||
>>> backbone_channels = [64, 128, 256, 512] |
|
||||||
>>> fpn_neck = FpnNeck(256, backbone_channels) |
|
||||||
>>> print(fpn_neck) |
|
||||||
""" |
|
||||||
super().__init__() |
|
||||||
self.position_encoding = PositionEmbeddingSine(num_pos_feats=256) |
|
||||||
self.convs = nn.ModuleList() |
|
||||||
self.backbone_channel_list = backbone_channel_list |
|
||||||
for dim in backbone_channel_list: |
|
||||||
current = nn.Sequential() |
|
||||||
current.add_module( |
|
||||||
"conv", |
|
||||||
nn.Conv2d( |
|
||||||
in_channels=dim, |
|
||||||
out_channels=d_model, |
|
||||||
kernel_size=kernel_size, |
|
||||||
stride=stride, |
|
||||||
padding=padding, |
|
||||||
), |
|
||||||
) |
|
||||||
|
|
||||||
self.convs.append(current) |
|
||||||
self.fpn_interp_model = fpn_interp_model |
|
||||||
assert fuse_type in ["sum", "avg"] |
|
||||||
self.fuse_type = fuse_type |
|
||||||
|
|
||||||
# levels to have top-down features in its outputs |
|
||||||
# e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 |
|
||||||
# have top-down propagation, while outputs of level 0 and level 1 have only |
|
||||||
# lateral features from the same backbone level. |
|
||||||
if fpn_top_down_levels is None: |
|
||||||
# default is to have top-down features on all levels |
|
||||||
fpn_top_down_levels = range(len(self.convs)) |
|
||||||
self.fpn_top_down_levels = list(fpn_top_down_levels) |
|
||||||
|
|
||||||
def forward(self, xs: List[torch.Tensor]): |
|
||||||
""" |
|
||||||
Performs forward pass through the Feature Pyramid Network (FPN) neck. |
|
||||||
|
|
||||||
Args: |
|
||||||
xs (List[torch.Tensor]): List of input tensors from the backbone, with shape (B, C, H, W) for each tensor. |
|
||||||
|
|
||||||
Returns: |
|
||||||
(Tuple[List[torch.Tensor], List[torch.Tensor]]): A tuple containing two lists: |
|
||||||
- out: List of output feature maps after FPN processing, with shape (B, d_model, H, W) for each tensor. |
|
||||||
- pos: List of positional encodings corresponding to each output feature map. |
|
||||||
|
|
||||||
Examples: |
|
||||||
>>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512]) |
|
||||||
>>> inputs = [torch.rand(1, c, 32, 32) for c in [64, 128, 256, 512]] |
|
||||||
>>> outputs, positions = fpn_neck(inputs) |
|
||||||
""" |
|
||||||
out = [None] * len(self.convs) |
|
||||||
pos = [None] * len(self.convs) |
|
||||||
assert len(xs) == len(self.convs) |
|
||||||
# fpn forward pass |
|
||||||
# see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py |
|
||||||
prev_features = None |
|
||||||
# forward in top-down order (from low to high resolution) |
|
||||||
n = len(self.convs) - 1 |
|
||||||
for i in range(n, -1, -1): |
|
||||||
x = xs[i] |
|
||||||
lateral_features = self.convs[n - i](x) |
|
||||||
if i in self.fpn_top_down_levels and prev_features is not None: |
|
||||||
top_down_features = F.interpolate( |
|
||||||
prev_features.to(dtype=torch.float32), |
|
||||||
scale_factor=2.0, |
|
||||||
mode=self.fpn_interp_model, |
|
||||||
align_corners=(None if self.fpn_interp_model == "nearest" else False), |
|
||||||
antialias=False, |
|
||||||
) |
|
||||||
prev_features = lateral_features + top_down_features |
|
||||||
if self.fuse_type == "avg": |
|
||||||
prev_features /= 2 |
|
||||||
else: |
|
||||||
prev_features = lateral_features |
|
||||||
x_out = prev_features |
|
||||||
out[i] = x_out |
|
||||||
pos[i] = self.position_encoding(x_out).to(x_out.dtype) |
|
||||||
|
|
||||||
return out, pos |
|
||||||
|
|
||||||
|
|
||||||
class Hiera(nn.Module): |
|
||||||
"""Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.""" |
|
||||||
|
|
||||||
def __init__( |
|
||||||
self, |
|
||||||
embed_dim: int = 96, # initial embed dim |
|
||||||
num_heads: int = 1, # initial number of heads |
|
||||||
drop_path_rate: float = 0.0, # stochastic depth |
|
||||||
q_pool: int = 3, # number of q_pool stages |
|
||||||
q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages |
|
||||||
stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage |
|
||||||
dim_mul: float = 2.0, # dim_mul factor at stage shift |
|
||||||
head_mul: float = 2.0, # head_mul factor at stage shift |
|
||||||
window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), |
|
||||||
# window size per stage, when not using global att. |
|
||||||
window_spec: Tuple[int, ...] = ( |
|
||||||
8, |
|
||||||
4, |
|
||||||
14, |
|
||||||
7, |
|
||||||
), |
|
||||||
# global attn in these blocks |
|
||||||
global_att_blocks: Tuple[int, ...] = ( |
|
||||||
12, |
|
||||||
16, |
|
||||||
20, |
|
||||||
), |
|
||||||
return_interm_layers=True, # return feats from every stage |
|
||||||
): |
|
||||||
"""Initializes a Hiera model with configurable architecture for hierarchical vision transformers.""" |
|
||||||
super().__init__() |
|
||||||
|
|
||||||
assert len(stages) == len(window_spec) |
|
||||||
self.window_spec = window_spec |
|
||||||
|
|
||||||
depth = sum(stages) |
|
||||||
self.q_stride = q_stride |
|
||||||
self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] |
|
||||||
assert 0 <= q_pool <= len(self.stage_ends[:-1]) |
|
||||||
self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] |
|
||||||
self.return_interm_layers = return_interm_layers |
|
||||||
|
|
||||||
self.patch_embed = PatchEmbed( |
|
||||||
embed_dim=embed_dim, |
|
||||||
kernel_size=(7, 7), |
|
||||||
stride=(4, 4), |
|
||||||
padding=(3, 3), |
|
||||||
) |
|
||||||
# Which blocks have global att? |
|
||||||
self.global_att_blocks = global_att_blocks |
|
||||||
|
|
||||||
# Windowed positional embedding (https://arxiv.org/abs/2311.05613) |
|
||||||
self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size |
|
||||||
self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)) |
|
||||||
self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])) |
|
||||||
|
|
||||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule |
|
||||||
|
|
||||||
cur_stage = 1 |
|
||||||
self.blocks = nn.ModuleList() |
|
||||||
|
|
||||||
for i in range(depth): |
|
||||||
dim_out = embed_dim |
|
||||||
# lags by a block, so first block of |
|
||||||
# next stage uses an initial window size |
|
||||||
# of previous stage and final window size of current stage |
|
||||||
window_size = self.window_spec[cur_stage - 1] |
|
||||||
|
|
||||||
if self.global_att_blocks is not None: |
|
||||||
window_size = 0 if i in self.global_att_blocks else window_size |
|
||||||
|
|
||||||
if i - 1 in self.stage_ends: |
|
||||||
dim_out = int(embed_dim * dim_mul) |
|
||||||
num_heads = int(num_heads * head_mul) |
|
||||||
cur_stage += 1 |
|
||||||
|
|
||||||
block = MultiScaleBlock( |
|
||||||
dim=embed_dim, |
|
||||||
dim_out=dim_out, |
|
||||||
num_heads=num_heads, |
|
||||||
drop_path=dpr[i], |
|
||||||
q_stride=self.q_stride if i in self.q_pool_blocks else None, |
|
||||||
window_size=window_size, |
|
||||||
) |
|
||||||
|
|
||||||
embed_dim = dim_out |
|
||||||
self.blocks.append(block) |
|
||||||
|
|
||||||
self.channel_list = ( |
|
||||||
[self.blocks[i].dim_out for i in self.stage_ends[::-1]] |
|
||||||
if return_interm_layers |
|
||||||
else [self.blocks[-1].dim_out] |
|
||||||
) |
|
||||||
|
|
||||||
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: |
|
||||||
"""Generate positional embeddings by interpolating and combining window and background embeddings.""" |
|
||||||
h, w = hw |
|
||||||
window_embed = self.pos_embed_window |
|
||||||
pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") |
|
||||||
pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)]) |
|
||||||
pos_embed = pos_embed.permute(0, 2, 3, 1) |
|
||||||
return pos_embed |
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]: |
|
||||||
"""Performs hierarchical vision transformer forward pass, returning multiscale feature maps.""" |
|
||||||
x = self.patch_embed(x) |
|
||||||
# x: (B, H, W, C) |
|
||||||
|
|
||||||
# Add pos embed |
|
||||||
x = x + self._get_pos_embed(x.shape[1:3]) |
|
||||||
|
|
||||||
outputs = [] |
|
||||||
for i, blk in enumerate(self.blocks): |
|
||||||
x = blk(x) |
|
||||||
if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers): |
|
||||||
feats = x.permute(0, 3, 1, 2) |
|
||||||
outputs.append(feats) |
|
||||||
|
|
||||||
return outputs |
|
@ -1,804 +0,0 @@ |
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
|
||||||
|
|
||||||
import torch |
|
||||||
import torch.distributed |
|
||||||
import torch.nn.functional as F |
|
||||||
from torch.nn.init import trunc_normal_ |
|
||||||
|
|
||||||
from ultralytics.models.sam.modules.encoders import PromptEncoder |
|
||||||
from ultralytics.nn.modules import MLP |
|
||||||
|
|
||||||
from .decoders import MaskDecoder |
|
||||||
from .sam2_blocks import TwoWayTransformer |
|
||||||
from .utils import get_1d_sine_pe, select_closest_cond_frames |
|
||||||
|
|
||||||
# a large negative value as a placeholder score for missing objects |
|
||||||
NO_OBJ_SCORE = -1024.0 |
|
||||||
|
|
||||||
|
|
||||||
class SAM2Model(torch.nn.Module): |
|
||||||
"""SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.""" |
|
||||||
|
|
||||||
mask_threshold: float = 0.0 |
|
||||||
|
|
||||||
def __init__( |
|
||||||
self, |
|
||||||
image_encoder, |
|
||||||
memory_attention, |
|
||||||
memory_encoder, |
|
||||||
num_maskmem=7, # default 1 input frame + 6 previous frames |
|
||||||
image_size=512, |
|
||||||
backbone_stride=16, # stride of the image backbone output |
|
||||||
sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob |
|
||||||
sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob |
|
||||||
# During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks |
|
||||||
binarize_mask_from_pts_for_mem_enc=False, |
|
||||||
use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder |
|
||||||
# The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit, |
|
||||||
# we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model |
|
||||||
# a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM. |
|
||||||
max_cond_frames_in_attn=-1, |
|
||||||
# on the first frame, whether to directly add the no-memory embedding to the image feature |
|
||||||
# (instead of using the transformer encoder) |
|
||||||
directly_add_no_mem_embed=False, |
|
||||||
# whether to use high-resolution feature maps in the SAM mask decoder |
|
||||||
use_high_res_features_in_sam=False, |
|
||||||
# whether to output multiple (3) masks for the first click on initial conditioning frames |
|
||||||
multimask_output_in_sam=False, |
|
||||||
# the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`; |
|
||||||
# default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points) |
|
||||||
multimask_min_pt_num=1, |
|
||||||
multimask_max_pt_num=1, |
|
||||||
# whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`) |
|
||||||
multimask_output_for_tracking=False, |
|
||||||
# Whether to use multimask tokens for obj ptr; Only relevant when both |
|
||||||
# use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True |
|
||||||
use_multimask_token_for_obj_ptr: bool = False, |
|
||||||
# whether to use sigmoid to restrict ious prediction to [0-1] |
|
||||||
iou_prediction_use_sigmoid=False, |
|
||||||
# The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5). |
|
||||||
# For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of |
|
||||||
# (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame. |
|
||||||
memory_temporal_stride_for_eval=1, |
|
||||||
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click |
|
||||||
# if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames |
|
||||||
add_all_frames_to_correct_as_cond=False, |
|
||||||
# whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks) |
|
||||||
non_overlap_masks_for_mem_enc=False, |
|
||||||
# whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder |
|
||||||
use_obj_ptrs_in_encoder=False, |
|
||||||
# the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`) |
|
||||||
max_obj_ptrs_in_encoder=16, |
|
||||||
# whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`) |
|
||||||
add_tpos_enc_to_obj_ptrs=True, |
|
||||||
# whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference |
|
||||||
# with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) |
|
||||||
proj_tpos_enc_in_obj_ptrs=False, |
|
||||||
# whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation |
|
||||||
# (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking) |
|
||||||
only_obj_ptrs_in_the_past_for_eval=False, |
|
||||||
# Whether to predict if there is an object in the frame |
|
||||||
pred_obj_scores: bool = False, |
|
||||||
# Whether to use an MLP to predict object scores |
|
||||||
pred_obj_scores_mlp: bool = False, |
|
||||||
# Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True; |
|
||||||
# Whether to have a fixed no obj pointer when there is no object present |
|
||||||
# or to use it as an additive embedding with obj_ptr produced by decoder |
|
||||||
fixed_no_obj_ptr: bool = False, |
|
||||||
# Soft no object, i.e. mix in no_obj_ptr softly, |
|
||||||
# hope to make recovery easier if there is a mistake and mitigate accumulation of errors |
|
||||||
soft_no_obj_ptr: bool = False, |
|
||||||
use_mlp_for_obj_ptr_proj: bool = False, |
|
||||||
# extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class. |
|
||||||
sam_mask_decoder_extra_args=None, |
|
||||||
compile_image_encoder: bool = False, |
|
||||||
): |
|
||||||
"""Initializes SAM2Model model with image encoder, memory attention, and memory encoder components.""" |
|
||||||
super().__init__() |
|
||||||
|
|
||||||
# Part 1: the image backbone |
|
||||||
self.image_encoder = image_encoder |
|
||||||
# Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting |
|
||||||
self.use_high_res_features_in_sam = use_high_res_features_in_sam |
|
||||||
self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 |
|
||||||
self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder |
|
||||||
self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder |
|
||||||
if use_obj_ptrs_in_encoder: |
|
||||||
# A conv layer to downsample the mask prompt to stride 4 (the same stride as |
|
||||||
# low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, |
|
||||||
# so that it can be fed into the SAM mask decoder to generate a pointer. |
|
||||||
self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) |
|
||||||
self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs |
|
||||||
if proj_tpos_enc_in_obj_ptrs: |
|
||||||
assert add_tpos_enc_to_obj_ptrs # these options need to be used together |
|
||||||
self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs |
|
||||||
self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval |
|
||||||
|
|
||||||
# Part 2: memory attention to condition current frame's visual features |
|
||||||
# with memories (and obj ptrs) from past frames |
|
||||||
self.memory_attention = memory_attention |
|
||||||
self.hidden_dim = memory_attention.d_model |
|
||||||
|
|
||||||
# Part 3: memory encoder for the previous frame's outputs |
|
||||||
self.memory_encoder = memory_encoder |
|
||||||
self.mem_dim = self.hidden_dim |
|
||||||
if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"): |
|
||||||
# if there is compression of memories along channel dim |
|
||||||
self.mem_dim = self.memory_encoder.out_proj.weight.shape[0] |
|
||||||
self.num_maskmem = num_maskmem # Number of memories accessible |
|
||||||
# Temporal encoding of the memories |
|
||||||
self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim)) |
|
||||||
trunc_normal_(self.maskmem_tpos_enc, std=0.02) |
|
||||||
# a single token to indicate no memory embedding from previous frames |
|
||||||
self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) |
|
||||||
self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) |
|
||||||
trunc_normal_(self.no_mem_embed, std=0.02) |
|
||||||
trunc_normal_(self.no_mem_pos_enc, std=0.02) |
|
||||||
self.directly_add_no_mem_embed = directly_add_no_mem_embed |
|
||||||
# Apply sigmoid to the output raw mask logits (to turn them from |
|
||||||
# range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder |
|
||||||
self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc |
|
||||||
self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc |
|
||||||
self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc |
|
||||||
self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc |
|
||||||
self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval |
|
||||||
# On frames with mask input, whether to directly output the input mask without |
|
||||||
# using a SAM prompt encoder + mask decoder |
|
||||||
self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam |
|
||||||
self.multimask_output_in_sam = multimask_output_in_sam |
|
||||||
self.multimask_min_pt_num = multimask_min_pt_num |
|
||||||
self.multimask_max_pt_num = multimask_max_pt_num |
|
||||||
self.multimask_output_for_tracking = multimask_output_for_tracking |
|
||||||
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr |
|
||||||
self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid |
|
||||||
|
|
||||||
# Part 4: SAM-style prompt encoder (for both mask and point inputs) |
|
||||||
# and SAM-style mask decoder for the final mask output |
|
||||||
self.image_size = image_size |
|
||||||
self.backbone_stride = backbone_stride |
|
||||||
self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args |
|
||||||
self.pred_obj_scores = pred_obj_scores |
|
||||||
self.pred_obj_scores_mlp = pred_obj_scores_mlp |
|
||||||
self.fixed_no_obj_ptr = fixed_no_obj_ptr |
|
||||||
self.soft_no_obj_ptr = soft_no_obj_ptr |
|
||||||
if self.fixed_no_obj_ptr: |
|
||||||
assert self.pred_obj_scores |
|
||||||
assert self.use_obj_ptrs_in_encoder |
|
||||||
if self.pred_obj_scores and self.use_obj_ptrs_in_encoder: |
|
||||||
self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) |
|
||||||
trunc_normal_(self.no_obj_ptr, std=0.02) |
|
||||||
self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj |
|
||||||
|
|
||||||
self._build_sam_heads() |
|
||||||
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond |
|
||||||
self.max_cond_frames_in_attn = max_cond_frames_in_attn |
|
||||||
|
|
||||||
# Model compilation |
|
||||||
if compile_image_encoder: |
|
||||||
# Compile the forward function (not the full module) to allow loading checkpoints. |
|
||||||
print("Image encoder compilation is enabled. First forward pass will be slow.") |
|
||||||
self.image_encoder.forward = torch.compile( |
|
||||||
self.image_encoder.forward, |
|
||||||
mode="max-autotune", |
|
||||||
fullgraph=True, |
|
||||||
dynamic=False, |
|
||||||
) |
|
||||||
|
|
||||||
@property |
|
||||||
def device(self): |
|
||||||
"""Returns the device on which the model's parameters are stored.""" |
|
||||||
return next(self.parameters()).device |
|
||||||
|
|
||||||
def forward(self, *args, **kwargs): |
|
||||||
"""Processes input frames and prompts to generate object masks and scores in video sequences.""" |
|
||||||
raise NotImplementedError( |
|
||||||
"Please use the corresponding methods in SAM2VideoPredictor for inference." |
|
||||||
"See notebooks/video_predictor_example.ipynb for an example." |
|
||||||
) |
|
||||||
|
|
||||||
def _build_sam_heads(self): |
|
||||||
"""Builds SAM-style prompt encoder and mask decoder for image segmentation tasks.""" |
|
||||||
self.sam_prompt_embed_dim = self.hidden_dim |
|
||||||
self.sam_image_embedding_size = self.image_size // self.backbone_stride |
|
||||||
|
|
||||||
# build PromptEncoder and MaskDecoder from SAM |
|
||||||
# (their hyperparameters like `mask_in_chans=16` are from SAM code) |
|
||||||
self.sam_prompt_encoder = PromptEncoder( |
|
||||||
embed_dim=self.sam_prompt_embed_dim, |
|
||||||
image_embedding_size=( |
|
||||||
self.sam_image_embedding_size, |
|
||||||
self.sam_image_embedding_size, |
|
||||||
), |
|
||||||
input_image_size=(self.image_size, self.image_size), |
|
||||||
mask_in_chans=16, |
|
||||||
) |
|
||||||
self.sam_mask_decoder = MaskDecoder( |
|
||||||
num_multimask_outputs=3, |
|
||||||
transformer=TwoWayTransformer( |
|
||||||
depth=2, |
|
||||||
embedding_dim=self.sam_prompt_embed_dim, |
|
||||||
mlp_dim=2048, |
|
||||||
num_heads=8, |
|
||||||
), |
|
||||||
transformer_dim=self.sam_prompt_embed_dim, |
|
||||||
iou_head_depth=3, |
|
||||||
iou_head_hidden_dim=256, |
|
||||||
use_high_res_features=self.use_high_res_features_in_sam, |
|
||||||
iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, |
|
||||||
pred_obj_scores=self.pred_obj_scores, |
|
||||||
pred_obj_scores_mlp=self.pred_obj_scores_mlp, |
|
||||||
use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, |
|
||||||
**(self.sam_mask_decoder_extra_args or {}), |
|
||||||
) |
|
||||||
if self.use_obj_ptrs_in_encoder: |
|
||||||
# a linear projection on SAM output tokens to turn them into object pointers |
|
||||||
self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) |
|
||||||
if self.use_mlp_for_obj_ptr_proj: |
|
||||||
self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3) |
|
||||||
else: |
|
||||||
self.obj_ptr_proj = torch.nn.Identity() |
|
||||||
if self.proj_tpos_enc_in_obj_ptrs: |
|
||||||
# a linear projection on temporal positional encoding in object pointers to |
|
||||||
# avoid potential interference with spatial positional encoding |
|
||||||
self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) |
|
||||||
else: |
|
||||||
self.obj_ptr_tpos_proj = torch.nn.Identity() |
|
||||||
|
|
||||||
def _forward_sam_heads( |
|
||||||
self, |
|
||||||
backbone_features, |
|
||||||
point_inputs=None, |
|
||||||
mask_inputs=None, |
|
||||||
high_res_features=None, |
|
||||||
multimask_output=False, |
|
||||||
): |
|
||||||
""" |
|
||||||
Forward SAM prompt encoders and mask heads. |
|
||||||
|
|
||||||
Args: |
|
||||||
backbone_features (torch.Tensor): Image features with shape (B, C, H, W). |
|
||||||
point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts. |
|
||||||
'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute |
|
||||||
pixel-unit coordinates in (x, y) format for P input points. |
|
||||||
'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks, |
|
||||||
0 means negative clicks, and -1 means padding. |
|
||||||
mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the |
|
||||||
same spatial size as the image. |
|
||||||
high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes |
|
||||||
(B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps |
|
||||||
for SAM decoder. |
|
||||||
multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False, |
|
||||||
output only 1 mask and its IoU estimate. |
|
||||||
|
|
||||||
Returns: |
|
||||||
(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]): |
|
||||||
low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits. |
|
||||||
high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits. |
|
||||||
ious: Tensor of shape (B, M) with estimated IoU for each output mask. |
|
||||||
low_res_masks: Tensor of shape (B, 1, H*4, W*4) with best low-resolution mask. |
|
||||||
high_res_masks: Tensor of shape (B, 1, H*16, W*16) with best high-resolution mask. |
|
||||||
obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask. |
|
||||||
object_score_logits: Tensor of shape (B,) with object score logits. |
|
||||||
|
|
||||||
Where M is 3 if multimask_output=True, and 1 if multimask_output=False. |
|
||||||
|
|
||||||
Examples: |
|
||||||
>>> backbone_features = torch.rand(1, 256, 32, 32) |
|
||||||
>>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])} |
|
||||||
>>> mask_inputs = torch.rand(1, 1, 512, 512) |
|
||||||
>>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs) |
|
||||||
>>> low_res_multimasks, high_res_multimasks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits = results |
|
||||||
""" |
|
||||||
B = backbone_features.size(0) |
|
||||||
device = backbone_features.device |
|
||||||
assert backbone_features.size(1) == self.sam_prompt_embed_dim |
|
||||||
assert backbone_features.size(2) == self.sam_image_embedding_size |
|
||||||
assert backbone_features.size(3) == self.sam_image_embedding_size |
|
||||||
|
|
||||||
# a) Handle point prompts |
|
||||||
if point_inputs is not None: |
|
||||||
sam_point_coords = point_inputs["point_coords"] |
|
||||||
sam_point_labels = point_inputs["point_labels"] |
|
||||||
assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B |
|
||||||
else: |
|
||||||
# If no points are provide, pad with an empty point (with label -1) |
|
||||||
sam_point_coords = torch.zeros(B, 1, 2, device=device) |
|
||||||
sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) |
|
||||||
|
|
||||||
# b) Handle mask prompts |
|
||||||
if mask_inputs is not None: |
|
||||||
# If mask_inputs is provided, downsize it into low-res mask input if needed |
|
||||||
# and feed it as a dense mask prompt into the SAM mask encoder |
|
||||||
assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) |
|
||||||
if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: |
|
||||||
sam_mask_prompt = F.interpolate( |
|
||||||
mask_inputs.float(), |
|
||||||
size=self.sam_prompt_encoder.mask_input_size, |
|
||||||
align_corners=False, |
|
||||||
mode="bilinear", |
|
||||||
antialias=True, # use antialias for downsampling |
|
||||||
) |
|
||||||
else: |
|
||||||
sam_mask_prompt = mask_inputs |
|
||||||
else: |
|
||||||
# Otherwise, simply feed None (and SAM's prompt encoder will add |
|
||||||
# a learned `no_mask_embed` to indicate no mask input in this case). |
|
||||||
sam_mask_prompt = None |
|
||||||
|
|
||||||
sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( |
|
||||||
points=(sam_point_coords, sam_point_labels), |
|
||||||
boxes=None, |
|
||||||
masks=sam_mask_prompt, |
|
||||||
) |
|
||||||
( |
|
||||||
low_res_multimasks, |
|
||||||
ious, |
|
||||||
sam_output_tokens, |
|
||||||
object_score_logits, |
|
||||||
) = self.sam_mask_decoder( |
|
||||||
image_embeddings=backbone_features, |
|
||||||
image_pe=self.sam_prompt_encoder.get_dense_pe(), |
|
||||||
sparse_prompt_embeddings=sparse_embeddings, |
|
||||||
dense_prompt_embeddings=dense_embeddings, |
|
||||||
multimask_output=multimask_output, |
|
||||||
repeat_image=False, # the image is already batched |
|
||||||
high_res_features=high_res_features, |
|
||||||
) |
|
||||||
if self.pred_obj_scores: |
|
||||||
is_obj_appearing = object_score_logits > 0 |
|
||||||
|
|
||||||
# Mask used for spatial memories is always a *hard* choice between obj and no obj, |
|
||||||
# consistent with the actual mask prediction |
|
||||||
low_res_multimasks = torch.where( |
|
||||||
is_obj_appearing[:, None, None], |
|
||||||
low_res_multimasks, |
|
||||||
NO_OBJ_SCORE, |
|
||||||
) |
|
||||||
|
|
||||||
# convert masks from possibly bfloat16 (or float16) to float32 |
|
||||||
# (older PyTorch versions before 2.1 don't support `interpolate` on bf16) |
|
||||||
low_res_multimasks = low_res_multimasks.float() |
|
||||||
high_res_multimasks = F.interpolate( |
|
||||||
low_res_multimasks, |
|
||||||
size=(self.image_size, self.image_size), |
|
||||||
mode="bilinear", |
|
||||||
align_corners=False, |
|
||||||
) |
|
||||||
|
|
||||||
sam_output_token = sam_output_tokens[:, 0] |
|
||||||
if multimask_output: |
|
||||||
# take the best mask prediction (with the highest IoU estimation) |
|
||||||
best_iou_inds = torch.argmax(ious, dim=-1) |
|
||||||
batch_inds = torch.arange(B, device=device) |
|
||||||
low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) |
|
||||||
high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) |
|
||||||
if sam_output_tokens.size(1) > 1: |
|
||||||
sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] |
|
||||||
else: |
|
||||||
low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks |
|
||||||
|
|
||||||
# Extract object pointer from the SAM output token (with occlusion handling) |
|
||||||
obj_ptr = self.obj_ptr_proj(sam_output_token) |
|
||||||
if self.pred_obj_scores: |
|
||||||
# Allow *soft* no obj ptr, unlike for masks |
|
||||||
if self.soft_no_obj_ptr: |
|
||||||
# Only hard possible with gt |
|
||||||
assert not self.teacher_force_obj_scores_for_mem |
|
||||||
lambda_is_obj_appearing = object_score_logits.sigmoid() |
|
||||||
else: |
|
||||||
lambda_is_obj_appearing = is_obj_appearing.float() |
|
||||||
|
|
||||||
if self.fixed_no_obj_ptr: |
|
||||||
obj_ptr = lambda_is_obj_appearing * obj_ptr |
|
||||||
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr |
|
||||||
|
|
||||||
return ( |
|
||||||
low_res_multimasks, |
|
||||||
high_res_multimasks, |
|
||||||
ious, |
|
||||||
low_res_masks, |
|
||||||
high_res_masks, |
|
||||||
obj_ptr, |
|
||||||
object_score_logits, |
|
||||||
) |
|
||||||
|
|
||||||
def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): |
|
||||||
"""Processes mask inputs to generate output mask logits and object pointers without using SAM.""" |
|
||||||
# Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). |
|
||||||
out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 |
|
||||||
mask_inputs_float = mask_inputs.float() |
|
||||||
high_res_masks = mask_inputs_float * out_scale + out_bias |
|
||||||
low_res_masks = F.interpolate( |
|
||||||
high_res_masks, |
|
||||||
size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), |
|
||||||
align_corners=False, |
|
||||||
mode="bilinear", |
|
||||||
antialias=True, # use antialias for downsampling |
|
||||||
) |
|
||||||
# a dummy IoU prediction of all 1's under mask input |
|
||||||
ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float() |
|
||||||
if not self.use_obj_ptrs_in_encoder: |
|
||||||
# all zeros as a dummy object pointer (of shape [B, C]) |
|
||||||
obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device) |
|
||||||
else: |
|
||||||
# produce an object pointer using the SAM decoder from the mask input |
|
||||||
_, _, _, _, _, obj_ptr, _ = self._forward_sam_heads( |
|
||||||
backbone_features=backbone_features, |
|
||||||
mask_inputs=self.mask_downsample(mask_inputs_float), |
|
||||||
high_res_features=high_res_features, |
|
||||||
) |
|
||||||
# In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; |
|
||||||
# Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying |
|
||||||
# on the object_scores from the SAM decoder. |
|
||||||
is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1) |
|
||||||
is_obj_appearing = is_obj_appearing[..., None] |
|
||||||
lambda_is_obj_appearing = is_obj_appearing.float() |
|
||||||
object_score_logits = out_scale * lambda_is_obj_appearing + out_bias |
|
||||||
if self.pred_obj_scores: |
|
||||||
if self.fixed_no_obj_ptr: |
|
||||||
obj_ptr = lambda_is_obj_appearing * obj_ptr |
|
||||||
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr |
|
||||||
|
|
||||||
return ( |
|
||||||
low_res_masks, |
|
||||||
high_res_masks, |
|
||||||
ious, |
|
||||||
low_res_masks, |
|
||||||
high_res_masks, |
|
||||||
obj_ptr, |
|
||||||
object_score_logits, |
|
||||||
) |
|
||||||
|
|
||||||
def forward_image(self, img_batch: torch.Tensor): |
|
||||||
"""Process image batch through encoder to extract multi-level features for SAM model.""" |
|
||||||
backbone_out = self.image_encoder(img_batch) |
|
||||||
if self.use_high_res_features_in_sam: |
|
||||||
# precompute projected level 0 and level 1 features in SAM decoder |
|
||||||
# to avoid running it again on every SAM click |
|
||||||
backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0]) |
|
||||||
backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1]) |
|
||||||
return backbone_out |
|
||||||
|
|
||||||
def _prepare_backbone_features(self, backbone_out): |
|
||||||
"""Prepare and flatten visual features from the image backbone output.""" |
|
||||||
backbone_out = backbone_out.copy() |
|
||||||
assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) |
|
||||||
assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels |
|
||||||
|
|
||||||
feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :] |
|
||||||
vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :] |
|
||||||
|
|
||||||
feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] |
|
||||||
# flatten NxCxHxW to HWxNxC |
|
||||||
vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] |
|
||||||
vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] |
|
||||||
|
|
||||||
return backbone_out, vision_feats, vision_pos_embeds, feat_sizes |
|
||||||
|
|
||||||
def _prepare_memory_conditioned_features( |
|
||||||
self, |
|
||||||
frame_idx, |
|
||||||
is_init_cond_frame, |
|
||||||
current_vision_feats, |
|
||||||
current_vision_pos_embeds, |
|
||||||
feat_sizes, |
|
||||||
output_dict, |
|
||||||
num_frames, |
|
||||||
track_in_reverse=False, # tracking in reverse time order (for demo usage) |
|
||||||
): |
|
||||||
"""Prepares memory-conditioned features by fusing current frame's visual features with previous memories.""" |
|
||||||
B = current_vision_feats[-1].size(1) # batch size on this frame |
|
||||||
C = self.hidden_dim |
|
||||||
H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size |
|
||||||
device = current_vision_feats[-1].device |
|
||||||
# The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images. |
|
||||||
# In this case, we skip the fusion with any memory. |
|
||||||
if self.num_maskmem == 0: # Disable memory and skip fusion |
|
||||||
pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) |
|
||||||
return pix_feat |
|
||||||
|
|
||||||
num_obj_ptr_tokens = 0 |
|
||||||
# Step 1: condition the visual features of the current frame on previous memories |
|
||||||
if not is_init_cond_frame: |
|
||||||
# Retrieve the memories encoded with the maskmem backbone |
|
||||||
to_cat_memory, to_cat_memory_pos_embed = [], [] |
|
||||||
# Add conditioning frames's output first (all cond frames have t_pos=0 for |
|
||||||
# when getting temporal positional embedding below) |
|
||||||
assert len(output_dict["cond_frame_outputs"]) > 0 |
|
||||||
# Select a maximum number of temporally closest cond frames for cross attention |
|
||||||
cond_outputs = output_dict["cond_frame_outputs"] |
|
||||||
selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( |
|
||||||
frame_idx, cond_outputs, self.max_cond_frames_in_attn |
|
||||||
) |
|
||||||
t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()] |
|
||||||
# Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory |
|
||||||
# the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 |
|
||||||
# We also allow taking the memory frame non-consecutively (with r>1), in which case |
|
||||||
# we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame. |
|
||||||
r = self.memory_temporal_stride_for_eval |
|
||||||
for t_pos in range(1, self.num_maskmem): |
|
||||||
t_rel = self.num_maskmem - t_pos # how many frames before current frame |
|
||||||
if t_rel == 1: |
|
||||||
# for t_rel == 1, we take the last frame (regardless of r) |
|
||||||
if not track_in_reverse: |
|
||||||
# the frame immediately before this frame (i.e. frame_idx - 1) |
|
||||||
prev_frame_idx = frame_idx - t_rel |
|
||||||
else: |
|
||||||
# the frame immediately after this frame (i.e. frame_idx + 1) |
|
||||||
prev_frame_idx = frame_idx + t_rel |
|
||||||
else: |
|
||||||
# for t_rel >= 2, we take the memory frame from every r-th frames |
|
||||||
if not track_in_reverse: |
|
||||||
# first find the nearest frame among every r-th frames before this frame |
|
||||||
# for r=1, this would be (frame_idx - 2) |
|
||||||
prev_frame_idx = ((frame_idx - 2) // r) * r |
|
||||||
# then seek further among every r-th frames |
|
||||||
prev_frame_idx = prev_frame_idx - (t_rel - 2) * r |
|
||||||
else: |
|
||||||
# first find the nearest frame among every r-th frames after this frame |
|
||||||
# for r=1, this would be (frame_idx + 2) |
|
||||||
prev_frame_idx = -(-(frame_idx + 2) // r) * r |
|
||||||
# then seek further among every r-th frames |
|
||||||
prev_frame_idx = prev_frame_idx + (t_rel - 2) * r |
|
||||||
out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) |
|
||||||
if out is None: |
|
||||||
# If an unselected conditioning frame is among the last (self.num_maskmem - 1) |
|
||||||
# frames, we still attend to it as if it's a non-conditioning frame. |
|
||||||
out = unselected_cond_outputs.get(prev_frame_idx, None) |
|
||||||
t_pos_and_prevs.append((t_pos, out)) |
|
||||||
|
|
||||||
for t_pos, prev in t_pos_and_prevs: |
|
||||||
if prev is None: |
|
||||||
continue # skip padding frames |
|
||||||
# "maskmem_features" might have been offloaded to CPU in demo use cases, |
|
||||||
# so we load it back to GPU (it's a no-op if it's already on GPU). |
|
||||||
feats = prev["maskmem_features"].cuda(non_blocking=True) |
|
||||||
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) |
|
||||||
# Spatial positional encoding (it might have been offloaded to CPU in eval) |
|
||||||
maskmem_enc = prev["maskmem_pos_enc"][-1].cuda() |
|
||||||
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) |
|
||||||
# Temporal positional encoding |
|
||||||
maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1] |
|
||||||
to_cat_memory_pos_embed.append(maskmem_enc) |
|
||||||
|
|
||||||
# Construct the list of past object pointers |
|
||||||
if self.use_obj_ptrs_in_encoder: |
|
||||||
max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) |
|
||||||
# First add those object pointers from selected conditioning frames |
|
||||||
# (optionally, only include object pointers in the past during evaluation) |
|
||||||
if not self.training and self.only_obj_ptrs_in_the_past_for_eval: |
|
||||||
ptr_cond_outputs = { |
|
||||||
t: out |
|
||||||
for t, out in selected_cond_outputs.items() |
|
||||||
if (t >= frame_idx if track_in_reverse else t <= frame_idx) |
|
||||||
} |
|
||||||
else: |
|
||||||
ptr_cond_outputs = selected_cond_outputs |
|
||||||
pos_and_ptrs = [ |
|
||||||
# Temporal pos encoding contains how far away each pointer is from current frame |
|
||||||
(abs(frame_idx - t), out["obj_ptr"]) |
|
||||||
for t, out in ptr_cond_outputs.items() |
|
||||||
] |
|
||||||
# Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame |
|
||||||
for t_diff in range(1, max_obj_ptrs_in_encoder): |
|
||||||
t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff |
|
||||||
if t < 0 or (num_frames is not None and t >= num_frames): |
|
||||||
break |
|
||||||
out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None)) |
|
||||||
if out is not None: |
|
||||||
pos_and_ptrs.append((t_diff, out["obj_ptr"])) |
|
||||||
# If we have at least one object pointer, add them to the across attention |
|
||||||
if len(pos_and_ptrs) > 0: |
|
||||||
pos_list, ptrs_list = zip(*pos_and_ptrs) |
|
||||||
# stack object pointers along dim=0 into [ptr_seq_len, B, C] shape |
|
||||||
obj_ptrs = torch.stack(ptrs_list, dim=0) |
|
||||||
# a temporal positional embedding based on how far each object pointer is from |
|
||||||
# the current frame (sine embedding normalized by the max pointer num). |
|
||||||
if self.add_tpos_enc_to_obj_ptrs: |
|
||||||
t_diff_max = max_obj_ptrs_in_encoder - 1 |
|
||||||
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim |
|
||||||
obj_pos = torch.tensor(pos_list, device=device) |
|
||||||
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) |
|
||||||
obj_pos = self.obj_ptr_tpos_proj(obj_pos) |
|
||||||
obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) |
|
||||||
else: |
|
||||||
obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim) |
|
||||||
if self.mem_dim < C: |
|
||||||
# split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C |
|
||||||
obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim) |
|
||||||
obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1) |
|
||||||
obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0) |
|
||||||
to_cat_memory.append(obj_ptrs) |
|
||||||
to_cat_memory_pos_embed.append(obj_pos) |
|
||||||
num_obj_ptr_tokens = obj_ptrs.shape[0] |
|
||||||
else: |
|
||||||
num_obj_ptr_tokens = 0 |
|
||||||
else: |
|
||||||
# for initial conditioning frames, encode them without using any previous memory |
|
||||||
if self.directly_add_no_mem_embed: |
|
||||||
# directly add no-mem embedding (instead of using the transformer encoder) |
|
||||||
pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed |
|
||||||
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) |
|
||||||
return pix_feat_with_mem |
|
||||||
|
|
||||||
# Use a dummy token on the first frame (to avoid empty memory input to transformer encoder) |
|
||||||
to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] |
|
||||||
to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] |
|
||||||
|
|
||||||
# Step 2: Concatenate the memories and forward through the transformer encoder |
|
||||||
memory = torch.cat(to_cat_memory, dim=0) |
|
||||||
memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0) |
|
||||||
|
|
||||||
pix_feat_with_mem = self.memory_attention( |
|
||||||
curr=current_vision_feats, |
|
||||||
curr_pos=current_vision_pos_embeds, |
|
||||||
memory=memory, |
|
||||||
memory_pos=memory_pos_embed, |
|
||||||
num_obj_ptr_tokens=num_obj_ptr_tokens, |
|
||||||
) |
|
||||||
# reshape the output (HW)BC => BCHW |
|
||||||
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) |
|
||||||
return pix_feat_with_mem |
|
||||||
|
|
||||||
def _encode_new_memory( |
|
||||||
self, |
|
||||||
current_vision_feats, |
|
||||||
feat_sizes, |
|
||||||
pred_masks_high_res, |
|
||||||
is_mask_from_pts, |
|
||||||
): |
|
||||||
"""Encodes the current frame's features and predicted masks into a new memory representation.""" |
|
||||||
B = current_vision_feats[-1].size(1) # batch size on this frame |
|
||||||
C = self.hidden_dim |
|
||||||
H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size |
|
||||||
# top-level feature, (HW)BC => BCHW |
|
||||||
pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) |
|
||||||
if self.non_overlap_masks_for_mem_enc and not self.training: |
|
||||||
# optionally, apply non-overlapping constraints to the masks (it's applied |
|
||||||
# in the batch dimension and should only be used during eval, where all |
|
||||||
# the objects come from the same video under batch size 1). |
|
||||||
pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res) |
|
||||||
# scale the raw mask logits with a temperature before applying sigmoid |
|
||||||
binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts |
|
||||||
if binarize and not self.training: |
|
||||||
mask_for_mem = (pred_masks_high_res > 0).float() |
|
||||||
else: |
|
||||||
# apply sigmoid on the raw mask logits to turn them into range (0, 1) |
|
||||||
mask_for_mem = torch.sigmoid(pred_masks_high_res) |
|
||||||
# apply scale and bias terms to the sigmoid probabilities |
|
||||||
if self.sigmoid_scale_for_mem_enc != 1.0: |
|
||||||
mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc |
|
||||||
if self.sigmoid_bias_for_mem_enc != 0.0: |
|
||||||
mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc |
|
||||||
maskmem_out = self.memory_encoder( |
|
||||||
pix_feat, |
|
||||||
mask_for_mem, |
|
||||||
skip_mask_sigmoid=True, # sigmoid already applied |
|
||||||
) |
|
||||||
maskmem_features = maskmem_out["vision_features"] |
|
||||||
maskmem_pos_enc = maskmem_out["vision_pos_enc"] |
|
||||||
|
|
||||||
return maskmem_features, maskmem_pos_enc |
|
||||||
|
|
||||||
def track_step( |
|
||||||
self, |
|
||||||
frame_idx, |
|
||||||
is_init_cond_frame, |
|
||||||
current_vision_feats, |
|
||||||
current_vision_pos_embeds, |
|
||||||
feat_sizes, |
|
||||||
point_inputs, |
|
||||||
mask_inputs, |
|
||||||
output_dict, |
|
||||||
num_frames, |
|
||||||
track_in_reverse=False, # tracking in reverse time order (for demo usage) |
|
||||||
# Whether to run the memory encoder on the predicted masks. Sometimes we might want |
|
||||||
# to skip the memory encoder with `run_mem_encoder=False`. For example, |
|
||||||
# in demo we might call `track_step` multiple times for each user click, |
|
||||||
# and only encode the memory when the user finalizes their clicks. And in ablation |
|
||||||
# settings like SAM training on static images, we don't need the memory encoder. |
|
||||||
run_mem_encoder=True, |
|
||||||
# The previously predicted SAM mask logits (which can be fed together with new clicks in demo). |
|
||||||
prev_sam_mask_logits=None, |
|
||||||
): |
|
||||||
"""Performs a single tracking step, updating object masks and memory features based on current frame inputs.""" |
|
||||||
current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} |
|
||||||
# High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW |
|
||||||
if len(current_vision_feats) > 1: |
|
||||||
high_res_features = [ |
|
||||||
x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) |
|
||||||
for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) |
|
||||||
] |
|
||||||
else: |
|
||||||
high_res_features = None |
|
||||||
if mask_inputs is not None and self.use_mask_input_as_output_without_sam: |
|
||||||
# When use_mask_input_as_output_without_sam=True, we directly output the mask input |
|
||||||
# (see it as a GT mask) without using a SAM prompt encoder + mask decoder. |
|
||||||
pix_feat = current_vision_feats[-1].permute(1, 2, 0) |
|
||||||
pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) |
|
||||||
sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs) |
|
||||||
else: |
|
||||||
# fused the visual feature with previous memory features in the memory bank |
|
||||||
pix_feat_with_mem = self._prepare_memory_conditioned_features( |
|
||||||
frame_idx=frame_idx, |
|
||||||
is_init_cond_frame=is_init_cond_frame, |
|
||||||
current_vision_feats=current_vision_feats[-1:], |
|
||||||
current_vision_pos_embeds=current_vision_pos_embeds[-1:], |
|
||||||
feat_sizes=feat_sizes[-1:], |
|
||||||
output_dict=output_dict, |
|
||||||
num_frames=num_frames, |
|
||||||
track_in_reverse=track_in_reverse, |
|
||||||
) |
|
||||||
# apply SAM-style segmentation head |
|
||||||
# here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, |
|
||||||
# e.g. in demo where such logits come from earlier interaction instead of correction sampling |
|
||||||
# (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) |
|
||||||
if prev_sam_mask_logits is not None: |
|
||||||
assert point_inputs is not None and mask_inputs is None |
|
||||||
mask_inputs = prev_sam_mask_logits |
|
||||||
multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) |
|
||||||
sam_outputs = self._forward_sam_heads( |
|
||||||
backbone_features=pix_feat_with_mem, |
|
||||||
point_inputs=point_inputs, |
|
||||||
mask_inputs=mask_inputs, |
|
||||||
high_res_features=high_res_features, |
|
||||||
multimask_output=multimask_output, |
|
||||||
) |
|
||||||
( |
|
||||||
_, |
|
||||||
_, |
|
||||||
_, |
|
||||||
low_res_masks, |
|
||||||
high_res_masks, |
|
||||||
obj_ptr, |
|
||||||
_, |
|
||||||
) = sam_outputs |
|
||||||
|
|
||||||
current_out["pred_masks"] = low_res_masks |
|
||||||
current_out["pred_masks_high_res"] = high_res_masks |
|
||||||
current_out["obj_ptr"] = obj_ptr |
|
||||||
|
|
||||||
# Finally run the memory encoder on the predicted mask to encode |
|
||||||
# it into a new memory feature (that can be used in future frames) |
|
||||||
if run_mem_encoder and self.num_maskmem > 0: |
|
||||||
high_res_masks_for_mem_enc = high_res_masks |
|
||||||
maskmem_features, maskmem_pos_enc = self._encode_new_memory( |
|
||||||
current_vision_feats=current_vision_feats, |
|
||||||
feat_sizes=feat_sizes, |
|
||||||
pred_masks_high_res=high_res_masks_for_mem_enc, |
|
||||||
is_mask_from_pts=(point_inputs is not None), |
|
||||||
) |
|
||||||
current_out["maskmem_features"] = maskmem_features |
|
||||||
current_out["maskmem_pos_enc"] = maskmem_pos_enc |
|
||||||
else: |
|
||||||
current_out["maskmem_features"] = None |
|
||||||
current_out["maskmem_pos_enc"] = None |
|
||||||
|
|
||||||
return current_out |
|
||||||
|
|
||||||
def _use_multimask(self, is_init_cond_frame, point_inputs): |
|
||||||
"""Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs.""" |
|
||||||
num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) |
|
||||||
multimask_output = ( |
|
||||||
self.multimask_output_in_sam |
|
||||||
and (is_init_cond_frame or self.multimask_output_for_tracking) |
|
||||||
and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) |
|
||||||
) |
|
||||||
return multimask_output |
|
||||||
|
|
||||||
def _apply_non_overlapping_constraints(self, pred_masks): |
|
||||||
"""Applies non-overlapping constraints to object masks, keeping highest scoring object at each location.""" |
|
||||||
batch_size = pred_masks.size(0) |
|
||||||
if batch_size == 1: |
|
||||||
return pred_masks |
|
||||||
|
|
||||||
device = pred_masks.device |
|
||||||
# "max_obj_inds": object index of the object with the highest score at each location |
|
||||||
max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) |
|
||||||
# "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` |
|
||||||
batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] |
|
||||||
keep = max_obj_inds == batch_obj_inds |
|
||||||
# suppress overlapping regions' scores below -10.0 so that the foreground regions |
|
||||||
# don't overlap (here sigmoid(-10.0)=4.5398e-05) |
|
||||||
pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) |
|
||||||
return pred_masks |
|
@ -1,715 +0,0 @@ |
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
|
||||||
|
|
||||||
import copy |
|
||||||
import math |
|
||||||
from functools import partial |
|
||||||
from typing import Optional, Tuple, Type, Union |
|
||||||
|
|
||||||
import torch |
|
||||||
import torch.nn.functional as F |
|
||||||
from torch import Tensor, nn |
|
||||||
|
|
||||||
from ultralytics.models.sam.modules.transformer import ( |
|
||||||
Attention, |
|
||||||
) |
|
||||||
from ultralytics.models.sam.modules.transformer import ( |
|
||||||
TwoWayAttentionBlock as SAMTwoWayAttentionBlock, |
|
||||||
) |
|
||||||
from ultralytics.models.sam.modules.transformer import ( |
|
||||||
TwoWayTransformer as SAMTwoWayTransformer, |
|
||||||
) |
|
||||||
from ultralytics.nn.modules import MLP, LayerNorm2d |
|
||||||
|
|
||||||
from .utils import apply_rotary_enc, compute_axial_cis, window_partition, window_unpartition |
|
||||||
|
|
||||||
|
|
||||||
class DropPath(nn.Module): |
|
||||||
"""Implements stochastic depth regularization for neural networks during training.""" |
|
||||||
|
|
||||||
def __init__(self, drop_prob=0.0, scale_by_keep=True): |
|
||||||
"""Initialize DropPath module with specified drop probability and scaling option.""" |
|
||||||
super(DropPath, self).__init__() |
|
||||||
self.drop_prob = drop_prob |
|
||||||
self.scale_by_keep = scale_by_keep |
|
||||||
|
|
||||||
def forward(self, x): |
|
||||||
"""Applies stochastic depth to input tensor during training, with optional scaling.""" |
|
||||||
if self.drop_prob == 0.0 or not self.training: |
|
||||||
return x |
|
||||||
keep_prob = 1 - self.drop_prob |
|
||||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) |
|
||||||
random_tensor = x.new_empty(shape).bernoulli_(keep_prob) |
|
||||||
if keep_prob > 0.0 and self.scale_by_keep: |
|
||||||
random_tensor.div_(keep_prob) |
|
||||||
return x * random_tensor |
|
||||||
|
|
||||||
|
|
||||||
class MaskDownSampler(nn.Module): |
|
||||||
"""Downsamples and embeds masks using convolutional layers and layer normalization for efficient processing.""" |
|
||||||
|
|
||||||
def __init__( |
|
||||||
self, |
|
||||||
embed_dim=256, |
|
||||||
kernel_size=4, |
|
||||||
stride=4, |
|
||||||
padding=0, |
|
||||||
total_stride=16, |
|
||||||
activation=nn.GELU, |
|
||||||
): |
|
||||||
"""Initializes a mask downsampler module for progressive downsampling and channel expansion.""" |
|
||||||
super().__init__() |
|
||||||
num_layers = int(math.log2(total_stride) // math.log2(stride)) |
|
||||||
assert stride**num_layers == total_stride |
|
||||||
self.encoder = nn.Sequential() |
|
||||||
mask_in_chans, mask_out_chans = 1, 1 |
|
||||||
for _ in range(num_layers): |
|
||||||
mask_out_chans = mask_in_chans * (stride**2) |
|
||||||
self.encoder.append( |
|
||||||
nn.Conv2d( |
|
||||||
mask_in_chans, |
|
||||||
mask_out_chans, |
|
||||||
kernel_size=kernel_size, |
|
||||||
stride=stride, |
|
||||||
padding=padding, |
|
||||||
) |
|
||||||
) |
|
||||||
self.encoder.append(LayerNorm2d(mask_out_chans)) |
|
||||||
self.encoder.append(activation()) |
|
||||||
mask_in_chans = mask_out_chans |
|
||||||
|
|
||||||
self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) |
|
||||||
|
|
||||||
def forward(self, x): |
|
||||||
"""Downsamples and encodes input mask to embed_dim channels using convolutional layers and LayerNorm2d.""" |
|
||||||
return self.encoder(x) |
|
||||||
|
|
||||||
|
|
||||||
# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) |
|
||||||
class CXBlock(nn.Module): |
|
||||||
""" |
|
||||||
ConvNeXt Block for efficient feature extraction in convolutional neural networks. |
|
||||||
|
|
||||||
This block implements a modified version of the ConvNeXt architecture, offering two equivalent |
|
||||||
implementations for improved performance and flexibility. |
|
||||||
|
|
||||||
Attributes: |
|
||||||
dwconv (nn.Conv2d): Depthwise convolution layer. |
|
||||||
norm (LayerNorm2d): Layer normalization applied to channels. |
|
||||||
pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer. |
|
||||||
act (nn.GELU): GELU activation function. |
|
||||||
pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer. |
|
||||||
gamma (nn.Parameter | None): Learnable scale parameter for layer scaling. |
|
||||||
drop_path (nn.Module): DropPath layer for stochastic depth regularization. |
|
||||||
|
|
||||||
Methods: |
|
||||||
forward: Processes the input tensor through the ConvNeXt block. |
|
||||||
|
|
||||||
Examples: |
|
||||||
>>> import torch |
|
||||||
>>> x = torch.randn(1, 64, 56, 56) |
|
||||||
>>> block = CXBlock(dim=64, kernel_size=7, padding=3) |
|
||||||
>>> output = block(x) |
|
||||||
>>> print(output.shape) |
|
||||||
torch.Size([1, 64, 56, 56]) |
|
||||||
""" |
|
||||||
|
|
||||||
def __init__( |
|
||||||
self, |
|
||||||
dim, |
|
||||||
kernel_size=7, |
|
||||||
padding=3, |
|
||||||
drop_path=0.0, |
|
||||||
layer_scale_init_value=1e-6, |
|
||||||
use_dwconv=True, |
|
||||||
): |
|
||||||
""" |
|
||||||
Initialize a ConvNeXt Block. |
|
||||||
|
|
||||||
This block implements a ConvNeXt architecture with optional depthwise convolution, layer normalization, |
|
||||||
pointwise convolutions, and GELU activation. |
|
||||||
|
|
||||||
Args: |
|
||||||
dim (int): Number of input channels. |
|
||||||
kernel_size (int): Size of the convolutional kernel. Default is 7. |
|
||||||
padding (int): Padding size for the convolution. Default is 3. |
|
||||||
drop_path (float): Stochastic depth rate. Default is 0.0. |
|
||||||
layer_scale_init_value (float): Initial value for Layer Scale. Default is 1e-6. |
|
||||||
use_dwconv (bool): Whether to use depthwise convolution. Default is True. |
|
||||||
|
|
||||||
Attributes: |
|
||||||
dwconv (nn.Conv2d): Depthwise or standard 2D convolution layer. |
|
||||||
norm (LayerNorm2d): Layer normalization applied to the output of dwconv. |
|
||||||
pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer. |
|
||||||
act (nn.GELU): GELU activation function. |
|
||||||
pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer. |
|
||||||
gamma (nn.Parameter | None): Learnable scale parameter for the residual path. |
|
||||||
|
|
||||||
Examples: |
|
||||||
>>> block = CXBlock(dim=64, kernel_size=7, padding=3) |
|
||||||
>>> x = torch.randn(1, 64, 32, 32) |
|
||||||
>>> output = block(x) |
|
||||||
>>> print(output.shape) |
|
||||||
torch.Size([1, 64, 32, 32]) |
|
||||||
""" |
|
||||||
super().__init__() |
|
||||||
self.dwconv = nn.Conv2d( |
|
||||||
dim, |
|
||||||
dim, |
|
||||||
kernel_size=kernel_size, |
|
||||||
padding=padding, |
|
||||||
groups=dim if use_dwconv else 1, |
|
||||||
) # depthwise conv |
|
||||||
self.norm = LayerNorm2d(dim, eps=1e-6) |
|
||||||
self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers |
|
||||||
self.act = nn.GELU() |
|
||||||
self.pwconv2 = nn.Linear(4 * dim, dim) |
|
||||||
self.gamma = ( |
|
||||||
nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) |
|
||||||
if layer_scale_init_value > 0 |
|
||||||
else None |
|
||||||
) |
|
||||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() |
|
||||||
|
|
||||||
def forward(self, x): |
|
||||||
"""Applies ConvNeXt block operations to input tensor, including convolutions and residual connection.""" |
|
||||||
input = x |
|
||||||
x = self.dwconv(x) |
|
||||||
x = self.norm(x) |
|
||||||
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) |
|
||||||
x = self.pwconv1(x) |
|
||||||
x = self.act(x) |
|
||||||
x = self.pwconv2(x) |
|
||||||
if self.gamma is not None: |
|
||||||
x = self.gamma * x |
|
||||||
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) |
|
||||||
|
|
||||||
x = input + self.drop_path(x) |
|
||||||
return x |
|
||||||
|
|
||||||
|
|
||||||
class Fuser(nn.Module): |
|
||||||
""" |
|
||||||
A module for fusing features through multiple layers of a neural network. |
|
||||||
|
|
||||||
This class applies a series of identical layers to an input tensor, optionally projecting the input first. |
|
||||||
|
|
||||||
Attributes: |
|
||||||
proj (nn.Module): An optional input projection layer. Identity if no projection is needed. |
|
||||||
layers (nn.ModuleList): A list of identical layers to be applied sequentially. |
|
||||||
|
|
||||||
Methods: |
|
||||||
forward: Applies the fuser to an input tensor. |
|
||||||
|
|
||||||
Examples: |
|
||||||
>>> layer = CXBlock(dim=256) |
|
||||||
>>> fuser = Fuser(layer, num_layers=3, dim=256, input_projection=True) |
|
||||||
>>> x = torch.randn(1, 256, 32, 32) |
|
||||||
>>> output = fuser(x) |
|
||||||
>>> print(output.shape) |
|
||||||
torch.Size([1, 256, 32, 32]) |
|
||||||
""" |
|
||||||
|
|
||||||
def __init__(self, layer, num_layers, dim=None, input_projection=False): |
|
||||||
""" |
|
||||||
Initializes the Fuser module. |
|
||||||
|
|
||||||
This module creates a sequence of identical layers and optionally applies an input projection. |
|
||||||
|
|
||||||
Args: |
|
||||||
layer (nn.Module): The layer to be replicated in the fuser. |
|
||||||
num_layers (int): The number of times to replicate the layer. |
|
||||||
dim (int | None): The dimension for input projection, if used. |
|
||||||
input_projection (bool): Whether to use input projection. |
|
||||||
|
|
||||||
Attributes: |
|
||||||
proj (nn.Module): The input projection layer, or nn.Identity if not used. |
|
||||||
layers (nn.ModuleList): A list of replicated layers. |
|
||||||
|
|
||||||
Examples: |
|
||||||
>>> layer = nn.Linear(64, 64) |
|
||||||
>>> fuser = Fuser(layer, num_layers=3, dim=64, input_projection=True) |
|
||||||
>>> input_tensor = torch.randn(1, 64) |
|
||||||
>>> output = fuser(input_tensor) |
|
||||||
""" |
|
||||||
super().__init__() |
|
||||||
self.proj = nn.Identity() |
|
||||||
self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)]) |
|
||||||
|
|
||||||
if input_projection: |
|
||||||
assert dim is not None |
|
||||||
self.proj = nn.Conv2d(dim, dim, kernel_size=1) |
|
||||||
|
|
||||||
def forward(self, x): |
|
||||||
"""Applies a series of layers to the input tensor, optionally projecting it first.""" |
|
||||||
x = self.proj(x) |
|
||||||
for layer in self.layers: |
|
||||||
x = layer(x) |
|
||||||
return x |
|
||||||
|
|
||||||
|
|
||||||
class TwoWayAttentionBlock(SAMTwoWayAttentionBlock): |
|
||||||
""" |
|
||||||
A two-way attention block for performing self-attention and cross-attention in both directions. |
|
||||||
|
|
||||||
This block extends the SAMTwoWayAttentionBlock and consists of four main components: self-attention on |
|
||||||
sparse inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and |
|
||||||
cross-attention from dense to sparse inputs. |
|
||||||
|
|
||||||
Attributes: |
|
||||||
self_attn (Attention): Self-attention layer for queries. |
|
||||||
norm1 (nn.LayerNorm): Layer normalization after the first attention block. |
|
||||||
cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys. |
|
||||||
norm2 (nn.LayerNorm): Layer normalization after the second attention block. |
|
||||||
mlp (MLP): MLP block for transforming query embeddings. |
|
||||||
norm3 (nn.LayerNorm): Layer normalization after the MLP block. |
|
||||||
norm4 (nn.LayerNorm): Layer normalization after the third attention block. |
|
||||||
cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries. |
|
||||||
skip_first_layer_pe (bool): Flag to skip positional encoding in the first layer. |
|
||||||
|
|
||||||
Methods: |
|
||||||
forward: Processes input through the attention blocks and MLP. |
|
||||||
|
|
||||||
Examples: |
|
||||||
>>> block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8) |
|
||||||
>>> sparse_input = torch.randn(1, 100, 256) |
|
||||||
>>> dense_input = torch.randn(1, 256, 16, 16) |
|
||||||
>>> sparse_output, dense_output = block(sparse_input, dense_input) |
|
||||||
""" |
|
||||||
|
|
||||||
def __init__( |
|
||||||
self, |
|
||||||
embedding_dim: int, |
|
||||||
num_heads: int, |
|
||||||
mlp_dim: int = 2048, |
|
||||||
activation: Type[nn.Module] = nn.ReLU, |
|
||||||
attention_downsample_rate: int = 2, |
|
||||||
skip_first_layer_pe: bool = False, |
|
||||||
) -> None: |
|
||||||
""" |
|
||||||
Initializes a TwoWayAttentionBlock for performing self-attention and cross-attention in two directions. |
|
||||||
|
|
||||||
This block consists of four main layers: self-attention on sparse inputs, cross-attention of sparse inputs |
|
||||||
to dense inputs, an MLP block on sparse inputs, and cross-attention of dense inputs to sparse inputs. |
|
||||||
|
|
||||||
Args: |
|
||||||
embedding_dim (int): The channel dimension of the embeddings. |
|
||||||
num_heads (int): The number of heads in the attention layers. |
|
||||||
mlp_dim (int): The hidden dimension of the MLP block. |
|
||||||
activation (Type[nn.Module]): The activation function of the MLP block. |
|
||||||
attention_downsample_rate (int): The downsample rate for attention computations. |
|
||||||
skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer. |
|
||||||
|
|
||||||
Attributes: |
|
||||||
self_attn (Attention): The self-attention layer for the queries. |
|
||||||
norm1 (nn.LayerNorm): Layer normalization following the first attention block. |
|
||||||
cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys. |
|
||||||
norm2 (nn.LayerNorm): Layer normalization following the second attention block. |
|
||||||
mlp (MLP): MLP block that transforms the query embeddings. |
|
||||||
norm3 (nn.LayerNorm): Layer normalization following the MLP block. |
|
||||||
norm4 (nn.LayerNorm): Layer normalization following the third attention block. |
|
||||||
cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries. |
|
||||||
skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer. |
|
||||||
|
|
||||||
Examples: |
|
||||||
>>> block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048) |
|
||||||
>>> sparse_inputs = torch.randn(1, 100, 256) |
|
||||||
>>> dense_inputs = torch.randn(1, 256, 32, 32) |
|
||||||
>>> sparse_outputs, dense_outputs = block(sparse_inputs, dense_inputs) |
|
||||||
""" |
|
||||||
super().__init__(embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate, skip_first_layer_pe) |
|
||||||
self.mlp = MLP(embedding_dim, mlp_dim, embedding_dim, num_layers=2, act=activation) |
|
||||||
|
|
||||||
|
|
||||||
class TwoWayTransformer(SAMTwoWayTransformer): |
|
||||||
""" |
|
||||||
A Two-Way Transformer module for simultaneous attention to image and query points. |
|
||||||
|
|
||||||
This class implements a specialized transformer decoder that attends to an input image using queries with |
|
||||||
supplied positional embeddings. It is particularly useful for tasks like object detection, image |
|
||||||
segmentation, and point cloud processing. |
|
||||||
|
|
||||||
Attributes: |
|
||||||
depth (int): Number of layers in the transformer. |
|
||||||
embedding_dim (int): Channel dimension for input embeddings. |
|
||||||
num_heads (int): Number of heads for multihead attention. |
|
||||||
mlp_dim (int): Internal channel dimension for the MLP block. |
|
||||||
layers (nn.ModuleList): List of TwoWayAttentionBlock layers comprising the transformer. |
|
||||||
final_attn_token_to_image (Attention): Final attention layer from queries to image. |
|
||||||
norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries. |
|
||||||
|
|
||||||
Methods: |
|
||||||
forward: Processes input image embeddings and query embeddings through the transformer. |
|
||||||
|
|
||||||
Examples: |
|
||||||
>>> transformer = TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048) |
|
||||||
>>> image_embedding = torch.randn(1, 256, 64, 64) |
|
||||||
>>> query_embedding = torch.randn(1, 100, 256) |
|
||||||
>>> output = transformer(image_embedding, query_embedding) |
|
||||||
""" |
|
||||||
|
|
||||||
def __init__( |
|
||||||
self, |
|
||||||
depth: int, |
|
||||||
embedding_dim: int, |
|
||||||
num_heads: int, |
|
||||||
mlp_dim: int, |
|
||||||
activation: Type[nn.Module] = nn.ReLU, |
|
||||||
attention_downsample_rate: int = 2, |
|
||||||
) -> None: |
|
||||||
""" |
|
||||||
Initializes a TwoWayTransformer instance. |
|
||||||
|
|
||||||
This transformer decoder attends to an input image using queries with supplied positional embeddings. |
|
||||||
It is designed for tasks like object detection, image segmentation, and point cloud processing. |
|
||||||
|
|
||||||
Args: |
|
||||||
depth (int): Number of layers in the transformer. |
|
||||||
embedding_dim (int): Channel dimension for the input embeddings. |
|
||||||
num_heads (int): Number of heads for multihead attention. Must divide embedding_dim. |
|
||||||
mlp_dim (int): Channel dimension internal to the MLP block. |
|
||||||
activation (Type[nn.Module]): Activation function to use in the MLP block. |
|
||||||
attention_downsample_rate (int): Downsampling rate for attention computations. |
|
||||||
|
|
||||||
Attributes: |
|
||||||
depth (int): Number of layers in the transformer. |
|
||||||
embedding_dim (int): Channel dimension for the input embeddings. |
|
||||||
num_heads (int): Number of heads for multihead attention. |
|
||||||
mlp_dim (int): Internal channel dimension for the MLP block. |
|
||||||
layers (nn.ModuleList): List of TwoWayAttentionBlock layers comprising the transformer. |
|
||||||
final_attn_token_to_image (Attention): Final attention layer from queries to image. |
|
||||||
norm_final_attn (nn.LayerNorm): Layer normalization applied to the final queries. |
|
||||||
|
|
||||||
Examples: |
|
||||||
>>> transformer = TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048) |
|
||||||
>>> transformer |
|
||||||
TwoWayTransformer( |
|
||||||
(layers): ModuleList( |
|
||||||
(0-4): 5 x TwoWayAttentionBlock(...) |
|
||||||
) |
|
||||||
(final_attn_token_to_image): Attention(...) |
|
||||||
(norm_final_attn): LayerNorm(...) |
|
||||||
) |
|
||||||
""" |
|
||||||
super().__init__(depth, embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate) |
|
||||||
self.layers = nn.ModuleList() |
|
||||||
for i in range(depth): |
|
||||||
self.layers.append( |
|
||||||
TwoWayAttentionBlock( |
|
||||||
embedding_dim=embedding_dim, |
|
||||||
num_heads=num_heads, |
|
||||||
mlp_dim=mlp_dim, |
|
||||||
activation=activation, |
|
||||||
attention_downsample_rate=attention_downsample_rate, |
|
||||||
skip_first_layer_pe=(i == 0), |
|
||||||
) |
|
||||||
) |
|
||||||
|
|
||||||
|
|
||||||
class RoPEAttention(Attention): |
|
||||||
"""Implements rotary position encoding for attention mechanisms in transformer architectures.""" |
|
||||||
|
|
||||||
def __init__( |
|
||||||
self, |
|
||||||
*args, |
|
||||||
rope_theta=10000.0, |
|
||||||
# whether to repeat q rope to match k length |
|
||||||
# this is needed for cross-attention to memories |
|
||||||
rope_k_repeat=False, |
|
||||||
feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution |
|
||||||
**kwargs, |
|
||||||
): |
|
||||||
"""Initializes RoPEAttention with rotary position encoding for attention mechanisms.""" |
|
||||||
super().__init__(*args, **kwargs) |
|
||||||
|
|
||||||
self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta) |
|
||||||
freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) |
|
||||||
self.freqs_cis = freqs_cis |
|
||||||
self.rope_k_repeat = rope_k_repeat |
|
||||||
|
|
||||||
def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor: |
|
||||||
"""Applies rotary position encoding and computes attention between query, key, and value tensors.""" |
|
||||||
q = self.q_proj(q) |
|
||||||
k = self.k_proj(k) |
|
||||||
v = self.v_proj(v) |
|
||||||
|
|
||||||
# Separate into heads |
|
||||||
q = self._separate_heads(q, self.num_heads) |
|
||||||
k = self._separate_heads(k, self.num_heads) |
|
||||||
v = self._separate_heads(v, self.num_heads) |
|
||||||
|
|
||||||
# Apply rotary position encoding |
|
||||||
w = h = math.sqrt(q.shape[-2]) |
|
||||||
self.freqs_cis = self.freqs_cis.to(q.device) |
|
||||||
if self.freqs_cis.shape[0] != q.shape[-2]: |
|
||||||
self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) |
|
||||||
if q.shape[-2] != k.shape[-2]: |
|
||||||
assert self.rope_k_repeat |
|
||||||
|
|
||||||
num_k_rope = k.size(-2) - num_k_exclude_rope |
|
||||||
q, k[:, :, :num_k_rope] = apply_rotary_enc( |
|
||||||
q, |
|
||||||
k[:, :, :num_k_rope], |
|
||||||
freqs_cis=self.freqs_cis, |
|
||||||
repeat_freqs_k=self.rope_k_repeat, |
|
||||||
) |
|
||||||
|
|
||||||
# Attention |
|
||||||
_, _, _, c_per_head = q.shape |
|
||||||
attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens |
|
||||||
attn = attn / math.sqrt(c_per_head) |
|
||||||
attn = torch.softmax(attn, dim=-1) |
|
||||||
|
|
||||||
# Get output |
|
||||||
out = attn @ v |
|
||||||
|
|
||||||
out = self._recombine_heads(out) |
|
||||||
out = self.out_proj(out) |
|
||||||
|
|
||||||
return out |
|
||||||
|
|
||||||
|
|
||||||
def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: |
|
||||||
"""Applies pooling and optional normalization to a tensor, handling permutations for spatial operations.""" |
|
||||||
if pool is None: |
|
||||||
return x |
|
||||||
# (B, H, W, C) -> (B, C, H, W) |
|
||||||
x = x.permute(0, 3, 1, 2) |
|
||||||
x = pool(x) |
|
||||||
# (B, C, H', W') -> (B, H', W', C) |
|
||||||
x = x.permute(0, 2, 3, 1) |
|
||||||
if norm: |
|
||||||
x = norm(x) |
|
||||||
|
|
||||||
return x |
|
||||||
|
|
||||||
|
|
||||||
class MultiScaleAttention(nn.Module): |
|
||||||
"""Implements multi-scale self-attention with optional query pooling for efficient feature extraction.""" |
|
||||||
|
|
||||||
def __init__( |
|
||||||
self, |
|
||||||
dim: int, |
|
||||||
dim_out: int, |
|
||||||
num_heads: int, |
|
||||||
q_pool: nn.Module = None, |
|
||||||
): |
|
||||||
"""Initializes a multi-scale attention module with configurable query pooling and linear projections.""" |
|
||||||
super().__init__() |
|
||||||
|
|
||||||
self.dim = dim |
|
||||||
self.dim_out = dim_out |
|
||||||
|
|
||||||
self.num_heads = num_heads |
|
||||||
head_dim = dim_out // num_heads |
|
||||||
self.scale = head_dim**-0.5 |
|
||||||
|
|
||||||
self.q_pool = q_pool |
|
||||||
self.qkv = nn.Linear(dim, dim_out * 3) |
|
||||||
self.proj = nn.Linear(dim_out, dim_out) |
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
||||||
"""Applies multi-scale attention to input tensor, optionally downsampling query features.""" |
|
||||||
B, H, W, _ = x.shape |
|
||||||
# qkv with shape (B, H * W, 3, nHead, C) |
|
||||||
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) |
|
||||||
# q, k, v with shape (B, H * W, nheads, C) |
|
||||||
q, k, v = torch.unbind(qkv, 2) |
|
||||||
|
|
||||||
# Q pooling (for downsample at stage changes) |
|
||||||
if self.q_pool: |
|
||||||
q = do_pool(q.reshape(B, H, W, -1), self.q_pool) |
|
||||||
H, W = q.shape[1:3] # downsampled shape |
|
||||||
q = q.reshape(B, H * W, self.num_heads, -1) |
|
||||||
|
|
||||||
# Torch's SDPA expects [B, nheads, H*W, C] so we transpose |
|
||||||
x = F.scaled_dot_product_attention( |
|
||||||
q.transpose(1, 2), |
|
||||||
k.transpose(1, 2), |
|
||||||
v.transpose(1, 2), |
|
||||||
) |
|
||||||
# Transpose back |
|
||||||
x = x.transpose(1, 2) |
|
||||||
x = x.reshape(B, H, W, -1) |
|
||||||
|
|
||||||
x = self.proj(x) |
|
||||||
|
|
||||||
return x |
|
||||||
|
|
||||||
|
|
||||||
class MultiScaleBlock(nn.Module): |
|
||||||
"""Multiscale attention block with window partitioning and query pooling for efficient vision transformers.""" |
|
||||||
|
|
||||||
def __init__( |
|
||||||
self, |
|
||||||
dim: int, |
|
||||||
dim_out: int, |
|
||||||
num_heads: int, |
|
||||||
mlp_ratio: float = 4.0, |
|
||||||
drop_path: float = 0.0, |
|
||||||
norm_layer: Union[nn.Module, str] = "LayerNorm", |
|
||||||
q_stride: Tuple[int, int] = None, |
|
||||||
act_layer: nn.Module = nn.GELU, |
|
||||||
window_size: int = 0, |
|
||||||
): |
|
||||||
"""Initializes a multi-scale attention block with optional window partitioning and downsampling.""" |
|
||||||
super().__init__() |
|
||||||
|
|
||||||
if isinstance(norm_layer, str): |
|
||||||
norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) |
|
||||||
|
|
||||||
self.dim = dim |
|
||||||
self.dim_out = dim_out |
|
||||||
self.norm1 = norm_layer(dim) |
|
||||||
|
|
||||||
self.window_size = window_size |
|
||||||
|
|
||||||
self.pool, self.q_stride = None, q_stride |
|
||||||
if self.q_stride: |
|
||||||
self.pool = nn.MaxPool2d(kernel_size=q_stride, stride=q_stride, ceil_mode=False) |
|
||||||
|
|
||||||
self.attn = MultiScaleAttention( |
|
||||||
dim, |
|
||||||
dim_out, |
|
||||||
num_heads=num_heads, |
|
||||||
q_pool=self.pool, |
|
||||||
) |
|
||||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() |
|
||||||
|
|
||||||
self.norm2 = norm_layer(dim_out) |
|
||||||
self.mlp = MLP( |
|
||||||
dim_out, |
|
||||||
int(dim_out * mlp_ratio), |
|
||||||
dim_out, |
|
||||||
num_layers=2, |
|
||||||
act=act_layer, |
|
||||||
) |
|
||||||
|
|
||||||
if dim != dim_out: |
|
||||||
self.proj = nn.Linear(dim, dim_out) |
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
||||||
"""Applies multi-scale attention and MLP processing to input tensor, with optional windowing.""" |
|
||||||
shortcut = x # B, H, W, C |
|
||||||
x = self.norm1(x) |
|
||||||
|
|
||||||
# Skip connection |
|
||||||
if self.dim != self.dim_out: |
|
||||||
shortcut = do_pool(self.proj(x), self.pool) |
|
||||||
|
|
||||||
# Window partition |
|
||||||
window_size = self.window_size |
|
||||||
if window_size > 0: |
|
||||||
H, W = x.shape[1], x.shape[2] |
|
||||||
x, pad_hw = window_partition(x, window_size) |
|
||||||
|
|
||||||
# Window Attention + Q Pooling (if stage change) |
|
||||||
x = self.attn(x) |
|
||||||
if self.q_stride: |
|
||||||
# Shapes have changed due to Q pooling |
|
||||||
window_size = self.window_size // self.q_stride[0] |
|
||||||
H, W = shortcut.shape[1:3] |
|
||||||
|
|
||||||
pad_h = (window_size - H % window_size) % window_size |
|
||||||
pad_w = (window_size - W % window_size) % window_size |
|
||||||
pad_hw = (H + pad_h, W + pad_w) |
|
||||||
|
|
||||||
# Reverse window partition |
|
||||||
if self.window_size > 0: |
|
||||||
x = window_unpartition(x, window_size, pad_hw, (H, W)) |
|
||||||
|
|
||||||
x = shortcut + self.drop_path(x) |
|
||||||
# MLP |
|
||||||
x = x + self.drop_path(self.mlp(self.norm2(x))) |
|
||||||
return x |
|
||||||
|
|
||||||
|
|
||||||
class PositionEmbeddingSine(nn.Module): |
|
||||||
"""Generates sinusoidal positional embeddings for 2D inputs like images.""" |
|
||||||
|
|
||||||
def __init__( |
|
||||||
self, |
|
||||||
num_pos_feats, |
|
||||||
temperature: int = 10000, |
|
||||||
normalize: bool = True, |
|
||||||
scale: Optional[float] = None, |
|
||||||
): |
|
||||||
"""Initializes sinusoidal position embeddings for 2D image inputs.""" |
|
||||||
super().__init__() |
|
||||||
assert num_pos_feats % 2 == 0, "Expecting even model width" |
|
||||||
self.num_pos_feats = num_pos_feats // 2 |
|
||||||
self.temperature = temperature |
|
||||||
self.normalize = normalize |
|
||||||
if scale is not None and normalize is False: |
|
||||||
raise ValueError("normalize should be True if scale is passed") |
|
||||||
if scale is None: |
|
||||||
scale = 2 * math.pi |
|
||||||
self.scale = scale |
|
||||||
|
|
||||||
self.cache = {} |
|
||||||
|
|
||||||
def _encode_xy(self, x, y): |
|
||||||
"""Encodes 2D positions using sine and cosine functions for positional embeddings.""" |
|
||||||
assert len(x) == len(y) and x.ndim == y.ndim == 1 |
|
||||||
x_embed = x * self.scale |
|
||||||
y_embed = y * self.scale |
|
||||||
|
|
||||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) |
|
||||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) |
|
||||||
|
|
||||||
pos_x = x_embed[:, None] / dim_t |
|
||||||
pos_y = y_embed[:, None] / dim_t |
|
||||||
pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1) |
|
||||||
pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1) |
|
||||||
return pos_x, pos_y |
|
||||||
|
|
||||||
@torch.no_grad() |
|
||||||
def encode_boxes(self, x, y, w, h): |
|
||||||
"""Encodes box coordinates and dimensions into positional embeddings for object detection tasks.""" |
|
||||||
pos_x, pos_y = self._encode_xy(x, y) |
|
||||||
pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) |
|
||||||
return pos |
|
||||||
|
|
||||||
encode = encode_boxes # Backwards compatibility |
|
||||||
|
|
||||||
@torch.no_grad() |
|
||||||
def encode_points(self, x, y, labels): |
|
||||||
"""Encodes 2D point coordinates with sinusoidal positional embeddings and appends labels.""" |
|
||||||
(bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape |
|
||||||
assert bx == by and nx == ny and bx == bl and nx == nl |
|
||||||
pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) |
|
||||||
pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) |
|
||||||
pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) |
|
||||||
return pos |
|
||||||
|
|
||||||
@torch.no_grad() |
|
||||||
def forward(self, x: torch.Tensor): |
|
||||||
"""Generate sinusoidal position embeddings for 2D inputs.""" |
|
||||||
cache_key = (x.shape[-2], x.shape[-1]) |
|
||||||
if cache_key in self.cache: |
|
||||||
return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) |
|
||||||
y_embed = ( |
|
||||||
torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) |
|
||||||
.view(1, -1, 1) |
|
||||||
.repeat(x.shape[0], 1, x.shape[-1]) |
|
||||||
) |
|
||||||
x_embed = ( |
|
||||||
torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) |
|
||||||
.view(1, 1, -1) |
|
||||||
.repeat(x.shape[0], x.shape[-2], 1) |
|
||||||
) |
|
||||||
|
|
||||||
if self.normalize: |
|
||||||
eps = 1e-6 |
|
||||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale |
|
||||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale |
|
||||||
|
|
||||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) |
|
||||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) |
|
||||||
|
|
||||||
pos_x = x_embed[:, :, :, None] / dim_t |
|
||||||
pos_y = y_embed[:, :, :, None] / dim_t |
|
||||||
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) |
|
||||||
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) |
|
||||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) |
|
||||||
self.cache[cache_key] = pos[0] |
|
||||||
return pos |
|
@ -1,177 +0,0 @@ |
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license |
|
||||||
|
|
||||||
import torch |
|
||||||
|
|
||||||
from ..sam.predict import Predictor |
|
||||||
from .build import build_sam2 |
|
||||||
|
|
||||||
|
|
||||||
class SAM2Predictor(Predictor): |
|
||||||
""" |
|
||||||
A predictor class for the Segment Anything Model 2 (SAM2), extending the base Predictor class. |
|
||||||
|
|
||||||
This class provides an interface for model inference tailored to image segmentation tasks, leveraging SAM2's |
|
||||||
advanced architecture and promptable segmentation capabilities. It facilitates flexible and real-time mask |
|
||||||
generation, working with various types of prompts such as bounding boxes, points, and low-resolution masks. |
|
||||||
|
|
||||||
Attributes: |
|
||||||
cfg (Dict): Configuration dictionary specifying model and task-related parameters. |
|
||||||
overrides (Dict): Dictionary containing values that override the default configuration. |
|
||||||
_callbacks (Dict): Dictionary of user-defined callback functions to augment behavior. |
|
||||||
args (namespace): Namespace to hold command-line arguments or other operational variables. |
|
||||||
im (torch.Tensor): Preprocessed input image tensor. |
|
||||||
features (torch.Tensor): Extracted image features used for inference. |
|
||||||
prompts (Dict): Collection of various prompt types, such as bounding boxes and points. |
|
||||||
segment_all (bool): Flag to control whether to segment all objects in the image or only specified ones. |
|
||||||
model (torch.nn.Module): The loaded SAM2 model. |
|
||||||
device (torch.device): The device (CPU or GPU) on which the model is loaded. |
|
||||||
_bb_feat_sizes (List[Tuple[int, int]]): List of feature sizes for different backbone levels. |
|
||||||
|
|
||||||
Methods: |
|
||||||
get_model: Builds and returns the SAM2 model. |
|
||||||
prompt_inference: Performs image segmentation inference based on various prompts. |
|
||||||
set_image: Preprocesses and sets a single image for inference. |
|
||||||
get_im_features: Extracts image features from the SAM2 image encoder. |
|
||||||
|
|
||||||
Examples: |
|
||||||
>>> predictor = SAM2Predictor(model='sam2_l.pt') |
|
||||||
>>> predictor.set_image('path/to/image.jpg') |
|
||||||
>>> masks, scores = predictor.prompt_inference(im=predictor.im, points=[[500, 375]], labels=[1]) |
|
||||||
>>> print(f"Generated {len(masks)} mask(s) with scores: {scores}") |
|
||||||
""" |
|
||||||
|
|
||||||
_bb_feat_sizes = [ |
|
||||||
(256, 256), |
|
||||||
(128, 128), |
|
||||||
(64, 64), |
|
||||||
] |
|
||||||
|
|
||||||
def get_model(self): |
|
||||||
"""Retrieves and initializes the Segment Anything Model (SAM) for image segmentation tasks.""" |
|
||||||
return build_sam2(self.args.model) |
|
||||||
|
|
||||||
def prompt_inference( |
|
||||||
self, |
|
||||||
im, |
|
||||||
bboxes=None, |
|
||||||
points=None, |
|
||||||
labels=None, |
|
||||||
masks=None, |
|
||||||
multimask_output=False, |
|
||||||
img_idx=-1, |
|
||||||
): |
|
||||||
""" |
|
||||||
Performs image segmentation inference based on various prompts using SAM2 architecture. |
|
||||||
|
|
||||||
Args: |
|
||||||
im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W). |
|
||||||
bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4). |
|
||||||
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels. |
|
||||||
labels (np.ndarray | List | None): Labels for point prompts with shape (N,). 1 = foreground, 0 = background. |
|
||||||
masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W). |
|
||||||
multimask_output (bool): Flag to return multiple masks for ambiguous prompts. |
|
||||||
img_idx (int): Index of the image in the batch to process. |
|
||||||
|
|
||||||
Returns: |
|
||||||
(tuple): Tuple containing: |
|
||||||
- np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks. |
|
||||||
- np.ndarray: Quality scores for each mask, with length C. |
|
||||||
- np.ndarray: Low-resolution logits with shape (C, 256, 256) for subsequent inference. |
|
||||||
|
|
||||||
Examples: |
|
||||||
>>> predictor = SAM2Predictor(cfg) |
|
||||||
>>> image = torch.rand(1, 3, 640, 640) |
|
||||||
>>> bboxes = [[100, 100, 200, 200]] |
|
||||||
>>> masks, scores, logits = predictor.prompt_inference(image, bboxes=bboxes) |
|
||||||
""" |
|
||||||
features = self.get_im_features(im) if self.features is None else self.features |
|
||||||
|
|
||||||
src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:] |
|
||||||
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1]) |
|
||||||
# Transform input prompts |
|
||||||
if points is not None: |
|
||||||
points = torch.as_tensor(points, dtype=torch.float32, device=self.device) |
|
||||||
points = points[None] if points.ndim == 1 else points |
|
||||||
# Assuming labels are all positive if users don't pass labels. |
|
||||||
if labels is None: |
|
||||||
labels = torch.ones(points.shape[0]) |
|
||||||
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device) |
|
||||||
points *= r |
|
||||||
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1) |
|
||||||
points, labels = points[:, None], labels[:, None] |
|
||||||
if bboxes is not None: |
|
||||||
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device) |
|
||||||
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes |
|
||||||
bboxes = bboxes.view(-1, 2, 2) * r |
|
||||||
bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1) |
|
||||||
# NOTE: merge "boxes" and "points" into a single "points" input |
|
||||||
# (where boxes are added at the beginning) to model.sam_prompt_encoder |
|
||||||
if points is not None: |
|
||||||
points = torch.cat([bboxes, points], dim=1) |
|
||||||
labels = torch.cat([bbox_labels, labels], dim=1) |
|
||||||
else: |
|
||||||
points, labels = bboxes, bbox_labels |
|
||||||
if masks is not None: |
|
||||||
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1) |
|
||||||
|
|
||||||
points = (points, labels) if points is not None else None |
|
||||||
|
|
||||||
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( |
|
||||||
points=points, |
|
||||||
boxes=None, |
|
||||||
masks=masks, |
|
||||||
) |
|
||||||
# Predict masks |
|
||||||
batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction |
|
||||||
high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]] |
|
||||||
pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder( |
|
||||||
image_embeddings=features["image_embed"][img_idx].unsqueeze(0), |
|
||||||
image_pe=self.model.sam_prompt_encoder.get_dense_pe(), |
|
||||||
sparse_prompt_embeddings=sparse_embeddings, |
|
||||||
dense_prompt_embeddings=dense_embeddings, |
|
||||||
multimask_output=multimask_output, |
|
||||||
repeat_image=batched_mode, |
|
||||||
high_res_features=high_res_features, |
|
||||||
) |
|
||||||
# (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, ) |
|
||||||
# `d` could be 1 or 3 depends on `multimask_output`. |
|
||||||
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1) |
|
||||||
|
|
||||||
def set_image(self, image): |
|
||||||
""" |
|
||||||
Preprocesses and sets a single image for inference. |
|
||||||
|
|
||||||
This function sets up the model if not already initialized, configures the data source to the specified image, |
|
||||||
and preprocesses the image for feature extraction. Only one image can be set at a time. |
|
||||||
|
|
||||||
Args: |
|
||||||
image (str | np.ndarray): Image file path as a string, or a numpy array image read by cv2. |
|
||||||
|
|
||||||
Raises: |
|
||||||
AssertionError: If more than one image is set. |
|
||||||
|
|
||||||
Examples: |
|
||||||
>>> predictor = SAM2Predictor() |
|
||||||
>>> predictor.set_image("path/to/image.jpg") |
|
||||||
>>> predictor.set_image(np.array([...])) # Using a numpy array |
|
||||||
""" |
|
||||||
if self.model is None: |
|
||||||
self.setup_model(model=None) |
|
||||||
self.setup_source(image) |
|
||||||
assert len(self.dataset) == 1, "`set_image` only supports setting one image!" |
|
||||||
for batch in self.dataset: |
|
||||||
im = self.preprocess(batch[1]) |
|
||||||
self.features = self.get_im_features(im) |
|
||||||
break |
|
||||||
|
|
||||||
def get_im_features(self, im): |
|
||||||
"""Extracts and processes image features using SAM2's image encoder for subsequent segmentation tasks.""" |
|
||||||
backbone_out = self.model.forward_image(im) |
|
||||||
_, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) |
|
||||||
if self.model.directly_add_no_mem_embed: |
|
||||||
vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed |
|
||||||
feats = [ |
|
||||||
feat.permute(1, 2, 0).view(1, -1, *feat_size) |
|
||||||
for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) |
|
||||||
][::-1] |
|
||||||
return {"image_embed": feats[-1], "high_res_feats": feats[:-1]} |
|
Loading…
Reference in new issue