@@ -671,26 +671,19 @@ class SAM2Model(torch.nn.Module):
                 t_rel = self.num_maskmem - t_pos  # how many frames before current frame
                 if t_rel == 1:
                     # for t_rel == 1, we take the last frame (regardless of r)
-                    if not track_in_reverse:
-                        # the frame immediately before this frame (i.e. frame_idx - 1)
-                        prev_frame_idx = frame_idx - t_rel
-                    else:
-                        # the frame immediately after this frame (i.e. frame_idx + 1)
-                        prev_frame_idx = frame_idx + t_rel
+                    prev_frame_idx = frame_idx + t_rel if track_in_reverse else frame_idx - t_rel
+                elif not track_in_reverse:
+                    # first find the nearest frame among every r-th frames before this frame
+                    # for r=1, this would be (frame_idx - 2)
+                    prev_frame_idx = ((frame_idx - 2) // r) * r
+                    # then seek further among every r-th frames
+                    prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
                 else:
-                    # for t_rel >= 2, we take the memory frame from every r-th frames
-                    if not track_in_reverse:
-                        # first find the nearest frame among every r-th frames before this frame
-                        # for r=1, this would be (frame_idx - 2)
-                        prev_frame_idx = ((frame_idx - 2) // r) * r
-                        # then seek further among every r-th frames
-                        prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
-                    else:
-                        # first find the nearest frame among every r-th frames after this frame
-                        # for r=1, this would be (frame_idx + 2)
-                        prev_frame_idx = -(-(frame_idx + 2) // r) * r
-                        # then seek further among every r-th frames
-                        prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
+                    # first find the nearest frame among every r-th frames after this frame
+                    # for r=1, this would be (frame_idx + 2)
+                    prev_frame_idx = -(-(frame_idx + 2) // r) * r
+                    # then seek further among every r-th frames
+                    prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
                 out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
                 if out is None:
                     # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
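
This hunk flattens two levels of nesting without changing behavior: the `t_rel == 1` case collapses into a single conditional expression, and the old `else:` plus nested `if not track_in_reverse:` becomes an `elif not track_in_reverse:` / `else:` pair, dedenting the stride-`r` selection by one level. The only non-obvious arithmetic is `-(-(frame_idx + 2) // r) * r`: floor-dividing the negation is ceiling division, so the expression rounds `frame_idx + 2` up to the nearest multiple of `r`, mirroring `((frame_idx - 2) // r) * r`, which rounds down. A minimal standalone sketch of the selection logic (`select_prev_frame_idx` is a hypothetical helper name; in the real code this runs inline in the loop over `t_pos`):

```python
def select_prev_frame_idx(frame_idx: int, t_rel: int, r: int, track_in_reverse: bool = False) -> int:
    """Pick the memory frame index for temporal offset t_rel under memory stride r."""
    if t_rel == 1:
        # the immediately adjacent frame, regardless of the stride r
        return frame_idx + t_rel if track_in_reverse else frame_idx - t_rel
    if not track_in_reverse:
        # nearest multiple of r at or below (frame_idx - 2), then (t_rel - 2) strides further back
        return ((frame_idx - 2) // r) * r - (t_rel - 2) * r
    # nearest multiple of r at or above (frame_idx + 2), then (t_rel - 2) strides further forward
    return -(-(frame_idx + 2) // r) * r + (t_rel - 2) * r


# frame_idx=10, r=2 selects frames 9, 8, 6 for t_rel = 1, 2, 3
assert [select_prev_frame_idx(10, t, r=2) for t in (1, 2, 3)] == [9, 8, 6]
# reverse tracking mirrors this forward: 11, 12, 14
assert [select_prev_frame_idx(10, t, r=2, track_in_reverse=True) for t in (1, 2, 3)] == [11, 12, 14]
```
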
@@ -739,7 +732,7 @@ class SAM2Model(torch.nn.Module):
                 if out is not None:
                     pos_and_ptrs.append((t_diff, out["obj_ptr"]))
             # If we have at least one object pointer, add them to the across attention
-            if len(pos_and_ptrs) > 0:
+            if pos_and_ptrs:
                 pos_list, ptrs_list = zip(*pos_and_ptrs)
                 # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
                 obj_ptrs = torch.stack(ptrs_list, dim=0)
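
Here `if len(pos_and_ptrs) > 0:` becomes the idiomatic truthiness check `if pos_and_ptrs:`. An empty list is falsy in Python, so the two guards are interchangeable, and the guard is still needed: unpacking `zip(*[])` into `pos_list, ptrs_list` would raise a `ValueError`. A quick illustration of the equivalence:

```python
# Both guards agree for any list; the bare truthiness form is the
# PEP 8-preferred spelling and avoids a len() call.
for pos_and_ptrs in ([], [(1, "ptr_a")], [(1, "ptr_a"), (2, "ptr_b")]):
    assert (len(pos_and_ptrs) > 0) == bool(pos_and_ptrs)
```
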
@@ -930,12 +923,11 @@ class SAM2Model(torch.nn.Module):
     def _use_multimask(self, is_init_cond_frame, point_inputs):
         """Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs."""
         num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
-        multimask_output = (
+        return (
             self.multimask_output_in_sam
             and (is_init_cond_frame or self.multimask_output_for_tracking)
             and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
         )
-        return multimask_output
 
     def _apply_non_overlapping_constraints(self, pred_masks):
         """Applies non-overlapping constraints to masks, keeping highest scoring object per location."""