@@ -161,18 +161,19 @@ class SAM2Model(torch.nn.Module):
        use_multimask_token_for_obj_ptr: bool = False,
        iou_prediction_use_sigmoid=False,
        memory_temporal_stride_for_eval=1,
        add_all_frames_to_correct_as_cond=False,
        non_overlap_masks_for_mem_enc=False,
        use_obj_ptrs_in_encoder=False,
        max_obj_ptrs_in_encoder=16,
        add_tpos_enc_to_obj_ptrs=True,
        proj_tpos_enc_in_obj_ptrs=False,
        use_signed_tpos_enc_to_obj_ptrs=False,
        only_obj_ptrs_in_the_past_for_eval=False,
        pred_obj_scores: bool = False,
        pred_obj_scores_mlp: bool = False,
        fixed_no_obj_ptr: bool = False,
        soft_no_obj_ptr: bool = False,
        use_mlp_for_obj_ptr_proj: bool = False,
        no_obj_embed_spatial: bool = False,
        sam_mask_decoder_extra_args=None,
        compile_image_encoder: bool = False,
    ):
@@ -205,8 +206,6 @@ class SAM2Model(torch.nn.Module):
            use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
            iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
            memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
            add_all_frames_to_correct_as_cond (bool): Whether to append frames with correction clicks to conditioning
                frame list.
            non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
                memory encoder during evaluation.
            use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
@@ -216,6 +215,9 @@ class SAM2Model(torch.nn.Module):
                the encoder.
            proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
                encoding in object pointers.
            use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance (instead of unsigned absolute distance)
                in the temporal positional encoding in the object pointers, only relevant when both `use_obj_ptrs_in_encoder=True`
                and `add_tpos_enc_to_obj_ptrs=True`.
            only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past
                during evaluation.
            pred_obj_scores (bool): Whether to predict if there is an object in the frame.
@@ -223,6 +225,7 @@ class SAM2Model(torch.nn.Module):
            fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
            soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.
            use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
            no_obj_embed_spatial (bool): Whether to add a no-object embedding to spatial frames.
            sam_mask_decoder_extra_args (Dict | None): Extra arguments for constructing the SAM mask decoder.
            compile_image_encoder (bool): Whether to compile the image encoder for faster inference.
@@ -253,6 +256,7 @@ class SAM2Model(torch.nn.Module):
        if proj_tpos_enc_in_obj_ptrs:
            assert add_tpos_enc_to_obj_ptrs  # these options need to be used together
        self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
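        # Signed temporal offsets only take effect when both `use_obj_ptrs_in_encoder` and
        # `add_tpos_enc_to_obj_ptrs` are enabled (see the docstring above)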
        self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
        self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
        # Part 2: memory attention to condition current frame's visual features
@@ -309,9 +313,12 @@ class SAM2Model(torch.nn.Module):
            self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
            trunc_normal_(self.no_obj_ptr, std=0.02)
        self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
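        # Optional learnable "no object" embedding, added to the spatial memory of frames where the
        # object is predicted to be absent (see the no-object handling in `_encode_new_memory`)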
        self.no_obj_embed_spatial = None
        if no_obj_embed_spatial:
            self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
            trunc_normal_(self.no_obj_embed_spatial, std=0.02)
        self._build_sam_heads()
        self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
        self.max_cond_frames_in_attn = max_cond_frames_in_attn
        # Model compilation
@@ -533,8 +540,6 @@ class SAM2Model(torch.nn.Module):
        if self.pred_obj_scores:
            # Allow *soft* no obj ptr, unlike for masks
            if self.soft_no_obj_ptr:
                # Only hard possible with gt
                assert not self.teacher_force_obj_scores_for_mem
                lambda_is_obj_appearing = object_score_logits.sigmoid()
            else:
                lambda_is_obj_appearing = is_obj_appearing.float()
@@ -647,6 +652,7 @@ class SAM2Model(torch.nn.Module):
        if self.num_maskmem == 0:  # Disable memory and skip fusion
            return current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
        num_obj_ptr_tokens = 0
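        # Sign multiplier for the (optional) signed temporal offsets of object pointers: tracking in
        # reverse flips the sign so that already-tracked frames keep a positive (past) offset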
        tpos_sign_mul = -1 if track_in_reverse else 1
        # Step 1: condition the visual features of the current frame on previous memories
        if not is_init_cond_frame:
            # Retrieve the memories encoded with the maskmem backbone
@@ -664,7 +670,7 @@ class SAM2Model(torch.nn.Module):
            # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
            # We also allow taking the memory frame non-consecutively (with r>1), in which case
            # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
            r = self.memory_temporal_stride_for_eval
            r = 1 if self.training else self.memory_temporal_stride_for_eval
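            # (a memory stride r > 1 is only applied at evaluation; training always uses consecutive memory frames)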
            for t_pos in range(1, self.num_maskmem):
                t_rel = self.num_maskmem - t_pos  # how many frames before current frame
                if t_rel == 1:
@@ -718,7 +724,14 @@ class SAM2Model(torch.nn.Module):
                    ptr_cond_outputs = selected_cond_outputs
                pos_and_ptrs = [
                    # Temporal pos encoding contains how far away each pointer is from current frame
                    (abs(frame_idx - t), out["obj_ptr"])
                    (
                        (
                            (frame_idx - t) * tpos_sign_mul
                            if self.use_signed_tpos_enc_to_obj_ptrs
                            else abs(frame_idx - t)
                        ),
                        out["obj_ptr"],
                    )
                    for t, out in ptr_cond_outputs.items()
                ]
                # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
@@ -787,6 +800,7 @@ class SAM2Model(torch.nn.Module):
        current_vision_feats,
        feat_sizes,
        pred_masks_high_res,
        object_score_logits,
        is_mask_from_pts,
    ):
        """Encodes frame features and masks into a new memory representation for video segmentation."""
@@ -819,10 +833,17 @@ class SAM2Model(torch.nn.Module):
        )
        maskmem_features = maskmem_out["vision_features"]
        maskmem_pos_enc = maskmem_out["vision_pos_enc"]
        # add a no-object embedding to the spatial memory to indicate that the frame
        # is predicted to be occluded (i.e. no object is appearing in the frame)
        if self.no_obj_embed_spatial is not None:
            is_obj_appearing = (object_score_logits > 0).float()
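            # object_score_logits > 0 means an object is predicted to be present, so the embedding
            # below is only added on frames predicted to be occluded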
            maskmem_features += (1 - is_obj_appearing[..., None, None]) * self.no_obj_embed_spatial[
                ..., None, None
            ].expand(*maskmem_features.shape)
        return maskmem_features, maskmem_pos_enc
    def track_step(
    def _track_step(
        self,
        frame_idx,
        is_init_cond_frame,
@@ -833,15 +854,7 @@ class SAM2Model(torch.nn.Module):
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
        # Whether to run the memory encoder on the predicted masks. Sometimes we might want
        # to skip the memory encoder with `run_mem_encoder=False`. For example,
        # in demo we might call `track_step` multiple times for each user click,
        # and only encode the memory when the user finalizes their clicks. And in ablation
        # settings like SAM training on static images, we don't need the memory encoder.
        run_mem_encoder=True,
        # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
        prev_sam_mask_logits=None,
        prev_sam_mask_logits,
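        # (no default value here: `_track_step` is internal, and `track_step` below always passes
        # `prev_sam_mask_logits` explicitly)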
    ):
        """Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
@@ -861,7 +874,7 @@ class SAM2Model(torch.nn.Module):
            sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
        else:
            # fuse the visual feature with previous memory features in the memory bank
            pix_feat_with_mem = self._prepare_memory_conditioned_features(
            pix_feat = self._prepare_memory_conditioned_features(
                frame_idx=frame_idx,
                is_init_cond_frame=is_init_cond_frame,
                current_vision_feats=current_vision_feats[-1:],
@@ -880,12 +893,78 @@ class SAM2Model(torch.nn.Module):
                mask_inputs = prev_sam_mask_logits
            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
            sam_outputs = self._forward_sam_heads(
                backbone_features=pix_feat_with_mem,
                backbone_features=pix_feat,
                point_inputs=point_inputs,
                mask_inputs=mask_inputs,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
            )
        return current_out, sam_outputs, high_res_features, pix_feat
    def _encode_memory_in_output(
        self,
        current_vision_feats,
        feat_sizes,
        point_inputs,
        run_mem_encoder,
        high_res_masks,
        object_score_logits,
        current_out,
    ):
""" Finally run the memory encoder on the predicted mask to encode, it into a new memory feature (that can be
used in future frames ) .
"""
        if run_mem_encoder and self.num_maskmem > 0:
            high_res_masks_for_mem_enc = high_res_masks
            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
                current_vision_feats=current_vision_feats,
                feat_sizes=feat_sizes,
                pred_masks_high_res=high_res_masks_for_mem_enc,
                object_score_logits=object_score_logits,
                is_mask_from_pts=(point_inputs is not None),
            )
            current_out["maskmem_features"] = maskmem_features
            current_out["maskmem_pos_enc"] = maskmem_pos_enc
        else:
            current_out["maskmem_features"] = None
            current_out["maskmem_pos_enc"] = None
    def track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
        # Whether to run the memory encoder on the predicted masks. Sometimes we might want
        # to skip the memory encoder with `run_mem_encoder=False`. For example,
        # in demo we might call `track_step` multiple times for each user click,
        # and only encode the memory when the user finalizes their clicks. And in ablation
        # settings like SAM training on static images, we don't need the memory encoder.
        run_mem_encoder=True,
        # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
        prev_sam_mask_logits=None,
    ):
        """Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
        current_out, sam_outputs, _, _ = self._track_step(
            frame_idx,
            is_init_cond_frame,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
            point_inputs,
            mask_inputs,
            output_dict,
            num_frames,
            track_in_reverse,
            prev_sam_mask_logits,
        )
        (
            _,
            _,
@@ -893,28 +972,28 @@ class SAM2Model(torch.nn.Module):
            low_res_masks,
            high_res_masks,
            obj_ptr,
            _,
            object_score_logits,
        ) = sam_outputs
        current_out["pred_masks"] = low_res_masks
        current_out["pred_masks_high_res"] = high_res_masks
        current_out["obj_ptr"] = obj_ptr
        if not self.training:
            # Only add this in inference (to avoid unused param in activation checkpointing;
            # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
            current_out["object_score_logits"] = object_score_logits
        # Finally run the memory encoder on the predicted mask to encode
        # it into a new memory feature (that can be used in future frames)
        if run_mem_encoder and self.num_maskmem > 0:
            high_res_masks_for_mem_enc = high_res_masks
            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
                current_vision_feats=current_vision_feats,
                feat_sizes=feat_sizes,
                pred_masks_high_res=high_res_masks_for_mem_enc,
                is_mask_from_pts=(point_inputs is not None),
            )
            current_out["maskmem_features"] = maskmem_features
            current_out["maskmem_pos_enc"] = maskmem_pos_enc
        else:
            current_out["maskmem_features"] = None
            current_out["maskmem_pos_enc"] = None
        self._encode_memory_in_output(
            current_vision_feats,
            feat_sizes,
            point_inputs,
            run_mem_encoder,
            high_res_masks,
            object_score_logits,
            current_out,
        )
        return current_out