diff --git a/docs/en/reference/models/sam/build.md b/docs/en/reference/models/sam/build.md
index c5a66c656..5b608c7ec 100644
--- a/docs/en/reference/models/sam/build.md
+++ b/docs/en/reference/models/sam/build.md
@@ -1,6 +1,6 @@
---
-description: Discover detailed instructions for building various Segment Anything Model (SAM) architectures with Ultralytics, including SAM ViT and Mobile-SAM.
-keywords: Ultralytics, SAM model, Segment Anything Model, SAM ViT, Mobile-SAM, model building, deep learning, AI
+description: Discover detailed instructions for building various Segment Anything Model (SAM) and Segment Anything Model 2 (SAM 2) architectures with Ultralytics, including SAM ViT and Mobile-SAM.
+keywords: Ultralytics, SAM model, Segment Anything Model, SAM 2 model, Segment Anything Model 2, SAM ViT, Mobile-SAM, model building, deep learning, AI
---
# Reference for `ultralytics/models/sam/build.py`
@@ -27,10 +27,30 @@ keywords: Ultralytics, SAM model, Segment Anything Model, SAM ViT, Mobile-SAM, m
+## ::: ultralytics.models.sam.build.build_sam2_t
+
+
+
+## ::: ultralytics.models.sam.build.build_sam2_s
+
+
+
+## ::: ultralytics.models.sam.build.build_sam2_b
+
+
+
+## ::: ultralytics.models.sam.build.build_sam2_l
+
+
+
## ::: ultralytics.models.sam.build._build_sam
+## ::: ultralytics.models.sam.build._build_sam2
+
+
+
## ::: ultralytics.models.sam.build.build_sam
diff --git a/docs/en/reference/models/sam/model.md b/docs/en/reference/models/sam/model.md
index 2255ad882..9c0f9301c 100644
--- a/docs/en/reference/models/sam/model.md
+++ b/docs/en/reference/models/sam/model.md
@@ -1,6 +1,6 @@
---
-description: Explore the SAM (Segment Anything Model) interface for real-time image segmentation. Learn about promptable segmentation and zero-shot capabilities.
-keywords: Ultralytics, SAM, Segment Anything Model, image segmentation, real-time segmentation, zero-shot performance, promptable segmentation, SA-1B dataset
+description: Explore the SAM (Segment Anything Model) and SAM 2 (Segment Anything Model 2) interface for real-time image segmentation. Learn about promptable segmentation and zero-shot capabilities.
+keywords: Ultralytics, SAM, Segment Anything Model, SAM 2, Segment Anything Model 2, image segmentation, real-time segmentation, zero-shot performance, promptable segmentation, SA-1B dataset
---
# Reference for `ultralytics/models/sam/model.py`
diff --git a/docs/en/reference/models/sam/modules/blocks.md b/docs/en/reference/models/sam/modules/blocks.md
new file mode 100644
index 000000000..5303a852e
--- /dev/null
+++ b/docs/en/reference/models/sam/modules/blocks.md
@@ -0,0 +1,72 @@
+---
+description: Explore detailed documentation of various SAM and SAM 2 modules such as MaskDownSampler, CXBlock, and more, available in Ultralytics' repository.
+keywords: Ultralytics, SAM encoder, SAM 2 encoder, DropPath, MaskDownSampler, CXBlock, Fuser, TwoWayTransformer, TwoWayAttentionBlock, RoPEAttention, MultiScaleAttention, MultiScaleBlock, PositionEmbeddingSine, do_pool
+---
+
+# Reference for `ultralytics/models/sam/modules/blocks.py`
+
+!!! Note
+
+ This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/blocks.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/blocks.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam/modules/blocks.py) 🛠️. Thank you 🙏!
+
+
+
+## ::: ultralytics.models.sam.modules.blocks.DropPath
+
+
+
+## ::: ultralytics.models.sam.modules.blocks.MaskDownSampler
+
+
+
+## ::: ultralytics.models.sam.modules.blocks.CXBlock
+
+
+
+## ::: ultralytics.models.sam.modules.blocks.Fuser
+
+
+
+## ::: ultralytics.models.sam.modules.blocks.SAM2TwoWayAttentionBlock
+
+
+
+## ::: ultralytics.models.sam.modules.blocks.SAM2TwoWayTransformer
+
+
+
+## ::: ultralytics.models.sam.modules.blocks.RoPEAttention
+
+
+
+## ::: ultralytics.models.sam.modules.blocks.MultiScaleAttention
+
+
+
+## ::: ultralytics.models.sam.modules.blocks.MultiScaleBlock
+
+
+
+## ::: ultralytics.models.sam.modules.blocks.PositionEmbeddingSine
+
+
+
+## ::: ultralytics.models.sam.modules.blocks.PositionEmbeddingRandom
+
+
+
+## ::: ultralytics.models.sam.modules.blocks.Block
+
+
+
+## ::: ultralytics.models.sam.modules.blocks.REAttention
+
+
+
+## ::: ultralytics.models.sam.modules.blocks.PatchEmbed
+
+
+
+## ::: ultralytics.models.sam.modules.blocks.do_pool
+
+
diff --git a/docs/en/reference/models/sam/modules/decoders.md b/docs/en/reference/models/sam/modules/decoders.md
index a224738b7..df2fc528f 100644
--- a/docs/en/reference/models/sam/modules/decoders.md
+++ b/docs/en/reference/models/sam/modules/decoders.md
@@ -13,4 +13,8 @@ keywords: Ultralytics, MaskDecoder, MLP, machine learning, transformer architect
## ::: ultralytics.models.sam.modules.decoders.MaskDecoder
+
+
+## ::: ultralytics.models.sam.modules.decoders.SAM2MaskDecoder
+
diff --git a/docs/en/reference/models/sam/modules/encoders.md b/docs/en/reference/models/sam/modules/encoders.md
index 4e839e479..a2ff493f6 100644
--- a/docs/en/reference/models/sam/modules/encoders.md
+++ b/docs/en/reference/models/sam/modules/encoders.md
@@ -19,34 +19,18 @@ keywords: Ultralytics, SAM encoder, ImageEncoderViT, PromptEncoder, PositionEmbe
-## ::: ultralytics.models.sam.modules.encoders.PositionEmbeddingRandom
+## ::: ultralytics.models.sam.modules.encoders.MemoryEncoder
-## ::: ultralytics.models.sam.modules.encoders.Block
+## ::: ultralytics.models.sam.modules.encoders.ImageEncoder
-## ::: ultralytics.models.sam.modules.encoders.Attention
+## ::: ultralytics.models.sam.modules.encoders.FpnNeck
-## ::: ultralytics.models.sam.modules.encoders.PatchEmbed
-
-
-
-## ::: ultralytics.models.sam.modules.encoders.window_partition
-
-
-
-## ::: ultralytics.models.sam.modules.encoders.window_unpartition
-
-
-
-## ::: ultralytics.models.sam.modules.encoders.get_rel_pos
-
-
-
-## ::: ultralytics.models.sam.modules.encoders.add_decomposed_rel_pos
+## ::: ultralytics.models.sam.modules.encoders.Hiera
diff --git a/docs/en/reference/models/sam/modules/memory_attention.md b/docs/en/reference/models/sam/modules/memory_attention.md
new file mode 100644
index 000000000..a341d6fed
--- /dev/null
+++ b/docs/en/reference/models/sam/modules/memory_attention.md
@@ -0,0 +1,20 @@
+---
+description: Explore detailed documentation of the SAM 2 memory attention modules, such as MemoryAttentionLayer and MemoryAttention, available in Ultralytics' repository.
+keywords: Ultralytics, SAM 2 encoder, MemoryAttentionLayer, MemoryAttention
+---
+
+# Reference for `ultralytics/models/sam/modules/memory_attention.py`
+
+!!! Note
+
+ This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/memory_attention.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/memory_attention.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam/modules/memory_attention.py) 🛠️. Thank you 🙏!
+
+
+
+## ::: ultralytics.models.sam.modules.memory_attention.MemoryAttentionLayer
+
+
+
+## ::: ultralytics.models.sam.modules.memory_attention.MemoryAttention
+
+
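The memory attention classes referenced above are constructed by the SAM 2 builder added later in this diff; a minimal instantiation sketch that mirrors the `_build_sam2` configuration (the no-argument `MemoryAttentionLayer()` default comes from that builder, and the parameter-count print is only a quick sanity check):

```python
from ultralytics.models.sam.modules.memory_attention import MemoryAttention, MemoryAttentionLayer

# Same settings as _build_sam2(): a 4-layer memory attention stack over 256-dim features.
memory_attention = MemoryAttention(
    d_model=256,
    pos_enc_at_input=True,
    num_layers=4,
    layer=MemoryAttentionLayer(),
)
print(sum(p.numel() for p in memory_attention.parameters()))  # total parameter count
```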
diff --git a/docs/en/reference/models/sam/modules/sam.md b/docs/en/reference/models/sam/modules/sam.md
index a31cece15..d66508bb5 100644
--- a/docs/en/reference/models/sam/modules/sam.md
+++ b/docs/en/reference/models/sam/modules/sam.md
@@ -1,6 +1,6 @@
---
-description: Discover the Ultralytics SAM module for object segmentation. Learn about its components, such as image encoders and mask decoders, in this comprehensive guide.
-keywords: Ultralytics, SAM Module, object segmentation, image encoder, mask decoder, prompt encoder, AI, machine learning
+description: Discover the Ultralytics SAM and SAM 2 modules for object segmentation. Learn about their components, such as image encoders and mask decoders, in this comprehensive guide.
+keywords: Ultralytics, SAM Module, SAM 2 Module, object segmentation, image encoder, mask decoder, prompt encoder, AI, machine learning
---
# Reference for `ultralytics/models/sam/modules/sam.py`
@@ -13,4 +13,8 @@ keywords: Ultralytics, SAM Module, object segmentation, image encoder, mask deco
## ::: ultralytics.models.sam.modules.sam.SAMModel
+
+
+## ::: ultralytics.models.sam.modules.sam.SAM2Model
+
diff --git a/docs/en/reference/models/sam/modules/tiny_encoder.md b/docs/en/reference/models/sam/modules/tiny_encoder.md
index a5cac0529..fa5e0988d 100644
--- a/docs/en/reference/models/sam/modules/tiny_encoder.md
+++ b/docs/en/reference/models/sam/modules/tiny_encoder.md
@@ -47,10 +47,6 @@ keywords: Ultralytics, TinyViT, Conv2d_BN, PatchEmbed, MBConv, Attention, PyTorc
-## ::: ultralytics.models.sam.modules.tiny_encoder.LayerNorm2d
-
-
-
## ::: ultralytics.models.sam.modules.tiny_encoder.TinyViT
diff --git a/docs/en/reference/models/sam/modules/utils.md b/docs/en/reference/models/sam/modules/utils.md
new file mode 100644
index 000000000..5611d156e
--- /dev/null
+++ b/docs/en/reference/models/sam/modules/utils.md
@@ -0,0 +1,52 @@
+---
+description: Explore the detailed API reference for Ultralytics SAM and SAM 2 models.
+keywords: Ultralytics, SAM, SAM 2, API Reference, models, window partition, data processing, YOLO
+---
+
+# Reference for `ultralytics/models/sam/modules/utils.py`
+
+!!! Note
+
+ This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/utils.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/utils.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam/modules/utils.py) 🛠️. Thank you 🙏!
+
+
+
+## ::: ultralytics.models.sam.modules.utils.select_closest_cond_frames
+
+
+
+## ::: ultralytics.models.sam.modules.utils.get_1d_sine_pe
+
+
+
+## ::: ultralytics.models.sam.modules.utils.init_t_xy
+
+
+
+## ::: ultralytics.models.sam.modules.utils.compute_axial_cis
+
+
+
+## ::: ultralytics.models.sam.modules.utils.reshape_for_broadcast
+
+
+
+## ::: ultralytics.models.sam.modules.utils.apply_rotary_enc
+
+
+
+## ::: ultralytics.models.sam.modules.utils.window_partition
+
+
+
+## ::: ultralytics.models.sam.modules.utils.window_unpartition
+
+
+
+## ::: ultralytics.models.sam.modules.utils.get_rel_pos
+
+
+
+## ::: ultralytics.models.sam.modules.utils.add_decomposed_rel_pos
+
+
diff --git a/docs/en/reference/models/sam/predict.md b/docs/en/reference/models/sam/predict.md
index 6c35fdc00..76bdce18b 100644
--- a/docs/en/reference/models/sam/predict.md
+++ b/docs/en/reference/models/sam/predict.md
@@ -1,6 +1,6 @@
---
-description: Explore Ultralytics SAM Predictor for advanced, real-time image segmentation using the Segment Anything Model (SAM). Complete implementation details and auxiliary utilities.
-keywords: Ultralytics, SAM, Segment Anything Model, image segmentation, real-time, prediction, AI, machine learning, Python, torch, inference
+description: Explore the Ultralytics SAM and SAM 2 Predictors for advanced, real-time image segmentation with the Segment Anything Model (SAM) and Segment Anything Model 2 (SAM 2). Complete implementation details and auxiliary utilities.
+keywords: Ultralytics, SAM, Segment Anything Model, SAM 2, Segment Anything Model 2, image segmentation, real-time, prediction, AI, machine learning, Python, torch, inference
---
# Reference for `ultralytics/models/sam/predict.py`
@@ -13,4 +13,8 @@ keywords: Ultralytics, SAM, Segment Anything Model, image segmentation, real-tim
## ::: ultralytics.models.sam.predict.Predictor
+
+
+## ::: ultralytics.models.sam.predict.SAM2Predictor
+
diff --git a/docs/en/reference/models/sam2/build.md b/docs/en/reference/models/sam2/build.md
deleted file mode 100644
index 94e9f5bef..000000000
--- a/docs/en/reference/models/sam2/build.md
+++ /dev/null
@@ -1,36 +0,0 @@
----
-description: Discover detailed instructions for building various Segment Anything Model 2 (SAM 2) architectures with Ultralytics.
-keywords: Ultralytics, SAM 2 model, Segment Anything Model 2, SAM, model building, deep learning, AI
----
-
-# Reference for `ultralytics/models/sam2/build.py`
-
-!!! Note
-
- This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/build.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/build.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/build.py) 🛠️. Thank you 🙏!
-
-
-
-## ::: ultralytics.models.sam2.build.build_sam2_t
-
-
-
-## ::: ultralytics.models.sam2.build.build_sam2_s
-
-
-
-## ::: ultralytics.models.sam2.build.build_sam2_b
-
-
-
-## ::: ultralytics.models.sam2.build.build_sam2_l
-
-
-
-## ::: ultralytics.models.sam2.build._build_sam2
-
-
-
-## ::: ultralytics.models.sam2.build.build_sam2
-
-
diff --git a/docs/en/reference/models/sam2/model.md b/docs/en/reference/models/sam2/model.md
deleted file mode 100644
index fc0d2e250..000000000
--- a/docs/en/reference/models/sam2/model.md
+++ /dev/null
@@ -1,16 +0,0 @@
----
-description: Explore the SAM 2 (Segment Anything Model 2) interface for real-time image segmentation. Learn about promptable segmentation and zero-shot capabilities.
-keywords: Ultralytics, SAM 2, Segment Anything Model 2, image segmentation, real-time segmentation, zero-shot performance, promptable segmentation, SA-1B dataset
----
-
-# Reference for `ultralytics/models/sam2/model.py`
-
-!!! Note
-
- This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/model.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/model.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/model.py) 🛠️. Thank you 🙏!
-
-
-
-## ::: ultralytics.models.sam2.model.SAM2
-
-
diff --git a/docs/en/reference/models/sam2/modules/decoders.md b/docs/en/reference/models/sam2/modules/decoders.md
deleted file mode 100644
index 989169a77..000000000
--- a/docs/en/reference/models/sam2/modules/decoders.md
+++ /dev/null
@@ -1,16 +0,0 @@
----
-description: Explore the MaskDecoder and MLP modules in Ultralytics for efficient mask prediction using transformer architecture. Detailed attributes, functionalities, and implementation.
-keywords: Ultralytics, MaskDecoder, MLP, machine learning, transformer architecture, mask prediction, neural networks, PyTorch modules
----
-
-# Reference for `ultralytics/models/sam2/modules/decoders.py`
-
-!!! Note
-
- This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/decoders.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/decoders.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/decoders.py) 🛠️. Thank you 🙏!
-
-
-
-## ::: ultralytics.models.sam2.modules.decoders.MaskDecoder
-
-
diff --git a/docs/en/reference/models/sam2/modules/encoders.md b/docs/en/reference/models/sam2/modules/encoders.md
deleted file mode 100644
index a3da286e1..000000000
--- a/docs/en/reference/models/sam2/modules/encoders.md
+++ /dev/null
@@ -1,28 +0,0 @@
----
-description: Discover the Ultralytics SAM 2 module for object segmentation. Learn about its components, such as image encoders and mask decoders, in this comprehensive guide.
-keywords: Ultralytics, SAM 2 Module, object segmentation, image encoder, mask decoder, prompt encoder, AI, machine learning
----
-
-# Reference for `ultralytics/models/sam2/modules/encoders.py`
-
-!!! Note
-
- This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/encoders.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/encoders.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/encoders.py) 🛠️. Thank you 🙏!
-
-
-
-## ::: ultralytics.models.sam2.modules.encoders.MemoryEncoder
-
-
-
-## ::: ultralytics.models.sam2.modules.encoders.ImageEncoder
-
-
-
-## ::: ultralytics.models.sam2.modules.encoders.FpnNeck
-
-
-
-## ::: ultralytics.models.sam2.modules.encoders.Hiera
-
-
diff --git a/docs/en/reference/models/sam2/modules/memory_attention.md b/docs/en/reference/models/sam2/modules/memory_attention.md
deleted file mode 100644
index fef22174e..000000000
--- a/docs/en/reference/models/sam2/modules/memory_attention.md
+++ /dev/null
@@ -1,20 +0,0 @@
----
-description: Explore detailed documentation of various SAM 2 encoder modules such as MemoryAttentionLayer, MemoryAttention, available in Ultralytics' repository.
-keywords: Ultralytics, SAM 2 encoder, MemoryAttentionLayer, MemoryAttention
----
-
-# Reference for `ultralytics/models/sam2/modules/memory_attention.py`
-
-!!! Note
-
- This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/memory_attention.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/memory_attention.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/memory_attention.py) 🛠️. Thank you 🙏!
-
-
-
-## ::: ultralytics.models.sam2.modules.memory_attention.MemoryAttentionLayer
-
-
-
-## ::: ultralytics.models.sam2.modules.memory_attention.MemoryAttention
-
-
diff --git a/docs/en/reference/models/sam2/modules/sam2.md b/docs/en/reference/models/sam2/modules/sam2.md
deleted file mode 100644
index fcd47bbdc..000000000
--- a/docs/en/reference/models/sam2/modules/sam2.md
+++ /dev/null
@@ -1,16 +0,0 @@
----
-description: Discover the Ultralytics SAM 2 module for object segmentation. Learn about its components, such as image encoders and mask decoders, in this comprehensive guide.
-keywords: Ultralytics, SAM 2 Module, object segmentation, image encoder, mask decoder, prompt encoder, AI, machine learning
----
-
-# Reference for `ultralytics/models/sam2/modules/sam2.py`
-
-!!! Note
-
- This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/sam2.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/sam2.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/sam2.py) 🛠️. Thank you 🙏!
-
-
-
-## ::: ultralytics.models.sam2.modules.sam2.SAM2Model
-
-
diff --git a/docs/en/reference/models/sam2/modules/sam2_blocks.md b/docs/en/reference/models/sam2/modules/sam2_blocks.md
deleted file mode 100644
index 796669a03..000000000
--- a/docs/en/reference/models/sam2/modules/sam2_blocks.md
+++ /dev/null
@@ -1,56 +0,0 @@
----
-description: Explore detailed documentation of various SAM 2 modules such as MaskDownSampler, CXBlock, and more, available in Ultralytics' repository.
-keywords: Ultralytics, SAM 2 encoder, DropPath, MaskDownSampler, CXBlock, Fuser, TwoWayTransformer, TwoWayAttentionBlock, RoPEAttention, MultiScaleAttention, MultiScaleBlock. PositionEmbeddingSine, do_pool
----
-
-# Reference for `ultralytics/models/sam2/modules/sam2_blocks.py`
-
-!!! Note
-
- This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/sam2_blocks.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/sam2_blocks.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/sam2_blocks.py) 🛠️. Thank you 🙏!
-
-
-
-## ::: ultralytics.models.sam2.modules.sam2_blocks.DropPath
-
-
-
-## ::: ultralytics.models.sam2.modules.sam2_blocks.MaskDownSampler
-
-
-
-## ::: ultralytics.models.sam2.modules.sam2_blocks.CXBlock
-
-
-
-## ::: ultralytics.models.sam2.modules.sam2_blocks.Fuser
-
-
-
-## ::: ultralytics.models.sam2.modules.sam2_blocks.TwoWayAttentionBlock
-
-
-
-## ::: ultralytics.models.sam2.modules.sam2_blocks.TwoWayTransformer
-
-
-
-## ::: ultralytics.models.sam2.modules.sam2_blocks.RoPEAttention
-
-
-
-## ::: ultralytics.models.sam2.modules.sam2_blocks.MultiScaleAttention
-
-
-
-## ::: ultralytics.models.sam2.modules.sam2_blocks.MultiScaleBlock
-
-
-
-## ::: ultralytics.models.sam2.modules.sam2_blocks.PositionEmbeddingSine
-
-
-
-## ::: ultralytics.models.sam2.modules.sam2_blocks.do_pool
-
-
diff --git a/docs/en/reference/models/sam2/modules/utils.md b/docs/en/reference/models/sam2/modules/utils.md
deleted file mode 100644
index 357cea623..000000000
--- a/docs/en/reference/models/sam2/modules/utils.md
+++ /dev/null
@@ -1,44 +0,0 @@
----
-description: Explore the detailed API reference for Ultralytics SAM 2 models.
-keywords: Ultralytics, SAM 2, API Reference, models, window partition, data processing, YOLO
----
-
-# Reference for `ultralytics/models/sam2/modules/utils.py`
-
-!!! Note
-
- This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/utils.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/utils.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/utils.py) 🛠️. Thank you 🙏!
-
-
-
-## ::: ultralytics.models.sam2.modules.utils.select_closest_cond_frames
-
-
-
-## ::: ultralytics.models.sam2.modules.utils.get_1d_sine_pe
-
-
-
-## ::: ultralytics.models.sam2.modules.utils.init_t_xy
-
-
-
-## ::: ultralytics.models.sam2.modules.utils.compute_axial_cis
-
-
-
-## ::: ultralytics.models.sam2.modules.utils.reshape_for_broadcast
-
-
-
-## ::: ultralytics.models.sam2.modules.utils.apply_rotary_enc
-
-
-
-## ::: ultralytics.models.sam2.modules.utils.window_partition
-
-
-
-## ::: ultralytics.models.sam2.modules.utils.window_unpartition
-
-
diff --git a/docs/en/reference/models/sam2/predict.md b/docs/en/reference/models/sam2/predict.md
deleted file mode 100644
index 23b30140b..000000000
--- a/docs/en/reference/models/sam2/predict.md
+++ /dev/null
@@ -1,16 +0,0 @@
----
-description: Explore Ultralytics SAM 2 Predictor for advanced, real-time image segmentation using the Segment Anything Model 2 (SAM 2). Complete implementation details and auxiliary utilities.
-keywords: Ultralytics, SAM 2, Segment Anything Model 2, image segmentation, real-time, prediction, AI, machine learning, Python, torch, inference
----
-
-# Reference for `ultralytics/models/sam2/predict.py`
-
-!!! Note
-
- This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/predict.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/predict.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/predict.py) 🛠️. Thank you 🙏!
-
-
-
-## ::: ultralytics.models.sam2.predict.SAM2Predictor
-
-
diff --git a/mkdocs.yml b/mkdocs.yml
index fcf5516b5..eb150fd7a 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -503,23 +503,15 @@ nav:
- build: reference/models/sam/build.md
- model: reference/models/sam/model.md
- modules:
+ - blocks: reference/models/sam/modules/blocks.md
- decoders: reference/models/sam/modules/decoders.md
- encoders: reference/models/sam/modules/encoders.md
+ - memory_attention: reference/models/sam/modules/memory_attention.md
- sam: reference/models/sam/modules/sam.md
- tiny_encoder: reference/models/sam/modules/tiny_encoder.md
- transformer: reference/models/sam/modules/transformer.md
+ - utils: reference/models/sam/modules/utils.md
- predict: reference/models/sam/predict.md
- - sam2:
- - build: reference/models/sam2/build.md
- - model: reference/models/sam2/model.md
- - modules:
- - decoders: reference/models/sam2/modules/decoders.md
- - encoders: reference/models/sam2/modules/encoders.md
- - memory_attention: reference/models/sam2/modules/memory_attention.md
- - sam2: reference/models/sam2/modules/sam2.md
- - sam2_blocks: reference/models/sam2/modules/sam2_blocks.md
- - utils: reference/models/sam2/modules/utils.md
- - predict: reference/models/sam2/predict.md
- utils:
- loss: reference/models/utils/loss.md
- ops: reference/models/utils/ops.md
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 16bacbbe5..313c06fd0 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = "8.2.72"
+__version__ = "8.2.73"
import os
@@ -8,7 +8,7 @@ import os
os.environ["OMP_NUM_THREADS"] = "1" # reduce CPU utilization during training
from ultralytics.data.explorer.explorer import Explorer
-from ultralytics.models import NAS, RTDETR, SAM, SAM2, YOLO, FastSAM, YOLOWorld
+from ultralytics.models import NAS, RTDETR, SAM, YOLO, FastSAM, YOLOWorld
from ultralytics.utils import ASSETS, SETTINGS
from ultralytics.utils.checks import check_yolo as checks
from ultralytics.utils.downloads import download
@@ -21,7 +21,6 @@ __all__ = (
"YOLOWorld",
"NAS",
"SAM",
- "SAM2",
"FastSAM",
"RTDETR",
"checks",
diff --git a/ultralytics/models/__init__.py b/ultralytics/models/__init__.py
index bbade0472..aff620a9a 100644
--- a/ultralytics/models/__init__.py
+++ b/ultralytics/models/__init__.py
@@ -4,7 +4,6 @@ from .fastsam import FastSAM
from .nas import NAS
from .rtdetr import RTDETR
from .sam import SAM
-from .sam2 import SAM2
from .yolo import YOLO, YOLOWorld
-__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld", "SAM2" # allow simpler import
+__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld" # allow simpler import
diff --git a/ultralytics/models/sam/__init__.py b/ultralytics/models/sam/__init__.py
index 8701fccea..a29f5cb3f 100644
--- a/ultralytics/models/sam/__init__.py
+++ b/ultralytics/models/sam/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .model import SAM
-from .predict import Predictor
+from .predict import Predictor, SAM2Predictor
-__all__ = "SAM", "Predictor" # tuple or list
+__all__ = "SAM", "Predictor", "SAM2Predictor" # tuple or list
diff --git a/ultralytics/models/sam/amg.py b/ultralytics/models/sam/amg.py
index b61c6a710..55db3e011 100644
--- a/ultralytics/models/sam/amg.py
+++ b/ultralytics/models/sam/amg.py
@@ -11,7 +11,7 @@ import torch
def is_box_near_crop_edge(
boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
) -> torch.Tensor:
- """Return a boolean tensor indicating if boxes are near the crop edge."""
+ """Determines if bounding boxes are near the edge of a cropped image region using a specified tolerance."""
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
@@ -22,7 +22,7 @@ def is_box_near_crop_edge(
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
- """Yield batches of data from the input arguments."""
+ """Yields batches of data from input arguments with specified batch size for efficient processing."""
assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs."
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
for b in range(n_batches):
@@ -33,12 +33,26 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh
"""
Computes the stability score for a batch of masks.
- The stability score is the IoU between the binary masks obtained by thresholding the predicted mask logits at high
- and low values.
+ The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at
+ high and low values.
+
+ Args:
+ masks (torch.Tensor): Batch of predicted mask logits.
+ mask_threshold (float): Threshold value for creating binary masks.
+ threshold_offset (float): Offset applied to the threshold for creating high and low binary masks.
+
+ Returns:
+ (torch.Tensor): Stability scores for each mask in the batch.
Notes:
- One mask is always contained inside the other.
- - Save memory by preventing unnecessary cast to torch.int64
+ - Memory is saved by preventing unnecessary cast to torch.int64.
+
+ Examples:
+ >>> masks = torch.rand(10, 256, 256) # Batch of 10 masks
+ >>> mask_threshold = 0.5
+ >>> threshold_offset = 0.1
+ >>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)
"""
intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
@@ -46,7 +60,7 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh
def build_point_grid(n_per_side: int) -> np.ndarray:
- """Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1]."""
+ """Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1] for image segmentation tasks."""
offset = 1 / (2 * n_per_side)
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
@@ -55,18 +69,14 @@ def build_point_grid(n_per_side: int) -> np.ndarray:
def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
- """Generate point grids for all crop layers."""
+ """Generates point grids for multiple crop layers with varying scales and densities."""
return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]
def generate_crop_boxes(
im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
) -> Tuple[List[List[int]], List[int]]:
- """
- Generates a list of crop boxes of different sizes.
-
- Each layer has (2**i)**2 boxes for the ith layer.
- """
+ """Generates crop boxes of varying sizes for multi-scale image processing, with layered overlapping regions."""
crop_boxes, layer_idxs = [], []
im_h, im_w = im_size
short_side = min(im_h, im_w)
@@ -99,7 +109,7 @@ def generate_crop_boxes(
def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
- """Uncrop bounding boxes by adding the crop box offset."""
+ """Uncrop bounding boxes by adding the crop box offset to their coordinates."""
x0, y0, _, _ = crop_box
offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
# Check if boxes has a channel dimension
@@ -109,7 +119,7 @@ def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
- """Uncrop points by adding the crop box offset."""
+ """Uncrop points by adding the crop box offset to their coordinates."""
x0, y0, _, _ = crop_box
offset = torch.tensor([[x0, y0]], device=points.device)
# Check if points has a channel dimension
@@ -119,7 +129,7 @@ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:
- """Uncrop masks by padding them to the original image size."""
+ """Uncrop masks by padding them to the original image size, handling coordinate transformations."""
x0, y0, x1, y1 = crop_box
if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
return masks
@@ -130,7 +140,7 @@ def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w:
def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
- """Remove small disconnected regions or holes in a mask, returning the mask and a modification indicator."""
+ """Removes small disconnected regions or holes in a mask based on area threshold and mode."""
import cv2 # type: ignore
assert mode in {"holes", "islands"}, f"Provided mode {mode} is invalid"
@@ -150,11 +160,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
- """
- Calculates boxes in XYXY format around masks.
-
- Return [0,0,0,0] for an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
- """
+ """Calculates bounding boxes in XYXY format around binary masks, handling empty masks and various input shapes."""
# torch.max below raises an error on empty inputs, just skip in this case
if torch.numel(masks) == 0:
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
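A short sketch of how the `amg.py` helpers whose docstrings are updated here are typically called; the image size, thresholds, and random masks below are illustrative only:

```python
import torch

from ultralytics.models.sam.amg import calculate_stability_score, generate_crop_boxes

# Crop boxes for a 768x1024 image with one extra crop layer and 25% overlap:
# one full-image box for layer 0 plus four overlapping boxes for layer 1.
crop_boxes, layer_idxs = generate_crop_boxes((768, 1024), n_layers=1, overlap_ratio=0.25)
print(len(crop_boxes), layer_idxs)  # 5 [0, 1, 1, 1, 1]

# Stability score: IoU between masks thresholded at mask_threshold +/- threshold_offset.
masks = torch.randn(4, 256, 256)
scores = calculate_stability_score(masks, mask_threshold=0.0, threshold_offset=1.0)
print(scores.shape)  # torch.Size([4])
```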
diff --git a/ultralytics/models/sam/build.py b/ultralytics/models/sam/build.py
index 253c0159d..0e7ddedcf 100644
--- a/ultralytics/models/sam/build.py
+++ b/ultralytics/models/sam/build.py
@@ -13,14 +13,15 @@ import torch
from ultralytics.utils.downloads import attempt_download_asset
from .modules.decoders import MaskDecoder
-from .modules.encoders import ImageEncoderViT, PromptEncoder
-from .modules.sam import SAMModel
+from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder
+from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
+from .modules.sam import SAM2Model, SAMModel
from .modules.tiny_encoder import TinyViT
from .modules.transformer import TwoWayTransformer
def build_sam_vit_h(checkpoint=None):
- """Build and return a Segment Anything Model (SAM) h-size model."""
+ """Builds and returns a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
return _build_sam(
encoder_embed_dim=1280,
encoder_depth=32,
@@ -31,7 +32,7 @@ def build_sam_vit_h(checkpoint=None):
def build_sam_vit_l(checkpoint=None):
- """Build and return a Segment Anything Model (SAM) l-size model."""
+ """Builds and returns a Segment Anything Model (SAM) l-size model with specified encoder parameters."""
return _build_sam(
encoder_embed_dim=1024,
encoder_depth=24,
@@ -42,7 +43,7 @@ def build_sam_vit_l(checkpoint=None):
def build_sam_vit_b(checkpoint=None):
- """Build and return a Segment Anything Model (SAM) b-size model."""
+ """Constructs and returns a Segment Anything Model (SAM) with b-size architecture and optional checkpoint."""
return _build_sam(
encoder_embed_dim=768,
encoder_depth=12,
@@ -53,7 +54,7 @@ def build_sam_vit_b(checkpoint=None):
def build_mobile_sam(checkpoint=None):
- """Build and return Mobile Segment Anything Model (Mobile-SAM)."""
+ """Builds and returns a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation."""
return _build_sam(
encoder_embed_dim=[64, 128, 160, 320],
encoder_depth=[2, 2, 6, 2],
@@ -64,10 +65,85 @@ def build_mobile_sam(checkpoint=None):
)
+def build_sam2_t(checkpoint=None):
+ """Builds and returns a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters."""
+ return _build_sam2(
+ encoder_embed_dim=96,
+ encoder_stages=[1, 2, 7, 2],
+ encoder_num_heads=1,
+ encoder_global_att_blocks=[5, 7, 9],
+ encoder_window_spec=[8, 4, 14, 7],
+ encoder_backbone_channel_list=[768, 384, 192, 96],
+ checkpoint=checkpoint,
+ )
+
+
+def build_sam2_s(checkpoint=None):
+ """Builds and returns a small-size Segment Anything Model (SAM2) with specified architecture parameters."""
+ return _build_sam2(
+ encoder_embed_dim=96,
+ encoder_stages=[1, 2, 11, 2],
+ encoder_num_heads=1,
+ encoder_global_att_blocks=[7, 10, 13],
+ encoder_window_spec=[8, 4, 14, 7],
+ encoder_backbone_channel_list=[768, 384, 192, 96],
+ checkpoint=checkpoint,
+ )
+
+
+def build_sam2_b(checkpoint=None):
+ """Builds and returns a SAM2 base-size model with specified architecture parameters."""
+ return _build_sam2(
+ encoder_embed_dim=112,
+ encoder_stages=[2, 3, 16, 3],
+ encoder_num_heads=2,
+ encoder_global_att_blocks=[12, 16, 20],
+ encoder_window_spec=[8, 4, 14, 7],
+ encoder_window_spatial_size=[14, 14],
+ encoder_backbone_channel_list=[896, 448, 224, 112],
+ checkpoint=checkpoint,
+ )
+
+
+def build_sam2_l(checkpoint=None):
+ """Builds and returns a large-size Segment Anything Model (SAM2) with specified architecture parameters."""
+ return _build_sam2(
+ encoder_embed_dim=144,
+ encoder_stages=[2, 6, 36, 4],
+ encoder_num_heads=2,
+ encoder_global_att_blocks=[23, 33, 43],
+ encoder_window_spec=[8, 4, 16, 8],
+ encoder_backbone_channel_list=[1152, 576, 288, 144],
+ checkpoint=checkpoint,
+ )
+
+
def _build_sam(
- encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint=None, mobile_sam=False
+ encoder_embed_dim,
+ encoder_depth,
+ encoder_num_heads,
+ encoder_global_attn_indexes,
+ checkpoint=None,
+ mobile_sam=False,
):
- """Builds the selected SAM model architecture."""
+ """
+ Builds a Segment Anything Model (SAM) with specified encoder parameters.
+
+ Args:
+ encoder_embed_dim (int | List[int]): Embedding dimension for the encoder.
+ encoder_depth (int | List[int]): Depth of the encoder.
+ encoder_num_heads (int | List[int]): Number of attention heads in the encoder.
+ encoder_global_attn_indexes (List[int] | None): Indexes for global attention in the encoder.
+ checkpoint (str | None): Path to the model checkpoint file.
+ mobile_sam (bool): Whether to build a Mobile-SAM model.
+
+ Returns:
+ (SAMModel): A Segment Anything Model instance with the specified architecture.
+
+ Examples:
+ >>> sam = _build_sam(768, 12, 12, [2, 5, 8, 11])
+ >>> sam = _build_sam([64, 128, 160, 320], [2, 2, 6, 2], [2, 4, 5, 10], None, mobile_sam=True)
+ """
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
@@ -139,16 +215,131 @@ def _build_sam(
return sam
+def _build_sam2(
+ encoder_embed_dim=1280,
+ encoder_stages=[2, 6, 36, 4],
+ encoder_num_heads=2,
+ encoder_global_att_blocks=[7, 15, 23, 31],
+ encoder_backbone_channel_list=[1152, 576, 288, 144],
+ encoder_window_spatial_size=[7, 7],
+ encoder_window_spec=[8, 4, 16, 8],
+ checkpoint=None,
+):
+ """
+ Builds and returns a Segment Anything Model 2 (SAM2) with specified architecture parameters.
+
+ Args:
+ encoder_embed_dim (int): Embedding dimension for the encoder.
+ encoder_stages (List[int]): Number of blocks in each stage of the encoder.
+ encoder_num_heads (int): Number of attention heads in the encoder.
+ encoder_global_att_blocks (List[int]): Indices of global attention blocks in the encoder.
+ encoder_backbone_channel_list (List[int]): Channel dimensions for each level of the encoder backbone.
+ encoder_window_spatial_size (List[int]): Spatial size of the window for position embeddings.
+ encoder_window_spec (List[int]): Window specifications for each stage of the encoder.
+ checkpoint (str | None): Path to the checkpoint file for loading pre-trained weights.
+
+ Returns:
+ (SAM2Model): A configured and initialized SAM2 model.
+
+ Examples:
+ >>> sam2_model = _build_sam2(encoder_embed_dim=96, encoder_stages=[1, 2, 7, 2])
+ >>> sam2_model.eval()
+ """
+ image_encoder = ImageEncoder(
+ trunk=Hiera(
+ embed_dim=encoder_embed_dim,
+ num_heads=encoder_num_heads,
+ stages=encoder_stages,
+ global_att_blocks=encoder_global_att_blocks,
+ window_pos_embed_bkg_spatial_size=encoder_window_spatial_size,
+ window_spec=encoder_window_spec,
+ ),
+ neck=FpnNeck(
+ d_model=256,
+ backbone_channel_list=encoder_backbone_channel_list,
+ fpn_top_down_levels=[2, 3],
+ fpn_interp_model="nearest",
+ ),
+ scalp=1,
+ )
+ memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())
+ memory_encoder = MemoryEncoder(out_dim=64)
+
+ sam2 = SAM2Model(
+ image_encoder=image_encoder,
+ memory_attention=memory_attention,
+ memory_encoder=memory_encoder,
+ num_maskmem=7,
+ image_size=1024,
+ sigmoid_scale_for_mem_enc=20.0,
+ sigmoid_bias_for_mem_enc=-10.0,
+ use_mask_input_as_output_without_sam=True,
+ directly_add_no_mem_embed=True,
+ use_high_res_features_in_sam=True,
+ multimask_output_in_sam=True,
+ iou_prediction_use_sigmoid=True,
+ use_obj_ptrs_in_encoder=True,
+ add_tpos_enc_to_obj_ptrs=True,
+ only_obj_ptrs_in_the_past_for_eval=True,
+ pred_obj_scores=True,
+ pred_obj_scores_mlp=True,
+ fixed_no_obj_ptr=True,
+ multimask_output_for_tracking=True,
+ use_multimask_token_for_obj_ptr=True,
+ multimask_min_pt_num=0,
+ multimask_max_pt_num=1,
+ use_mlp_for_obj_ptr_proj=True,
+ compile_image_encoder=False,
+ sam_mask_decoder_extra_args=dict(
+ dynamic_multimask_via_stability=True,
+ dynamic_multimask_stability_delta=0.05,
+ dynamic_multimask_stability_thresh=0.98,
+ ),
+ )
+
+ if checkpoint is not None:
+ checkpoint = attempt_download_asset(checkpoint)
+ with open(checkpoint, "rb") as f:
+ state_dict = torch.load(f)["model"]
+ sam2.load_state_dict(state_dict)
+ sam2.eval()
+ return sam2
+
+
sam_model_map = {
"sam_h.pt": build_sam_vit_h,
"sam_l.pt": build_sam_vit_l,
"sam_b.pt": build_sam_vit_b,
"mobile_sam.pt": build_mobile_sam,
+ "sam2_t.pt": build_sam2_t,
+ "sam2_s.pt": build_sam2_s,
+ "sam2_b.pt": build_sam2_b,
+ "sam2_l.pt": build_sam2_l,
}
def build_sam(ckpt="sam_b.pt"):
- """Build a SAM model specified by ckpt."""
+ """
+ Builds and returns a Segment Anything Model (SAM) based on the provided checkpoint.
+
+ Args:
+ ckpt (str | Path): Path to the checkpoint file or name of a pre-defined SAM model.
+
+ Returns:
+ (SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance.
+
+ Raises:
+ FileNotFoundError: If the provided checkpoint is not a supported SAM model.
+
+ Examples:
+ >>> sam_model = build_sam("sam_b.pt")
+ >>> sam_model = build_sam("path/to/custom_checkpoint.pt")
+
+ Notes:
+ Supported pre-defined models include:
+ - SAM: 'sam_h.pt', 'sam_l.pt', 'sam_b.pt', 'mobile_sam.pt'
+ - SAM2: 'sam2_t.pt', 'sam2_s.pt', 'sam2_b.pt', 'sam2_l.pt'
+ """
model_builder = None
ckpt = str(ckpt) # to allow Path ckpt types
for k in sam_model_map.keys():
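As described in the new `build_sam` docstring, pre-defined names are resolved through `sam_model_map`; a brief sketch of both entry points (the named checkpoints are fetched with `attempt_download_asset` if not present locally):

```python
from ultralytics.models.sam.build import build_sam, build_sam2_t

# Name-based dispatch: "sam2_t.pt" resolves to build_sam2_t via sam_model_map.
sam2 = build_sam("sam2_t.pt")

# The size-specific builders can also be called directly; with checkpoint=None the
# SAM2 architecture is constructed with randomly initialized weights.
sam2_random = build_sam2_t(checkpoint=None)
print(type(sam2).__name__, type(sam2_random).__name__)  # SAM2Model SAM2Model
```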
diff --git a/ultralytics/models/sam/model.py b/ultralytics/models/sam/model.py
index a819b6eaf..30d8644bb 100644
--- a/ultralytics/models/sam/model.py
+++ b/ultralytics/models/sam/model.py
@@ -20,27 +20,46 @@ from ultralytics.engine.model import Model
from ultralytics.utils.torch_utils import model_info
from .build import build_sam
-from .predict import Predictor
+from .predict import Predictor, SAM2Predictor
class SAM(Model):
"""
- SAM (Segment Anything Model) interface class.
-
- SAM is designed for promptable real-time image segmentation. It can be used with a variety of prompts such as
- bounding boxes, points, or labels. The model has capabilities for zero-shot performance and is trained on the SA-1B
- dataset.
+ SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
+
+ This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for
+ promptable segmentation with versatility in image analysis. It supports various prompts such as bounding
+ boxes, points, or labels, and features zero-shot performance capabilities.
+
+ Attributes:
+ model (torch.nn.Module): The loaded SAM model.
+        is_sam2 (bool): Indicates whether the model is a SAM2 variant.
+ task (str): The task type, set to "segment" for SAM models.
+
+ Methods:
+ predict: Performs segmentation prediction on the given image or video source.
+ info: Logs information about the SAM model.
+
+ Examples:
+ >>> sam = SAM('sam_b.pt')
+ >>> results = sam.predict('image.jpg', points=[[500, 375]])
+ >>> for r in results:
+        ...     print(f"Detected {len(r.masks)} masks")
"""
def __init__(self, model="sam_b.pt") -> None:
"""
- Initializes the SAM model with a pre-trained model file.
+ Initializes the SAM (Segment Anything Model) instance.
Args:
model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.
Raises:
NotImplementedError: If the model file extension is not .pt or .pth.
+
+ Examples:
+ >>> sam = SAM('sam_b.pt')
+ >>> print(sam.is_sam2)
"""
if model and Path(model).suffix not in {".pt", ".pth"}:
raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
@@ -51,30 +70,40 @@ class SAM(Model):
"""
Loads the specified weights into the SAM model.
+ This method initializes the SAM model with the provided weights file, setting up the model architecture
+ and loading the pre-trained parameters.
+
Args:
- weights (str): Path to the weights file.
- task (str, optional): Task name. Defaults to None.
- """
- if self.is_sam2:
- from ..sam2.build import build_sam2
+ weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.
+ task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for.
- self.model = build_sam2(weights)
- else:
- self.model = build_sam(weights)
+ Examples:
+ >>> sam = SAM('sam_b.pt')
+ >>> sam._load('path/to/custom_weights.pt')
+ """
+ self.model = build_sam(weights)
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""
Performs segmentation prediction on the given image or video source.
Args:
- source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
- stream (bool, optional): If True, enables real-time streaming. Defaults to False.
- bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
- points (list, optional): List of points for prompted segmentation. Defaults to None.
- labels (list, optional): List of labels for prompted segmentation. Defaults to None.
+ source (str | PIL.Image | numpy.ndarray): Path to the image or video file, or a PIL.Image object, or
+ a numpy.ndarray object.
+ stream (bool): If True, enables real-time streaming.
+ bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
+ points (List[List[float]] | None): List of points for prompted segmentation.
+ labels (List[int] | None): List of labels for prompted segmentation.
+ **kwargs (Any): Additional keyword arguments for prediction.
Returns:
- (list): The model predictions.
+ (List): The model predictions.
+
+ Examples:
+ >>> sam = SAM('sam_b.pt')
+ >>> results = sam.predict('image.jpg', points=[[500, 375]])
+ >>> for r in results:
+ ... print(f"Detected {len(r.masks)} masks")
"""
overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
kwargs.update(overrides)
@@ -83,17 +112,27 @@ class SAM(Model):
def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""
- Alias for the 'predict' method.
+ Performs segmentation prediction on the given image or video source.
+
+ This method is an alias for the 'predict' method, providing a convenient way to call the SAM model
+ for segmentation tasks.
Args:
- source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
- stream (bool, optional): If True, enables real-time streaming. Defaults to False.
- bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
- points (list, optional): List of points for prompted segmentation. Defaults to None.
- labels (list, optional): List of labels for prompted segmentation. Defaults to None.
+ source (str | PIL.Image | numpy.ndarray | None): Path to the image or video file, or a PIL.Image
+ object, or a numpy.ndarray object.
+ stream (bool): If True, enables real-time streaming.
+ bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
+ points (List[List[float]] | None): List of points for prompted segmentation.
+ labels (List[int] | None): List of labels for prompted segmentation.
+ **kwargs (Any): Additional keyword arguments to be passed to the predict method.
Returns:
- (list): The model predictions.
+ (List): The model predictions, typically containing segmentation masks and other relevant information.
+
+ Examples:
+ >>> sam = SAM('sam_b.pt')
+ >>> results = sam('image.jpg', points=[[500, 375]])
+ >>> print(f"Detected {len(results[0].masks)} masks")
"""
return self.predict(source, stream, bboxes, points, labels, **kwargs)
@@ -101,12 +140,20 @@ class SAM(Model):
"""
Logs information about the SAM model.
+ This method provides details about the Segment Anything Model (SAM), including its architecture,
+ parameters, and computational requirements.
+
Args:
- detailed (bool, optional): If True, displays detailed information about the model. Defaults to False.
- verbose (bool, optional): If True, displays information on the console. Defaults to True.
+ detailed (bool): If True, displays detailed information about the model layers and operations.
+ verbose (bool): If True, prints the information to the console.
Returns:
- (tuple): A tuple containing the model's information.
+ (Tuple): A tuple containing the model's information (string representations of the model).
+
+ Examples:
+ >>> sam = SAM('sam_b.pt')
+ >>> info = sam.info()
+ >>> print(info[0]) # Print summary information
"""
return model_info(self.model, detailed=detailed, verbose=verbose)
@@ -116,8 +163,13 @@ class SAM(Model):
Provides a mapping from the 'segment' task to its corresponding 'Predictor'.
Returns:
- (dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
+ (Dict[str, Type[Predictor]]): A dictionary mapping the 'segment' task to its corresponding Predictor
+ class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.
+
+ Examples:
+ >>> sam = SAM('sam_b.pt')
+ >>> task_map = sam.task_map
+ >>> print(task_map)
+            {'segment': {'predictor': <class 'ultralytics.models.sam.predict.Predictor'>}}
"""
- from ..sam2.predict import SAM2Predictor
-
return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
diff --git a/ultralytics/models/sam/modules/blocks.py b/ultralytics/models/sam/modules/blocks.py
new file mode 100644
index 000000000..008185a29
--- /dev/null
+++ b/ultralytics/models/sam/modules/blocks.py
@@ -0,0 +1,1131 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+import copy
+import math
+from functools import partial
+from typing import Any, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from ultralytics.nn.modules import MLP, LayerNorm2d, MLPBlock
+
+from .transformer import Attention, TwoWayAttentionBlock, TwoWayTransformer
+from .utils import add_decomposed_rel_pos, apply_rotary_enc, compute_axial_cis, window_partition, window_unpartition
+
+
+class DropPath(nn.Module):
+ """
+ Implements stochastic depth regularization for neural networks during training.
+
+ Attributes:
+ drop_prob (float): Probability of dropping a path during training.
+ scale_by_keep (bool): Whether to scale the output by the keep probability.
+
+ Methods:
+ forward: Applies stochastic depth to input tensor during training, with optional scaling.
+
+ Examples:
+ >>> drop_path = DropPath(drop_prob=0.2, scale_by_keep=True)
+ >>> x = torch.randn(32, 64, 224, 224)
+ >>> output = drop_path(x)
+ """
+
+ def __init__(self, drop_prob=0.0, scale_by_keep=True):
+ """Initialize DropPath module for stochastic depth regularization during training."""
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ """Applies stochastic depth to input tensor during training, with optional scaling."""
+ if self.drop_prob == 0.0 or not self.training:
+ return x
+ keep_prob = 1 - self.drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and self.scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+
+class MaskDownSampler(nn.Module):
+ """
+ A mask downsampling and embedding module for efficient processing of input masks.
+
+ This class implements a mask downsampler that progressively reduces the spatial dimensions of input masks
+ while expanding their channel dimensions using convolutional layers, layer normalization, and activation
+ functions.
+
+ Attributes:
+ encoder (nn.Sequential): A sequential container of convolutional layers, layer normalization, and
+ activation functions for downsampling and embedding masks.
+
+ Methods:
+ forward: Downsamples and encodes input mask to embed_dim channels.
+
+ Examples:
+ >>> mask_downsampler = MaskDownSampler(embed_dim=256, kernel_size=4, stride=4, padding=0, total_stride=16)
+ >>> input_mask = torch.randn(1, 1, 256, 256)
+ >>> output = mask_downsampler(input_mask)
+ >>> print(output.shape)
+ torch.Size([1, 256, 16, 16])
+ """
+
+ def __init__(
+ self,
+ embed_dim=256,
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ total_stride=16,
+ activation=nn.GELU,
+ ):
+ """Initializes a mask downsampler module for progressive downsampling and channel expansion."""
+ super().__init__()
+ num_layers = int(math.log2(total_stride) // math.log2(stride))
+ assert stride**num_layers == total_stride
+ self.encoder = nn.Sequential()
+ mask_in_chans, mask_out_chans = 1, 1
+ for _ in range(num_layers):
+ mask_out_chans = mask_in_chans * (stride**2)
+ self.encoder.append(
+ nn.Conv2d(
+ mask_in_chans,
+ mask_out_chans,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ )
+ )
+ self.encoder.append(LayerNorm2d(mask_out_chans))
+ self.encoder.append(activation())
+ mask_in_chans = mask_out_chans
+
+ self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
+
+ def forward(self, x):
+ """Downsamples and encodes input mask to embed_dim channels using convolutional layers and LayerNorm2d."""
+ return self.encoder(x)
+
+
+class CXBlock(nn.Module):
+ """
+ ConvNeXt Block for efficient feature extraction in convolutional neural networks.
+
+ This block implements a modified version of the ConvNeXt architecture, offering improved performance and
+ flexibility in feature extraction.
+
+ Attributes:
+ dwconv (nn.Conv2d): Depthwise or standard 2D convolution layer.
+ norm (LayerNorm2d): Layer normalization applied to channels.
+ pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer.
+ act (nn.GELU): GELU activation function.
+ pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer.
+ gamma (nn.Parameter | None): Learnable scale parameter for layer scaling.
+ drop_path (nn.Module): DropPath layer for stochastic depth regularization.
+
+ Methods:
+ forward: Processes the input tensor through the ConvNeXt block.
+
+ Examples:
+ >>> import torch
+ >>> x = torch.randn(1, 64, 56, 56)
+ >>> block = CXBlock(dim=64, kernel_size=7, padding=3)
+ >>> output = block(x)
+ >>> print(output.shape)
+ torch.Size([1, 64, 56, 56])
+ """
+
+ def __init__(
+ self,
+ dim,
+ kernel_size=7,
+ padding=3,
+ drop_path=0.0,
+ layer_scale_init_value=1e-6,
+ use_dwconv=True,
+ ):
+ """
+ Initialize a ConvNeXt Block for efficient feature extraction in convolutional neural networks.
+
+ This block implements a modified version of the ConvNeXt architecture, offering improved performance and
+ flexibility in feature extraction.
+
+ Args:
+ dim (int): Number of input channels.
+ kernel_size (int): Size of the convolutional kernel.
+ padding (int): Padding size for the convolution.
+ drop_path (float): Stochastic depth rate.
+ layer_scale_init_value (float): Initial value for Layer Scale.
+ use_dwconv (bool): Whether to use depthwise convolution.
+
+ Examples:
+ >>> block = CXBlock(dim=64, kernel_size=7, padding=3)
+ >>> x = torch.randn(1, 64, 32, 32)
+ >>> output = block(x)
+ >>> print(output.shape)
+ torch.Size([1, 64, 32, 32])
+ """
+ super().__init__()
+ self.dwconv = nn.Conv2d(
+ dim,
+ dim,
+ kernel_size=kernel_size,
+ padding=padding,
+ groups=dim if use_dwconv else 1,
+ ) # depthwise conv
+ self.norm = LayerNorm2d(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(4 * dim, dim)
+ self.gamma = (
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
+ if layer_scale_init_value > 0
+ else None
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ def forward(self, x):
+ """Applies ConvNeXt block operations to input tensor, including convolutions and residual connection."""
+ input = x
+ x = self.dwconv(x)
+ x = self.norm(x)
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.gamma is not None:
+ x = self.gamma * x
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+
+ x = input + self.drop_path(x)
+ return x
+
+
+class Fuser(nn.Module):
+ """
+ A module for fusing features through multiple layers of a neural network.
+
+ This class applies a series of identical layers to an input tensor, optionally projecting the input first.
+
+ Attributes:
+ proj (nn.Module): An optional input projection layer. Identity if no projection is needed.
+ layers (nn.ModuleList): A list of identical layers to be applied sequentially.
+
+ Methods:
+ forward: Applies the fuser to an input tensor.
+
+ Examples:
+ >>> layer = CXBlock(dim=256)
+ >>> fuser = Fuser(layer, num_layers=3, dim=256, input_projection=True)
+ >>> x = torch.randn(1, 256, 32, 32)
+ >>> output = fuser(x)
+ >>> print(output.shape)
+ torch.Size([1, 256, 32, 32])
+ """
+
+ def __init__(self, layer, num_layers, dim=None, input_projection=False):
+ """
+ Initializes the Fuser module for feature fusion through multiple layers.
+
+ This module creates a sequence of identical layers and optionally applies an input projection.
+
+ Args:
+ layer (nn.Module): The layer to be replicated in the fuser.
+ num_layers (int): The number of times to replicate the layer.
+ dim (int | None): The dimension for input projection, if used.
+ input_projection (bool): Whether to use input projection.
+
+ Examples:
+ >>> layer = CXBlock(dim=64)
+ >>> fuser = Fuser(layer, num_layers=3, dim=64, input_projection=True)
+ >>> input_tensor = torch.randn(1, 64, 32, 32)
+ >>> output = fuser(input_tensor)
+ """
+ super().__init__()
+ self.proj = nn.Identity()
+ self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
+
+ if input_projection:
+ assert dim is not None
+ self.proj = nn.Conv2d(dim, dim, kernel_size=1)
+
+ def forward(self, x):
+ """Applies a series of layers to the input tensor, optionally projecting it first."""
+ x = self.proj(x)
+ for layer in self.layers:
+ x = layer(x)
+ return x
+
+
+class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock):
+ """
+ A two-way attention block for performing self-attention and cross-attention in both directions.
+
+ This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on
+ sparse inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and
+ cross-attention from dense to sparse inputs.
+
+ Attributes:
+ self_attn (Attention): Self-attention layer for queries.
+ norm1 (nn.LayerNorm): Layer normalization after the first attention block.
+ cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
+ norm2 (nn.LayerNorm): Layer normalization after the second attention block.
+ mlp (MLP): MLP block for transforming query embeddings.
+ norm3 (nn.LayerNorm): Layer normalization after the MLP block.
+ norm4 (nn.LayerNorm): Layer normalization after the third attention block.
+ cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
+ skip_first_layer_pe (bool): Flag to skip positional encoding in the first layer.
+
+ Methods:
+ forward: Processes input through the attention blocks and MLP.
+
+ Examples:
+ >>> block = SAM2TwoWayAttentionBlock(embedding_dim=256, num_heads=8)
+ >>> sparse_input = torch.randn(1, 100, 256)
+ >>> dense_input = torch.randn(1, 256, 256)  # 16x16 image embedding flattened to (B, H*W, C)
+ >>> sparse_pe, dense_pe = torch.randn(1, 100, 256), torch.randn(1, 256, 256)
+ >>> sparse_output, dense_output = block(sparse_input, dense_input, sparse_pe, dense_pe)
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int = 2048,
+ activation: Type[nn.Module] = nn.ReLU,
+ attention_downsample_rate: int = 2,
+ skip_first_layer_pe: bool = False,
+ ) -> None:
+ """
+ Initializes a SAM2TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.
+
+ This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on sparse
+ inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and cross-attention
+ from dense to sparse inputs.
+
+ Args:
+ embedding_dim (int): The channel dimension of the embeddings.
+ num_heads (int): The number of heads in the attention layers.
+ mlp_dim (int): The hidden dimension of the MLP block.
+ activation (Type[nn.Module]): The activation function of the MLP block.
+ attention_downsample_rate (int): The downsample rate for attention computations.
+ skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
+
+ Examples:
+ >>> block = SAM2TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048)
+ >>> sparse_inputs = torch.randn(1, 100, 256)
+ >>> dense_inputs = torch.randn(1, 1024, 256)  # 32x32 image embedding flattened to (B, H*W, C)
+ >>> sparse_pe, dense_pe = torch.randn(1, 100, 256), torch.randn(1, 1024, 256)
+ >>> sparse_outputs, dense_outputs = block(sparse_inputs, dense_inputs, sparse_pe, dense_pe)
+ """
+ super().__init__(embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate, skip_first_layer_pe)
+ self.mlp = MLP(embedding_dim, mlp_dim, embedding_dim, num_layers=2, act=activation)
+
+
+class SAM2TwoWayTransformer(TwoWayTransformer):
+ """
+ A Two-Way Transformer module for simultaneous attention to image and query points.
+
+ This class extends the TwoWayTransformer, implementing a specialized transformer decoder that attends to an
+ input image using queries with supplied positional embeddings. It is particularly useful for tasks like
+ object detection, image segmentation, and point cloud processing.
+
+ Attributes:
+ depth (int): Number of layers in the transformer.
+ embedding_dim (int): Channel dimension for input embeddings.
+ num_heads (int): Number of heads for multihead attention.
+ mlp_dim (int): Internal channel dimension for the MLP block.
+ layers (nn.ModuleList): List of SAM2TwoWayAttentionBlock layers comprising the transformer.
+ final_attn_token_to_image (Attention): Final attention layer from queries to image.
+ norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
+
+ Methods:
+ forward: Processes input image embeddings and query embeddings through the transformer.
+
+ Examples:
+ >>> transformer = SAM2TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)
+ >>> image_embedding = torch.randn(1, 256, 64, 64)
+ >>> image_pe = torch.randn(1, 256, 64, 64)
+ >>> query_embedding = torch.randn(1, 100, 256)
+ >>> queries, keys = transformer(image_embedding, image_pe, query_embedding)
+ >>> print(queries.shape, keys.shape)
+ torch.Size([1, 100, 256]) torch.Size([1, 4096, 256])
+ """
+
+ def __init__(
+ self,
+ depth: int,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int,
+ activation: Type[nn.Module] = nn.ReLU,
+ attention_downsample_rate: int = 2,
+ ) -> None:
+ """
+ Initializes a SAM2TwoWayTransformer instance.
+
+ This transformer decoder attends to an input image using queries with supplied positional embeddings.
+ It is designed for tasks like object detection, image segmentation, and point cloud processing.
+
+ Args:
+ depth (int): Number of layers in the transformer.
+ embedding_dim (int): Channel dimension for the input embeddings.
+ num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
+ mlp_dim (int): Channel dimension internal to the MLP block.
+ activation (Type[nn.Module]): Activation function to use in the MLP block.
+ attention_downsample_rate (int): Downsampling rate for attention computations.
+
+ Examples:
+ >>> transformer = SAM2TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)
+ >>> transformer
+ SAM2TwoWayTransformer(
+ (layers): ModuleList(
+ (0-4): 5 x SAM2TwoWayAttentionBlock(...)
+ )
+ (final_attn_token_to_image): Attention(...)
+ (norm_final_attn): LayerNorm(...)
+ )
+ """
+ super().__init__(depth, embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate)
+ self.layers = nn.ModuleList()
+ for i in range(depth):
+ self.layers.append(
+ SAM2TwoWayAttentionBlock(
+ embedding_dim=embedding_dim,
+ num_heads=num_heads,
+ mlp_dim=mlp_dim,
+ activation=activation,
+ attention_downsample_rate=attention_downsample_rate,
+ skip_first_layer_pe=(i == 0),
+ )
+ )
+
+
+class RoPEAttention(Attention):
+ """
+ Implements rotary position encoding for attention mechanisms in transformer architectures.
+
+ This class extends the base Attention class by incorporating Rotary Position Encoding (RoPE) to enhance
+ the positional awareness of the attention mechanism.
+
+ Attributes:
+ compute_cis (Callable): Function to compute axial complex numbers for rotary encoding.
+ freqs_cis (Tensor): Precomputed frequency tensor for rotary encoding.
+ rope_k_repeat (bool): Flag to repeat query RoPE to match key length for cross-attention to memories.
+
+ Methods:
+ forward: Applies rotary position encoding and computes attention between query, key, and value tensors.
+
+ Examples:
+ >>> rope_attn = RoPEAttention(embedding_dim=256, num_heads=8, rope_theta=10000.0, feat_sizes=(32, 32))
+ >>> q = torch.randn(1, 1024, 256)
+ >>> k = torch.randn(1, 1024, 256)
+ >>> v = torch.randn(1, 1024, 256)
+ >>> output = rope_attn(q, k, v)
+ >>> print(output.shape)
+ torch.Size([1, 1024, 256])
+ """
+
+ def __init__(
+ self,
+ *args,
+ rope_theta=10000.0,
+ rope_k_repeat=False,
+ feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
+ **kwargs,
+ ):
+ """Initializes RoPEAttention with rotary position encoding for enhanced positional awareness."""
+ super().__init__(*args, **kwargs)
+
+ self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta)
+ freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
+ self.freqs_cis = freqs_cis
+ self.rope_k_repeat = rope_k_repeat # repeat q rope to match k length, needed for cross-attention to memories
+
+ def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor:
+ """Applies rotary position encoding and computes attention between query, key, and value tensors."""
+ q = self.q_proj(q)
+ k = self.k_proj(k)
+ v = self.v_proj(v)
+
+ # Separate into heads
+ q = self._separate_heads(q, self.num_heads)
+ k = self._separate_heads(k, self.num_heads)
+ v = self._separate_heads(v, self.num_heads)
+
+ # Apply rotary position encoding
+ w = h = math.sqrt(q.shape[-2])
+ self.freqs_cis = self.freqs_cis.to(q.device)
+ if self.freqs_cis.shape[0] != q.shape[-2]:
+ self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
+ if q.shape[-2] != k.shape[-2]:
+ assert self.rope_k_repeat
+
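+ # The trailing num_k_exclude_rope keys (e.g., non-spatial tokens) skip rotary encoding;
+ # only the first num_k_rope keys receive positional rotation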
+ num_k_rope = k.size(-2) - num_k_exclude_rope
+ q, k[:, :, :num_k_rope] = apply_rotary_enc(
+ q,
+ k[:, :, :num_k_rope],
+ freqs_cis=self.freqs_cis,
+ repeat_freqs_k=self.rope_k_repeat,
+ )
+
+ # Attention
+ _, _, _, c_per_head = q.shape
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
+ attn = attn / math.sqrt(c_per_head)
+ attn = torch.softmax(attn, dim=-1)
+
+ # Get output
+ out = attn @ v
+
+ out = self._recombine_heads(out)
+ out = self.out_proj(out)
+
+ return out
+
+
+def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
+ """Applies pooling and optional normalization to a tensor, handling spatial dimension permutations."""
+ if pool is None:
+ return x
+ # (B, H, W, C) -> (B, C, H, W)
+ x = x.permute(0, 3, 1, 2)
+ x = pool(x)
+ # (B, C, H', W') -> (B, H', W', C)
+ x = x.permute(0, 2, 3, 1)
+ if norm:
+ x = norm(x)
+
+ return x
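+
+
+ # Illustrative sketch of do_pool semantics (nn.MaxPool2d is just an example pooling module here):
+ # a 2x2 max pool halves the spatial dimensions of a channels-last tensor and leaves channels intact.
+ # >>> pooled = do_pool(torch.randn(1, 56, 56, 256), nn.MaxPool2d(kernel_size=2, stride=2))
+ # >>> pooled.shape  # torch.Size([1, 28, 28, 256])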
+
+
+class MultiScaleAttention(nn.Module):
+ """
+ Implements multi-scale self-attention with optional query pooling for efficient feature extraction.
+
+ This class provides a flexible implementation of multi-scale attention, allowing for optional
+ downsampling of query features through pooling. It's designed to enhance the model's ability to
+ capture multi-scale information in visual tasks.
+
+ Attributes:
+ dim (int): Input dimension of the feature map.
+ dim_out (int): Output dimension of the attention module.
+ num_heads (int): Number of attention heads.
+ scale (float): Scaling factor for dot-product attention.
+ q_pool (nn.Module | None): Optional pooling module for query features.
+ qkv (nn.Linear): Linear projection for query, key, and value.
+ proj (nn.Linear): Output projection.
+
+ Methods:
+ forward: Applies multi-scale attention to the input tensor.
+
+ Examples:
+ >>> import torch
+ >>> from torch import nn
+ >>> x = torch.randn(1, 64, 64, 256)
+ >>> msa = MultiScaleAttention(dim=256, dim_out=256, num_heads=8)
+ >>> output = msa(x)
+ >>> print(output.shape)
+ torch.Size([1, 64, 64, 256])
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: int,
+ num_heads: int,
+ q_pool: nn.Module = None,
+ ):
+ """Initializes multi-scale attention with optional query pooling for efficient feature extraction."""
+ super().__init__()
+
+ self.dim = dim
+ self.dim_out = dim_out
+
+ self.num_heads = num_heads
+ head_dim = dim_out // num_heads
+ self.scale = head_dim**-0.5
+
+ self.q_pool = q_pool
+ self.qkv = nn.Linear(dim, dim_out * 3)
+ self.proj = nn.Linear(dim_out, dim_out)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Applies multi-scale attention with optional query pooling to extract multi-scale features."""
+ B, H, W, _ = x.shape
+ # qkv with shape (B, H * W, 3, nHead, C)
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
+ # q, k, v with shape (B, H * W, nheads, C)
+ q, k, v = torch.unbind(qkv, 2)
+
+ # Q pooling (for downsample at stage changes)
+ if self.q_pool:
+ q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
+ H, W = q.shape[1:3] # downsampled shape
+ q = q.reshape(B, H * W, self.num_heads, -1)
+
+ # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
+ x = F.scaled_dot_product_attention(
+ q.transpose(1, 2),
+ k.transpose(1, 2),
+ v.transpose(1, 2),
+ )
+ # Transpose back
+ x = x.transpose(1, 2)
+ x = x.reshape(B, H, W, -1)
+
+ x = self.proj(x)
+
+ return x
+
+
+class MultiScaleBlock(nn.Module):
+ """
+ A multi-scale attention block with window partitioning and query pooling for efficient vision transformers.
+
+ This class implements a multi-scale attention mechanism with optional window partitioning and downsampling,
+ designed for use in vision transformer architectures.
+
+ Attributes:
+ dim (int): Input dimension of the block.
+ dim_out (int): Output dimension of the block.
+ norm1 (nn.Module): First normalization layer.
+ window_size (int): Size of the window for partitioning.
+ pool (nn.Module | None): Pooling layer for query downsampling.
+ q_stride (Tuple[int, int] | None): Stride for query pooling.
+ attn (MultiScaleAttention): Multi-scale attention module.
+ drop_path (nn.Module): Drop path layer for regularization.
+ norm2 (nn.Module): Second normalization layer.
+ mlp (MLP): Multi-layer perceptron module.
+ proj (nn.Linear | None): Projection layer for dimension mismatch.
+
+ Methods:
+ forward: Processes input tensor through the multi-scale block.
+
+ Examples:
+ >>> block = MultiScaleBlock(dim=256, dim_out=512, num_heads=8, window_size=7)
+ >>> x = torch.randn(1, 56, 56, 256)
+ >>> output = block(x)
+ >>> print(output.shape)
+ torch.Size([1, 56, 56, 512])
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ drop_path: float = 0.0,
+ norm_layer: Union[nn.Module, str] = "LayerNorm",
+ q_stride: Tuple[int, int] = None,
+ act_layer: nn.Module = nn.GELU,
+ window_size: int = 0,
+ ):
+ """Initializes a multi-scale attention block with window partitioning and optional query pooling."""
+ super().__init__()
+
+ if isinstance(norm_layer, str):
+ norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
+
+ self.dim = dim
+ self.dim_out = dim_out
+ self.norm1 = norm_layer(dim)
+
+ self.window_size = window_size
+
+ self.pool, self.q_stride = None, q_stride
+ if self.q_stride:
+ self.pool = nn.MaxPool2d(kernel_size=q_stride, stride=q_stride, ceil_mode=False)
+
+ self.attn = MultiScaleAttention(
+ dim,
+ dim_out,
+ num_heads=num_heads,
+ q_pool=self.pool,
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim_out)
+ self.mlp = MLP(
+ dim_out,
+ int(dim_out * mlp_ratio),
+ dim_out,
+ num_layers=2,
+ act=act_layer,
+ )
+
+ if dim != dim_out:
+ self.proj = nn.Linear(dim, dim_out)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Processes input through multi-scale attention and MLP, with optional windowing and downsampling."""
+ shortcut = x # B, H, W, C
+ x = self.norm1(x)
+
+ # Skip connection
+ if self.dim != self.dim_out:
+ shortcut = do_pool(self.proj(x), self.pool)
+
+ # Window partition
+ window_size = self.window_size
+ if window_size > 0:
+ H, W = x.shape[1], x.shape[2]
+ x, pad_hw = window_partition(x, window_size)
+
+ # Window Attention + Q Pooling (if stage change)
+ x = self.attn(x)
+ if self.q_stride:
+ # Shapes have changed due to Q pooling
+ window_size = self.window_size // self.q_stride[0]
+ H, W = shortcut.shape[1:3]
+
+ pad_h = (window_size - H % window_size) % window_size
+ pad_w = (window_size - W % window_size) % window_size
+ pad_hw = (H + pad_h, W + pad_w)
+
+ # Reverse window partition
+ if self.window_size > 0:
+ x = window_unpartition(x, window_size, pad_hw, (H, W))
+
+ x = shortcut + self.drop_path(x)
+ # MLP
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ A module for generating sinusoidal positional embeddings for 2D inputs like images.
+
+ This class implements sinusoidal position encoding for 2D spatial positions, which can be used in
+ transformer-based models for computer vision tasks.
+
+ Attributes:
+ num_pos_feats (int): Number of positional features (half of the embedding dimension).
+ temperature (int): Temperature parameter for the sinusoidal functions.
+ normalize (bool): Whether to normalize the positional embeddings.
+ scale (float): Scaling factor for the embeddings when normalize is True.
+ cache (Dict): Cache for storing precomputed embeddings.
+
+ Methods:
+ _encode_xy: Encodes 2D positions using sine and cosine functions.
+ encode_boxes: Encodes box coordinates and dimensions into positional embeddings.
+ encode_points: Encodes 2D point coordinates with sinusoidal positional embeddings.
+ forward: Generates sinusoidal position embeddings for 2D inputs.
+
+ Examples:
+ >>> pos_emb = PositionEmbeddingSine(num_pos_feats=128)
+ >>> x = torch.randn(1, 3, 224, 224)
+ >>> embeddings = pos_emb(x)
+ >>> print(embeddings.shape)
+ torch.Size([1, 128, 224, 224])
+ """
+
+ def __init__(
+ self,
+ num_pos_feats,
+ temperature: int = 10000,
+ normalize: bool = True,
+ scale: Optional[float] = None,
+ ):
+ """Initializes sinusoidal position embeddings for 2D image inputs."""
+ super().__init__()
+ assert num_pos_feats % 2 == 0, "Expecting even model width"
+ self.num_pos_feats = num_pos_feats // 2
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ self.cache = {}
+
+ def _encode_xy(self, x, y):
+ """Encodes 2D positions using sine/cosine functions for transformer positional embeddings."""
+ assert len(x) == len(y) and x.ndim == y.ndim == 1
+ x_embed = x * self.scale
+ y_embed = y * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, None] / dim_t
+ pos_y = y_embed[:, None] / dim_t
+ pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
+ pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
+ return pos_x, pos_y
+
+ @torch.no_grad()
+ def encode_boxes(self, x, y, w, h):
+ """Encodes box coordinates and dimensions into positional embeddings for detection."""
+ pos_x, pos_y = self._encode_xy(x, y)
+ pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
+ return pos
+
+ encode = encode_boxes # Backwards compatibility
+
+ @torch.no_grad()
+ def encode_points(self, x, y, labels):
+ """Encodes 2D points with sinusoidal embeddings and appends labels."""
+ (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
+ assert bx == by and nx == ny and bx == bl and nx == nl
+ pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
+ pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
+ pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
+ return pos
+
+ @torch.no_grad()
+ def forward(self, x: torch.Tensor):
+ """Generates sinusoidal position embeddings for 2D inputs like images."""
+ cache_key = (x.shape[-2], x.shape[-1])
+ if cache_key in self.cache:
+ return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
+ y_embed = (
+ torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
+ .view(1, -1, 1)
+ .repeat(x.shape[0], 1, x.shape[-1])
+ )
+ x_embed = (
+ torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
+ .view(1, 1, -1)
+ .repeat(x.shape[0], x.shape[-2], 1)
+ )
+
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ self.cache[cache_key] = pos[0]
+ return pos
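+
+
+ # Illustrative sketch of the point encoder: with num_pos_feats=128, encode_points returns
+ # 64 sine/cosine features per axis plus the label channel, i.e. a (B, N, 129) tensor.
+ # >>> pe = PositionEmbeddingSine(num_pos_feats=128)
+ # >>> x, y, labels = torch.rand(1, 5), torch.rand(1, 5), torch.ones(1, 5)
+ # >>> pe.encode_points(x, y, labels).shape  # torch.Size([1, 5, 129])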
+
+
+class PositionEmbeddingRandom(nn.Module):
+ """
+ Positional encoding using random spatial frequencies.
+
+ This class generates positional embeddings for input coordinates using random spatial frequencies. It is
+ particularly useful for transformer-based models that require position information.
+
+ Attributes:
+ positional_encoding_gaussian_matrix (torch.Tensor): A buffer containing random values for encoding.
+
+ Methods:
+ _pe_encoding: Positionally encodes points that are normalized to [0,1].
+ forward: Generates positional encoding for a grid of the specified size.
+ forward_with_coords: Positionally encodes points that are not normalized to [0,1].
+
+ Examples:
+ >>> pe = PositionEmbeddingRandom(num_pos_feats=64)
+ >>> size = (32, 32)
+ >>> encoding = pe(size)
+ >>> print(encoding.shape)
+ torch.Size([128, 32, 32])
+ """
+
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+ """Initializes random spatial frequency position embedding for transformers."""
+ super().__init__()
+ if scale is None or scale <= 0.0:
+ scale = 1.0
+ self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats)))
+
+ # Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation'
+ torch.use_deterministic_algorithms(False)
+ torch.backends.cudnn.deterministic = False
+
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+ """Encodes normalized [0,1] coordinates using random spatial frequencies."""
+ # Assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+ coords = 2 * coords - 1
+ coords = coords @ self.positional_encoding_gaussian_matrix
+ coords = 2 * np.pi * coords
+ # Outputs d_1 x ... x d_n x C shape
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+ """Generates positional encoding for a grid using random spatial frequencies."""
+ h, w = size
+ device: Any = self.positional_encoding_gaussian_matrix.device
+ grid = torch.ones((h, w), device=device, dtype=torch.float32)
+ y_embed = grid.cumsum(dim=0) - 0.5
+ x_embed = grid.cumsum(dim=1) - 0.5
+ y_embed = y_embed / h
+ x_embed = x_embed / w
+
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+ return pe.permute(2, 0, 1) # C x H x W
+
+ def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
+ """Positionally encodes input coordinates, normalizing them to [0,1] based on the given image size."""
+ coords = coords_input.clone()
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
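+
+
+ # Illustrative sketch: forward_with_coords maps unnormalized pixel coordinates to
+ # 2 * num_pos_feats random-frequency features per point.
+ # >>> pe = PositionEmbeddingRandom(num_pos_feats=64)
+ # >>> coords = torch.rand(1, 5, 2) * 1024  # (x, y) in pixel space
+ # >>> pe.forward_with_coords(coords, image_size=(1024, 1024)).shape  # torch.Size([1, 5, 128])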
+
+
+class Block(nn.Module):
+ """
+ Transformer block with support for window attention and residual propagation.
+
+ This class implements a transformer block that can use either global or windowed self-attention,
+ followed by a feed-forward network. It supports relative positional embeddings and is designed
+ for use in vision transformer architectures.
+
+ Attributes:
+ norm1 (nn.Module): First normalization layer.
+ attn (REAttention): Self-attention layer with optional relative positional encoding.
+ norm2 (nn.Module): Second normalization layer.
+ mlp (MLPBlock): Multi-layer perceptron block.
+ window_size (int): Size of attention window. If 0, global attention is used.
+
+ Methods:
+ forward: Processes input through the transformer block.
+
+ Examples:
+ >>> import torch
+ >>> block = Block(dim=256, num_heads=8, window_size=7)
+ >>> x = torch.randn(1, 56, 56, 256)
+ >>> output = block(x)
+ >>> print(output.shape)
+ torch.Size([1, 56, 56, 256])
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = True,
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
+ act_layer: Type[nn.Module] = nn.GELU,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ window_size: int = 0,
+ input_size: Optional[Tuple[int, int]] = None,
+ ) -> None:
+ """
+ Initializes a transformer block with optional window attention and relative positional embeddings.
+
+ This constructor sets up a transformer block that can use either global or windowed self-attention,
+ followed by a feed-forward network. It supports relative positional embeddings and is designed
+ for use in vision transformer architectures.
+
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads in the self-attention layer.
+ mlp_ratio (float): Ratio of mlp hidden dimension to embedding dimension.
+ qkv_bias (bool): If True, adds a learnable bias to query, key, value projections.
+ norm_layer (Type[nn.Module]): Type of normalization layer to use.
+ act_layer (Type[nn.Module]): Type of activation function to use in the MLP block.
+ use_rel_pos (bool): If True, uses relative positional embeddings in attention.
+ rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
+ window_size (int): Size of attention window. If 0, uses global attention.
+ input_size (Optional[Tuple[int, int]]): Input resolution for calculating relative positional parameter size.
+
+ Examples:
+ >>> block = Block(dim=256, num_heads=8, window_size=7)
+ >>> x = torch.randn(1, 56, 56, 256)
+ >>> output = block(x)
+ >>> print(output.shape)
+ torch.Size([1, 56, 56, 256])
+ """
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = REAttention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ use_rel_pos=use_rel_pos,
+ rel_pos_zero_init=rel_pos_zero_init,
+ input_size=input_size if window_size == 0 else (window_size, window_size),
+ )
+
+ self.norm2 = norm_layer(dim)
+ self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
+
+ self.window_size = window_size
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Processes input through transformer block with optional windowed self-attention and residual connection."""
+ shortcut = x
+ x = self.norm1(x)
+ # Window partition
+ if self.window_size > 0:
+ H, W = x.shape[1], x.shape[2]
+ x, pad_hw = window_partition(x, self.window_size)
+
+ x = self.attn(x)
+ # Reverse window partition
+ if self.window_size > 0:
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+
+ x = shortcut + x
+ return x + self.mlp(self.norm2(x))
+
+
+class REAttention(nn.Module):
+ """
+ Relative Position Attention module for efficient self-attention in transformer architectures.
+
+ This class implements multi-head attention with optional decomposed relative positional embeddings,
+ designed for vision transformer models. It operates on 2D spatial inputs of shape (B, H, W, C) rather
+ than flattened token sequences.
+
+ Attributes:
+ num_heads (int): Number of attention heads.
+ scale (float): Scaling factor applied to queries before the dot-product attention.
+ qkv (nn.Linear): Combined linear projection for query, key, and value.
+ proj (nn.Linear): Output projection.
+ use_rel_pos (bool): Whether decomposed relative positional embeddings are used.
+ rel_pos_h (nn.Parameter): Relative positional embeddings for the height axis, present if use_rel_pos is True.
+ rel_pos_w (nn.Parameter): Relative positional embeddings for the width axis, present if use_rel_pos is True.
+
+ Methods:
+ forward: Applies multi-head attention with optional relative positional encoding to the input tensor.
+
+ Examples:
+ >>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))
+ >>> x = torch.randn(1, 32, 32, 256)
+ >>> output = attention(x)
+ >>> print(output.shape)
+ torch.Size([1, 32, 32, 256])
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = True,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ input_size: Optional[Tuple[int, int]] = None,
+ ) -> None:
+ """
+ Initializes a Relative Position Attention module for transformer-based architectures.
+
+ This module implements multi-head attention with optional relative positional encodings, designed
+ specifically for vision tasks in transformer models.
+
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads. Default is 8.
+ qkv_bias (bool): If True, adds a learnable bias to query, key, value projections. Default is True.
+ use_rel_pos (bool): If True, uses relative positional encodings. Default is False.
+ rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero. Default is True.
+ input_size (Tuple[int, int] | None): Input resolution for calculating relative positional parameter size.
+ Required if use_rel_pos is True. Default is None.
+
+ Examples:
+ >>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))
+ >>> x = torch.randn(1, 32, 32, 256)
+ >>> output = attention(x)
+ >>> print(output.shape)
+ torch.Size([1, 32, 32, 256])
+ """
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ self.use_rel_pos = use_rel_pos
+ if self.use_rel_pos:
+ assert input_size is not None, "Input size must be provided if using relative positional encoding."
+ # Initialize relative positional embeddings
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Applies multi-head attention with optional relative positional encoding to input tensor."""
+ B, H, W, _ = x.shape
+ # qkv with shape (3, B, nHead, H * W, C)
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ # q, k, v with shape (B * nHead, H * W, C)
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
+
+ attn = (q * self.scale) @ k.transpose(-2, -1)
+
+ if self.use_rel_pos:
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
+
+ attn = attn.softmax(dim=-1)
+ x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
+ return self.proj(x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ Image to Patch Embedding module for vision transformer architectures.
+
+ This module converts an input image into a sequence of patch embeddings using a convolutional layer.
+ It is commonly used as the first layer in vision transformer architectures to transform image data
+ into a suitable format for subsequent transformer blocks.
+
+ Attributes:
+ proj (nn.Conv2d): Convolutional layer for projecting image patches to embeddings.
+
+ Methods:
+ forward: Applies patch embedding to the input tensor.
+
+ Examples:
+ >>> patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768)
+ >>> x = torch.randn(1, 3, 224, 224)
+ >>> output = patch_embed(x)
+ >>> print(output.shape)
+ torch.Size([1, 768, 14, 14])
+ """
+
+ def __init__(
+ self,
+ kernel_size: Tuple[int, int] = (16, 16),
+ stride: Tuple[int, int] = (16, 16),
+ padding: Tuple[int, int] = (0, 0),
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ ) -> None:
+ """
+ Initializes the PatchEmbed module for converting image patches to embeddings.
+
+ This module is typically used as the first layer in vision transformer architectures to transform
+ image data into a suitable format for subsequent transformer blocks.
+
+ Args:
+ kernel_size (Tuple[int, int]): Size of the convolutional kernel for patch extraction.
+ stride (Tuple[int, int]): Stride of the convolutional operation.
+ padding (Tuple[int, int]): Padding applied to the input before convolution.
+ in_chans (int): Number of input image channels.
+ embed_dim (int): Dimensionality of the output patch embeddings.
+
+ Examples:
+ >>> patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768)
+ >>> x = torch.randn(1, 3, 224, 224)
+ >>> output = patch_embed(x)
+ >>> print(output.shape)
+ torch.Size([1, 768, 14, 14])
+ """
+ super().__init__()
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Computes patch embedding by applying convolution and transposing resulting tensor."""
+ return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C
diff --git a/ultralytics/models/sam/modules/decoders.py b/ultralytics/models/sam/modules/decoders.py
index eeaab6b45..5b172bdd2 100644
--- a/ultralytics/models/sam/modules/decoders.py
+++ b/ultralytics/models/sam/modules/decoders.py
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-from typing import List, Tuple, Type
+from typing import List, Optional, Tuple, Type
import torch
from torch import nn
@@ -10,12 +10,14 @@ from ultralytics.nn.modules import MLP, LayerNorm2d
class MaskDecoder(nn.Module):
"""
- Decoder module for generating masks and their associated quality scores, using a transformer architecture to predict
- masks given image and prompt embeddings.
+ Decoder module for generating masks and their associated quality scores using a transformer architecture.
+
+ This class predicts masks given image and prompt embeddings, utilizing a transformer to process the inputs and
+ generate mask predictions along with their quality scores.
Attributes:
transformer_dim (int): Channel dimension for the transformer module.
- transformer (nn.Module): The transformer module used for mask prediction.
+ transformer (nn.Module): Transformer module used for mask prediction.
num_multimask_outputs (int): Number of masks to predict for disambiguating masks.
iou_token (nn.Embedding): Embedding for the IoU token.
num_mask_tokens (int): Number of mask tokens.
@@ -23,6 +25,16 @@ class MaskDecoder(nn.Module):
output_upscaling (nn.Sequential): Neural network sequence for upscaling the output.
output_hypernetworks_mlps (nn.ModuleList): Hypernetwork MLPs for generating masks.
iou_prediction_head (nn.Module): MLP for predicting mask quality.
+
+ Methods:
+ forward: Predicts masks given image and prompt embeddings.
+ predict_masks: Internal method for mask prediction.
+
+ Examples:
+ >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
+ >>> masks, iou_pred = decoder(image_embeddings, image_pe, sparse_prompt_embeddings,
+ ... dense_prompt_embeddings, multimask_output=True)
+ >>> print(f"Predicted masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
"""
def __init__(
@@ -35,15 +47,20 @@ class MaskDecoder(nn.Module):
iou_head_hidden_dim: int = 256,
) -> None:
"""
- Predicts masks given an image and prompt embeddings, using a transformer architecture.
+ Initializes the MaskDecoder module for generating masks and their quality scores.
Args:
- transformer_dim (int): the channel dimension of the transformer module
- transformer (nn.Module): the transformer used to predict masks
- num_multimask_outputs (int): the number of masks to predict when disambiguating masks
- activation (nn.Module): the type of activation to use when upscaling masks
- iou_head_depth (int): the depth of the MLP used to predict mask quality
- iou_head_hidden_dim (int): the hidden dimension of the MLP used to predict mask quality
+ transformer_dim (int): Channel dimension for the transformer module.
+ transformer (nn.Module): Transformer module used for mask prediction.
+ num_multimask_outputs (int): Number of masks to predict for disambiguating masks.
+ activation (Type[nn.Module]): Type of activation to use when upscaling masks.
+ iou_head_depth (int): Depth of the MLP used to predict mask quality.
+ iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.
+
+ Examples:
+ >>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)
+ >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer)
+ >>> print(decoder)
"""
super().__init__()
self.transformer_dim = transformer_dim
@@ -77,18 +94,28 @@ class MaskDecoder(nn.Module):
multimask_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
- Predict masks given image and prompt embeddings.
+ Predicts masks given image and prompt embeddings.
Args:
- image_embeddings (torch.Tensor): the embeddings from the image encoder
- image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
- sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
- dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
+ image_embeddings (torch.Tensor): Embeddings from the image encoder.
+ image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings.
+ sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes.
+ dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs.
multimask_output (bool): Whether to return multiple masks or a single mask.
Returns:
- torch.Tensor: batched predicted masks
- torch.Tensor: batched predictions of mask quality
+ (Tuple[torch.Tensor, torch.Tensor]): A tuple containing:
+ - masks (torch.Tensor): Batched predicted masks.
+ - iou_pred (torch.Tensor): Batched predictions of mask quality.
+
+ Examples:
+ >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
+ >>> image_emb = torch.rand(1, 256, 64, 64)
+ >>> image_pe = torch.rand(1, 256, 64, 64)
+ >>> sparse_emb = torch.rand(1, 2, 256)
+ >>> dense_emb = torch.rand(1, 256, 64, 64)
+ >>> masks, iou_pred = decoder(image_emb, image_pe, sparse_emb, dense_emb, multimask_output=True)
+ >>> print(f"Masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
"""
masks, iou_pred = self.predict_masks(
image_embeddings=image_embeddings,
@@ -112,11 +139,7 @@ class MaskDecoder(nn.Module):
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Predicts masks.
-
- See 'forward' for more details.
- """
+ """Predicts masks and quality scores using image and prompt embeddings via transformer architecture."""
# Concatenate output tokens
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1)
@@ -147,3 +170,347 @@ class MaskDecoder(nn.Module):
iou_pred = self.iou_prediction_head(iou_token_out)
return masks, iou_pred
+
+
+class SAM2MaskDecoder(nn.Module):
+ """
+ Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings.
+
+ This class extends the functionality of the MaskDecoder, incorporating additional features such as
+ high-resolution feature processing, dynamic multimask output, and object score prediction.
+
+ Attributes:
+ transformer_dim (int): Channel dimension of the transformer.
+ transformer (nn.Module): Transformer used to predict masks.
+ num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
+ iou_token (nn.Embedding): Embedding for IOU token.
+ num_mask_tokens (int): Total number of mask tokens.
+ mask_tokens (nn.Embedding): Embedding for mask tokens.
+ pred_obj_scores (bool): Whether to predict object scores.
+ obj_score_token (nn.Embedding): Embedding for object score token.
+ use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
+ output_upscaling (nn.Sequential): Upscaling layers for output.
+ use_high_res_features (bool): Whether to use high-resolution features.
+ conv_s0 (nn.Conv2d): Convolutional layer for high-resolution features (s0).
+ conv_s1 (nn.Conv2d): Convolutional layer for high-resolution features (s1).
+ output_hypernetworks_mlps (nn.ModuleList): List of MLPs for output hypernetworks.
+ iou_prediction_head (MLP): MLP for IOU prediction.
+ pred_obj_score_head (nn.Linear | MLP): Linear layer or MLP for object score prediction.
+ dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
+ dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
+ dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.
+
+ Methods:
+ forward: Predicts masks given image and prompt embeddings.
+ predict_masks: Predicts instance segmentation masks from image and prompt embeddings.
+ _get_stability_scores: Computes mask stability scores based on IoU between thresholds.
+ _dynamic_multimask_via_stability: Dynamically selects the most stable mask output.
+
+ Examples:
+ >>> image_embeddings = torch.rand(1, 256, 64, 64)
+ >>> image_pe = torch.rand(1, 256, 64, 64)
+ >>> sparse_prompt_embeddings = torch.rand(1, 2, 256)
+ >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
+ >>> decoder = SAM2MaskDecoder(256, transformer)
+ >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
+ ... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False)
+ """
+
+ def __init__(
+ self,
+ transformer_dim: int,
+ transformer: nn.Module,
+ num_multimask_outputs: int = 3,
+ activation: Type[nn.Module] = nn.GELU,
+ iou_head_depth: int = 3,
+ iou_head_hidden_dim: int = 256,
+ use_high_res_features: bool = False,
+ iou_prediction_use_sigmoid=False,
+ dynamic_multimask_via_stability=False,
+ dynamic_multimask_stability_delta=0.05,
+ dynamic_multimask_stability_thresh=0.98,
+ pred_obj_scores: bool = False,
+ pred_obj_scores_mlp: bool = False,
+ use_multimask_token_for_obj_ptr: bool = False,
+ ) -> None:
+ """
+ Initializes the SAM2MaskDecoder module for predicting instance segmentation masks.
+
+ This decoder extends the functionality of MaskDecoder, incorporating additional features such as
+ high-resolution feature processing, dynamic multimask output, and object score prediction.
+
+ Args:
+ transformer_dim (int): Channel dimension of the transformer.
+ transformer (nn.Module): Transformer used to predict masks.
+ num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
+ activation (Type[nn.Module]): Type of activation to use when upscaling masks.
+ iou_head_depth (int): Depth of the MLP used to predict mask quality.
+ iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.
+ use_high_res_features (bool): Whether to use high-resolution features.
+ iou_prediction_use_sigmoid (bool): Whether to use sigmoid for IOU prediction.
+ dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
+ dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
+ dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.
+ pred_obj_scores (bool): Whether to predict object scores.
+ pred_obj_scores_mlp (bool): Whether to use MLP for object score prediction.
+ use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
+
+ Examples:
+ >>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)
+ >>> decoder = SAM2MaskDecoder(transformer_dim=256, transformer=transformer)
+ >>> print(decoder)
+ """
+ super().__init__()
+ self.transformer_dim = transformer_dim
+ self.transformer = transformer
+
+ self.num_multimask_outputs = num_multimask_outputs
+
+ self.iou_token = nn.Embedding(1, transformer_dim)
+ self.num_mask_tokens = num_multimask_outputs + 1
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+
+ self.pred_obj_scores = pred_obj_scores
+ if self.pred_obj_scores:
+ self.obj_score_token = nn.Embedding(1, transformer_dim)
+ self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
+
+ self.output_upscaling = nn.Sequential(
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
+ LayerNorm2d(transformer_dim // 4),
+ activation(),
+ nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+ activation(),
+ )
+ self.use_high_res_features = use_high_res_features
+ if use_high_res_features:
+ self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1)
+ self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1)
+
+ self.output_hypernetworks_mlps = nn.ModuleList(
+ [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
+ )
+
+ self.iou_prediction_head = MLP(
+ transformer_dim,
+ iou_head_hidden_dim,
+ self.num_mask_tokens,
+ iou_head_depth,
+ sigmoid=iou_prediction_use_sigmoid,
+ )
+ if self.pred_obj_scores:
+ self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
+ if pred_obj_scores_mlp:
+ self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
+
+ # When outputting a single mask, optionally we can dynamically fall back to the best
+ # multimask output token if the single mask output token gives low stability scores.
+ self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
+ self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
+ self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
+
+ def forward(
+ self,
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ multimask_output: bool,
+ repeat_image: bool,
+ high_res_features: Optional[List[torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Predicts masks given image and prompt embeddings.
+
+ Args:
+ image_embeddings (torch.Tensor): Embeddings from the image encoder with shape (B, C, H, W).
+ image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings (B, C, H, W).
+ sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes with shape (B, N, C).
+ dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs with shape (B, C, H, W).
+ multimask_output (bool): Whether to return multiple masks or a single mask.
+ repeat_image (bool): Flag to repeat the image embeddings.
+ high_res_features (List[torch.Tensor] | None): Optional high-resolution features.
+
+ Returns:
+ (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing:
+ - masks (torch.Tensor): Batched predicted masks with shape (B, N, H, W).
+ - iou_pred (torch.Tensor): Batched predictions of mask quality with shape (B, N).
+ - sam_tokens_out (torch.Tensor): Batched SAM token for mask output with shape (B, N, C).
+ - object_score_logits (torch.Tensor): Batched object score logits with shape (B, 1).
+
+ Examples:
+ >>> image_embeddings = torch.rand(1, 256, 64, 64)
+ >>> image_pe = torch.rand(1, 256, 64, 64)
+ >>> sparse_prompt_embeddings = torch.rand(1, 2, 256)
+ >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
+ >>> decoder = SAM2MaskDecoder(256, transformer)
+ >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
+ ... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False)
+ """
+ masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
+ image_embeddings=image_embeddings,
+ image_pe=image_pe,
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
+ dense_prompt_embeddings=dense_prompt_embeddings,
+ repeat_image=repeat_image,
+ high_res_features=high_res_features,
+ )
+
+ # Select the correct mask or masks for output
+ if multimask_output:
+ masks = masks[:, 1:, :, :]
+ iou_pred = iou_pred[:, 1:]
+ elif self.dynamic_multimask_via_stability and not self.training:
+ masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
+ else:
+ masks = masks[:, 0:1, :, :]
+ iou_pred = iou_pred[:, 0:1]
+
+ if multimask_output and self.use_multimask_token_for_obj_ptr:
+ sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
+ else:
+ # Take the mask output token. Here we *always* use the token for single mask output.
+ # At test time, even if we track after 1-click (and using multimask_output=True),
+ # we still take the single mask token here. The rationale is that we always track
+ # after multiple clicks during training, so the past tokens seen during training
+ # are always the single mask token (and we'll let it be the object-memory token).
+ sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
+
+ # Prepare output
+ return masks, iou_pred, sam_tokens_out, object_score_logits
+
+ def predict_masks(
+ self,
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ repeat_image: bool,
+ high_res_features: Optional[List[torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Predicts instance segmentation masks from image and prompt embeddings using a transformer."""
+ # Concatenate output tokens
+ s = 0
+ if self.pred_obj_scores:
+ output_tokens = torch.cat(
+ [
+ self.obj_score_token.weight,
+ self.iou_token.weight,
+ self.mask_tokens.weight,
+ ],
+ dim=0,
+ )
+ s = 1
+ else:
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
+ output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+ # Expand per-image data in batch direction to be per-mask
+ if repeat_image:
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+ else:
+ assert image_embeddings.shape[0] == tokens.shape[0]
+ src = image_embeddings
+ src = src + dense_prompt_embeddings
+ assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+ b, c, h, w = src.shape
+
+ # Run the transformer
+ hs, src = self.transformer(src, pos_src, tokens)
+ iou_token_out = hs[:, s, :]
+ mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
+
+ # Upscale mask embeddings and predict masks using the mask tokens
+ src = src.transpose(1, 2).view(b, c, h, w)
+ if not self.use_high_res_features:
+ upscaled_embedding = self.output_upscaling(src)
+ else:
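+ # Unpack the upscaling stack so the two high-resolution feature maps can be added in at
+ # their matching scales during upsampling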
+ dc1, ln1, act1, dc2, act2 = self.output_upscaling
+ feat_s0, feat_s1 = high_res_features
+ upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
+ upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
+
+ hyper_in_list: List[torch.Tensor] = []
+ for i in range(self.num_mask_tokens):
+ hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
+ hyper_in = torch.stack(hyper_in_list, dim=1)
+ b, c, h, w = upscaled_embedding.shape
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+
+ # Generate mask quality predictions
+ iou_pred = self.iou_prediction_head(iou_token_out)
+ if self.pred_obj_scores:
+ assert s == 1
+ object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
+ else:
+ # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
+ object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
+
+ return masks, iou_pred, mask_tokens_out, object_score_logits
+
+ def _get_stability_scores(self, mask_logits):
+ """Computes mask stability scores based on IoU between upper and lower thresholds."""
+ mask_logits = mask_logits.flatten(-2)
+ stability_delta = self.dynamic_multimask_stability_delta
+ area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
+ area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
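+ # Illustrative values: if 60 pixels exceed +delta and 80 pixels exceed -delta, the stability
+ # score is 60 / 80 = 0.75, which falls below the default 0.98 threshold and triggers the
+ # multimask fallback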
+ stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
+ return stability_scores
+
+ def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
+ """
+ Dynamically selects the most stable mask output based on stability scores and IoU predictions.
+
+ This method is used when outputting a single mask. If the stability score from the current single-mask
+ output (based on output token 0) falls below a threshold, it instead selects from multi-mask outputs
+ (based on output tokens 1-3) the mask with the highest predicted IoU score. This ensures a valid mask
+ for both clicking and tracking scenarios.
+
+ Args:
+ all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is
+ batch size, N is number of masks (typically 4), and H, W are mask dimensions.
+ all_iou_scores (torch.Tensor): Predicted IoU scores for all masks, shape (B, N).
+
+ Returns:
+ (Tuple[torch.Tensor, torch.Tensor]):
+ - mask_logits_out (torch.Tensor): Selected mask logits, shape (B, 1, H, W).
+ - iou_scores_out (torch.Tensor): Selected IoU scores, shape (B, 1).
+
+ Examples:
+ >>> decoder = SAM2MaskDecoder(...)
+ >>> all_mask_logits = torch.rand(2, 4, 256, 256) # 2 images, 4 masks each
+ >>> all_iou_scores = torch.rand(2, 4)
+ >>> mask_logits, iou_scores = decoder._dynamic_multimask_via_stability(all_mask_logits, all_iou_scores)
+ >>> print(mask_logits.shape, iou_scores.shape)
+ torch.Size([2, 1, 256, 256]) torch.Size([2, 1])
+ """
+ # The best mask from multimask output tokens (1~3)
+ multimask_logits = all_mask_logits[:, 1:, :, :]
+ multimask_iou_scores = all_iou_scores[:, 1:]
+ best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
+ batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)
+ best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
+ best_multimask_logits = best_multimask_logits.unsqueeze(1)
+ best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
+ best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
+
+ # The mask from singlemask output token 0 and its stability score
+ singlemask_logits = all_mask_logits[:, 0:1, :, :]
+ singlemask_iou_scores = all_iou_scores[:, 0:1]
+ stability_scores = self._get_stability_scores(singlemask_logits)
+ is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
+
+ # Dynamically fall back to best multimask output upon low stability scores.
+ mask_logits_out = torch.where(
+ is_stable[..., None, None].expand_as(singlemask_logits),
+ singlemask_logits,
+ best_multimask_logits,
+ )
+ iou_scores_out = torch.where(
+ is_stable.expand_as(singlemask_iou_scores),
+ singlemask_iou_scores,
+ best_multimask_iou_scores,
+ )
+ return mask_logits_out, iou_scores_out
diff --git a/ultralytics/models/sam/modules/encoders.py b/ultralytics/models/sam/modules/encoders.py
index 2bf3945dc..9432fec4b 100644
--- a/ultralytics/models/sam/modules/encoders.py
+++ b/ultralytics/models/sam/modules/encoders.py
@@ -1,30 +1,48 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-from typing import Any, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type
-import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
-from ultralytics.nn.modules import LayerNorm2d, MLPBlock
+from ultralytics.nn.modules import LayerNorm2d
+
+from .blocks import (
+ Block,
+ CXBlock,
+ Fuser,
+ MaskDownSampler,
+ MultiScaleBlock,
+ PatchEmbed,
+ PositionEmbeddingRandom,
+ PositionEmbeddingSine,
+)
class ImageEncoderViT(nn.Module):
"""
- An image encoder using Vision Transformer (ViT) architecture for encoding an image into a compact latent space. The
- encoder takes an image, splits it into patches, and processes these patches through a series of transformer blocks.
- The encoded patches are then processed through a neck to generate the final encoded representation.
+ An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.
- This class and its supporting functions below lightly adapted from the ViTDet backbone available at
- https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py.
+ This class processes images by splitting them into patches, applying transformer blocks, and generating a final
+ encoded representation through a neck module.
Attributes:
img_size (int): Dimension of input images, assumed to be square.
patch_embed (PatchEmbed): Module for patch embedding.
- pos_embed (nn.Parameter, optional): Absolute positional embedding for patches.
+ pos_embed (nn.Parameter | None): Absolute positional embedding for patches.
blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings.
neck (nn.Sequential): Neck module to further process the output.
+
+ Methods:
+ forward: Processes input through patch embedding, positional embedding, blocks, and neck.
+
+ Examples:
+ >>> import torch
+ >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
+ >>> input_image = torch.randn(1, 3, 224, 224)
+ >>> output = encoder(input_image)
+ >>> print(output.shape)
"""
def __init__(
@@ -47,22 +65,38 @@ class ImageEncoderViT(nn.Module):
global_attn_indexes: Tuple[int, ...] = (),
) -> None:
"""
+ Initializes an ImageEncoderViT instance for encoding images using Vision Transformer architecture.
+
Args:
- img_size (int): Input image size.
- patch_size (int): Patch size.
+ img_size (int): Input image size, assumed to be square.
+ patch_size (int): Size of image patches.
in_chans (int): Number of input image channels.
- embed_dim (int): Patch embedding dimension.
- depth (int): Depth of ViT.
- num_heads (int): Number of attention heads in each ViT block.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool): If True, add a learnable bias to query, key, value.
- norm_layer (nn.Module): Normalization layer.
- act_layer (nn.Module): Activation layer.
- use_abs_pos (bool): If True, use absolute positional embeddings.
- use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
- rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
- window_size (int): Window size for window attention blocks.
- global_attn_indexes (list): Indexes for blocks using global attention.
+ embed_dim (int): Dimension of patch embeddings.
+ depth (int): Number of transformer blocks.
+ num_heads (int): Number of attention heads in each block.
+ mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
+ out_chans (int): Number of output channels from the neck module.
+ qkv_bias (bool): If True, adds learnable bias to query, key, value projections.
+ norm_layer (Type[nn.Module]): Type of normalization layer to use.
+ act_layer (Type[nn.Module]): Type of activation layer to use.
+ use_abs_pos (bool): If True, uses absolute positional embeddings.
+ use_rel_pos (bool): If True, adds relative positional embeddings to attention maps.
+ rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
+ window_size (int): Size of attention window for windowed attention blocks.
+ global_attn_indexes (Tuple[int, ...]): Indices of blocks that use global attention.
+
+ Attributes:
+ img_size (int): Dimension of input images.
+ patch_embed (PatchEmbed): Module for patch embedding.
+ pos_embed (nn.Parameter | None): Absolute positional embedding for patches.
+ blocks (nn.ModuleList): List of transformer blocks.
+ neck (nn.Sequential): Neck module for final processing.
+
+ Examples:
+ >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
+ >>> input_image = torch.randn(1, 3, 224, 224)
+ >>> output = encoder(input_image)
+ >>> print(output.shape)
"""
super().__init__()
self.img_size = img_size
@@ -114,9 +148,7 @@ class ImageEncoderViT(nn.Module):
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Processes input through patch embedding, applies positional embedding if present, and passes through blocks
- and neck.
- """
+ """Processes input through patch embedding, positional embedding, transformer blocks, and neck module."""
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + self.pos_embed
@@ -127,8 +159,7 @@ class ImageEncoderViT(nn.Module):
class PromptEncoder(nn.Module):
"""
- Encodes different types of prompts, including points, boxes, and masks, for input to SAM's mask decoder. The encoder
- produces both sparse and dense embeddings for the input prompts.
+ Encodes different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.
Attributes:
embed_dim (int): Dimension of the embeddings.
@@ -137,10 +168,23 @@ class PromptEncoder(nn.Module):
pe_layer (PositionEmbeddingRandom): Module for random position embedding.
num_point_embeddings (int): Number of point embeddings for different types of points.
point_embeddings (nn.ModuleList): List of point embeddings.
- not_a_point_embed (nn.Embedding): Embedding for points that are not a part of any label.
+ not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.
mask_input_size (Tuple[int, int]): Size of the input mask.
mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided.
+
+ Methods:
+ get_dense_pe: Returns the positional encoding used to encode point prompts.
+ forward: Embeds different types of prompts, returning both sparse and dense embeddings.
+
+ Examples:
+ >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
+ >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
+ >>> boxes = torch.rand(1, 2, 2)
+ >>> masks = torch.rand(1, 1, 256, 256)
+ >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)
+ >>> print(sparse_embeddings.shape, dense_embeddings.shape)
+ torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
"""
def __init__(
@@ -152,18 +196,37 @@ class PromptEncoder(nn.Module):
activation: Type[nn.Module] = nn.GELU,
) -> None:
"""
- Encodes prompts for input to SAM's mask decoder.
+ Initializes the PromptEncoder module for encoding various types of prompts.
+
+ This module encodes different types of prompts (points, boxes, masks) for input to SAM's mask decoder,
+ producing both sparse and dense embeddings.
Args:
- embed_dim (int): The prompts' embedding dimension
- image_embedding_size (tuple(int, int)): The spatial size of the
- image embedding, as (H, W).
- input_image_size (int): The padded size of the image as input
- to the image encoder, as (H, W).
- mask_in_chans (int): The number of hidden channels used for
- encoding input masks.
- activation (nn.Module): The activation to use when encoding
- input masks.
+ embed_dim (int): The dimension of the embeddings.
+ image_embedding_size (Tuple[int, int]): The spatial size of the image embedding as (H, W).
+ input_image_size (Tuple[int, int]): The padded size of the input image as (H, W).
+ mask_in_chans (int): The number of hidden channels used for encoding input masks.
+ activation (Type[nn.Module]): The activation function to use when encoding input masks.
+
+ Attributes:
+ embed_dim (int): Dimension of the embeddings.
+ input_image_size (Tuple[int, int]): Size of the input image as (H, W).
+ image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W).
+ pe_layer (PositionEmbeddingRandom): Module for random position embedding.
+ num_point_embeddings (int): Number of point embeddings for different types of points.
+ point_embeddings (nn.ModuleList): List of point embeddings.
+ not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.
+ mask_input_size (Tuple[int, int]): Size of the input mask.
+ mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
+
+ Examples:
+ >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
+ >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
+ >>> boxes = torch.rand(1, 2, 2)
+ >>> masks = torch.rand(1, 1, 256, 256)
+ >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)
+ >>> print(sparse_embeddings.shape, dense_embeddings.shape)
+ torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
"""
super().__init__()
self.embed_dim = embed_dim
@@ -190,16 +253,25 @@ class PromptEncoder(nn.Module):
def get_dense_pe(self) -> torch.Tensor:
"""
- Returns the positional encoding used to encode point prompts, applied to a dense set of points the shape of the
- image encoding.
+ Returns the dense positional encoding used for encoding point prompts.
+
+ This method generates a positional encoding for a dense set of points matching the shape of the image
+ encoding. The encoding is used to provide spatial information to the model when processing point prompts.
Returns:
- torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w)
+ (torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the
+ height and width of the image embedding size, respectively.
+
+ Examples:
+ >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
+ >>> dense_pe = prompt_encoder.get_dense_pe()
+ >>> print(dense_pe.shape)
+ torch.Size([1, 256, 64, 64])
"""
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
- """Embeds point prompts."""
+ """Embeds point prompts by applying positional encoding and label-specific embeddings."""
points = points + 0.5 # Shift to center of pixel
if pad:
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
@@ -216,7 +288,7 @@ class PromptEncoder(nn.Module):
return point_embedding
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
- """Embeds box prompts."""
+ """Embeds box prompts by applying positional encoding and adding corner embeddings."""
boxes = boxes + 0.5 # Shift to center of pixel
coords = boxes.reshape(-1, 2, 2)
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
@@ -225,7 +297,7 @@ class PromptEncoder(nn.Module):
return corner_embedding
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
- """Embeds mask inputs."""
+ """Embeds mask inputs by downscaling and processing through convolutional layers."""
return self.mask_downscaling(masks)
@staticmethod
@@ -258,14 +330,25 @@ class PromptEncoder(nn.Module):
Embeds different types of prompts, returning both sparse and dense embeddings.
Args:
- points (tuple(torch.Tensor, torch.Tensor), None): point coordinates and labels to embed.
- boxes (torch.Tensor, None): boxes to embed
- masks (torch.Tensor, None): masks to embed
+ points (Tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first
+ tensor contains coordinates with shape (B, N, 2), and the second tensor contains labels with
+ shape (B, N).
+ boxes (torch.Tensor | None): Boxes to embed with shape (B, M, 2, 2), where M is the number of boxes.
+ masks (torch.Tensor | None): Masks to embed with shape (B, 1, H, W).
Returns:
- torch.Tensor: sparse embeddings for the points and boxes, with shape BxNx(embed_dim), where N is determined
- by the number of input points and boxes.
- torch.Tensor: dense embeddings for the masks, in the shape Bx(embed_dim)x(embed_H)x(embed_W)
+ (Tuple[torch.Tensor, torch.Tensor]): A tuple containing:
+ - sparse_embeddings (torch.Tensor): Sparse embeddings for points and boxes with shape (B, N, embed_dim).
+ - dense_embeddings (torch.Tensor): Dense embeddings for masks of shape (B, embed_dim, embed_H, embed_W).
+
+ Examples:
+ >>> encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
+ >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
+            >>> boxes = torch.rand(1, 2, 2)
+ >>> masks = torch.rand(1, 1, 256, 256)
+ >>> sparse_emb, dense_emb = encoder(points, boxes, masks)
+ >>> print(sparse_emb.shape, dense_emb.shape)
+ torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
"""
bs = self._get_batch_size(points, boxes, masks)
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
@@ -287,319 +370,421 @@ class PromptEncoder(nn.Module):
return sparse_embeddings, dense_embeddings
-class PositionEmbeddingRandom(nn.Module):
- """Positional encoding using random spatial frequencies."""
+class MemoryEncoder(nn.Module):
+ """
+ Encodes pixel features and masks into a memory representation for efficient image segmentation.
- def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
- """Initializes a position embedding using random spatial frequencies."""
- super().__init__()
- if scale is None or scale <= 0.0:
- scale = 1.0
- self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats)))
-
- # Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation'
- torch.use_deterministic_algorithms(False)
- torch.backends.cudnn.deterministic = False
-
- def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
- """Positionally encode points that are normalized to [0,1]."""
- # Assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
- coords = 2 * coords - 1
- coords = coords @ self.positional_encoding_gaussian_matrix
- coords = 2 * np.pi * coords
- # Outputs d_1 x ... x d_n x C shape
- return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
-
- def forward(self, size: Tuple[int, int]) -> torch.Tensor:
- """Generate positional encoding for a grid of the specified size."""
- h, w = size
- device: Any = self.positional_encoding_gaussian_matrix.device
- grid = torch.ones((h, w), device=device, dtype=torch.float32)
- y_embed = grid.cumsum(dim=0) - 0.5
- x_embed = grid.cumsum(dim=1) - 0.5
- y_embed = y_embed / h
- x_embed = x_embed / w
-
- pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
- return pe.permute(2, 0, 1) # C x H x W
-
- def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
- """Positionally encode points that are not normalized to [0,1]."""
- coords = coords_input.clone()
- coords[:, :, 0] = coords[:, :, 0] / image_size[1]
- coords[:, :, 1] = coords[:, :, 1] / image_size[0]
- return self._pe_encoding(coords.to(torch.float)) # B x N x C
-
-
-class Block(nn.Module):
- """Transformer blocks with support of window attention and residual propagation blocks."""
+ This class processes pixel-level features and masks, fusing them to generate encoded memory representations
+    suitable for downstream tasks in video object segmentation models such as SAM 2 (Segment Anything Model 2).
+
+ Attributes:
+ mask_downsampler (MaskDownSampler): Module for downsampling input masks.
+ pix_feat_proj (nn.Conv2d): Convolutional layer for projecting pixel features.
+ fuser (Fuser): Module for fusing pixel features and masks.
+ position_encoding (PositionEmbeddingSine): Module for adding positional encoding to features.
+ out_proj (nn.Module): Output projection layer, either nn.Identity or nn.Conv2d.
+
+ Methods:
+ forward: Processes input pixel features and masks to generate encoded memory representations.
+
+ Examples:
+ >>> import torch
+ >>> encoder = MemoryEncoder(out_dim=256, in_dim=256)
+ >>> pix_feat = torch.randn(1, 256, 64, 64)
+        >>> masks = torch.randn(1, 1, 1024, 1024)
+        >>> out = encoder(pix_feat, masks)
+        >>> print(out["vision_features"].shape, out["vision_pos_enc"][0].shape)
+        torch.Size([1, 256, 64, 64]) torch.Size([1, 64, 64, 64])
+ """
def __init__(
self,
- dim: int,
- num_heads: int,
- mlp_ratio: float = 4.0,
- qkv_bias: bool = True,
- norm_layer: Type[nn.Module] = nn.LayerNorm,
- act_layer: Type[nn.Module] = nn.GELU,
- use_rel_pos: bool = False,
- rel_pos_zero_init: bool = True,
- window_size: int = 0,
- input_size: Optional[Tuple[int, int]] = None,
- ) -> None:
- """
- Args:
- dim (int): Number of input channels.
- num_heads (int): Number of attention heads in each ViT block.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool): If True, add a learnable bias to query, key, value.
- norm_layer (nn.Module): Normalization layer.
- act_layer (nn.Module): Activation layer.
- use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
- rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
- window_size (int): Window size for window attention blocks. If it equals 0, then
- use global attention.
- input_size (tuple(int, int), None): Input resolution for calculating the relative
- positional parameter size.
- """
+ out_dim,
+ in_dim=256, # in_dim of pix_feats
+ ):
+ """Initializes the MemoryEncoder for encoding pixel features and masks into memory representations."""
super().__init__()
- self.norm1 = norm_layer(dim)
- self.attn = Attention(
- dim,
- num_heads=num_heads,
- qkv_bias=qkv_bias,
- use_rel_pos=use_rel_pos,
- rel_pos_zero_init=rel_pos_zero_init,
- input_size=input_size if window_size == 0 else (window_size, window_size),
- )
- self.norm2 = norm_layer(dim)
- self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
+ self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
- self.window_size = window_size
+ self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
+ self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
+ self.position_encoding = PositionEmbeddingSine(num_pos_feats=64)
+ self.out_proj = nn.Identity()
+ if out_dim != in_dim:
+ self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Executes a forward pass through the transformer block with window attention and non-overlapping windows."""
- shortcut = x
- x = self.norm1(x)
- # Window partition
- if self.window_size > 0:
- H, W = x.shape[1], x.shape[2]
- x, pad_hw = window_partition(x, self.window_size)
+ def forward(
+ self,
+ pix_feat: torch.Tensor,
+ masks: torch.Tensor,
+ skip_mask_sigmoid: bool = False,
+    ) -> dict:
+ """Processes pixel features and masks to generate encoded memory representations for segmentation."""
+ if not skip_mask_sigmoid:
+ masks = F.sigmoid(masks)
+ masks = self.mask_downsampler(masks)
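+        # Downsampling reduces the mask by a total stride of 16 and projects it to 256 channels, matching
+        # the spatial size and channel dim of the projected pixel features it is added to below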
- x = self.attn(x)
- # Reverse window partition
- if self.window_size > 0:
- x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+        # Fuse pix_feats and downsampled masks; if the visual features are on a different device (e.g. CPU), move them to the masks' device
+ pix_feat = pix_feat.to(masks.device)
- x = shortcut + x
- return x + self.mlp(self.norm2(x))
+ x = self.pix_feat_proj(pix_feat)
+ x = x + masks
+ x = self.fuser(x)
+ x = self.out_proj(x)
+ pos = self.position_encoding(x).to(x.dtype)
-class Attention(nn.Module):
- """Multi-head Attention block with relative position embeddings."""
+ return {"vision_features": x, "vision_pos_enc": [pos]}
- def __init__(
- self,
- dim: int,
- num_heads: int = 8,
- qkv_bias: bool = True,
- use_rel_pos: bool = False,
- rel_pos_zero_init: bool = True,
- input_size: Optional[Tuple[int, int]] = None,
- ) -> None:
- """
- Initialize Attention module.
- Args:
- dim (int): Number of input channels.
- num_heads (int): Number of attention heads.
- qkv_bias (bool): If True, add a learnable bias to query, key, value.
- rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
- input_size (tuple(int, int), None): Input resolution for calculating the relative
- positional parameter size.
- """
- super().__init__()
- self.num_heads = num_heads
- head_dim = dim // num_heads
- self.scale = head_dim**-0.5
+class ImageEncoder(nn.Module):
+ """
+ Encodes images using a trunk-neck architecture, producing multiscale features and positional encodings.
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.proj = nn.Linear(dim, dim)
+ This class combines a trunk network for feature extraction with a neck network for feature refinement
+ and positional encoding generation. It can optionally discard the lowest resolution features.
- self.use_rel_pos = use_rel_pos
- if self.use_rel_pos:
- assert input_size is not None, "Input size must be provided if using relative positional encoding."
- # Initialize relative positional embeddings
- self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
- self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
+ Attributes:
+ trunk (nn.Module): The trunk network for initial feature extraction.
+ neck (nn.Module): The neck network for feature refinement and positional encoding generation.
+ scalp (int): Number of lowest resolution feature levels to discard.
+
+ Methods:
+ forward: Processes the input image through the trunk and neck networks.
+
+ Examples:
+ >>> trunk = SomeTrunkNetwork()
+ >>> neck = SomeNeckNetwork()
+ >>> encoder = ImageEncoder(trunk, neck, scalp=1)
+ >>> image = torch.randn(1, 3, 224, 224)
+ >>> output = encoder(image)
+ >>> print(output.keys())
+ dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])
+ """
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Applies the forward operation including attention, normalization, MLP, and indexing within window limits."""
- B, H, W, _ = x.shape
- # qkv with shape (3, B, nHead, H * W, C)
- qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
- # q, k, v with shape (B * nHead, H * W, C)
- q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
+ def __init__(
+ self,
+ trunk: nn.Module,
+ neck: nn.Module,
+ scalp: int = 0,
+ ):
+ """Initializes the ImageEncoder with trunk and neck networks for feature extraction and refinement."""
+ super().__init__()
+ self.trunk = trunk
+ self.neck = neck
+ self.scalp = scalp
+ assert (
+ self.trunk.channel_list == self.neck.backbone_channel_list
+ ), f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match."
+
+ def forward(self, sample: torch.Tensor):
+ """Encodes input through patch embedding, positional embedding, transformer blocks, and neck module."""
+ features, pos = self.neck(self.trunk(sample))
+ if self.scalp > 0:
+ # Discard the lowest resolution features
+ features, pos = features[: -self.scalp], pos[: -self.scalp]
+
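+        # The last remaining (lowest-resolution) feature map serves as the primary vision feature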
+ src = features[-1]
+ output = {
+ "vision_features": src,
+ "vision_pos_enc": pos,
+ "backbone_fpn": features,
+ }
+ return output
+
+
+class FpnNeck(nn.Module):
+ """
+ A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.
- attn = (q * self.scale) @ k.transpose(-2, -1)
+ This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
+ similar to ViT positional embedding interpolation.
- if self.use_rel_pos:
- attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
+ Attributes:
+ position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding module.
+ convs (nn.ModuleList): List of convolutional layers for each backbone level.
+ backbone_channel_list (List[int]): List of channel dimensions from the backbone.
+ fpn_interp_model (str): Interpolation mode for FPN feature resizing.
+ fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
+ fpn_top_down_levels (List[int]): Levels to have top-down features in outputs.
+
+ Methods:
+ forward: Performs forward pass through the FPN neck.
+
+ Examples:
+        >>> backbone_channels = [512, 256, 128, 64]
+        >>> fpn_neck = FpnNeck(256, backbone_channels)
+        >>> inputs = [torch.rand(1, c, s, s) for c, s in zip([64, 128, 256, 512], [256, 128, 64, 32])]
+ >>> outputs, positions = fpn_neck(inputs)
+ >>> print(len(outputs), len(positions))
+ 4 4
+ """
- attn = attn.softmax(dim=-1)
- x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
- return self.proj(x)
+ def __init__(
+ self,
+ d_model: int,
+ backbone_channel_list: List[int],
+ kernel_size: int = 1,
+ stride: int = 1,
+ padding: int = 0,
+ fpn_interp_model: str = "bilinear",
+ fuse_type: str = "sum",
+ fpn_top_down_levels: Optional[List[int]] = None,
+ ):
+ """
+ Initializes a modified Feature Pyramid Network (FPN) neck.
+ This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
+ similar to ViT positional embedding interpolation.
-def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
- """
- Partition into non-overlapping windows with padding if needed.
- Args:
- x (tensor): input tokens with [B, H, W, C].
- window_size (int): window size.
-
- Returns:
- windows: windows after partition with [B * num_windows, window_size, window_size, C].
- (Hp, Wp): padded height and width before partition
- """
- B, H, W, C = x.shape
+ Args:
+ d_model (int): Dimension of the model.
+ backbone_channel_list (List[int]): List of channel dimensions from the backbone.
+ kernel_size (int): Kernel size for the convolutional layers.
+ stride (int): Stride for the convolutional layers.
+ padding (int): Padding for the convolutional layers.
+ fpn_interp_model (str): Interpolation mode for FPN feature resizing.
+ fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
+ fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs.
+
+ Examples:
+ >>> backbone_channels = [64, 128, 256, 512]
+ >>> fpn_neck = FpnNeck(256, backbone_channels)
+ >>> print(fpn_neck)
+ """
+ super().__init__()
+ self.position_encoding = PositionEmbeddingSine(num_pos_feats=256)
+ self.convs = nn.ModuleList()
+ self.backbone_channel_list = backbone_channel_list
+ for dim in backbone_channel_list:
+ current = nn.Sequential()
+ current.add_module(
+ "conv",
+ nn.Conv2d(
+ in_channels=dim,
+ out_channels=d_model,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ ),
+ )
- pad_h = (window_size - H % window_size) % window_size
- pad_w = (window_size - W % window_size) % window_size
- if pad_h > 0 or pad_w > 0:
- x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
- Hp, Wp = H + pad_h, W + pad_w
+ self.convs.append(current)
+ self.fpn_interp_model = fpn_interp_model
+ assert fuse_type in ["sum", "avg"]
+ self.fuse_type = fuse_type
+
+ # levels to have top-down features in its outputs
+ # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
+ # have top-down propagation, while outputs of level 0 and level 1 have only
+ # lateral features from the same backbone level.
+ if fpn_top_down_levels is None:
+ # default is to have top-down features on all levels
+ fpn_top_down_levels = range(len(self.convs))
+ self.fpn_top_down_levels = list(fpn_top_down_levels)
+
+ def forward(self, xs: List[torch.Tensor]):
+ """
+ Performs forward pass through the Feature Pyramid Network (FPN) neck.
- x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
- return windows, (Hp, Wp)
+ This method processes a list of input tensors from the backbone through the FPN, applying lateral connections
+ and top-down feature fusion. It generates output feature maps and corresponding positional encodings.
+ Args:
+ xs (List[torch.Tensor]): List of input tensors from the backbone, each with shape (B, C, H, W).
-def window_unpartition(
- windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
-) -> torch.Tensor:
+ Returns:
+ (Tuple[List[torch.Tensor], List[torch.Tensor]]): A tuple containing:
+ - out (List[torch.Tensor]): List of output feature maps after FPN processing, each with shape
+ (B, d_model, H, W).
+ - pos (List[torch.Tensor]): List of positional encodings corresponding to each output feature map.
+
+ Examples:
+            >>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[512, 256, 128, 64])
+            >>> inputs = [torch.rand(1, c, s, s) for c, s in zip([64, 128, 256, 512], [256, 128, 64, 32])]
+ >>> outputs, positions = fpn_neck(inputs)
+ >>> print(len(outputs), len(positions))
+ 4 4
+ """
+ out = [None] * len(self.convs)
+ pos = [None] * len(self.convs)
+ assert len(xs) == len(self.convs)
+ # fpn forward pass
+ # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
+ prev_features = None
+ # forward in top-down order (from low to high resolution)
+ n = len(self.convs) - 1
+ for i in range(n, -1, -1):
+ x = xs[i]
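+            # The lateral convs follow backbone_channel_list, which lists the levels in the reverse
+            # order of xs, hence the flipped index n - i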
+ lateral_features = self.convs[n - i](x)
+ if i in self.fpn_top_down_levels and prev_features is not None:
+ top_down_features = F.interpolate(
+ prev_features.to(dtype=torch.float32),
+ scale_factor=2.0,
+ mode=self.fpn_interp_model,
+ align_corners=(None if self.fpn_interp_model == "nearest" else False),
+ antialias=False,
+ )
+ prev_features = lateral_features + top_down_features
+ if self.fuse_type == "avg":
+ prev_features /= 2
+ else:
+ prev_features = lateral_features
+ x_out = prev_features
+ out[i] = x_out
+ pos[i] = self.position_encoding(x_out).to(x_out.dtype)
+
+ return out, pos
+
+
+class Hiera(nn.Module):
"""
- Window unpartition into original sequences and removing padding.
+ Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.
- Args:
- windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
- window_size (int): window size.
- pad_hw (Tuple): padded height and width (Hp, Wp).
- hw (Tuple): original height and width (H, W) before padding.
+ This class implements a Hiera model, which is a hierarchical vision transformer architecture designed for
+ efficient multiscale feature extraction. It uses a series of transformer blocks organized into stages,
+ with optional pooling and global attention mechanisms.
- Returns:
- x: unpartitioned sequences with [B, H, W, C].
+ Attributes:
+ window_spec (Tuple[int, ...]): Window sizes for each stage.
+ q_stride (Tuple[int, int]): Downsampling stride between stages.
+ stage_ends (List[int]): Indices of the last block in each stage.
+ q_pool_blocks (List[int]): Indices of blocks where pooling is applied.
+ return_interm_layers (bool): Whether to return intermediate layer outputs.
+ patch_embed (PatchEmbed): Module for patch embedding.
+ global_att_blocks (Tuple[int, ...]): Indices of blocks with global attention.
+ window_pos_embed_bkg_spatial_size (Tuple[int, int]): Spatial size for window positional embedding background.
+ pos_embed (nn.Parameter): Positional embedding for the background.
+ pos_embed_window (nn.Parameter): Positional embedding for the window.
+ blocks (nn.ModuleList): List of MultiScaleBlock modules.
+ channel_list (List[int]): List of output channel dimensions for each stage.
+
+ Methods:
+ _get_pos_embed: Generates positional embeddings by interpolating and combining window and background embeddings.
+ forward: Performs the forward pass through the Hiera model.
+
+ Examples:
+ >>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
+ >>> input_tensor = torch.randn(1, 3, 224, 224)
+ >>> output_features = model(input_tensor)
+ >>> for feat in output_features:
+ ... print(feat.shape)
"""
- Hp, Wp = pad_hw
- H, W = hw
- B = windows.shape[0] // (Hp * Wp // window_size // window_size)
- x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
-
- if Hp > H or Wp > W:
- x = x[:, :H, :W, :].contiguous()
- return x
+ def __init__(
+ self,
+ embed_dim: int = 96, # initial embed dim
+ num_heads: int = 1, # initial number of heads
+ drop_path_rate: float = 0.0, # stochastic depth
+ q_pool: int = 3, # number of q_pool stages
+        q_stride: Tuple[int, int] = (2, 2),  # downsample stride between stages
+ stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
+ dim_mul: float = 2.0, # dim_mul factor at stage shift
+ head_mul: float = 2.0, # head_mul factor at stage shift
+ window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
+ # window size per stage, when not using global att.
+ window_spec: Tuple[int, ...] = (
+ 8,
+ 4,
+ 14,
+ 7,
+ ),
+ # global attn in these blocks
+ global_att_blocks: Tuple[int, ...] = (
+ 12,
+ 16,
+ 20,
+ ),
+ return_interm_layers=True, # return feats from every stage
+ ):
+ """Initializes the Hiera model, configuring its hierarchical vision transformer architecture."""
+ super().__init__()
-def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
- """
- Get relative positional embeddings according to the relative positions of query and key sizes.
+ assert len(stages) == len(window_spec)
+ self.window_spec = window_spec
- Args:
- q_size (int): size of query q.
- k_size (int): size of key k.
- rel_pos (Tensor): relative position embeddings (L, C).
+ depth = sum(stages)
+ self.q_stride = q_stride
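+        # stage_ends marks the last block index of each stage; spatial pooling (q_stride) is applied
+        # at the first block of each following stage, for at most q_pool stages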
+ self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
+ assert 0 <= q_pool <= len(self.stage_ends[:-1])
+ self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
+ self.return_interm_layers = return_interm_layers
- Returns:
- Extracted positional embeddings according to relative positions.
- """
- max_rel_dist = int(2 * max(q_size, k_size) - 1)
- # Interpolate rel pos if needed.
- if rel_pos.shape[0] != max_rel_dist:
- # Interpolate rel pos.
- rel_pos_resized = F.interpolate(
- rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
- size=max_rel_dist,
- mode="linear",
+ self.patch_embed = PatchEmbed(
+ embed_dim=embed_dim,
+ kernel_size=(7, 7),
+ stride=(4, 4),
+ padding=(3, 3),
)
- rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
- else:
- rel_pos_resized = rel_pos
-
- # Scale the coords with short length if shapes for q and k are different.
- q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
- k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
- relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
-
- return rel_pos_resized[relative_coords.long()]
-
-
-def add_decomposed_rel_pos(
- attn: torch.Tensor,
- q: torch.Tensor,
- rel_pos_h: torch.Tensor,
- rel_pos_w: torch.Tensor,
- q_size: Tuple[int, int],
- k_size: Tuple[int, int],
-) -> torch.Tensor:
- """
- Calculate decomposed Relative Positional Embeddings from mvitv2 paper at
- https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py.
-
- Args:
- attn (Tensor): attention map.
- q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
- rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
- rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
- q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
- k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
-
- Returns:
- attn (Tensor): attention map with added relative positional embeddings.
- """
- q_h, q_w = q_size
- k_h, k_w = k_size
- Rh = get_rel_pos(q_h, k_h, rel_pos_h)
- Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+ # Which blocks have global att?
+ self.global_att_blocks = global_att_blocks
- B, _, dim = q.shape
- r_q = q.reshape(B, q_h, q_w, dim)
- rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
- rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+ # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
+ self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
+ self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size))
+ self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))
- attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
- B, q_h * q_w, k_h * k_w
- )
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
- return attn
+ cur_stage = 1
+ self.blocks = nn.ModuleList()
+ for i in range(depth):
+ dim_out = embed_dim
+            # The window size lags by one block: the first block of each new stage
+            # still uses the previous stage's window size, while the remaining blocks
+            # of that stage use the current stage's window size
+ window_size = self.window_spec[cur_stage - 1]
-class PatchEmbed(nn.Module):
- """Image to Patch Embedding."""
+ if self.global_att_blocks is not None:
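+                # A window_size of 0 makes the block use global attention instead of windowed attention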
+ window_size = 0 if i in self.global_att_blocks else window_size
- def __init__(
- self,
- kernel_size: Tuple[int, int] = (16, 16),
- stride: Tuple[int, int] = (16, 16),
- padding: Tuple[int, int] = (0, 0),
- in_chans: int = 3,
- embed_dim: int = 768,
- ) -> None:
- """
- Initialize PatchEmbed module.
+ if i - 1 in self.stage_ends:
+ dim_out = int(embed_dim * dim_mul)
+ num_heads = int(num_heads * head_mul)
+ cur_stage += 1
- Args:
- kernel_size (Tuple): kernel size of the projection layer.
- stride (Tuple): stride of the projection layer.
- padding (Tuple): padding size of the projection layer.
- in_chans (int): Number of input image channels.
- embed_dim (int): Patch embedding dimension.
- """
- super().__init__()
+ block = MultiScaleBlock(
+ dim=embed_dim,
+ dim_out=dim_out,
+ num_heads=num_heads,
+ drop_path=dpr[i],
+ q_stride=self.q_stride if i in self.q_pool_blocks else None,
+ window_size=window_size,
+ )
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+ embed_dim = dim_out
+ self.blocks.append(block)
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Computes patch embedding by applying convolution and transposing resulting tensor."""
- return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C
+ self.channel_list = (
+ [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
+ if return_interm_layers
+ else [self.blocks[-1].dim_out]
+ )
+
+ def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
+ """Generates positional embeddings by interpolating and combining window and background embeddings."""
+ h, w = hw
+ window_embed = self.pos_embed_window
+ pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
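+        # Tile the learned window embedding across the interpolated background embedding and add the two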
+ pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
+ pos_embed = pos_embed.permute(0, 2, 3, 1)
+ return pos_embed
+
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ """Performs forward pass through Hiera model, extracting multiscale features from input images."""
+ x = self.patch_embed(x)
+ # x: (B, H, W, C)
+
+ # Add pos embed
+ x = x + self._get_pos_embed(x.shape[1:3])
+
+ outputs = []
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers):
+ feats = x.permute(0, 3, 1, 2)
+ outputs.append(feats)
+
+ return outputs
diff --git a/ultralytics/models/sam2/modules/memory_attention.py b/ultralytics/models/sam/modules/memory_attention.py
similarity index 57%
rename from ultralytics/models/sam2/modules/memory_attention.py
rename to ultralytics/models/sam/modules/memory_attention.py
index 8b5673c31..b55b07302 100644
--- a/ultralytics/models/sam2/modules/memory_attention.py
+++ b/ultralytics/models/sam/modules/memory_attention.py
@@ -6,11 +6,50 @@ from typing import Optional
import torch
from torch import Tensor, nn
-from .sam2_blocks import RoPEAttention
+from .blocks import RoPEAttention
class MemoryAttentionLayer(nn.Module):
- """Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks."""
+ """
+ Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.
+
+ This class combines self-attention, cross-attention, and feedforward components to process input tensors and
+ generate memory-based attention outputs.
+
+ Attributes:
+ d_model (int): Dimensionality of the model.
+ dim_feedforward (int): Dimensionality of the feedforward network.
+ dropout_value (float): Dropout rate for regularization.
+ self_attn (RoPEAttention): Self-attention mechanism using RoPE (Rotary Position Embedding).
+ cross_attn_image (RoPEAttention): Cross-attention mechanism for image processing.
+ linear1 (nn.Linear): First linear layer of the feedforward network.
+ linear2 (nn.Linear): Second linear layer of the feedforward network.
+ norm1 (nn.LayerNorm): Layer normalization for self-attention output.
+ norm2 (nn.LayerNorm): Layer normalization for cross-attention output.
+ norm3 (nn.LayerNorm): Layer normalization for feedforward network output.
+ dropout1 (nn.Dropout): Dropout layer after self-attention.
+ dropout2 (nn.Dropout): Dropout layer after cross-attention.
+ dropout3 (nn.Dropout): Dropout layer after feedforward network.
+ activation (nn.ReLU): Activation function for the feedforward network.
+ pos_enc_at_attn (bool): Flag to add positional encoding at attention.
+ pos_enc_at_cross_attn_queries (bool): Flag to add positional encoding to cross-attention queries.
+ pos_enc_at_cross_attn_keys (bool): Flag to add positional encoding to cross-attention keys.
+
+ Methods:
+ forward: Performs the full memory attention operation on input tensors.
+ _forward_sa: Performs self-attention on input tensor.
+ _forward_ca: Performs cross-attention between target and memory tensors.
+
+ Examples:
+ >>> layer = MemoryAttentionLayer(d_model=256, dim_feedforward=2048, dropout=0.1)
+ >>> tgt = torch.randn(1, 100, 256)
+ >>> memory = torch.randn(1, 100, 64)
+ >>> pos = torch.randn(1, 100, 256)
+ >>> query_pos = torch.randn(1, 100, 256)
+ >>> output = layer(tgt, memory, pos, query_pos)
+ >>> print(output.shape)
+ torch.Size([1, 100, 256])
+ """
def __init__(
self,
@@ -21,7 +60,7 @@ class MemoryAttentionLayer(nn.Module):
pos_enc_at_cross_attn_keys: bool = True,
pos_enc_at_cross_attn_queries: bool = False,
):
- """Initializes a MemoryAttentionLayer with self-attention, cross-attention, and feedforward components."""
+ """Initializes a memory attention layer with self-attention, cross-attention, and feedforward components."""
super().__init__()
self.d_model = d_model
self.dim_feedforward = dim_feedforward
@@ -88,7 +127,7 @@ class MemoryAttentionLayer(nn.Module):
query_pos: Optional[Tensor] = None,
num_k_exclude_rope: int = 0,
) -> torch.Tensor:
- """Performs self-attention, cross-attention, and MLP operations on input tensors for memory-based attention."""
+ """Processes input tensors using self-attention, cross-attention, and MLP for memory-based attention."""
tgt = self._forward_sa(tgt, query_pos)
tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
# MLP
@@ -99,7 +138,35 @@ class MemoryAttentionLayer(nn.Module):
class MemoryAttention(nn.Module):
- """Memory attention module for processing sequential data with self and cross-attention mechanisms."""
+ """
+ Memory attention module for processing sequential data with self and cross-attention mechanisms.
+
+ This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
+ for processing sequential data, particularly useful in transformer-like architectures.
+
+ Attributes:
+ d_model (int): The dimension of the model's hidden state.
+ layers (nn.ModuleList): A list of MemoryAttentionLayer modules.
+ num_layers (int): The number of attention layers.
+ norm (nn.LayerNorm): Layer normalization applied to the output.
+ pos_enc_at_input (bool): Whether to apply positional encoding at the input.
+ batch_first (bool): Whether the input tensors are in batch-first format.
+
+ Methods:
+ forward: Processes input tensors through the attention layers.
+
+ Examples:
+ >>> d_model = 256
+ >>> layer = MemoryAttentionLayer(d_model)
+ >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
+ >>> curr = torch.randn(10, 32, d_model) # (seq_len, batch_size, d_model)
+ >>> memory = torch.randn(20, 32, d_model) # (mem_len, batch_size, d_model)
+ >>> curr_pos = torch.randn(10, 32, d_model)
+ >>> memory_pos = torch.randn(20, 32, d_model)
+ >>> output = attention(curr, memory, curr_pos, memory_pos)
+ >>> print(output.shape)
+ torch.Size([10, 32, 256])
+ """
def __init__(
self,
@@ -126,7 +193,7 @@ class MemoryAttention(nn.Module):
memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
):
- """Applies self-attention and cross-attention to input tensors, processing through multiple layers."""
+ """Processes input tensors through multiple attention layers, applying self and cross-attention mechanisms."""
if isinstance(curr, list):
assert isinstance(curr_pos, list)
assert len(curr) == len(curr_pos) == 1
diff --git a/ultralytics/models/sam/modules/sam.py b/ultralytics/models/sam/modules/sam.py
index 1617527f0..27509575b 100644
--- a/ultralytics/models/sam/modules/sam.py
+++ b/ultralytics/models/sam/modules/sam.py
@@ -9,25 +9,48 @@
from typing import List
import torch
+import torch.nn.functional as F
from torch import nn
+from torch.nn.init import trunc_normal_
-from .decoders import MaskDecoder
+from ultralytics.nn.modules import MLP
+
+from .blocks import SAM2TwoWayTransformer
+from .decoders import MaskDecoder, SAM2MaskDecoder
from .encoders import ImageEncoderViT, PromptEncoder
+from .utils import get_1d_sine_pe, select_closest_cond_frames
+
+# a large negative value as a placeholder score for missing objects
+NO_OBJ_SCORE = -1024.0
class SAMModel(nn.Module):
"""
- SAMModel (Segment Anything Model) is designed for object segmentation tasks. It uses image encoders to generate
- image embeddings, and prompt encoders to encode various types of input prompts. These embeddings are then used by
- the mask decoder to predict object masks.
+ Segment Anything Model (SAM) for object segmentation tasks.
+
+ This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images
+ and input prompts.
Attributes:
mask_threshold (float): Threshold value for mask prediction.
- image_encoder (ImageEncoderViT): The backbone used to encode the image into embeddings.
- prompt_encoder (PromptEncoder): Encodes various types of input prompts.
- mask_decoder (MaskDecoder): Predicts object masks from the image and prompt embeddings.
- pixel_mean (List[float]): Mean pixel values for image normalization.
- pixel_std (List[float]): Standard deviation values for image normalization.
+ image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings.
+ prompt_encoder (PromptEncoder): Encoder for various types of input prompts.
+ mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings.
+ pixel_mean (torch.Tensor): Mean pixel values for image normalization, shape (3, 1, 1).
+ pixel_std (torch.Tensor): Standard deviation values for image normalization, shape (3, 1, 1).
+
+ Methods:
+ __init__: Initializes the SAMModel with encoders, decoder, and normalization parameters.
+
+ Examples:
+ >>> image_encoder = ImageEncoderViT(...)
+ >>> prompt_encoder = PromptEncoder(...)
+ >>> mask_decoder = MaskDecoder(...)
+ >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
+ >>> # Further usage depends on SAMPredictor class
+
+ Notes:
+ All forward() operations are implemented in the SAMPredictor class.
"""
mask_threshold: float = 0.0
@@ -43,17 +66,22 @@ class SAMModel(nn.Module):
"""
Initialize the SAMModel class to predict object masks from an image and input prompts.
- Note:
- All forward() operations moved to SAMPredictor.
-
Args:
image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
- pixel_mean (List[float], optional): Mean values for normalizing pixels in the input image. Defaults to
- (123.675, 116.28, 103.53).
- pixel_std (List[float], optional): Std values for normalizing pixels in the input image. Defaults to
- (58.395, 57.12, 57.375).
+ pixel_mean (List[float]): Mean values for normalizing pixels in the input image.
+ pixel_std (List[float]): Std values for normalizing pixels in the input image.
+
+ Examples:
+ >>> image_encoder = ImageEncoderViT(...)
+ >>> prompt_encoder = PromptEncoder(...)
+ >>> mask_decoder = MaskDecoder(...)
+ >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
+ >>> # Further usage depends on SAMPredictor class
+
+ Notes:
+ All forward() operations moved to SAMPredictor.
"""
super().__init__()
self.image_encoder = image_encoder
@@ -61,3 +89,846 @@ class SAMModel(nn.Module):
self.mask_decoder = mask_decoder
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+
+
+class SAM2Model(torch.nn.Module):
+ """
+ SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.
+
+ This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms
+ for temporal consistency and efficient tracking of objects across frames.
+
+ Attributes:
+ mask_threshold (float): Threshold value for mask prediction.
+        image_encoder (ImageEncoder): Visual encoder for extracting image features.
+ memory_attention (nn.Module): Module for attending to memory features.
+ memory_encoder (nn.Module): Encoder for generating memory representations.
+ num_maskmem (int): Number of accessible memory frames.
+ image_size (int): Size of input images.
+ backbone_stride (int): Stride of the backbone network output.
+ sam_prompt_embed_dim (int): Dimension of SAM prompt embeddings.
+ sam_image_embedding_size (int): Size of SAM image embeddings.
+ sam_prompt_encoder (PromptEncoder): Encoder for processing input prompts.
+ sam_mask_decoder (SAM2MaskDecoder): Decoder for generating object masks.
+ obj_ptr_proj (nn.Module): Projection layer for object pointers.
+ obj_ptr_tpos_proj (nn.Module): Projection for temporal positional encoding in object pointers.
+
+ Methods:
+ forward_image: Processes image batch through encoder to extract multi-level features.
+ track_step: Performs a single tracking step, updating object masks and memory features.
+
+ Examples:
+ >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
+ >>> image_batch = torch.rand(1, 3, 512, 512)
+ >>> features = model.forward_image(image_batch)
+ >>> track_results = model.track_step(0, True, features, None, None, None, {})
+ """
+
+ mask_threshold: float = 0.0
+
+ def __init__(
+ self,
+ image_encoder,
+ memory_attention,
+ memory_encoder,
+ num_maskmem=7,
+ image_size=512,
+ backbone_stride=16,
+ sigmoid_scale_for_mem_enc=1.0,
+ sigmoid_bias_for_mem_enc=0.0,
+ binarize_mask_from_pts_for_mem_enc=False,
+ use_mask_input_as_output_without_sam=False,
+ max_cond_frames_in_attn=-1,
+ directly_add_no_mem_embed=False,
+ use_high_res_features_in_sam=False,
+ multimask_output_in_sam=False,
+ multimask_min_pt_num=1,
+ multimask_max_pt_num=1,
+ multimask_output_for_tracking=False,
+ use_multimask_token_for_obj_ptr: bool = False,
+ iou_prediction_use_sigmoid=False,
+ memory_temporal_stride_for_eval=1,
+ add_all_frames_to_correct_as_cond=False,
+ non_overlap_masks_for_mem_enc=False,
+ use_obj_ptrs_in_encoder=False,
+ max_obj_ptrs_in_encoder=16,
+ add_tpos_enc_to_obj_ptrs=True,
+ proj_tpos_enc_in_obj_ptrs=False,
+ only_obj_ptrs_in_the_past_for_eval=False,
+ pred_obj_scores: bool = False,
+ pred_obj_scores_mlp: bool = False,
+ fixed_no_obj_ptr: bool = False,
+ soft_no_obj_ptr: bool = False,
+ use_mlp_for_obj_ptr_proj: bool = False,
+ sam_mask_decoder_extra_args=None,
+ compile_image_encoder: bool = False,
+ ):
+ """
+ Initializes the SAM2Model for video object segmentation with memory-based tracking.
+
+ Args:
+ image_encoder (nn.Module): Visual encoder for extracting image features.
+ memory_attention (nn.Module): Module for attending to memory features.
+ memory_encoder (nn.Module): Encoder for generating memory representations.
+ num_maskmem (int): Number of accessible memory frames. Default is 7 (1 input frame + 6 previous frames).
+ image_size (int): Size of input images.
+ backbone_stride (int): Stride of the image backbone output.
+ sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
+ sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
+ binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames
+ with clicks during evaluation.
+ use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM
+ prompt encoder and mask decoder on frames with mask input.
+ max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
+ -1 means no limit.
+ directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the
+ first frame.
+ use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.
+ multimask_output_in_sam (bool): Whether to output multiple (3) masks for the first click on initial
+ conditioning frames.
+ multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
+ multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
+ multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
+ use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
+ iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
+ memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
+ add_all_frames_to_correct_as_cond (bool): Whether to append frames with correction clicks to conditioning
+ frame list.
+ non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
+ memory encoder during evaluation.
+ use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
+ max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder
+ cross-attention.
+ add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in
+ the encoder.
+ proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
+ encoding in object pointers.
+ only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past
+ during evaluation.
+ pred_obj_scores (bool): Whether to predict if there is an object in the frame.
+ pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.
+ fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
+ soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.
+ use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
+ sam_mask_decoder_extra_args (Dict | None): Extra arguments for constructing the SAM mask decoder.
+ compile_image_encoder (bool): Whether to compile the image encoder for faster inference.
+
+ Examples:
+            >>> image_encoder = ImageEncoder(...)
+            >>> memory_attention = MemoryAttention(...)
+            >>> memory_encoder = MemoryEncoder(...)
+ >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
+ >>> image_batch = torch.rand(1, 3, 512, 512)
+ >>> features = model.forward_image(image_batch)
+ >>> track_results = model.track_step(0, True, features, None, None, None, {})
+ """
+ super().__init__()
+
+ # Part 1: the image backbone
+ self.image_encoder = image_encoder
+ # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
+ self.use_high_res_features_in_sam = use_high_res_features_in_sam
+ self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
+ self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
+ self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
+ if use_obj_ptrs_in_encoder:
+ # A conv layer to downsample the mask prompt to stride 4 (the same stride as
+ # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
+ # so that it can be fed into the SAM mask decoder to generate a pointer.
+ self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
+ self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
+ if proj_tpos_enc_in_obj_ptrs:
+ assert add_tpos_enc_to_obj_ptrs # these options need to be used together
+ self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
+ self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
+
+ # Part 2: memory attention to condition current frame's visual features
+ # with memories (and obj ptrs) from past frames
+ self.memory_attention = memory_attention
+ self.hidden_dim = memory_attention.d_model
+
+ # Part 3: memory encoder for the previous frame's outputs
+ self.memory_encoder = memory_encoder
+ self.mem_dim = self.hidden_dim
+ if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"):
+ # if there is compression of memories along channel dim
+ self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
+ self.num_maskmem = num_maskmem # Number of memories accessible
+ # Temporal encoding of the memories
+ self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))
+ trunc_normal_(self.maskmem_tpos_enc, std=0.02)
+ # a single token to indicate no memory embedding from previous frames
+ self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
+ self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
+ trunc_normal_(self.no_mem_embed, std=0.02)
+ trunc_normal_(self.no_mem_pos_enc, std=0.02)
+ self.directly_add_no_mem_embed = directly_add_no_mem_embed
+ # Apply sigmoid to the output raw mask logits (to turn them from
+ # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
+ self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
+ self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
+ self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
+ self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
+ self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
+ # On frames with mask input, whether to directly output the input mask without
+ # using a SAM prompt encoder + mask decoder
+ self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
+ self.multimask_output_in_sam = multimask_output_in_sam
+ self.multimask_min_pt_num = multimask_min_pt_num
+ self.multimask_max_pt_num = multimask_max_pt_num
+ self.multimask_output_for_tracking = multimask_output_for_tracking
+ self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
+ self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
+
+ # Part 4: SAM-style prompt encoder (for both mask and point inputs)
+ # and SAM-style mask decoder for the final mask output
+ self.image_size = image_size
+ self.backbone_stride = backbone_stride
+ self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
+ self.pred_obj_scores = pred_obj_scores
+ self.pred_obj_scores_mlp = pred_obj_scores_mlp
+ self.fixed_no_obj_ptr = fixed_no_obj_ptr
+ self.soft_no_obj_ptr = soft_no_obj_ptr
+ if self.fixed_no_obj_ptr:
+ assert self.pred_obj_scores
+ assert self.use_obj_ptrs_in_encoder
+ if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
+ self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
+ trunc_normal_(self.no_obj_ptr, std=0.02)
+ self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
+
+ self._build_sam_heads()
+ self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
+ self.max_cond_frames_in_attn = max_cond_frames_in_attn
+
+ # Model compilation
+ if compile_image_encoder:
+ # Compile the forward function (not the full module) to allow loading checkpoints.
+ print("Image encoder compilation is enabled. First forward pass will be slow.")
+ self.image_encoder.forward = torch.compile(
+ self.image_encoder.forward,
+ mode="max-autotune",
+ fullgraph=True,
+ dynamic=False,
+ )
+
+ @property
+ def device(self):
+ """Returns the device on which the model's parameters are stored."""
+ return next(self.parameters()).device
+
+ def forward(self, *args, **kwargs):
+ """Processes image and prompt inputs to generate object masks and scores in video sequences."""
+ raise NotImplementedError(
+ "Please use the corresponding methods in SAM2VideoPredictor for inference."
+ "See notebooks/video_predictor_example.ipynb for an example."
+ )
+
+ def _build_sam_heads(self):
+ """Builds SAM-style prompt encoder and mask decoder for image segmentation tasks."""
+ self.sam_prompt_embed_dim = self.hidden_dim
+ self.sam_image_embedding_size = self.image_size // self.backbone_stride
+
+ # build PromptEncoder and MaskDecoder from SAM
+ # (their hyperparameters like `mask_in_chans=16` are from SAM code)
+ self.sam_prompt_encoder = PromptEncoder(
+ embed_dim=self.sam_prompt_embed_dim,
+ image_embedding_size=(
+ self.sam_image_embedding_size,
+ self.sam_image_embedding_size,
+ ),
+ input_image_size=(self.image_size, self.image_size),
+ mask_in_chans=16,
+ )
+ self.sam_mask_decoder = SAM2MaskDecoder(
+ num_multimask_outputs=3,
+ transformer=SAM2TwoWayTransformer(
+ depth=2,
+ embedding_dim=self.sam_prompt_embed_dim,
+ mlp_dim=2048,
+ num_heads=8,
+ ),
+ transformer_dim=self.sam_prompt_embed_dim,
+ iou_head_depth=3,
+ iou_head_hidden_dim=256,
+ use_high_res_features=self.use_high_res_features_in_sam,
+ iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
+ pred_obj_scores=self.pred_obj_scores,
+ pred_obj_scores_mlp=self.pred_obj_scores_mlp,
+ use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
+ **(self.sam_mask_decoder_extra_args or {}),
+ )
+ if self.use_obj_ptrs_in_encoder:
+ # a linear projection on SAM output tokens to turn them into object pointers
+ self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
+ if self.use_mlp_for_obj_ptr_proj:
+ self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
+ else:
+ self.obj_ptr_proj = torch.nn.Identity()
+ if self.proj_tpos_enc_in_obj_ptrs:
+ # a linear projection on temporal positional encoding in object pointers to
+ # avoid potential interference with spatial positional encoding
+ self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
+ else:
+ self.obj_ptr_tpos_proj = torch.nn.Identity()
+
+ def _forward_sam_heads(
+ self,
+ backbone_features,
+ point_inputs=None,
+ mask_inputs=None,
+ high_res_features=None,
+ multimask_output=False,
+ ):
+ """
+ Forward pass through SAM prompt encoders and mask heads.
+
+ This method processes image features and optional point/mask inputs to generate object masks and scores.
+
+ Args:
+ backbone_features (torch.Tensor): Image features with shape (B, C, H, W).
+ point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts.
+ 'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute
+ pixel-unit coordinates in (x, y) format for P input points.
+ 'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks,
+ 0 means negative clicks, and -1 means padding.
+ mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the
+ same spatial size as the image.
+ high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes
+ (B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps
+ for SAM decoder.
+ multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False,
+ output only 1 mask and its IoU estimate.
+
+ Returns:
+ (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]):
+ low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
+ high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
+ ious: Tensor of shape (B, M) with estimated IoU for each output mask.
+ low_res_masks: Tensor of shape (B, 1, H*4, W*4) with best low-resolution mask.
+ high_res_masks: Tensor of shape (B, 1, H*16, W*16) with best high-resolution mask.
+ obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
+ object_score_logits: Tensor of shape (B,) with object score logits.
+
+ Where M is 3 if multimask_output=True, and 1 if multimask_output=False.
+
+ Examples:
+ >>> backbone_features = torch.rand(1, 256, 32, 32)
+ >>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])}
+ >>> mask_inputs = torch.rand(1, 1, 512, 512)
+ >>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs)
+ >>> low_res_multimasks, high_res_multimasks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits = results
+ """
+ B = backbone_features.size(0)
+ device = backbone_features.device
+ assert backbone_features.size(1) == self.sam_prompt_embed_dim
+ assert backbone_features.size(2) == self.sam_image_embedding_size
+ assert backbone_features.size(3) == self.sam_image_embedding_size
+
+ # a) Handle point prompts
+ if point_inputs is not None:
+ sam_point_coords = point_inputs["point_coords"]
+ sam_point_labels = point_inputs["point_labels"]
+ assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
+ else:
+ # If no points are provided, pad with an empty point (with label -1)
+ sam_point_coords = torch.zeros(B, 1, 2, device=device)
+ sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
+
+ # b) Handle mask prompts
+ if mask_inputs is not None:
+ # If mask_inputs is provided, downsize it into low-res mask input if needed
+ # and feed it as a dense mask prompt into the SAM mask encoder
+ assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
+ if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
+ sam_mask_prompt = F.interpolate(
+ mask_inputs.float(),
+ size=self.sam_prompt_encoder.mask_input_size,
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ )
+ else:
+ sam_mask_prompt = mask_inputs
+ else:
+ # Otherwise, simply feed None (and SAM's prompt encoder will add
+ # a learned `no_mask_embed` to indicate no mask input in this case).
+ sam_mask_prompt = None
+
+ sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
+ points=(sam_point_coords, sam_point_labels),
+ boxes=None,
+ masks=sam_mask_prompt,
+ )
+ (
+ low_res_multimasks,
+ ious,
+ sam_output_tokens,
+ object_score_logits,
+ ) = self.sam_mask_decoder(
+ image_embeddings=backbone_features,
+ image_pe=self.sam_prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ repeat_image=False, # the image is already batched
+ high_res_features=high_res_features,
+ )
+ if self.pred_obj_scores:
+ is_obj_appearing = object_score_logits > 0
+
+ # Mask used for spatial memories is always a *hard* choice between obj and no obj,
+ # consistent with the actual mask prediction
+ low_res_multimasks = torch.where(
+ is_obj_appearing[:, None, None],
+ low_res_multimasks,
+ NO_OBJ_SCORE,
+ )
+
+ # convert masks from possibly bfloat16 (or float16) to float32
+ # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
+ low_res_multimasks = low_res_multimasks.float()
+ high_res_multimasks = F.interpolate(
+ low_res_multimasks,
+ size=(self.image_size, self.image_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ sam_output_token = sam_output_tokens[:, 0]
+ if multimask_output:
+ # take the best mask prediction (with the highest IoU estimation)
+ best_iou_inds = torch.argmax(ious, dim=-1)
+ batch_inds = torch.arange(B, device=device)
+ low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+ high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+ if sam_output_tokens.size(1) > 1:
+ sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
+ else:
+ low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
+
+ # Extract object pointer from the SAM output token (with occlusion handling)
+ obj_ptr = self.obj_ptr_proj(sam_output_token)
+ if self.pred_obj_scores:
+ # Allow *soft* no obj ptr, unlike for masks
+ if self.soft_no_obj_ptr:
+ # Only hard possible with gt
+ assert not self.teacher_force_obj_scores_for_mem
+ lambda_is_obj_appearing = object_score_logits.sigmoid()
+ else:
+ lambda_is_obj_appearing = is_obj_appearing.float()
+
+ if self.fixed_no_obj_ptr:
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
+
+ return (
+ low_res_multimasks,
+ high_res_multimasks,
+ ious,
+ low_res_masks,
+ high_res_masks,
+ obj_ptr,
+ object_score_logits,
+ )
+
+ def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
+ """Processes mask inputs directly as output, bypassing SAM encoder/decoder."""
+ # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
+ out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
+ mask_inputs_float = mask_inputs.float()
+ high_res_masks = mask_inputs_float * out_scale + out_bias
+ low_res_masks = F.interpolate(
+ high_res_masks,
+ size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
+ align_corners=False,
+ mode="bilinear",
+ antialias=True, # use antialias for downsampling
+ )
+ # a dummy IoU prediction of all 1's under mask input
+ ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
+ if not self.use_obj_ptrs_in_encoder:
+ # all zeros as a dummy object pointer (of shape [B, C])
+ obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
+ else:
+ # produce an object pointer using the SAM decoder from the mask input
+ _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
+ backbone_features=backbone_features,
+ mask_inputs=self.mask_downsample(mask_inputs_float),
+ high_res_features=high_res_features,
+ )
+ # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
+ # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
+ # on the object_scores from the SAM decoder.
+ is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
+ is_obj_appearing = is_obj_appearing[..., None]
+ lambda_is_obj_appearing = is_obj_appearing.float()
+ object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
+ if self.pred_obj_scores:
+ if self.fixed_no_obj_ptr:
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
+
+ return (
+ low_res_masks,
+ high_res_masks,
+ ious,
+ low_res_masks,
+ high_res_masks,
+ obj_ptr,
+ object_score_logits,
+ )
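+ # Worked example (illustrative): with out_scale=20.0 and out_bias=-10.0 above, a binary mask value
+ # of 0 maps to a logit of -10.0 (sigmoid ~ 4.54e-05) and a value of 1 maps to +10.0
+ # (sigmoid ~ 0.99995), so the mask input behaves like a near-certain prediction:
+ #   >>> m = torch.tensor([0.0, 1.0])
+ #   >>> torch.sigmoid(m * 20.0 - 10.0)
+ #   tensor([4.5398e-05, 9.9995e-01])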
+
+ def forward_image(self, img_batch: torch.Tensor):
+ """Processes image batch through encoder to extract multi-level features for SAM model."""
+ backbone_out = self.image_encoder(img_batch)
+ if self.use_high_res_features_in_sam:
+ # precompute projected level 0 and level 1 features in SAM decoder
+ # to avoid running it again on every SAM click
+ backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
+ backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
+ return backbone_out
+
+ def _prepare_backbone_features(self, backbone_out):
+ """Prepares and flattens visual features from the image backbone output for further processing."""
+ backbone_out = backbone_out.copy()
+ assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
+ assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
+
+ feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
+ vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
+
+ feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
+ # flatten NxCxHxW to HWxNxC
+ vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
+ vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
+
+ return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
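+ # Shape sketch (illustrative, assuming a single 256-channel 64x64 feature map): flatten(2) turns
+ # (N, C, H, W) = (1, 256, 64, 64) into (1, 256, 4096), and permute(2, 0, 1) yields the (HW, N, C)
+ # layout (4096, 1, 256) consumed by the memory attention step:
+ #   >>> feat = torch.rand(1, 256, 64, 64)
+ #   >>> feat.flatten(2).permute(2, 0, 1).shape
+ #   torch.Size([4096, 1, 256])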
+
+ def _prepare_memory_conditioned_features(
+ self,
+ frame_idx,
+ is_init_cond_frame,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ output_dict,
+ num_frames,
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
+ ):
+ """Prepares memory-conditioned features by fusing current frame's visual features with previous memories."""
+ B = current_vision_feats[-1].size(1) # batch size on this frame
+ C = self.hidden_dim
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
+ device = current_vision_feats[-1].device
+ # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
+ # In this case, we skip the fusion with any memory.
+ if self.num_maskmem == 0: # Disable memory and skip fusion
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
+ return pix_feat
+
+ num_obj_ptr_tokens = 0
+ # Step 1: condition the visual features of the current frame on previous memories
+ if not is_init_cond_frame:
+ # Retrieve the memories encoded with the maskmem backbone
+ to_cat_memory, to_cat_memory_pos_embed = [], []
+ # Add conditioning frames' outputs first (all cond frames have t_pos=0 when
+ # getting the temporal positional embedding below)
+ assert len(output_dict["cond_frame_outputs"]) > 0
+ # Select a maximum number of temporally closest cond frames for cross attention
+ cond_outputs = output_dict["cond_frame_outputs"]
+ selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
+ frame_idx, cond_outputs, self.max_cond_frames_in_attn
+ )
+ t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
+ # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
+ # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
+ # We also allow taking the memory frame non-consecutively (with r>1), in which case
+ # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
+ r = self.memory_temporal_stride_for_eval
+ for t_pos in range(1, self.num_maskmem):
+ t_rel = self.num_maskmem - t_pos # how many frames before current frame
+ if t_rel == 1:
+ # for t_rel == 1, we take the last frame (regardless of r)
+ if not track_in_reverse:
+ # the frame immediately before this frame (i.e. frame_idx - 1)
+ prev_frame_idx = frame_idx - t_rel
+ else:
+ # the frame immediately after this frame (i.e. frame_idx + 1)
+ prev_frame_idx = frame_idx + t_rel
+ else:
+ # for t_rel >= 2, we take the memory frame from every r-th frames
+ if not track_in_reverse:
+ # first find the nearest frame among every r-th frames before this frame
+ # for r=1, this would be (frame_idx - 2)
+ prev_frame_idx = ((frame_idx - 2) // r) * r
+ # then seek further among every r-th frames
+ prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
+ else:
+ # first find the nearest frame among every r-th frames after this frame
+ # for r=1, this would be (frame_idx + 2)
+ prev_frame_idx = -(-(frame_idx + 2) // r) * r
+ # then seek further among every r-th frames
+ prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
+ out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
+ if out is None:
+ # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
+ # frames, we still attend to it as if it's a non-conditioning frame.
+ out = unselected_cond_outputs.get(prev_frame_idx, None)
+ t_pos_and_prevs.append((t_pos, out))
+
+ for t_pos, prev in t_pos_and_prevs:
+ if prev is None:
+ continue # skip padding frames
+ # "maskmem_features" might have been offloaded to CPU in demo use cases,
+ # so we load it back to GPU (it's a no-op if it's already on GPU).
+ feats = prev["maskmem_features"].cuda(non_blocking=True)
+ to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
+ # Spatial positional encoding (it might have been offloaded to CPU in eval)
+ maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
+ maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
+ # Temporal positional encoding
+ maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
+ to_cat_memory_pos_embed.append(maskmem_enc)
+
+ # Construct the list of past object pointers
+ if self.use_obj_ptrs_in_encoder:
+ max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
+ # First add those object pointers from selected conditioning frames
+ # (optionally, only include object pointers in the past during evaluation)
+ if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
+ ptr_cond_outputs = {
+ t: out
+ for t, out in selected_cond_outputs.items()
+ if (t >= frame_idx if track_in_reverse else t <= frame_idx)
+ }
+ else:
+ ptr_cond_outputs = selected_cond_outputs
+ pos_and_ptrs = [
+ # Temporal pos encoding contains how far away each pointer is from current frame
+ (abs(frame_idx - t), out["obj_ptr"])
+ for t, out in ptr_cond_outputs.items()
+ ]
+ # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
+ for t_diff in range(1, max_obj_ptrs_in_encoder):
+ t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
+ if t < 0 or (num_frames is not None and t >= num_frames):
+ break
+ out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None))
+ if out is not None:
+ pos_and_ptrs.append((t_diff, out["obj_ptr"]))
+ # If we have at least one object pointer, add them to the cross-attention
+ if len(pos_and_ptrs) > 0:
+ pos_list, ptrs_list = zip(*pos_and_ptrs)
+ # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
+ obj_ptrs = torch.stack(ptrs_list, dim=0)
+ # a temporal positional embedding based on how far each object pointer is from
+ # the current frame (sine embedding normalized by the max pointer num).
+ if self.add_tpos_enc_to_obj_ptrs:
+ t_diff_max = max_obj_ptrs_in_encoder - 1
+ tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
+ obj_pos = torch.tensor(pos_list, device=device)
+ obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
+ obj_pos = self.obj_ptr_tpos_proj(obj_pos)
+ obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
+ else:
+ obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
+ if self.mem_dim < C:
+ # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
+ obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
+ obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
+ obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
+ to_cat_memory.append(obj_ptrs)
+ to_cat_memory_pos_embed.append(obj_pos)
+ num_obj_ptr_tokens = obj_ptrs.shape[0]
+ else:
+ num_obj_ptr_tokens = 0
+ else:
+ # for initial conditioning frames, encode them without using any previous memory
+ if self.directly_add_no_mem_embed:
+ # directly add no-mem embedding (instead of using the transformer encoder)
+ pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
+ pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
+ return pix_feat_with_mem
+
+ # Use a dummy token on the first frame (to avoid empty memory input to transformer encoder)
+ to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
+ to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
+
+ # Step 2: Concatenate the memories and forward through the transformer encoder
+ memory = torch.cat(to_cat_memory, dim=0)
+ memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
+
+ pix_feat_with_mem = self.memory_attention(
+ curr=current_vision_feats,
+ curr_pos=current_vision_pos_embeds,
+ memory=memory,
+ memory_pos=memory_pos_embed,
+ num_obj_ptr_tokens=num_obj_ptr_tokens,
+ )
+ # reshape the output (HW)BC => BCHW
+ pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
+ return pix_feat_with_mem
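+ # Worked example (illustrative): with num_maskmem=7, memory_temporal_stride_for_eval r=2 and
+ # forward tracking at frame_idx=10, the loop above selects non-conditioning memory frames
+ # 0, 2, 4, 6, 8 (every r-th frame, walking back from ((10 - 2) // 2) * 2 = 8) plus the immediately
+ # preceding frame 9 for t_rel == 1; with r=1 it simply takes the 6 most recent frames 4-9.
+ # Conditioning frames are prepended separately with t_pos=0.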
+
+ def _encode_new_memory(
+ self,
+ current_vision_feats,
+ feat_sizes,
+ pred_masks_high_res,
+ is_mask_from_pts,
+ ):
+ """Encodes frame features and masks into a new memory representation for video segmentation."""
+ B = current_vision_feats[-1].size(1) # batch size on this frame
+ C = self.hidden_dim
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
+ # top-level feature, (HW)BC => BCHW
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
+ if self.non_overlap_masks_for_mem_enc and not self.training:
+ # optionally, apply non-overlapping constraints to the masks (it's applied
+ # in the batch dimension and should only be used during eval, where all
+ # the objects come from the same video under batch size 1).
+ pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)
+ # scale the raw mask logits with a temperature before applying sigmoid
+ binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
+ if binarize and not self.training:
+ mask_for_mem = (pred_masks_high_res > 0).float()
+ else:
+ # apply sigmoid on the raw mask logits to turn them into range (0, 1)
+ mask_for_mem = torch.sigmoid(pred_masks_high_res)
+ # apply scale and bias terms to the sigmoid probabilities
+ if self.sigmoid_scale_for_mem_enc != 1.0:
+ mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
+ if self.sigmoid_bias_for_mem_enc != 0.0:
+ mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
+ maskmem_out = self.memory_encoder(
+ pix_feat,
+ mask_for_mem,
+ skip_mask_sigmoid=True, # sigmoid already applied
+ )
+ maskmem_features = maskmem_out["vision_features"]
+ maskmem_pos_enc = maskmem_out["vision_pos_enc"]
+
+ return maskmem_features, maskmem_pos_enc
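+ # Note (illustrative): the scale and bias above act on sigmoid probabilities, so values such as
+ # sigmoid_scale_for_mem_enc=20.0 and sigmoid_bias_for_mem_enc=-10.0 would stretch the (0, 1)
+ # probabilities back to an approximately (-10, 10) logit-like range before the memory encoder,
+ # whereas 1.0 and 0.0 leave them unchanged (which is why those cases are skipped above).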
+
+ def track_step(
+ self,
+ frame_idx,
+ is_init_cond_frame,
+ current_vision_feats,
+ current_vision_pos_embeds,
+ feat_sizes,
+ point_inputs,
+ mask_inputs,
+ output_dict,
+ num_frames,
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
+ # Whether to run the memory encoder on the predicted masks. Sometimes we might want
+ # to skip the memory encoder with `run_mem_encoder=False`. For example,
+ # in demo we might call `track_step` multiple times for each user click,
+ # and only encode the memory when the user finalizes their clicks. And in ablation
+ # settings like SAM training on static images, we don't need the memory encoder.
+ run_mem_encoder=True,
+ # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
+ prev_sam_mask_logits=None,
+ ):
+ """Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
+ current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
+ # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
+ if len(current_vision_feats) > 1:
+ high_res_features = [
+ x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
+ for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
+ ]
+ else:
+ high_res_features = None
+ if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
+ # When use_mask_input_as_output_without_sam=True, we directly output the mask input
+ # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0)
+ pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
+ sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
+ else:
+ # fuse the current frame's visual features with previous memory features from the memory bank
+ pix_feat_with_mem = self._prepare_memory_conditioned_features(
+ frame_idx=frame_idx,
+ is_init_cond_frame=is_init_cond_frame,
+ current_vision_feats=current_vision_feats[-1:],
+ current_vision_pos_embeds=current_vision_pos_embeds[-1:],
+ feat_sizes=feat_sizes[-1:],
+ output_dict=output_dict,
+ num_frames=num_frames,
+ track_in_reverse=track_in_reverse,
+ )
+ # apply SAM-style segmentation head
+ # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
+ # e.g. in demo where such logits come from earlier interaction instead of correction sampling
+ # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
+ if prev_sam_mask_logits is not None:
+ assert point_inputs is not None and mask_inputs is None
+ mask_inputs = prev_sam_mask_logits
+ multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
+ sam_outputs = self._forward_sam_heads(
+ backbone_features=pix_feat_with_mem,
+ point_inputs=point_inputs,
+ mask_inputs=mask_inputs,
+ high_res_features=high_res_features,
+ multimask_output=multimask_output,
+ )
+ (
+ _,
+ _,
+ _,
+ low_res_masks,
+ high_res_masks,
+ obj_ptr,
+ _,
+ ) = sam_outputs
+
+ current_out["pred_masks"] = low_res_masks
+ current_out["pred_masks_high_res"] = high_res_masks
+ current_out["obj_ptr"] = obj_ptr
+
+ # Finally run the memory encoder on the predicted mask to encode
+ # it into a new memory feature (that can be used in future frames)
+ if run_mem_encoder and self.num_maskmem > 0:
+ high_res_masks_for_mem_enc = high_res_masks
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+ current_vision_feats=current_vision_feats,
+ feat_sizes=feat_sizes,
+ pred_masks_high_res=high_res_masks_for_mem_enc,
+ is_mask_from_pts=(point_inputs is not None),
+ )
+ current_out["maskmem_features"] = maskmem_features
+ current_out["maskmem_pos_enc"] = maskmem_pos_enc
+ else:
+ current_out["maskmem_features"] = None
+ current_out["maskmem_pos_enc"] = None
+
+ return current_out
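+ # Minimal usage sketch (illustrative only; `model` is an instantiated SAM2Model as in the class
+ # docstring example, and `point_inputs` is a dict with 'point_coords'/'point_labels' as described
+ # in _forward_sam_heads -- not a tested API call chain):
+ #   >>> backbone_out = model.forward_image(torch.rand(1, 3, 512, 512))
+ #   >>> _, feats, pos, sizes = model._prepare_backbone_features(backbone_out)
+ #   >>> out = model.track_step(0, True, feats, pos, sizes, point_inputs, None,
+ #   ...     {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}, num_frames=1)
+ #   >>> out["pred_masks_high_res"].shape  # (B, 1, image_size, image_size) mask logits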
+
+ def _use_multimask(self, is_init_cond_frame, point_inputs):
+ """Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs."""
+ num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
+ multimask_output = (
+ self.multimask_output_in_sam
+ and (is_init_cond_frame or self.multimask_output_for_tracking)
+ and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
+ )
+ return multimask_output
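+ # Worked example (illustrative): with multimask_output_in_sam=True, multimask_min_pt_num=1 and
+ # multimask_max_pt_num=1, a single click on an initial conditioning frame yields
+ # multimask_output=True (3 candidate masks), while two or more clicks, or a non-initial frame when
+ # multimask_output_for_tracking=False, fall back to a single mask output.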
+
+ def _apply_non_overlapping_constraints(self, pred_masks):
+ """Applies non-overlapping constraints to masks, keeping highest scoring object per location."""
+ batch_size = pred_masks.size(0)
+ if batch_size == 1:
+ return pred_masks
+
+ device = pred_masks.device
+ # "max_obj_inds": object index of the object with the highest score at each location
+ max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
+ # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
+ batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
+ keep = max_obj_inds == batch_obj_inds
+ # clamp overlapping regions' scores to -10.0 or below so that the foreground regions
+ # don't overlap (here sigmoid(-10.0)=4.5398e-05)
+ pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
+ return pred_masks
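+ # Worked example (illustrative): for two objects sharing a pixel where object 0 scores 3.0 and
+ # object 1 scores 5.0, the argmax keeps object 1's logit at 5.0 and clamps object 0's to -10.0
+ # (sigmoid ~ 4.5e-05), so only the highest-scoring object remains foreground at that location:
+ #   >>> scores = torch.tensor([[[[3.0]]], [[[5.0]]]])  # (2 objects, 1, 1, 1)
+ #   >>> keep = torch.argmax(scores, dim=0, keepdim=True) == torch.arange(2)[:, None, None, None]
+ #   >>> torch.where(keep, scores, torch.clamp(scores, max=-10.0)).flatten().tolist()
+ #   [-10.0, 5.0]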
diff --git a/ultralytics/models/sam/modules/tiny_encoder.py b/ultralytics/models/sam/modules/tiny_encoder.py
index bc026dd64..0aca78ed6 100644
--- a/ultralytics/models/sam/modules/tiny_encoder.py
+++ b/ultralytics/models/sam/modules/tiny_encoder.py
@@ -17,16 +17,40 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
+from ultralytics.nn.modules import LayerNorm2d
from ultralytics.utils.instance import to_2tuple
class Conv2d_BN(torch.nn.Sequential):
- """A sequential container that performs 2D convolution followed by batch normalization."""
+ """
+ A sequential container that performs 2D convolution followed by batch normalization.
+
+ Attributes:
+ c (torch.nn.Conv2d): 2D convolution layer.
+ bn (torch.nn.BatchNorm2d): Batch normalization layer.
+
+ Methods:
+ __init__: Initializes the Conv2d_BN with specified parameters.
+
+ Args:
+ a (int): Number of input channels.
+ b (int): Number of output channels.
+ ks (int): Kernel size for the convolution. Defaults to 1.
+ stride (int): Stride for the convolution. Defaults to 1.
+ pad (int): Padding for the convolution. Defaults to 0.
+ dilation (int): Dilation factor for the convolution. Defaults to 1.
+ groups (int): Number of groups for the convolution. Defaults to 1.
+ bn_weight_init (float): Initial value for batch normalization weight. Defaults to 1.
+
+ Examples:
+ >>> conv_bn = Conv2d_BN(3, 64, ks=3, stride=1, pad=1)
+ >>> input_tensor = torch.randn(1, 3, 224, 224)
+ >>> output = conv_bn(input_tensor)
+ >>> print(output.shape)
+ """
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
- """Initializes the MBConv model with given input channels, output channels, expansion ratio, activation, and
- drop path.
- """
+ """Initializes a sequential container with 2D convolution followed by batch normalization."""
super().__init__()
self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
bn = torch.nn.BatchNorm2d(b)
@@ -36,12 +60,29 @@ class Conv2d_BN(torch.nn.Sequential):
class PatchEmbed(nn.Module):
- """Embeds images into patches and projects them into a specified embedding dimension."""
+ """
+ Embeds images into patches and projects them into a specified embedding dimension.
+
+ Attributes:
+ patches_resolution (Tuple[int, int]): Resolution of the patches after embedding.
+ num_patches (int): Total number of patches.
+ in_chans (int): Number of input channels.
+ embed_dim (int): Dimension of the embedding.
+ seq (nn.Sequential): Sequence of convolutional and activation layers for patch embedding.
+
+ Methods:
+ forward: Processes the input tensor through the patch embedding sequence.
+
+ Examples:
+ >>> import torch
+ >>> patch_embed = PatchEmbed(in_chans=3, embed_dim=96, resolution=224, activation=nn.GELU)
+ >>> x = torch.randn(1, 3, 224, 224)
+ >>> output = patch_embed(x)
+ >>> print(output.shape)
+ """
def __init__(self, in_chans, embed_dim, resolution, activation):
- """Initialize the PatchMerging class with specified input, output dimensions, resolution and activation
- function.
- """
+ """Initializes patch embedding with convolutional layers for image-to-patch conversion and projection."""
super().__init__()
img_size: Tuple[int, int] = to_2tuple(resolution)
self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
@@ -56,17 +97,40 @@ class PatchEmbed(nn.Module):
)
def forward(self, x):
- """Runs input tensor 'x' through the PatchMerging model's sequence of operations."""
+ """Processes input tensor through patch embedding sequence, converting images to patch embeddings."""
return self.seq(x)
class MBConv(nn.Module):
- """Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture."""
+ """
+ Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.
+
+ Attributes:
+ in_chans (int): Number of input channels.
+ hidden_chans (int): Number of hidden channels.
+ out_chans (int): Number of output channels.
+ conv1 (Conv2d_BN): First convolutional layer.
+ act1 (nn.Module): First activation function.
+ conv2 (Conv2d_BN): Depthwise convolutional layer.
+ act2 (nn.Module): Second activation function.
+ conv3 (Conv2d_BN): Final convolutional layer.
+ act3 (nn.Module): Third activation function.
+ drop_path (nn.Module): Drop path layer (Identity for inference).
+
+ Methods:
+ forward: Performs the forward pass through the MBConv layer.
+
+ Examples:
+ >>> in_chans, out_chans = 64, 64
+ >>> mbconv = MBConv(in_chans, out_chans, expand_ratio=4, activation=nn.ReLU, drop_path=0.1)
+ >>> x = torch.randn(1, in_chans, 56, 56)
+ >>> output = mbconv(x)
+ >>> print(output.shape)
+ torch.Size([1, 64, 56, 56])
+ """
def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
- """Initializes a convolutional layer with specified dimensions, input resolution, depth, and activation
- function.
- """
+ """Initializes the MBConv layer with specified input/output channels, expansion ratio, and activation."""
super().__init__()
self.in_chans = in_chans
self.hidden_chans = int(in_chans * expand_ratio)
@@ -86,7 +150,7 @@ class MBConv(nn.Module):
self.drop_path = nn.Identity()
def forward(self, x):
- """Implements the forward pass for the model architecture."""
+ """Implements the forward pass of MBConv, applying convolutions and skip connection."""
shortcut = x
x = self.conv1(x)
x = self.act1(x)
@@ -99,12 +163,34 @@ class MBConv(nn.Module):
class PatchMerging(nn.Module):
- """Merges neighboring patches in the feature map and projects to a new dimension."""
+ """
+ Merges neighboring patches in the feature map and projects to a new dimension.
+
+ This class implements a patch merging operation that combines spatial information and adjusts the feature
+ dimension. It uses a series of convolutional layers with batch normalization to achieve this.
+
+ Attributes:
+ input_resolution (Tuple[int, int]): The input resolution (height, width) of the feature map.
+ dim (int): The input dimension of the feature map.
+ out_dim (int): The output dimension after merging and projection.
+ act (nn.Module): The activation function used between convolutions.
+ conv1 (Conv2d_BN): The first convolutional layer for dimension projection.
+ conv2 (Conv2d_BN): The second convolutional layer for spatial merging.
+ conv3 (Conv2d_BN): The third convolutional layer for final projection.
+
+ Methods:
+ forward: Applies the patch merging operation to the input tensor.
+
+ Examples:
+ >>> input_resolution = (56, 56)
+ >>> patch_merging = PatchMerging(input_resolution, dim=64, out_dim=128, activation=nn.ReLU)
+ >>> x = torch.randn(4, 64, 56, 56)
+ >>> output = patch_merging(x)
+ >>> print(output.shape)
+ """
def __init__(self, input_resolution, dim, out_dim, activation):
- """Initializes the ConvLayer with specific dimension, input resolution, depth, activation, drop path, and other
- optional parameters.
- """
+ """Initializes the PatchMerging module for merging and projecting neighboring patches in feature maps."""
super().__init__()
self.input_resolution = input_resolution
@@ -117,7 +203,7 @@ class PatchMerging(nn.Module):
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
def forward(self, x):
- """Applies forward pass on the input utilizing convolution and activation layers, and returns the result."""
+ """Applies patch merging and dimension projection to the input feature map."""
if x.ndim == 3:
H, W = self.input_resolution
B = len(x)
@@ -137,7 +223,24 @@ class ConvLayer(nn.Module):
"""
Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).
- Optionally applies downsample operations to the output, and provides support for gradient checkpointing.
+ This layer optionally applies downsample operations to the output and supports gradient checkpointing.
+
+ Attributes:
+ dim (int): Dimensionality of the input and output.
+ input_resolution (Tuple[int, int]): Resolution of the input image.
+ depth (int): Number of MBConv layers in the block.
+ use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
+ blocks (nn.ModuleList): List of MBConv layers.
+ downsample (Optional[Callable]): Function for downsampling the output.
+
+ Methods:
+ forward: Processes the input through the convolutional layers.
+
+ Examples:
+ >>> input_tensor = torch.randn(1, 64, 56, 56)
+ >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)
+ >>> output = conv_layer(input_tensor)
+ >>> print(output.shape)
"""
def __init__(
@@ -155,16 +258,25 @@ class ConvLayer(nn.Module):
"""
Initializes the ConvLayer with the given dimensions and settings.
+ This layer consists of multiple MobileNetV3-style inverted bottleneck convolutions (MBConv) and
+ optionally applies downsampling to the output.
+
Args:
dim (int): The dimensionality of the input and output.
input_resolution (Tuple[int, int]): The resolution of the input image.
depth (int): The number of MBConv layers in the block.
activation (Callable): Activation function applied after each convolution.
- drop_path (Union[float, List[float]]): Drop path rate. Single float or a list of floats for each MBConv.
+ drop_path (float | List[float]): Drop path rate. Single float or a list of floats for each MBConv.
downsample (Optional[Callable]): Function for downsampling the output. None to skip downsampling.
use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
out_dim (Optional[int]): The dimensionality of the output. None means it will be the same as `dim`.
conv_expand_ratio (float): Expansion ratio for the MBConv layers.
+
+ Examples:
+ >>> input_tensor = torch.randn(1, 64, 56, 56)
+ >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)
+ >>> output = conv_layer(input_tensor)
+ >>> print(output.shape)
"""
super().__init__()
self.dim = dim
@@ -194,7 +306,7 @@ class ConvLayer(nn.Module):
)
def forward(self, x):
- """Processes the input through a series of convolutional layers and returns the activated output."""
+ """Processes input through convolutional layers, applying MBConv blocks and optional downsampling."""
for blk in self.blocks:
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
return x if self.downsample is None else self.downsample(x)
@@ -202,13 +314,33 @@ class ConvLayer(nn.Module):
class Mlp(nn.Module):
"""
- Multi-layer Perceptron (MLP) for transformer architectures.
+ Multi-layer Perceptron (MLP) module for transformer architectures.
+
+ This module applies layer normalization, two fully-connected layers with an activation function in between,
+ and dropout. It is commonly used in transformer-based architectures.
- This layer takes an input with in_features, applies layer normalization and two fully-connected layers.
+ Attributes:
+ norm (nn.LayerNorm): Layer normalization applied to the input.
+ fc1 (nn.Linear): First fully-connected layer.
+ fc2 (nn.Linear): Second fully-connected layer.
+ act (nn.Module): Activation function applied after the first fully-connected layer.
+ drop (nn.Dropout): Dropout layer applied after the activation function.
+
+ Methods:
+ forward: Applies the MLP operations on the input tensor.
+
+ Examples:
+ >>> import torch
+ >>> from torch import nn
+ >>> mlp = Mlp(in_features=256, hidden_features=512, out_features=256, act_layer=nn.GELU, drop=0.1)
+ >>> x = torch.randn(32, 100, 256)
+ >>> output = mlp(x)
+ >>> print(output.shape)
+ torch.Size([32, 100, 256])
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
- """Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc."""
+ """Initializes a multi-layer perceptron with configurable input, hidden, and output dimensions."""
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
@@ -219,7 +351,7 @@ class Mlp(nn.Module):
self.drop = nn.Dropout(drop)
def forward(self, x):
- """Applies operations on input x and returns modified x, runs downsample if not None."""
+ """Applies MLP operations: layer norm, FC layers, activation, and dropout to the input tensor."""
x = self.norm(x)
x = self.fc1(x)
x = self.act(x)
@@ -230,12 +362,37 @@ class Mlp(nn.Module):
class Attention(torch.nn.Module):
"""
- Multi-head attention module with support for spatial awareness, applying attention biases based on spatial
- resolution. Implements trainable attention biases for each unique offset between spatial positions in the resolution
- grid.
+ Multi-head attention module with spatial awareness and trainable attention biases.
+
+ This module implements a multi-head attention mechanism with support for spatial awareness, applying
+ attention biases based on spatial resolution. It includes trainable attention biases for each unique
+ offset between spatial positions in the resolution grid.
Attributes:
- ab (Tensor, optional): Cached attention biases for inference, deleted during training.
+ num_heads (int): Number of attention heads.
+ scale (float): Scaling factor for attention scores.
+ key_dim (int): Dimensionality of the keys and queries.
+ nh_kd (int): Product of num_heads and key_dim.
+ d (int): Dimensionality of the value vectors.
+ dh (int): Product of d and num_heads.
+ attn_ratio (float): Attention ratio affecting the dimensions of the value vectors.
+ norm (nn.LayerNorm): Layer normalization applied to input.
+ qkv (nn.Linear): Linear layer for computing query, key, and value projections.
+ proj (nn.Linear): Linear layer for final projection.
+ attention_biases (nn.Parameter): Learnable attention biases.
+ attention_bias_idxs (Tensor): Indices for attention biases.
+ ab (Tensor): Cached attention biases for inference, deleted during training.
+
+ Methods:
+ train: Sets the module in training mode and handles the 'ab' attribute.
+ forward: Performs the forward pass of the attention mechanism.
+
+ Examples:
+ >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
+ >>> x = torch.randn(1, 196, 256)
+ >>> output = attn(x)
+ >>> print(output.shape)
+ torch.Size([1, 196, 256])
"""
def __init__(
@@ -247,17 +404,28 @@ class Attention(torch.nn.Module):
resolution=(14, 14),
):
"""
- Initializes the Attention module.
+ Initializes the Attention module for multi-head attention with spatial awareness.
+
+ This module implements a multi-head attention mechanism with support for spatial awareness, applying
+ attention biases based on spatial resolution. It includes trainable attention biases for each unique
+ offset between spatial positions in the resolution grid.
Args:
dim (int): The dimensionality of the input and output.
key_dim (int): The dimensionality of the keys and queries.
- num_heads (int, optional): Number of attention heads. Default is 8.
- attn_ratio (float, optional): Attention ratio, affecting the dimensions of the value vectors. Default is 4.
- resolution (Tuple[int, int], optional): Spatial resolution of the input feature map. Default is (14, 14).
+ num_heads (int): Number of attention heads. Default is 8.
+ attn_ratio (float): Attention ratio, affecting the dimensions of the value vectors. Default is 4.
+ resolution (Tuple[int, int]): Spatial resolution of the input feature map. Default is (14, 14).
Raises:
- AssertionError: If `resolution` is not a tuple of length 2.
+ AssertionError: If 'resolution' is not a tuple of length 2.
+
+ Examples:
+ >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
+ >>> x = torch.randn(1, 196, 256)
+ >>> output = attn(x)
+ >>> print(output.shape)
+ torch.Size([1, 196, 256])
"""
super().__init__()
@@ -290,7 +458,7 @@ class Attention(torch.nn.Module):
@torch.no_grad()
def train(self, mode=True):
- """Sets the module in training mode and handles attribute 'ab' based on the mode."""
+ """Performs multi-head attention with spatial awareness and trainable attention biases."""
super().train(mode)
if mode and hasattr(self, "ab"):
del self.ab
@@ -298,7 +466,7 @@ class Attention(torch.nn.Module):
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x): # x
- """Performs forward pass over the input tensor 'x' by applying normalization and querying keys/values."""
+ """Applies multi-head attention with spatial awareness and trainable attention biases."""
B, N, _ = x.shape # B, N, C
# Normalization
@@ -322,7 +490,34 @@ class Attention(torch.nn.Module):
class TinyViTBlock(nn.Module):
- """TinyViT Block that applies self-attention and a local convolution to the input."""
+ """
+ TinyViT Block that applies self-attention and a local convolution to the input.
+
+ This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
+ local convolutions to process input features efficiently.
+
+ Attributes:
+ dim (int): The dimensionality of the input and output.
+ input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
+ num_heads (int): Number of attention heads.
+ window_size (int): Size of the attention window.
+ mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
+ drop_path (nn.Module): Stochastic depth layer, identity function during inference.
+ attn (Attention): Self-attention module.
+ mlp (Mlp): Multi-layer perceptron module.
+ local_conv (Conv2d_BN): Depth-wise local convolution layer.
+
+ Methods:
+ forward: Processes the input through the TinyViT block.
+ extra_repr: Returns a string with extra information about the block's parameters.
+
+ Examples:
+ >>> input_tensor = torch.randn(1, 196, 192)
+ >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)
+ >>> output = block(input_tensor)
+ >>> print(output.shape)
+ torch.Size([1, 196, 192])
+ """
def __init__(
self,
@@ -337,22 +532,32 @@ class TinyViTBlock(nn.Module):
activation=nn.GELU,
):
"""
- Initializes the TinyViTBlock.
+ Initializes a TinyViT block with self-attention and local convolution.
+
+ This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
+ local convolutions to process input features efficiently.
Args:
- dim (int): The dimensionality of the input and output.
- input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
+ dim (int): Dimensionality of the input and output features.
+ input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width).
num_heads (int): Number of attention heads.
- window_size (int, optional): Window size for attention. Default is 7.
- mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4.
- drop (float, optional): Dropout rate. Default is 0.
- drop_path (float, optional): Stochastic depth rate. Default is 0.
- local_conv_size (int, optional): The kernel size of the local convolution. Default is 3.
- activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU.
+ window_size (int): Size of the attention window. Must be greater than 0.
+ mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
+ drop (float): Dropout rate.
+ drop_path (float): Stochastic depth rate.
+ local_conv_size (int): Kernel size of the local convolution.
+ activation (torch.nn.Module): Activation function for MLP.
Raises:
- AssertionError: If `window_size` is not greater than 0.
- AssertionError: If `dim` is not divisible by `num_heads`.
+ AssertionError: If window_size is not greater than 0.
+ AssertionError: If dim is not divisible by num_heads.
+
+ Examples:
+ >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)
+ >>> input_tensor = torch.randn(1, 196, 192)
+ >>> output = block(input_tensor)
+ >>> print(output.shape)
+ torch.Size([1, 196, 192])
"""
super().__init__()
self.dim = dim
@@ -380,9 +585,7 @@ class TinyViTBlock(nn.Module):
self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
def forward(self, x):
- """Applies attention-based transformation or padding to input 'x' before passing it through a local
- convolution.
- """
+ """Applies self-attention, local convolution, and MLP operations to the input tensor."""
h, w = self.input_resolution
b, hw, c = x.shape # batch, height*width, channels
assert hw == h * w, "input feature has wrong size"
@@ -424,8 +627,19 @@ class TinyViTBlock(nn.Module):
return x + self.drop_path(self.mlp(x))
def extra_repr(self) -> str:
- """Returns a formatted string representing the TinyViTBlock's parameters: dimension, input resolution, number of
- attentions heads, window size, and MLP ratio.
+ """
+ Returns a string representation of the TinyViTBlock's parameters.
+
+ This method provides a formatted string containing key information about the TinyViTBlock, including its
+ dimension, input resolution, number of attention heads, window size, and MLP ratio.
+
+ Returns:
+ (str): A formatted string containing the block's parameters.
+
+ Examples:
+ >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0)
+ >>> print(block.extra_repr())
+ dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0
"""
return (
f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
@@ -434,7 +648,31 @@ class TinyViTBlock(nn.Module):
class BasicLayer(nn.Module):
- """A basic TinyViT layer for one stage in a TinyViT architecture."""
+ """
+ A basic TinyViT layer for one stage in a TinyViT architecture.
+
+ This class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks
+ and an optional downsampling operation.
+
+ Attributes:
+ dim (int): The dimensionality of the input and output features.
+ input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
+ depth (int): Number of TinyViT blocks in this layer.
+ use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
+ blocks (nn.ModuleList): List of TinyViT blocks that make up this layer.
+ downsample (nn.Module | None): Downsample layer at the end of the layer, if specified.
+
+ Methods:
+ forward: Processes the input through the layer's blocks and optional downsampling.
+ extra_repr: Returns a string with the layer's parameters for printing.
+
+ Examples:
+ >>> input_tensor = torch.randn(1, 3136, 192)
+ >>> layer = BasicLayer(dim=192, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
+ >>> output = layer(input_tensor)
+ >>> print(output.shape)
+ torch.Size([1, 3136, 192])
+ """
def __init__(
self,
@@ -453,25 +691,34 @@ class BasicLayer(nn.Module):
out_dim=None,
):
"""
- Initializes the BasicLayer.
+ Initializes a BasicLayer in the TinyViT architecture.
+
+ This layer consists of multiple TinyViT blocks and an optional downsampling operation. It is designed to
+ process feature maps at a specific resolution and dimensionality within the TinyViT model.
Args:
- dim (int): The dimensionality of the input and output.
- input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
- depth (int): Number of TinyViT blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4.
- drop (float, optional): Dropout rate. Default is 0.
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default is 0.
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default is None.
- use_checkpoint (bool, optional): Whether to use checkpointing to save memory. Default is False.
- local_conv_size (int, optional): Kernel size of the local convolution. Default is 3.
- activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU.
- out_dim (int | None, optional): The output dimension of the layer. Default is None.
+ dim (int): Dimensionality of the input and output features.
+ input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width).
+ depth (int): Number of TinyViT blocks in this layer.
+ num_heads (int): Number of attention heads in each TinyViT block.
+ window_size (int): Size of the local window for attention computation.
+ mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
+ drop (float): Dropout rate.
+ drop_path (float | List[float]): Stochastic depth rate. Can be a float or a list of floats for each block.
+ downsample (nn.Module | None): Downsampling layer at the end of the layer. None to skip downsampling.
+ use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
+ local_conv_size (int): Kernel size for the local convolution in each TinyViT block.
+ activation (nn.Module): Activation function used in the MLP.
+ out_dim (int | None): Output dimension after downsampling. None means it will be the same as `dim`.
Raises:
- ValueError: If `drop_path` is a list of float but its length doesn't match `depth`.
+ ValueError: If `drop_path` is a list and its length doesn't match `depth`.
+
+ Examples:
+ >>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
+ >>> x = torch.randn(1, 56*56, 96)
+ >>> output = layer(x)
+ >>> print(output.shape)
"""
super().__init__()
self.dim = dim
@@ -505,58 +752,49 @@ class BasicLayer(nn.Module):
)
def forward(self, x):
- """Performs forward propagation on the input tensor and returns a normalized tensor."""
+ """Processes input through TinyViT blocks and optional downsampling."""
for blk in self.blocks:
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
return x if self.downsample is None else self.downsample(x)
def extra_repr(self) -> str:
- """Returns a string representation of the extra_repr function with the layer's parameters."""
+ """Returns a string with the layer's parameters for printing."""
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
-class LayerNorm2d(nn.Module):
- """A PyTorch implementation of Layer Normalization in 2D."""
-
- def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
- """Initialize LayerNorm2d with the number of channels and an optional epsilon."""
- super().__init__()
- self.weight = nn.Parameter(torch.ones(num_channels))
- self.bias = nn.Parameter(torch.zeros(num_channels))
- self.eps = eps
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Perform a forward pass, normalizing the input tensor."""
- u = x.mean(1, keepdim=True)
- s = (x - u).pow(2).mean(1, keepdim=True)
- x = (x - u) / torch.sqrt(s + self.eps)
- return self.weight[:, None, None] * x + self.bias[:, None, None]
-
-
class TinyViT(nn.Module):
"""
- The TinyViT architecture for vision tasks.
+ TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction.
+
+ This class implements the TinyViT model, which combines elements of vision transformers and convolutional
+ neural networks for improved efficiency and performance on vision tasks.
Attributes:
img_size (int): Input image size.
- in_chans (int): Number of input channels.
num_classes (int): Number of classification classes.
- embed_dims (List[int]): List of embedding dimensions for each layer.
- depths (List[int]): List of depths for each layer.
- num_heads (List[int]): List of number of attention heads for each layer.
- window_sizes (List[int]): List of window sizes for each layer.
+ depths (List[int]): Number of blocks in each stage.
+ num_layers (int): Total number of layers in the network.
mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
- drop_rate (float): Dropout rate for drop layers.
- drop_path_rate (float): Drop path rate for stochastic depth.
- use_checkpoint (bool): Use checkpointing for efficient memory usage.
- mbconv_expand_ratio (float): Expansion ratio for MBConv layer.
- local_conv_size (int): Local convolution kernel size.
- layer_lr_decay (float): Layer-wise learning rate decay.
-
- Note:
- This implementation is generalized to accept a list of depths, attention heads,
- embedding dimensions and window sizes, which allows you to create a
- "stack" of TinyViT models of varying configurations.
+ patch_embed (PatchEmbed): Module for patch embedding.
+ patches_resolution (Tuple[int, int]): Resolution of embedded patches.
+ layers (nn.ModuleList): List of network layers.
+ norm_head (nn.LayerNorm): Layer normalization for the classifier head.
+ head (nn.Linear): Linear layer for final classification.
+ neck (nn.Sequential): Neck module for feature refinement.
+
+ Methods:
+ set_layer_lr_decay: Sets layer-wise learning rate decay.
+ _init_weights: Initializes weights for linear and normalization layers.
+ no_weight_decay_keywords: Returns keywords for parameters that should not use weight decay.
+ forward_features: Processes input through the feature extraction layers.
+ forward: Performs a forward pass through the entire network.
+
+ Examples:
+ >>> model = TinyViT(img_size=1024, num_classes=1000)
+ >>> x = torch.randn(1, 3, 1024, 1024)
+ >>> features = model.forward_features(x)
+ >>> print(features.shape)
+ torch.Size([1, 256, 64, 64])
"""
def __init__(
@@ -579,21 +817,33 @@ class TinyViT(nn.Module):
"""
Initializes the TinyViT model.
+ This constructor sets up the TinyViT architecture, including patch embedding, multiple layers of
+ attention and convolution blocks, and a classification head.
+
Args:
- img_size (int, optional): The input image size. Defaults to 224.
- in_chans (int, optional): Number of input channels. Defaults to 3.
- num_classes (int, optional): Number of classification classes. Defaults to 1000.
- embed_dims (List[int], optional): List of embedding dimensions per layer. Defaults to [96, 192, 384, 768].
- depths (List[int], optional): List of depths for each layer. Defaults to [2, 2, 6, 2].
- num_heads (List[int], optional): List of number of attention heads per layer. Defaults to [3, 6, 12, 24].
- window_sizes (List[int], optional): List of window sizes for each layer. Defaults to [7, 7, 14, 7].
- mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension. Defaults to 4.
- drop_rate (float, optional): Dropout rate. Defaults to 0.
- drop_path_rate (float, optional): Drop path rate for stochastic depth. Defaults to 0.1.
- use_checkpoint (bool, optional): Whether to use checkpointing for efficient memory usage. Defaults to False.
- mbconv_expand_ratio (float, optional): Expansion ratio for MBConv layer. Defaults to 4.0.
- local_conv_size (int, optional): Local convolution kernel size. Defaults to 3.
- layer_lr_decay (float, optional): Layer-wise learning rate decay. Defaults to 1.0.
+ img_size (int): Size of the input image. Default is 224.
+ in_chans (int): Number of input channels. Default is 3.
+ num_classes (int): Number of classes for classification. Default is 1000.
+ embed_dims (Tuple[int, int, int, int]): Embedding dimensions for each stage.
+ Default is (96, 192, 384, 768).
+ depths (Tuple[int, int, int, int]): Number of blocks in each stage. Default is (2, 2, 6, 2).
+ num_heads (Tuple[int, int, int, int]): Number of attention heads in each stage.
+ Default is (3, 6, 12, 24).
+ window_sizes (Tuple[int, int, int, int]): Window sizes for each stage. Default is (7, 7, 14, 7).
+ mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. Default is 4.0.
+ drop_rate (float): Dropout rate. Default is 0.0.
+ drop_path_rate (float): Stochastic depth rate. Default is 0.1.
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default is False.
+ mbconv_expand_ratio (float): Expansion ratio for MBConv layer. Default is 4.0.
+ local_conv_size (int): Kernel size for local convolutions. Default is 3.
+ layer_lr_decay (float): Layer-wise learning rate decay factor. Default is 1.0.
+
+ Examples:
+ >>> model = TinyViT(img_size=1024, num_classes=1000)
+ >>> x = torch.randn(1, 3, 1024, 1024)
+ >>> features = model(x)
+ >>> print(features.shape)
+ torch.Size([1, 256, 64, 64])
"""
super().__init__()
self.img_size = img_size
@@ -671,7 +921,7 @@ class TinyViT(nn.Module):
)
def set_layer_lr_decay(self, layer_lr_decay):
- """Sets the learning rate decay for each layer in the TinyViT model."""
+ """Sets layer-wise learning rate decay for the TinyViT model based on depth."""
decay_rate = layer_lr_decay
# Layers -> blocks (depth)
@@ -706,7 +956,7 @@ class TinyViT(nn.Module):
self.apply(_check_lr_scale)
def _init_weights(self, m):
- """Initializes weights for linear layers and layer normalization in the given module."""
+ """Initializes weights for linear and normalization layers in the TinyViT model."""
if isinstance(m, nn.Linear):
# NOTE: This initialization is needed only for training.
# trunc_normal_(m.weight, std=.02)
@@ -718,11 +968,11 @@ class TinyViT(nn.Module):
@torch.jit.ignore
def no_weight_decay_keywords(self):
- """Returns a dictionary of parameter names where weight decay should not be applied."""
+ """Returns a set of keywords for parameters that should not use weight decay."""
return {"attention_biases"}
def forward_features(self, x):
- """Runs the input through the model layers and returns the transformed output."""
+ """Processes input through feature extraction layers, returning spatial features."""
x = self.patch_embed(x) # x input is (N, C, H, W)
x = self.layers[0](x)
@@ -737,5 +987,5 @@ class TinyViT(nn.Module):
return self.neck(x)
def forward(self, x):
- """Executes a forward pass on the input tensor through the constructed model layers."""
+ """Performs the forward pass through the TinyViT model, extracting features from the input image."""
return self.forward_features(x)
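For context, `set_layer_lr_decay` scales each block's learning rate geometrically with its depth index, so early blocks train with a smaller rate than the head. A rough sketch of that scheme under the default depths (the variable names below are illustrative, not TinyViT attributes):

depths = (2, 2, 6, 2)        # blocks per stage, matching the default TinyViT config
layer_lr_decay = 0.9         # example decay factor; 1.0 disables the effect
n_blocks = sum(depths)

# Block i gets lr_scale = decay ** (n_blocks - i - 1): earliest block smallest, last block 1.0.
lr_scales = [layer_lr_decay ** (n_blocks - i - 1) for i in range(n_blocks)]
print(round(lr_scales[0], 4), lr_scales[-1])  # 0.3138 1.0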
diff --git a/ultralytics/models/sam/modules/transformer.py b/ultralytics/models/sam/modules/transformer.py
index 6375c2adc..b8c408193 100644
--- a/ultralytics/models/sam/modules/transformer.py
+++ b/ultralytics/models/sam/modules/transformer.py
@@ -11,19 +11,31 @@ from ultralytics.nn.modules import MLPBlock
class TwoWayTransformer(nn.Module):
"""
- A Two-Way Transformer module that enables the simultaneous attention to both image and query points. This class
- serves as a specialized transformer decoder that attends to an input image using queries whose positional embedding
- is supplied. This is particularly useful for tasks like object detection, image segmentation, and point cloud
- processing.
+ A Two-Way Transformer module for simultaneous attention to image and query points.
+
+ This class implements a specialized transformer decoder that attends to an input image using queries with
+ supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point
+ cloud processing.
Attributes:
- depth (int): The number of layers in the transformer.
- embedding_dim (int): The channel dimension for the input embeddings.
- num_heads (int): The number of heads for multihead attention.
- mlp_dim (int): The internal channel dimension for the MLP block.
- layers (nn.ModuleList): The list of TwoWayAttentionBlock layers that make up the transformer.
- final_attn_token_to_image (Attention): The final attention layer applied from the queries to the image.
- norm_final_attn (nn.LayerNorm): The layer normalization applied to the final queries.
+ depth (int): Number of layers in the transformer.
+ embedding_dim (int): Channel dimension for input embeddings.
+ num_heads (int): Number of heads for multihead attention.
+ mlp_dim (int): Internal channel dimension for the MLP block.
+ layers (nn.ModuleList): List of TwoWayAttentionBlock layers composing the transformer.
+ final_attn_token_to_image (Attention): Final attention layer from queries to image.
+ norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
+
+ Methods:
+ forward: Processes image and point embeddings through the transformer.
+
+ Examples:
+ >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
+ >>> image_embedding = torch.randn(1, 256, 32, 32)
+ >>> image_pe = torch.randn(1, 256, 32, 32)
+ >>> point_embedding = torch.randn(1, 100, 256)
+ >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
+ >>> print(output_queries.shape, output_image.shape)
"""
def __init__(
@@ -36,15 +48,33 @@ class TwoWayTransformer(nn.Module):
attention_downsample_rate: int = 2,
) -> None:
"""
- A transformer decoder that attends to an input image using queries whose positional embedding is supplied.
+ Initialize a Two-Way Transformer for simultaneous attention to image and query points.
Args:
- depth (int): number of layers in the transformer
- embedding_dim (int): the channel dimension for the input embeddings
- num_heads (int): the number of heads for multihead attention. Must
- divide embedding_dim
- mlp_dim (int): the channel dimension internal to the MLP block
- activation (nn.Module): the activation to use in the MLP block
+ depth (int): Number of layers in the transformer.
+ embedding_dim (int): Channel dimension for input embeddings.
+ num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
+ mlp_dim (int): Internal channel dimension for the MLP block.
+ activation (Type[nn.Module]): Activation function to use in the MLP block.
+ attention_downsample_rate (int): Downsampling rate for attention mechanism.
+
+ Attributes:
+ depth (int): Number of layers in the transformer.
+ embedding_dim (int): Channel dimension for input embeddings.
+ num_heads (int): Number of heads for multihead attention.
+ mlp_dim (int): Internal channel dimension for the MLP block.
+ layers (nn.ModuleList): List of TwoWayAttentionBlock layers.
+ final_attn_token_to_image (Attention): Final attention layer from queries to image.
+ norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
+
+ Examples:
+ >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
+ >>> image_embedding = torch.randn(1, 256, 32, 32)
+ >>> image_pe = torch.randn(1, 256, 32, 32)
+ >>> point_embedding = torch.randn(1, 100, 256)
+ >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
+ >>> print(output_queries.shape, output_image.shape)
"""
super().__init__()
self.depth = depth
@@ -75,15 +105,23 @@ class TwoWayTransformer(nn.Module):
point_embedding: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
+ Processes image and point embeddings through the Two-Way Transformer.
+
Args:
- image_embedding (torch.Tensor): image to attend to. Should be shape B x embedding_dim x h x w for any h and w.
- image_pe (torch.Tensor): the positional encoding to add to the image. Must have same shape as image_embedding.
- point_embedding (torch.Tensor): the embedding to add to the query points.
- Must have shape B x N_points x embedding_dim for any N_points.
+ image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
+ image_pe (torch.Tensor): Positional encoding to add to the image, with same shape as image_embedding.
+ point_embedding (torch.Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim).
Returns:
- (torch.Tensor): the processed point_embedding
- (torch.Tensor): the processed image_embedding
+ (Tuple[torch.Tensor, torch.Tensor]): Processed point_embedding and image_embedding.
+
+ Examples:
+ >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
+ >>> image_embedding = torch.randn(1, 256, 32, 32)
+ >>> image_pe = torch.randn(1, 256, 32, 32)
+ >>> point_embedding = torch.randn(1, 100, 256)
+ >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
+ >>> print(output_queries.shape, output_image.shape)
"""
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
@@ -114,21 +152,34 @@ class TwoWayTransformer(nn.Module):
class TwoWayAttentionBlock(nn.Module):
"""
- An attention block that performs both self-attention and cross-attention in two directions: queries to keys and
- keys to queries. This block consists of four main layers: (1) self-attention on sparse inputs, (2) cross-attention
- of sparse inputs to dense inputs, (3) an MLP block on sparse inputs, and (4) cross-attention of dense inputs to
- sparse inputs.
+ A two-way attention block for simultaneous attention to image and query points.
+
+ This class implements a specialized transformer block with four main layers: self-attention on sparse inputs,
+ cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense
+ inputs to sparse inputs.
Attributes:
- self_attn (Attention): The self-attention layer for the queries.
- norm1 (nn.LayerNorm): Layer normalization following the first attention block.
+ self_attn (Attention): Self-attention layer for queries.
+ norm1 (nn.LayerNorm): Layer normalization after self-attention.
cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
- norm2 (nn.LayerNorm): Layer normalization following the second attention block.
- mlp (MLPBlock): MLP block that transforms the query embeddings.
- norm3 (nn.LayerNorm): Layer normalization following the MLP block.
- norm4 (nn.LayerNorm): Layer normalization following the third attention block.
+ norm2 (nn.LayerNorm): Layer normalization after token-to-image attention.
+ mlp (MLPBlock): MLP block for transforming query embeddings.
+ norm3 (nn.LayerNorm): Layer normalization after MLP block.
+ norm4 (nn.LayerNorm): Layer normalization after image-to-token attention.
cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
- skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
+ skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
+
+ Methods:
+ forward: Applies self-attention and cross-attention to queries and keys.
+
+ Examples:
+ >>> embedding_dim, num_heads = 256, 8
+ >>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
+ >>> queries = torch.randn(1, 100, embedding_dim)
+ >>> keys = torch.randn(1, 1000, embedding_dim)
+ >>> query_pe = torch.randn(1, 100, embedding_dim)
+ >>> key_pe = torch.randn(1, 1000, embedding_dim)
+ >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
"""
def __init__(
@@ -141,16 +192,28 @@ class TwoWayAttentionBlock(nn.Module):
skip_first_layer_pe: bool = False,
) -> None:
"""
- A transformer block with four layers: (1) self-attention of sparse inputs, (2) cross attention of sparse
- inputs to dense inputs, (3) mlp block on sparse inputs, and (4) cross attention of dense inputs to sparse
- inputs.
+ Initializes a TwoWayAttentionBlock for simultaneous attention to image and query points.
+
+ This block implements a specialized transformer layer with four main components: self-attention on sparse
+ inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention
+ of dense inputs to sparse inputs.
Args:
- embedding_dim (int): the channel dimension of the embeddings
- num_heads (int): the number of heads in the attention layers
- mlp_dim (int): the hidden dimension of the mlp block
- activation (nn.Module): the activation of the mlp block
- skip_first_layer_pe (bool): skip the PE on the first layer
+ embedding_dim (int): Channel dimension of the embeddings.
+ num_heads (int): Number of attention heads in the attention layers.
+ mlp_dim (int): Hidden dimension of the MLP block.
+ activation (Type[nn.Module]): Activation function for the MLP block.
+ attention_downsample_rate (int): Downsampling rate for the attention mechanism.
+ skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
+
+ Examples:
+ >>> embedding_dim, num_heads = 256, 8
+ >>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
+ >>> queries = torch.randn(1, 100, embedding_dim)
+ >>> keys = torch.randn(1, 1000, embedding_dim)
+ >>> query_pe = torch.randn(1, 100, embedding_dim)
+ >>> key_pe = torch.randn(1, 1000, embedding_dim)
+ >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
"""
super().__init__()
self.self_attn = Attention(embedding_dim, num_heads)
@@ -168,7 +231,7 @@ class TwoWayAttentionBlock(nn.Module):
self.skip_first_layer_pe = skip_first_layer_pe
def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
- """Apply self-attention and cross-attention to queries and keys and return the processed embeddings."""
+ """Applies two-way attention to process query and key embeddings in a transformer block."""
# Self attention block
if self.skip_first_layer_pe:
@@ -202,8 +265,34 @@ class TwoWayAttentionBlock(nn.Module):
class Attention(nn.Module):
- """An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
- values.
+ """
+ An attention layer with downscaling capability for embedding size after projection.
+
+ This class implements a multi-head attention mechanism with the option to downsample the internal
+ dimension of queries, keys, and values.
+
+ Attributes:
+ embedding_dim (int): Dimensionality of input embeddings.
+ kv_in_dim (int): Dimensionality of key and value inputs.
+ internal_dim (int): Internal dimension after downsampling.
+ num_heads (int): Number of attention heads.
+ q_proj (nn.Linear): Linear projection for queries.
+ k_proj (nn.Linear): Linear projection for keys.
+ v_proj (nn.Linear): Linear projection for values.
+ out_proj (nn.Linear): Linear projection for output.
+
+ Methods:
+ _separate_heads: Separates input tensor into attention heads.
+ _recombine_heads: Recombines separated attention heads.
+ forward: Computes attention output for given query, key, and value tensors.
+
+ Examples:
+ >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
+ >>> q = torch.randn(1, 100, 256)
+ >>> k = v = torch.randn(1, 50, 256)
+ >>> output = attn(q, k, v)
+ >>> print(output.shape)
+ torch.Size([1, 100, 256])
"""
def __init__(
@@ -214,15 +303,27 @@ class Attention(nn.Module):
kv_in_dim: int = None,
) -> None:
"""
- Initializes the Attention model with the given dimensions and settings.
+ Initializes the Attention module with specified dimensions and settings.
+
+ This class implements a multi-head attention mechanism with optional downsampling of the internal
+ dimension for queries, keys, and values.
Args:
- embedding_dim (int): The dimensionality of the input embeddings.
- num_heads (int): The number of attention heads.
- downsample_rate (int, optional): The factor by which the internal dimensions are downsampled. Defaults to 1.
+ embedding_dim (int): Dimensionality of input embeddings.
+ num_heads (int): Number of attention heads.
+ downsample_rate (int): Factor by which internal dimensions are downsampled. Defaults to 1.
+ kv_in_dim (int | None): Dimensionality of key and value inputs. If None, uses embedding_dim.
Raises:
- AssertionError: If 'num_heads' does not evenly divide the internal dim (embedding_dim / downsample_rate).
+ AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate).
+
+ Examples:
+ >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
+ >>> q = torch.randn(1, 100, 256)
+ >>> k = v = torch.randn(1, 50, 256)
+ >>> output = attn(q, k, v)
+ >>> print(output.shape)
+ torch.Size([1, 100, 256])
"""
super().__init__()
self.embedding_dim = embedding_dim
@@ -238,20 +339,20 @@ class Attention(nn.Module):
@staticmethod
def _separate_heads(x: Tensor, num_heads: int) -> Tensor:
- """Separate the input tensor into the specified number of attention heads."""
+ """Separates the input tensor into the specified number of attention heads."""
b, n, c = x.shape
x = x.reshape(b, n, num_heads, c // num_heads)
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
@staticmethod
def _recombine_heads(x: Tensor) -> Tensor:
- """Recombine the separated attention heads into a single tensor."""
+ """Recombines separated attention heads into a single tensor."""
b, n_heads, n_tokens, c_per_head = x.shape
x = x.transpose(1, 2)
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
- """Compute the attention output given the input query, key, and value tensors."""
+ """Applies multi-head attention to query, key, and value tensors with optional downsampling."""
# Input projections
q = self.q_proj(q)
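To make the `downsample_rate` bookkeeping in `Attention` concrete: with `embedding_dim=256`, `num_heads=8`, and `downsample_rate=2`, queries are projected to an internal dimension of 128 and then split into 8 heads of 16 channels each. A small sketch of just that shape arithmetic (not the class itself):

import torch
import torch.nn as nn

embedding_dim, num_heads, downsample_rate = 256, 8, 2
internal_dim = embedding_dim // downsample_rate   # 128; must be divisible by num_heads
q_proj = nn.Linear(embedding_dim, internal_dim)

q = q_proj(torch.randn(1, 100, embedding_dim))    # (1, 100, 128)
b, n, c = q.shape
q = q.reshape(b, n, num_heads, c // num_heads).transpose(1, 2)
print(q.shape)                                    # torch.Size([1, 8, 100, 16])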
diff --git a/ultralytics/models/sam2/modules/utils.py b/ultralytics/models/sam/modules/utils.py
similarity index 65%
rename from ultralytics/models/sam2/modules/utils.py
rename to ultralytics/models/sam/modules/utils.py
index b09dd9b2f..debc62e8e 100644
--- a/ultralytics/models/sam2/modules/utils.py
+++ b/ultralytics/models/sam/modules/utils.py
@@ -1,5 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
+from typing import Tuple
+
import torch
import torch.nn.functional as F
@@ -70,7 +72,7 @@ def get_1d_sine_pe(pos_inds, dim, temperature=10000):
def init_t_xy(end_x: int, end_y: int):
- """Initializes 1D and 2D coordinate tensors for a grid of size end_x by end_y."""
+ """Initializes 1D and 2D coordinate tensors for a grid of specified dimensions."""
t = torch.arange(end_x * end_y, dtype=torch.float32)
t_x = (t % end_x).float()
t_y = torch.div(t, end_x, rounding_mode="floor").float()
@@ -78,7 +80,7 @@ def init_t_xy(end_x: int, end_y: int):
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
- """Computes axial complex exponential positional encodings for 2D spatial positions."""
+ """Computes axial complex exponential positional encodings for 2D spatial positions in a grid."""
freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
@@ -91,7 +93,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
- """Reshapes frequency tensor for broadcasting, ensuring compatibility with input tensor dimensions."""
+ """Reshapes frequency tensor for broadcasting with input tensor, ensuring dimensional compatibility."""
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
@@ -189,3 +191,103 @@ def window_unpartition(windows, window_size, pad_hw, hw):
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
+
+
+def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+ """
+ Extracts relative positional embeddings based on query and key sizes.
+
+ Args:
+ q_size (int): Size of the query.
+ k_size (int): Size of the key.
+ rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative
+ distance and C is the embedding dimension.
+
+ Returns:
+ (torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size,
+ k_size, C).
+
+ Examples:
+ >>> q_size, k_size = 8, 16
+ >>> rel_pos = torch.randn(31, 64) # 31 = 2 * max(8, 16) - 1
+ >>> extracted_pos = get_rel_pos(q_size, k_size, rel_pos)
+ >>> print(extracted_pos.shape)
+ torch.Size([8, 16, 64])
+ """
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
+ # Interpolate rel pos if needed.
+ if rel_pos.shape[0] != max_rel_dist:
+ # Interpolate rel pos.
+ rel_pos_resized = F.interpolate(
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+ size=max_rel_dist,
+ mode="linear",
+ )
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+ else:
+ rel_pos_resized = rel_pos
+
+ # Scale the coords with short length if shapes for q and k are different.
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+ return rel_pos_resized[relative_coords.long()]
+
+
+def add_decomposed_rel_pos(
+ attn: torch.Tensor,
+ q: torch.Tensor,
+ rel_pos_h: torch.Tensor,
+ rel_pos_w: torch.Tensor,
+ q_size: Tuple[int, int],
+ k_size: Tuple[int, int],
+) -> torch.Tensor:
+ """
+ Adds decomposed Relative Positional Embeddings to the attention map.
+
+ This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2
+ paper. It enhances the attention mechanism by incorporating spatial relationships between query and key
+ positions.
+
+ Args:
+ attn (torch.Tensor): Attention map with shape (B, q_h * q_w, k_h * k_w).
+ q (torch.Tensor): Query tensor in the attention layer with shape (B, q_h * q_w, C).
+ rel_pos_h (torch.Tensor): Relative position embeddings for height axis with shape (Lh, C).
+ rel_pos_w (torch.Tensor): Relative position embeddings for width axis with shape (Lw, C).
+ q_size (Tuple[int, int]): Spatial sequence size of query q as (q_h, q_w).
+ k_size (Tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).
+
+ Returns:
+ (torch.Tensor): Updated attention map with added relative positional embeddings, shape
+ (B, q_h * q_w, k_h * k_w).
+
+ Examples:
+ >>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8
+ >>> attn = torch.rand(B, q_h * q_w, k_h * k_w)
+ >>> q = torch.rand(B, q_h * q_w, C)
+ >>> rel_pos_h = torch.rand(2 * max(q_h, k_h) - 1, C)
+ >>> rel_pos_w = torch.rand(2 * max(q_w, k_w) - 1, C)
+ >>> q_size, k_size = (q_h, q_w), (k_h, k_w)
+ >>> updated_attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size)
+ >>> print(updated_attn.shape)
+ torch.Size([1, 64, 64])
+
+ References:
+ https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py
+ """
+ q_h, q_w = q_size
+ k_h, k_w = k_size
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+ B, _, dim = q.shape
+ r_q = q.reshape(B, q_h, q_w, dim)
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+
+ attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
+ B, q_h * q_w, k_h * k_w
+ )
+
+ return attn
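As a shape walkthrough for `add_decomposed_rel_pos` with a non-square query/key grid, the sketch below reproduces only the broadcasting step; the random `Rh`/`Rw` tensors stand in for what `get_rel_pos` would return:

import torch

q_h, q_w, k_h, k_w, C = 4, 6, 8, 6, 32
attn = torch.zeros(1, q_h * q_w, k_h * k_w)
q = torch.randn(1, q_h * q_w, C)
Rh = torch.randn(q_h, k_h, C)  # placeholder for get_rel_pos(q_h, k_h, rel_pos_h)
Rw = torch.randn(q_w, k_w, C)  # placeholder for get_rel_pos(q_w, k_w, rel_pos_w)

r_q = q.reshape(1, q_h, q_w, C)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)  # (1, 4, 6, 8)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)  # (1, 4, 6, 6)
out = (attn.view(1, q_h, q_w, k_h, k_w) + rel_h[..., None] + rel_w[:, :, :, None, :]).view(1, q_h * q_w, k_h * k_w)
print(out.shape)  # torch.Size([1, 24, 48])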
diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py
index e03ba6256..034133514 100644
--- a/ultralytics/models/sam/predict.py
+++ b/ultralytics/models/sam/predict.py
@@ -34,35 +34,64 @@ from .build import build_sam
class Predictor(BasePredictor):
"""
- Predictor class for the Segment Anything Model (SAM), extending BasePredictor.
+ Predictor class for SAM, enabling real-time image segmentation with promptable capabilities.
- The class provides an interface for model inference tailored to image segmentation tasks.
- With advanced architecture and promptable segmentation capabilities, it facilitates flexible and real-time
- mask generation. The class is capable of working with various types of prompts such as bounding boxes,
- points, and low-resolution masks.
+ This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image
+ segmentation tasks. It supports various input prompts like points, bounding boxes, and masks for
+ fine-grained control over segmentation results.
Attributes:
- cfg (dict): Configuration dictionary specifying model and task-related parameters.
- overrides (dict): Dictionary containing values that override the default configuration.
- _callbacks (dict): Dictionary of user-defined callback functions to augment behavior.
- args (namespace): Namespace to hold command-line arguments or other operational variables.
- im (torch.Tensor): Preprocessed input image tensor.
- features (torch.Tensor): Extracted image features used for inference.
- prompts (dict): Collection of various prompt types, such as bounding boxes and points.
- segment_all (bool): Flag to control whether to segment all objects in the image or only specified ones.
+ args (SimpleNamespace): Configuration arguments for the predictor.
+ model (torch.nn.Module): The loaded SAM model.
+ device (torch.device): The device (CPU or GPU) on which the model is loaded.
+ im (torch.Tensor): The preprocessed input image.
+ features (torch.Tensor): Extracted image features.
+ prompts (Dict): Dictionary to store various types of prompts (e.g., bboxes, points, masks).
+ segment_all (bool): Flag to indicate if full image segmentation should be performed.
+ mean (torch.Tensor): Mean values for image normalization.
+ std (torch.Tensor): Standard deviation values for image normalization.
+
+ Methods:
+ preprocess: Prepares input images for model inference.
+ pre_transform: Performs initial transformations on the input image.
+ inference: Performs segmentation inference based on input prompts.
+ prompt_inference: Internal function for prompt-based segmentation inference.
+ generate: Generates segmentation masks for an entire image.
+ setup_model: Initializes the SAM model for inference.
+ get_model: Builds and returns a SAM model.
+ postprocess: Post-processes model outputs to generate final results.
+ setup_source: Sets up the data source for inference.
+ set_image: Sets and preprocesses a single image for inference.
+ get_im_features: Extracts image features using the SAM image encoder.
+ set_prompts: Sets prompts for subsequent inference.
+ reset_image: Resets the current image and its features.
+ remove_small_regions: Removes small disconnected regions and holes from masks.
+
+ Examples:
+ >>> predictor = Predictor()
+ >>> predictor.setup_model(model_path='sam_model.pt')
+ >>> predictor.set_image('image.jpg')
+ >>> masks, scores, boxes = predictor.generate()
+ >>> results = predictor.postprocess((masks, scores, boxes), im, orig_img)
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize the Predictor with configuration, overrides, and callbacks.
- The method sets up the Predictor object and applies any configuration overrides or callbacks provided. It
- initializes task-specific settings for SAM, such as retina_masks being set to True for optimal results.
+ Sets up the Predictor object for SAM (Segment Anything Model) and applies any configuration overrides or
+ callbacks provided. Initializes task-specific settings for SAM, such as retina_masks being set to True
+ for optimal results.
Args:
- cfg (dict): Configuration dictionary.
- overrides (dict, optional): Dictionary of values to override default configuration.
- _callbacks (dict, optional): Dictionary of callback functions to customize behavior.
+ cfg (Dict): Configuration dictionary containing default settings.
+ overrides (Dict | None): Dictionary of values to override default configuration.
+ _callbacks (Dict | None): Dictionary of callback functions to customize behavior.
+
+ Examples:
+ >>> predictor = Predictor(cfg=DEFAULT_CFG)
+ >>> predictor = Predictor(overrides={'imgsz': 640})
+ >>> predictor = Predictor(_callbacks={'on_predict_start': custom_callback})
"""
if overrides is None:
overrides = {}
@@ -78,14 +107,19 @@ class Predictor(BasePredictor):
"""
Preprocess the input image for model inference.
- The method prepares the input image by applying transformations and normalization.
- It supports both torch.Tensor and list of np.ndarray as input formats.
+ This method prepares the input image by applying transformations and normalization. It supports both
+ torch.Tensor and list of np.ndarray as input formats.
Args:
- im (torch.Tensor | List[np.ndarray]): BCHW tensor format or list of HWC numpy arrays.
+ im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays.
Returns:
- (torch.Tensor): The preprocessed image tensor.
+ (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.
+
+ Examples:
+ >>> predictor = Predictor()
+ >>> image = torch.rand(1, 3, 640, 640)
+ >>> preprocessed_image = predictor.preprocess(image)
"""
if self.im is not None:
return self.im
@@ -106,14 +140,24 @@ class Predictor(BasePredictor):
"""
Perform initial transformations on the input image for preprocessing.
- The method applies transformations such as resizing to prepare the image for further preprocessing.
+ This method applies transformations such as resizing to prepare the image for further preprocessing.
Currently, batched inference is not supported; hence the list length should be 1.
Args:
- im (List[np.ndarray]): List containing images in HWC numpy array format.
+ im (List[np.ndarray]): List containing a single image in HWC numpy array format.
Returns:
- (List[np.ndarray]): List of transformed images.
+ (List[np.ndarray]): List containing the transformed image.
+
+ Raises:
+ AssertionError: If the input list contains more than one image.
+
+ Examples:
+ >>> predictor = Predictor()
+ >>> image = np.random.rand(480, 640, 3) # Single HWC image
+ >>> transformed = predictor.pre_transform([image])
+ >>> print(len(transformed))
+ 1
"""
assert len(im) == 1, "SAM model does not currently support batched inference"
letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
@@ -121,23 +165,32 @@ class Predictor(BasePredictor):
def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
"""
- Perform image segmentation inference based on the given input cues, using the currently loaded image. This
- method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
- mask decoder for real-time and promptable segmentation tasks.
+ Perform image segmentation inference based on the given input cues, using the currently loaded image.
+
+ This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt
+ encoder, and mask decoder for real-time and promptable segmentation tasks.
Args:
im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
- bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
- points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
- labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
- masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.
- multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts.
+ bboxes (np.ndarray | List | None): Bounding boxes with shape (N, 4), in XYXY format.
+ points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
+ labels (np.ndarray | List | None): Labels for point prompts, shape (N,). 1 = foreground, 0 = background.
+ masks (np.ndarray | None): Low-resolution masks from previous predictions, shape (N, H, W). For SAM H=W=256.
+ multimask_output (bool): Flag to return multiple masks. Helpful for ambiguous prompts.
+ *args (Any): Additional positional arguments.
+ **kwargs (Any): Additional keyword arguments.
Returns:
- (tuple): Contains the following three elements.
- - np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks.
+ (tuple): Contains the following three elements:
+ - np.ndarray: The output masks in shape (C, H, W), where C is the number of generated masks.
- np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
- - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
+ - np.ndarray: Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.
+
+ Examples:
+ >>> predictor = Predictor()
+ >>> predictor.setup_model(model_path='sam_model.pt')
+ >>> predictor.set_image('image.jpg')
+ >>> masks, scores, logits = predictor.inference(im, bboxes=[[0, 0, 100, 100]])
"""
# Override prompts if any stored in self.prompts
bboxes = self.prompts.pop("bboxes", bboxes)
@@ -151,22 +204,30 @@ class Predictor(BasePredictor):
def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
"""
- Internal function for image segmentation inference based on cues like bounding boxes, points, and masks.
- Leverages SAM's specialized architecture for prompt-based, real-time segmentation.
+ Performs image segmentation inference based on input cues using SAM's specialized architecture.
+
+ This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation.
+ It processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks.
Args:
- im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
- bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
- points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
- labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
- masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.
- multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts.
+ im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
+ bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
+ points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
+ labels (np.ndarray | List | None): Point prompt labels with shape (N,). 1 for foreground, 0 for background.
+ masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256.
+ multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
Returns:
- (tuple): Contains the following three elements.
- - np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks.
- - np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
- - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
+ (tuple): Tuple containing:
+ - np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
+ - np.ndarray: Quality scores predicted by the model for each mask, with length C.
+ - np.ndarray: Low-resolution logits with shape (C, H, W) for subsequent inference, where H=W=256.
+
+ Examples:
+ >>> predictor = Predictor()
+ >>> im = torch.rand(1, 3, 1024, 1024)
+ >>> bboxes = [[100, 100, 200, 200]]
+ >>> masks, scores, logits = predictor.prompt_inference(im, bboxes=bboxes)
"""
features = self.get_im_features(im) if self.features is None else self.features
@@ -224,27 +285,32 @@ class Predictor(BasePredictor):
"""
Perform image segmentation using the Segment Anything Model (SAM).
- This function segments an entire image into constituent parts by leveraging SAM's advanced architecture
+ This method segments an entire image into constituent parts by leveraging SAM's advanced architecture
and real-time performance capabilities. It can optionally work on image crops for finer segmentation.
Args:
- im (torch.Tensor): Input tensor representing the preprocessed image with dimensions (N, C, H, W).
- crop_n_layers (int): Specifies the number of layers for additional mask predictions on image crops.
- Each layer produces 2**i_layer number of image crops.
- crop_overlap_ratio (float): Determines the overlap between crops. Scaled down in subsequent layers.
- crop_downscale_factor (int): Scaling factor for the number of sampled points-per-side in each layer.
- point_grids (list[np.ndarray], optional): Custom grids for point sampling normalized to [0,1].
- Used in the nth crop layer.
- points_stride (int, optional): Number of points to sample along each side of the image.
- Exclusive with 'point_grids'.
+ im (torch.Tensor): Input tensor representing the preprocessed image with shape (N, C, H, W).
+ crop_n_layers (int): Number of layers for additional mask predictions on image crops.
+ crop_overlap_ratio (float): Overlap between crops, scaled down in subsequent layers.
+ crop_downscale_factor (int): Scaling factor for sampled points-per-side in each layer.
+ point_grids (List[np.ndarray] | None): Custom grids for point sampling normalized to [0,1].
+ points_stride (int): Number of points to sample along each side of the image.
points_batch_size (int): Batch size for the number of points processed simultaneously.
- conf_thres (float): Confidence threshold [0,1] for filtering based on the model's mask quality prediction.
- stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on mask stability.
+ conf_thres (float): Confidence threshold [0,1] for filtering based on mask quality prediction.
+ stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on stability.
stability_score_offset (float): Offset value for calculating stability score.
crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops.
Returns:
- (tuple): A tuple containing segmented masks, confidence scores, and bounding boxes.
+ (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing:
+ - pred_masks (torch.Tensor): Segmented masks with shape (N, H, W).
+ - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,).
+ - pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4).
+
+ Examples:
+ >>> predictor = Predictor()
+ >>> im = torch.rand(1, 3, 1024, 1024) # Example input image
+ >>> masks, scores, boxes = predictor.generate(im)
"""
import torchvision # scope for faster 'import ultralytics'
@@ -326,11 +392,9 @@ class Predictor(BasePredictor):
model (torch.nn.Module): A pre-trained SAM model. If None, a model will be built based on configuration.
verbose (bool): If True, prints selected device information.
- Attributes:
- model (torch.nn.Module): The SAM model allocated to the chosen device for inference.
- device (torch.device): The device to which the model and tensors are allocated.
- mean (torch.Tensor): The mean values for image normalization.
- std (torch.Tensor): The standard deviation values for image normalization.
+ Examples:
+ >>> predictor = Predictor()
+ >>> predictor.setup_model(model=sam_model, verbose=True)
"""
device = select_device(self.args.device, verbose=verbose)
if model is None:
@@ -349,23 +413,32 @@ class Predictor(BasePredictor):
self.done_warmup = True
def get_model(self):
- """Built Segment Anything Model (SAM) model."""
+ """Retrieves or builds the Segment Anything Model (SAM) for image segmentation tasks."""
return build_sam(self.args.model)
def postprocess(self, preds, img, orig_imgs):
"""
Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
- The method scales masks and boxes to the original image size and applies a threshold to the mask predictions.
- The SAM model uses advanced architecture and promptable segmentation tasks to achieve real-time performance.
+ This method scales masks and boxes to the original image size and applies a threshold to the mask
+ predictions. It leverages SAM's advanced architecture for real-time, promptable segmentation tasks.
Args:
- preds (tuple): The output from SAM model inference, containing masks, scores, and optional bounding boxes.
- img (torch.Tensor): The processed input image tensor.
- orig_imgs (list | torch.Tensor): The original, unprocessed images.
+ preds (Tuple[torch.Tensor]): The output from SAM model inference, containing:
+ - pred_masks (torch.Tensor): Predicted masks with shape (N, 1, H, W).
+ - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N, 1).
+ - pred_bboxes (torch.Tensor, optional): Predicted bounding boxes if segment_all is True.
+ img (torch.Tensor): The processed input image tensor with shape (C, H, W).
+ orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images.
Returns:
- (list): List of Results objects containing detection masks, bounding boxes, and other metadata.
+ (List[Results]): List of Results objects containing detection masks, bounding boxes, and other
+ metadata for each processed image.
+
+ Examples:
+ >>> predictor = Predictor()
+ >>> preds = predictor.inference(img)
+ >>> results = predictor.postprocess(preds, img, orig_imgs)
"""
# (N, 1, H, W), (N, 1)
pred_masks, pred_scores = preds[:2]
@@ -393,11 +466,23 @@ class Predictor(BasePredictor):
"""
Sets up the data source for inference.
- This method configures the data source from which images will be fetched for inference. The source could be a
- directory, a video file, or other types of image data sources.
+ This method configures the data source from which images will be fetched for inference. It supports
+ various input types such as image files, directories, video files, and other compatible data sources.
Args:
- source (str | Path): The path to the image data source for inference.
+ source (str | Path | None): The path or identifier for the image data source. Can be a file path,
+ directory path, URL, or other supported source types.
+
+ Examples:
+ >>> predictor = Predictor()
+ >>> predictor.setup_source('path/to/images')
+ >>> predictor.setup_source('video.mp4')
+ >>> predictor.setup_source(None) # Uses default source if available
+
+ Notes:
+ - If source is None, the method may use a default source if configured.
+ - The method adapts to different source types and prepares them for subsequent inference steps.
+ - Supported source types may include local files, directories, URLs, and video streams.
"""
if source is not None:
super().setup_source(source)
@@ -406,14 +491,25 @@ class Predictor(BasePredictor):
"""
Preprocesses and sets a single image for inference.
- This function sets up the model if not already initialized, configures the data source to the specified image,
- and preprocesses the image for feature extraction. Only one image can be set at a time.
+ This method prepares the model for inference on a single image by setting up the model if not already
+ initialized, configuring the data source, and preprocessing the image for feature extraction. It
+ ensures that only one image is set at a time and extracts image features for subsequent use.
Args:
- image (str | np.ndarray): Image file path as a string, or a np.ndarray image read by cv2.
+ image (str | np.ndarray): Path to the image file as a string, or a numpy array representing
+ an image read by cv2.
Raises:
- AssertionError: If more than one image is set.
+ AssertionError: If attempting to set more than one image.
+
+ Examples:
+ >>> predictor = Predictor()
+ >>> predictor.set_image('path/to/image.jpg')
+ >>> predictor.set_image(cv2.imread('path/to/image.jpg'))
+
+ Notes:
+ - This method should be called before performing inference on a new image.
+ - The extracted features are stored in the `self.features` attribute for later use.
"""
if self.model is None:
self.setup_model(model=None)
@@ -425,35 +521,44 @@ class Predictor(BasePredictor):
break
def get_im_features(self, im):
- """Get image features from the SAM image encoder."""
+ """Extracts image features using the SAM model's image encoder for subsequent mask prediction."""
return self.model.image_encoder(im)
def set_prompts(self, prompts):
- """Set prompts in advance."""
+ """Sets prompts for subsequent inference operations."""
self.prompts = prompts
def reset_image(self):
- """Resets the image and its features to None."""
+ """Resets the current image and its features, clearing them for subsequent inference."""
self.im = None
self.features = None
@staticmethod
def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
"""
- Perform post-processing on segmentation masks generated by the Segment Anything Model (SAM). Specifically, this
- function removes small disconnected regions and holes from the input masks, and then performs Non-Maximum
+ Remove small disconnected regions and holes from segmentation masks.
+
+ This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM).
+ It removes small disconnected regions and holes from the input masks, and then performs Non-Maximum
Suppression (NMS) to eliminate any newly created duplicate boxes.
Args:
- masks (torch.Tensor): A tensor containing the masks to be processed. Shape should be (N, H, W), where N is
- the number of masks, H is height, and W is width.
- min_area (int): The minimum area below which disconnected regions and holes will be removed. Defaults to 0.
- nms_thresh (float): The IoU threshold for the NMS algorithm. Defaults to 0.7.
+ masks (torch.Tensor): Segmentation masks to be processed, with shape (N, H, W) where N is the number of
+ masks, H is height, and W is width.
+ min_area (int): Minimum area threshold for removing disconnected regions and holes. Regions smaller than
+ this will be removed.
+ nms_thresh (float): IoU threshold for the NMS algorithm to remove duplicate boxes.
Returns:
- (tuple([torch.Tensor, List[int]])):
- - new_masks (torch.Tensor): The processed masks with small regions removed. Shape is (N, H, W).
- - keep (List[int]): The indices of the remaining masks post-NMS, which can be used to filter the boxes.
+ (tuple):
+ - new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W).
+ - keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes.
+
+ Examples:
+ >>> masks = torch.rand(5, 640, 640) > 0.5 # 5 random binary masks
+ >>> new_masks, keep = Predictor.remove_small_regions(masks, min_area=100, nms_thresh=0.7)
+ >>> print(f"Original masks: {masks.shape}, Processed masks: {new_masks.shape}")
+ >>> print(f"Indices of kept masks: {keep}")
"""
import torchvision # scope for faster 'import ultralytics'
@@ -480,3 +585,188 @@ class Predictor(BasePredictor):
keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh)
return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep
+
+
+class SAM2Predictor(Predictor):
+ """
+ SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture.
+
+ This class extends the base Predictor class to implement SAM2-specific functionality for image
+ segmentation tasks. It provides methods for model initialization, feature extraction, and
+ prompt-based inference.
+
+ Attributes:
+ _bb_feat_sizes (List[Tuple[int, int]]): Feature sizes for different backbone levels.
+ model (torch.nn.Module): The loaded SAM2 model.
+ device (torch.device): The device (CPU or GPU) on which the model is loaded.
+ features (Dict[str, torch.Tensor]): Cached image features for efficient inference.
+ segment_all (bool): Flag to indicate if all segments should be predicted.
+ prompts (Dict): Dictionary to store various types of prompts for inference.
+
+ Methods:
+ get_model: Retrieves and initializes the SAM2 model.
+ prompt_inference: Performs image segmentation inference based on various prompts.
+ set_image: Preprocesses and sets a single image for inference.
+ get_im_features: Extracts and processes image features using SAM2's image encoder.
+
+ Examples:
+ >>> predictor = SAM2Predictor(cfg)
+ >>> predictor.set_image("path/to/image.jpg")
+ >>> bboxes = [[100, 100, 200, 200]]
+ >>> masks, scores = predictor.prompt_inference(predictor.im, bboxes=bboxes)
+ >>> print(f"Predicted {len(masks)} masks with average score {scores.mean():.2f}")
+ """
+
+ _bb_feat_sizes = [
+ (256, 256),
+ (128, 128),
+ (64, 64),
+ ]
+
+ def get_model(self):
+ """Retrieves and initializes the Segment Anything Model 2 (SAM2) for image segmentation tasks."""
+ return build_sam(self.args.model)
+
+ def prompt_inference(
+ self,
+ im,
+ bboxes=None,
+ points=None,
+ labels=None,
+ masks=None,
+ multimask_output=False,
+ img_idx=-1,
+ ):
+ """
+ Performs image segmentation inference based on various prompts using SAM2 architecture.
+
+ This method leverages the Segment Anything Model 2 (SAM2) to generate segmentation masks for input images
+ based on provided prompts such as bounding boxes, points, or existing masks. It supports both single and
+ multi-object prediction scenarios.
+
+ Args:
+ im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
+ bboxes (np.ndarray | List[List[float]] | None): Bounding boxes in XYXY format with shape (N, 4).
+ points (np.ndarray | List[List[float]] | None): Object location points with shape (N, 2), in pixels.
+ labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background.
+ masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W).
+ multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
+ img_idx (int): Index of the image in the batch to process.
+
+ Returns:
+ (tuple): Tuple containing:
+ - np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
+ - np.ndarray: Quality scores for each mask, with length C.
+
+ Examples:
+ >>> predictor = SAM2Predictor(cfg)
+ >>> image = torch.rand(1, 3, 640, 640)
+ >>> bboxes = [[100, 100, 200, 200]]
+ >>> masks, scores = predictor.prompt_inference(image, bboxes=bboxes)
+ >>> print(f"Generated {masks.shape[0]} masks with average score {scores.mean():.2f}")
+
+ Notes:
+ - The method supports batched inference for multiple objects when points or bboxes are provided.
+ - Input prompts (bboxes, points) are automatically scaled to match the input image dimensions.
+ - When both bboxes and points are provided, they are merged into a single 'points' input for the model.
+
+ References:
+ - SAM2 Paper: [Add link to SAM2 paper when available]
+ """
+ features = self.get_im_features(im) if self.features is None else self.features
+
+ src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
+ r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
+ # Transform input prompts
+ if points is not None:
+ points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
+ points = points[None] if points.ndim == 1 else points
+ # Assuming labels are all positive if users don't pass labels.
+ if labels is None:
+ labels = torch.ones(points.shape[0])
+ labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
+ points *= r
+ # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
+ points, labels = points[:, None], labels[:, None]
+ if bboxes is not None:
+ bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
+ bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
+ bboxes = bboxes.view(-1, 2, 2) * r
+ bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
+ # NOTE: merge "boxes" and "points" into a single "points" input
+ # (where boxes are added at the beginning) to model.sam_prompt_encoder
+ if points is not None:
+ points = torch.cat([bboxes, points], dim=1)
+ labels = torch.cat([bbox_labels, labels], dim=1)
+ else:
+ points, labels = bboxes, bbox_labels
+ if masks is not None:
+ masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
+
+ points = (points, labels) if points is not None else None
+
+ sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
+ points=points,
+ boxes=None,
+ masks=masks,
+ )
+ # Predict masks
+ batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction
+ high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]]
+ pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(
+ image_embeddings=features["image_embed"][img_idx].unsqueeze(0),
+ image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ repeat_image=batched_mode,
+ high_res_features=high_res_features,
+ )
+ # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
+        # `d` can be 1 or 3 depending on `multimask_output`.
+ return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
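
The Notes above describe how box prompts are folded into the point stream: each XYXY box is reshaped into its two corner points, labeled 2 (top-left) and 3 (bottom-right), and prepended to any click points before reaching `sam_prompt_encoder`. A minimal standalone sketch of that merge with toy tensors (shapes follow the docstring; this is illustrative, not the predictor itself):

```python
import torch

bboxes = torch.tensor([[100.0, 100.0, 200.0, 200.0]])   # (N, 4) boxes in XYXY
points = torch.tensor([[150.0, 150.0]])                  # (N, 2) click coordinates
labels = torch.ones(points.shape[0], dtype=torch.int32)  # 1 = foreground click

# Each box becomes two pseudo-points labeled 2 (top-left) and 3 (bottom-right)
box_points = bboxes.view(-1, 2, 2)                                               # (N, 2, 2)
box_labels = torch.tensor([[2, 3]], dtype=torch.int32).expand(len(bboxes), -1)   # (N, 2)

# Click points gain a per-object axis, then boxes are prepended
points, labels = points[:, None], labels[:, None]          # (N, 1, 2), (N, 1)
merged_points = torch.cat([box_points, points], dim=1)     # (N, 3, 2)
merged_labels = torch.cat([box_labels, labels], dim=1)     # (N, 3)
print(merged_points.shape, merged_labels.tolist())         # torch.Size([1, 3, 2]) [[2, 3, 1]]
```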
+
+ def set_image(self, image):
+ """
+ Preprocesses and sets a single image for inference using the SAM2 model.
+
+ This method initializes the model if not already done, configures the data source to the specified image,
+ and preprocesses the image for feature extraction. It supports setting only one image at a time.
+
+ Args:
+ image (str | np.ndarray): Path to the image file as a string, or a numpy array representing the image.
+
+ Raises:
+ AssertionError: If more than one image is attempted to be set.
+
+ Examples:
+ >>> predictor = SAM2Predictor()
+ >>> predictor.set_image("path/to/image.jpg")
+ >>> predictor.set_image(np.array([...])) # Using a numpy array
+
+ Notes:
+ - This method must be called before performing any inference on a new image.
+ - The method caches the extracted features for efficient subsequent inferences on the same image.
+ - Only one image can be set at a time. To process multiple images, call this method for each new image.
+ """
+ if self.model is None:
+ self.setup_model(model=None)
+ self.setup_source(image)
+ assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
+ for batch in self.dataset:
+ im = self.preprocess(batch[1])
+ self.features = self.get_im_features(im)
+ break
+
+ def get_im_features(self, im):
+ """Extracts image features from the SAM image encoder for subsequent processing."""
+ backbone_out = self.model.forward_image(im)
+ _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
+ if self.model.directly_add_no_mem_embed:
+ vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
+ feats = [
+ feat.permute(1, 2, 0).view(1, -1, *feat_size)
+ for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
+ ][::-1]
+ return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
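
A hedged usage sketch of the workflow these methods enable: `set_image` extracts and caches the image features once, and each subsequent prompt call reuses them so only the prompt encoder and mask decoder run again. The import path, override keys, and prompt keywords below mirror the documented SAM `Predictor` usage and are assumptions for this merged SAM 2 predictor:

```python
# Hypothetical usage sketch; assumes SAM2Predictor is exported from ultralytics.models.sam
# and that sam2_b.pt weights are available locally or downloadable.
from ultralytics.models.sam import SAM2Predictor

overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model="sam2_b.pt")
predictor = SAM2Predictor(overrides=overrides)

predictor.set_image("ultralytics/assets/zidane.jpg")       # features extracted once and cached
results_box = predictor(bboxes=[439, 437, 524, 709])       # reuses the cached features
results_point = predictor(points=[500, 375], labels=[1])   # so does every later prompt
```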
diff --git a/ultralytics/models/sam2/__init__.py b/ultralytics/models/sam2/__init__.py
deleted file mode 100644
index 755a160af..000000000
--- a/ultralytics/models/sam2/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
-
-from .model import SAM2
-from .predict import SAM2Predictor
-
-__all__ = "SAM2", "SAM2Predictor" # tuple or list
diff --git a/ultralytics/models/sam2/build.py b/ultralytics/models/sam2/build.py
deleted file mode 100644
index 2791029a5..000000000
--- a/ultralytics/models/sam2/build.py
+++ /dev/null
@@ -1,156 +0,0 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
-
-import torch
-
-from ultralytics.utils.downloads import attempt_download_asset
-
-from .modules.encoders import FpnNeck, Hiera, ImageEncoder, MemoryEncoder
-from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
-from .modules.sam2 import SAM2Model
-
-
-def build_sam2_t(checkpoint=None):
- """Build and return a Segment Anything Model (SAM2) tiny-size model with specified architecture parameters."""
- return _build_sam2(
- encoder_embed_dim=96,
- encoder_stages=[1, 2, 7, 2],
- encoder_num_heads=1,
- encoder_global_att_blocks=[5, 7, 9],
- encoder_window_spec=[8, 4, 14, 7],
- encoder_backbone_channel_list=[768, 384, 192, 96],
- checkpoint=checkpoint,
- )
-
-
-def build_sam2_s(checkpoint=None):
- """Builds and returns a small-size Segment Anything Model (SAM2) with specified architecture parameters."""
- return _build_sam2(
- encoder_embed_dim=96,
- encoder_stages=[1, 2, 11, 2],
- encoder_num_heads=1,
- encoder_global_att_blocks=[7, 10, 13],
- encoder_window_spec=[8, 4, 14, 7],
- encoder_backbone_channel_list=[768, 384, 192, 96],
- checkpoint=checkpoint,
- )
-
-
-def build_sam2_b(checkpoint=None):
- """Builds and returns a Segment Anything Model (SAM2) base-size model with specified architecture parameters."""
- return _build_sam2(
- encoder_embed_dim=112,
- encoder_stages=[2, 3, 16, 3],
- encoder_num_heads=2,
- encoder_global_att_blocks=[12, 16, 20],
- encoder_window_spec=[8, 4, 14, 7],
- encoder_window_spatial_size=[14, 14],
- encoder_backbone_channel_list=[896, 448, 224, 112],
- checkpoint=checkpoint,
- )
-
-
-def build_sam2_l(checkpoint=None):
- """Build and return a Segment Anything Model (SAM2) large-size model with specified architecture parameters."""
- return _build_sam2(
- encoder_embed_dim=144,
- encoder_stages=[2, 6, 36, 4],
- encoder_num_heads=2,
- encoder_global_att_blocks=[23, 33, 43],
- encoder_window_spec=[8, 4, 16, 8],
- encoder_backbone_channel_list=[1152, 576, 288, 144],
- checkpoint=checkpoint,
- )
-
-
-def _build_sam2(
- encoder_embed_dim=1280,
- encoder_stages=[2, 6, 36, 4],
- encoder_num_heads=2,
- encoder_global_att_blocks=[7, 15, 23, 31],
- encoder_backbone_channel_list=[1152, 576, 288, 144],
- encoder_window_spatial_size=[7, 7],
- encoder_window_spec=[8, 4, 16, 8],
- checkpoint=None,
-):
- """Builds a SAM2 model with specified architecture parameters and optional checkpoint loading."""
- image_encoder = ImageEncoder(
- trunk=Hiera(
- embed_dim=encoder_embed_dim,
- num_heads=encoder_num_heads,
- stages=encoder_stages,
- global_att_blocks=encoder_global_att_blocks,
- window_pos_embed_bkg_spatial_size=encoder_window_spatial_size,
- window_spec=encoder_window_spec,
- ),
- neck=FpnNeck(
- d_model=256,
- backbone_channel_list=encoder_backbone_channel_list,
- fpn_top_down_levels=[2, 3],
- fpn_interp_model="nearest",
- ),
- scalp=1,
- )
- memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())
- memory_encoder = MemoryEncoder(out_dim=64)
-
- sam2 = SAM2Model(
- image_encoder=image_encoder,
- memory_attention=memory_attention,
- memory_encoder=memory_encoder,
- num_maskmem=7,
- image_size=1024,
- sigmoid_scale_for_mem_enc=20.0,
- sigmoid_bias_for_mem_enc=-10.0,
- use_mask_input_as_output_without_sam=True,
- directly_add_no_mem_embed=True,
- use_high_res_features_in_sam=True,
- multimask_output_in_sam=True,
- iou_prediction_use_sigmoid=True,
- use_obj_ptrs_in_encoder=True,
- add_tpos_enc_to_obj_ptrs=True,
- only_obj_ptrs_in_the_past_for_eval=True,
- pred_obj_scores=True,
- pred_obj_scores_mlp=True,
- fixed_no_obj_ptr=True,
- multimask_output_for_tracking=True,
- use_multimask_token_for_obj_ptr=True,
- multimask_min_pt_num=0,
- multimask_max_pt_num=1,
- use_mlp_for_obj_ptr_proj=True,
- compile_image_encoder=False,
- sam_mask_decoder_extra_args=dict(
- dynamic_multimask_via_stability=True,
- dynamic_multimask_stability_delta=0.05,
- dynamic_multimask_stability_thresh=0.98,
- ),
- )
-
- if checkpoint is not None:
- checkpoint = attempt_download_asset(checkpoint)
- with open(checkpoint, "rb") as f:
- state_dict = torch.load(f)["model"]
- sam2.load_state_dict(state_dict)
- sam2.eval()
- return sam2
-
-
-sam_model_map = {
- "sam2_t.pt": build_sam2_t,
- "sam2_s.pt": build_sam2_s,
- "sam2_b.pt": build_sam2_b,
- "sam2_l.pt": build_sam2_l,
-}
-
-
-def build_sam2(ckpt="sam_b.pt"):
- """Constructs a Segment Anything Model (SAM2) based on the specified checkpoint, with various size options."""
- model_builder = None
- ckpt = str(ckpt) # to allow Path ckpt types
- for k in sam_model_map.keys():
- if ckpt.endswith(k):
- model_builder = sam_model_map.get(k)
-
- if not model_builder:
- raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}")
-
- return model_builder(ckpt)
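
The checkpoint dispatch used by the removed `build_sam2` above is a plain filename-suffix lookup. A self-contained sketch of that pattern with stand-in builder functions (the real builders assemble `SAM2Model` instances with size-specific encoder settings):

```python
from pathlib import Path

def build_tiny(ckpt): return f"tiny model from {ckpt}"    # stand-in builder
def build_large(ckpt): return f"large model from {ckpt}"  # stand-in builder

model_map = {"sam2_t.pt": build_tiny, "sam2_l.pt": build_large}

def build(ckpt="sam2_t.pt"):
    ckpt = str(ckpt)  # allow Path checkpoints
    for suffix, builder in model_map.items():
        if ckpt.endswith(suffix):
            return builder(ckpt)
    raise FileNotFoundError(f"{ckpt} is not a supported checkpoint. Available: {list(model_map)}")

print(build(Path("weights/sam2_l.pt")))  # -> large model from weights/sam2_l.pt
```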
diff --git a/ultralytics/models/sam2/model.py b/ultralytics/models/sam2/model.py
deleted file mode 100644
index 4b3265e19..000000000
--- a/ultralytics/models/sam2/model.py
+++ /dev/null
@@ -1,97 +0,0 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
-"""
-SAM2 model interface.
-
-This module provides an interface to the Segment Anything Model (SAM2) from Ultralytics, designed for real-time image
-segmentation tasks. The SAM2 model allows for promptable segmentation with unparalleled versatility in image analysis,
-and has been trained on the SA-1B dataset. It features zero-shot performance capabilities, enabling it to adapt to new
-image distributions and tasks without prior knowledge.
-
-Key Features:
- - Promptable segmentation
- - Real-time performance
- - Zero-shot transfer capabilities
- - Trained on SA-1B dataset
-"""
-
-from ultralytics.models.sam import SAM
-
-from .build import build_sam2
-from .predict import SAM2Predictor
-
-
-class SAM2(SAM):
- """
- SAM2 class for real-time image segmentation using the Segment Anything Model (SAM2).
-
- This class extends the SAM base class, providing an interface to the SAM2 model for promptable segmentation
- tasks. It supports loading pre-trained weights and offers zero-shot performance capabilities.
-
- Attributes:
- model (torch.nn.Module): The loaded SAM2 model.
- task_map (Dict[str, Type[SAM2Predictor]]): Mapping of 'segment' task to SAM2Predictor.
-
- Methods:
- __init__: Initializes the SAM2 model with pre-trained weights.
- _load: Loads specified weights into the SAM2 model.
-
- Examples:
- >>> sam2 = SAM2("sam2_b.pt")
- >>> sam2._load('path/to/sam2_weights.pt')
- >>> task_map = sam2.task_map
- >>> print(task_map)
- {'segment': SAM2Predictor}
-
- Notes:
- - Supports .pt and .pth file extensions for model weights.
- - Offers zero-shot transfer capabilities for new image distributions and tasks.
- """
-
- def __init__(self, model="sam2_b.pt") -> None:
- """
- Initializes the SAM2 model with a pre-trained model file.
-
- Args:
- model (str): Path to the pre-trained SAM2 model file. File should have a .pt or .pth extension.
-
- Raises:
- NotImplementedError: If the model file extension is not .pt or .pth.
-
- Examples:
- >>> sam2 = SAM2("sam2_b.pt")
- """
- super().__init__(model=model)
-
- def _load(self, weights: str, task=None):
- """
- Loads the specified weights into the SAM2 model.
-
- This method is responsible for loading pre-trained weights into the SAM2 model. It supports loading
- weights from files with .pt or .pth extensions.
-
- Args:
- weights (str): Path to the weights file. Should be a file with .pt or .pth extension.
- task (str | None): Task name. If provided, it may be used to configure model-specific settings.
-
- Examples:
- >>> sam2_model = SAM2()
- >>> sam2_model._load('path/to/sam2_weights.pt')
- """
- self.model = build_sam2(weights)
-
- @property
- def task_map(self):
- """
- Provides a mapping from the 'segment' task to its corresponding 'Predictor'.
-
- Returns:
- (Dict[str, Type[SAM2Predictor]]): A dictionary mapping the 'segment' task to its corresponding
- SAM2Predictor class.
-
- Examples:
- >>> sam2 = SAM2()
- >>> task_map = sam2.task_map
- >>> print(task_map)
- {'segment': SAM2Predictor}
- """
- return {"segment": {"predictor": SAM2Predictor}}
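
Beyond the `task_map` example in the docstring, typical end-user access went through the model wrapper itself. A hedged sketch of that call pattern, following the standard Ultralytics SAM interface; whether the unified `SAM` class accepts `sam2_*.pt` checkpoints after this refactor is an assumption here:

```python
# Hypothetical usage sketch for the wrapper functionality removed above.
from ultralytics import SAM  # assumption: SAM now also loads sam2_*.pt checkpoints

model = SAM("sam2_b.pt")
model.info()  # display model information

# Prompted inference with boxes or points (same keywords the predictor accepts)
results = model("path/to/image.jpg", bboxes=[100, 100, 200, 200])
results = model("path/to/image.jpg", points=[150, 150], labels=[1])
```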
diff --git a/ultralytics/models/sam2/modules/__init__.py b/ultralytics/models/sam2/modules/__init__.py
deleted file mode 100644
index 9e68dc122..000000000
--- a/ultralytics/models/sam2/modules/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
diff --git a/ultralytics/models/sam2/modules/decoders.py b/ultralytics/models/sam2/modules/decoders.py
deleted file mode 100644
index ac6c60a60..000000000
--- a/ultralytics/models/sam2/modules/decoders.py
+++ /dev/null
@@ -1,305 +0,0 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
-
-from typing import List, Optional, Tuple, Type
-
-import torch
-from torch import nn
-
-from ultralytics.nn.modules import MLP, LayerNorm2d
-
-
-class MaskDecoder(nn.Module):
- """Transformer-based decoder predicting instance segmentation masks from image and prompt embeddings."""
-
- def __init__(
- self,
- transformer_dim: int,
- transformer: nn.Module,
- num_multimask_outputs: int = 3,
- activation: Type[nn.Module] = nn.GELU,
- iou_head_depth: int = 3,
- iou_head_hidden_dim: int = 256,
- use_high_res_features: bool = False,
- iou_prediction_use_sigmoid=False,
- dynamic_multimask_via_stability=False,
- dynamic_multimask_stability_delta=0.05,
- dynamic_multimask_stability_thresh=0.98,
- pred_obj_scores: bool = False,
- pred_obj_scores_mlp: bool = False,
- use_multimask_token_for_obj_ptr: bool = False,
- ) -> None:
- """
- Initializes the MaskDecoder module for predicting instance segmentation masks.
-
- Args:
- transformer_dim (int): Channel dimension of the transformer.
- transformer (nn.Module): Transformer used to predict masks.
- num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
- activation (Type[nn.Module]): Type of activation to use when upscaling masks.
- iou_head_depth (int): Depth of the MLP used to predict mask quality.
- iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.
- use_high_res_features (bool): Whether to use high-resolution features.
- iou_prediction_use_sigmoid (bool): Whether to use sigmoid for IOU prediction.
- dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
- dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
- dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.
- pred_obj_scores (bool): Whether to predict object scores.
- pred_obj_scores_mlp (bool): Whether to use MLP for object score prediction.
- use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
-
- Attributes:
- transformer_dim (int): Channel dimension of the transformer.
- transformer (nn.Module): Transformer used to predict masks.
- num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
- iou_token (nn.Embedding): Embedding for IOU token.
- num_mask_tokens (int): Total number of mask tokens.
- mask_tokens (nn.Embedding): Embedding for mask tokens.
- pred_obj_scores (bool): Whether to predict object scores.
- obj_score_token (nn.Embedding): Embedding for object score token.
- use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
- output_upscaling (nn.Sequential): Upscaling layers for output.
- use_high_res_features (bool): Whether to use high-resolution features.
- conv_s0 (nn.Conv2d): Convolutional layer for high-resolution features (s0).
- conv_s1 (nn.Conv2d): Convolutional layer for high-resolution features (s1).
- output_hypernetworks_mlps (nn.ModuleList): List of MLPs for output hypernetworks.
- iou_prediction_head (MLP): MLP for IOU prediction.
- pred_obj_score_head (nn.Linear | MLP): Linear layer or MLP for object score prediction.
- dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
- dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
- """
- super().__init__()
- self.transformer_dim = transformer_dim
- self.transformer = transformer
-
- self.num_multimask_outputs = num_multimask_outputs
-
- self.iou_token = nn.Embedding(1, transformer_dim)
- self.num_mask_tokens = num_multimask_outputs + 1
- self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
-
- self.pred_obj_scores = pred_obj_scores
- if self.pred_obj_scores:
- self.obj_score_token = nn.Embedding(1, transformer_dim)
- self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
-
- self.output_upscaling = nn.Sequential(
- nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
- LayerNorm2d(transformer_dim // 4),
- activation(),
- nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
- activation(),
- )
- self.use_high_res_features = use_high_res_features
- if use_high_res_features:
- self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1)
- self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1)
-
- self.output_hypernetworks_mlps = nn.ModuleList(
- [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
- )
-
- self.iou_prediction_head = MLP(
- transformer_dim,
- iou_head_hidden_dim,
- self.num_mask_tokens,
- iou_head_depth,
- sigmoid=iou_prediction_use_sigmoid,
- )
- if self.pred_obj_scores:
- self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
- if pred_obj_scores_mlp:
- self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
-
- # When outputting a single mask, optionally we can dynamically fall back to the best
- # multimask output token if the single mask output token gives low stability scores.
- self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
- self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
- self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
-
- def forward(
- self,
- image_embeddings: torch.Tensor,
- image_pe: torch.Tensor,
- sparse_prompt_embeddings: torch.Tensor,
- dense_prompt_embeddings: torch.Tensor,
- multimask_output: bool,
- repeat_image: bool,
- high_res_features: Optional[List[torch.Tensor]] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Predicts masks given image and prompt embeddings.
-
- Args:
- image_embeddings (torch.Tensor): Embeddings from the image encoder.
- image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings.
- sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes.
- dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs.
- multimask_output (bool): Whether to return multiple masks or a single mask.
- repeat_image (bool): Flag to repeat the image embeddings.
- high_res_features (List[torch.Tensor] | None): Optional high-resolution features.
-
- Returns:
- (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing:
- - masks (torch.Tensor): Batched predicted masks.
- - iou_pred (torch.Tensor): Batched predictions of mask quality.
- - sam_tokens_out (torch.Tensor): Batched SAM token for mask output.
-
- Examples:
- >>> image_embeddings = torch.rand(1, 256, 64, 64)
- >>> image_pe = torch.rand(1, 256, 64, 64)
- >>> sparse_prompt_embeddings = torch.rand(1, 2, 256)
- >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
- >>> decoder = MaskDecoder(256, transformer)
- >>> masks, iou_pred, sam_tokens_out = decoder.forward(image_embeddings, image_pe,
- ... sparse_prompt_embeddings, dense_prompt_embeddings, True, False)
- """
- masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
- image_embeddings=image_embeddings,
- image_pe=image_pe,
- sparse_prompt_embeddings=sparse_prompt_embeddings,
- dense_prompt_embeddings=dense_prompt_embeddings,
- repeat_image=repeat_image,
- high_res_features=high_res_features,
- )
-
- # Select the correct mask or masks for output
- if multimask_output:
- masks = masks[:, 1:, :, :]
- iou_pred = iou_pred[:, 1:]
- elif self.dynamic_multimask_via_stability and not self.training:
- masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
- else:
- masks = masks[:, 0:1, :, :]
- iou_pred = iou_pred[:, 0:1]
-
- if multimask_output and self.use_multimask_token_for_obj_ptr:
- sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
- else:
- # Take the mask output token. Here we *always* use the token for single mask output.
- # At test time, even if we track after 1-click (and using multimask_output=True),
- # we still take the single mask token here. The rationale is that we always track
- # after multiple clicks during training, so the past tokens seen during training
- # are always the single mask token (and we'll let it be the object-memory token).
- sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
-
- # Prepare output
- return masks, iou_pred, sam_tokens_out, object_score_logits
-
- def predict_masks(
- self,
- image_embeddings: torch.Tensor,
- image_pe: torch.Tensor,
- sparse_prompt_embeddings: torch.Tensor,
- dense_prompt_embeddings: torch.Tensor,
- repeat_image: bool,
- high_res_features: Optional[List[torch.Tensor]] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Predicts instance segmentation masks from image and prompt embeddings using a transformer architecture."""
- # Concatenate output tokens
- s = 0
- if self.pred_obj_scores:
- output_tokens = torch.cat(
- [
- self.obj_score_token.weight,
- self.iou_token.weight,
- self.mask_tokens.weight,
- ],
- dim=0,
- )
- s = 1
- else:
- output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
- output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
- tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
-
- # Expand per-image data in batch direction to be per-mask
- if repeat_image:
- src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
- else:
- assert image_embeddings.shape[0] == tokens.shape[0]
- src = image_embeddings
- src = src + dense_prompt_embeddings
- assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
- pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
- b, c, h, w = src.shape
-
- # Run the transformer
- hs, src = self.transformer(src, pos_src, tokens)
- iou_token_out = hs[:, s, :]
- mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
-
- # Upscale mask embeddings and predict masks using the mask tokens
- src = src.transpose(1, 2).view(b, c, h, w)
- if not self.use_high_res_features:
- upscaled_embedding = self.output_upscaling(src)
- else:
- dc1, ln1, act1, dc2, act2 = self.output_upscaling
- feat_s0, feat_s1 = high_res_features
- upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
- upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
-
- hyper_in_list: List[torch.Tensor] = []
- for i in range(self.num_mask_tokens):
- hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
- hyper_in = torch.stack(hyper_in_list, dim=1)
- b, c, h, w = upscaled_embedding.shape
- masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
-
- # Generate mask quality predictions
- iou_pred = self.iou_prediction_head(iou_token_out)
- if self.pred_obj_scores:
- assert s == 1
- object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
- else:
- # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
- object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
-
- return masks, iou_pred, mask_tokens_out, object_score_logits
-
- def _get_stability_scores(self, mask_logits):
- """Computes mask stability scores based on IoU between upper and lower thresholds."""
- mask_logits = mask_logits.flatten(-2)
- stability_delta = self.dynamic_multimask_stability_delta
- area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
- area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
- stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
- return stability_scores
-
- def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
- """
- Dynamically selects the most stable mask output based on stability scores and IoU predictions.
-
- When outputting a single mask, if the stability score from the current single-mask output (based on output token
- 0) falls below a threshold, we instead select from multi-mask outputs (based on output token 1~3) the mask with
- the highest predicted IoU score.
-
- This is intended to ensure a valid mask for both clicking and tracking.
- """
- # The best mask from multimask output tokens (1~3)
- multimask_logits = all_mask_logits[:, 1:, :, :]
- multimask_iou_scores = all_iou_scores[:, 1:]
- best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
- batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)
- best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
- best_multimask_logits = best_multimask_logits.unsqueeze(1)
- best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
- best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
-
- # The mask from singlemask output token 0 and its stability score
- singlemask_logits = all_mask_logits[:, 0:1, :, :]
- singlemask_iou_scores = all_iou_scores[:, 0:1]
- stability_scores = self._get_stability_scores(singlemask_logits)
- is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
-
- # Dynamically fall back to best multimask output upon low stability scores.
- mask_logits_out = torch.where(
- is_stable[..., None, None].expand_as(singlemask_logits),
- singlemask_logits,
- best_multimask_logits,
- )
- iou_scores_out = torch.where(
- is_stable.expand_as(singlemask_iou_scores),
- singlemask_iou_scores,
- best_multimask_iou_scores,
- )
- return mask_logits_out, iou_scores_out
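
`_get_stability_scores` above computes, per mask, the ratio of pixels above +delta to pixels above -delta in logit space, i.e. an IoU between a tighter and a looser binarization; `_dynamic_multimask_via_stability` then swaps in the best multimask candidate whenever that ratio falls below the threshold. A self-contained numeric sketch of the score itself (toy tensors):

```python
import torch

def stability_score(mask_logits: torch.Tensor, delta: float = 0.05) -> torch.Tensor:
    """IoU between the mask thresholded at +delta and at -delta (1.0 when the looser mask is empty)."""
    flat = mask_logits.flatten(-2)
    area_i = (flat > delta).sum(dim=-1).float()   # tighter binarization
    area_u = (flat > -delta).sum(dim=-1).float()  # looser binarization
    return torch.where(area_u > 0, area_i / area_u, torch.ones_like(area_u))

# A confident mask (logits far from 0) is stable; a borderline mask (logits near 0) is not
sharp = torch.full((1, 1, 8, 8), 5.0)
borderline = torch.zeros(1, 1, 8, 8).uniform_(-0.04, 0.04)
print(stability_score(sharp), stability_score(borderline))  # 1.0 vs. 0.0 here
```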
diff --git a/ultralytics/models/sam2/modules/encoders.py b/ultralytics/models/sam2/modules/encoders.py
deleted file mode 100644
index b4cd0f8f2..000000000
--- a/ultralytics/models/sam2/modules/encoders.py
+++ /dev/null
@@ -1,332 +0,0 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
-
-from typing import List, Optional, Tuple
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from ultralytics.models.sam.modules.encoders import PatchEmbed
-
-from .sam2_blocks import CXBlock, Fuser, MaskDownSampler, MultiScaleBlock, PositionEmbeddingSine
-
-
-class MemoryEncoder(nn.Module):
- """Encodes pixel features and masks into a memory representation for efficient image segmentation."""
-
- def __init__(
- self,
- out_dim,
- in_dim=256, # in_dim of pix_feats
- ):
- """Initializes the MemoryEncoder module for encoding pixel features and masks in SAM-like models."""
- super().__init__()
-
- self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
-
- self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
- self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
- self.position_encoding = PositionEmbeddingSine(num_pos_feats=64)
- self.out_proj = nn.Identity()
- if out_dim != in_dim:
- self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
-
- def forward(
- self,
- pix_feat: torch.Tensor,
- masks: torch.Tensor,
- skip_mask_sigmoid: bool = False,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Processes pixel features and masks, fusing them to generate encoded memory representations."""
- if not skip_mask_sigmoid:
- masks = F.sigmoid(masks)
- masks = self.mask_downsampler(masks)
-
-        # Fuse pix_feats and downsampled masks; if the visual features are on CPU, move them to the mask device first
- pix_feat = pix_feat.to(masks.device)
-
- x = self.pix_feat_proj(pix_feat)
- x = x + masks
- x = self.fuser(x)
- x = self.out_proj(x)
-
- pos = self.position_encoding(x).to(x.dtype)
-
- return {"vision_features": x, "vision_pos_enc": [pos]}
-
-
-class ImageEncoder(nn.Module):
- """Encodes images using a trunk-neck architecture, producing multiscale features and positional encodings."""
-
- def __init__(
- self,
- trunk: nn.Module,
- neck: nn.Module,
- scalp: int = 0,
- ):
- """Initializes an image encoder with a trunk, neck, and optional scalp for feature extraction."""
- super().__init__()
- self.trunk = trunk
- self.neck = neck
- self.scalp = scalp
- assert (
- self.trunk.channel_list == self.neck.backbone_channel_list
- ), f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match."
-
- def forward(self, sample: torch.Tensor):
- """Processes image input through trunk and neck, returning features, positional encodings, and FPN outputs."""
- features, pos = self.neck(self.trunk(sample))
- if self.scalp > 0:
- # Discard the lowest resolution features
- features, pos = features[: -self.scalp], pos[: -self.scalp]
-
- src = features[-1]
- output = {
- "vision_features": src,
- "vision_pos_enc": pos,
- "backbone_fpn": features,
- }
- return output
-
-
-class FpnNeck(nn.Module):
- """Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models."""
-
- def __init__(
- self,
- d_model: int,
- backbone_channel_list: List[int],
- kernel_size: int = 1,
- stride: int = 1,
- padding: int = 0,
- fpn_interp_model: str = "bilinear",
- fuse_type: str = "sum",
- fpn_top_down_levels: Optional[List[int]] = None,
- ):
- """
- Initializes a modified Feature Pyramid Network (FPN) neck.
-
- This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
- similar to ViT positional embedding interpolation.
-
- Args:
- d_model (int): Dimension of the model.
- backbone_channel_list (List[int]): List of channel dimensions from the backbone.
- kernel_size (int): Kernel size for the convolutional layers.
- stride (int): Stride for the convolutional layers.
- padding (int): Padding for the convolutional layers.
- fpn_interp_model (str): Interpolation mode for FPN feature resizing.
- fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
- fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs.
-
- Attributes:
- position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding.
- convs (nn.ModuleList): List of convolutional layers for each backbone level.
- backbone_channel_list (List[int]): List of channel dimensions from the backbone.
- fpn_interp_model (str): Interpolation mode for FPN feature resizing.
- fuse_type (str): Type of feature fusion.
- fpn_top_down_levels (List[int]): Levels with top-down feature propagation.
-
- Examples:
- >>> backbone_channels = [64, 128, 256, 512]
- >>> fpn_neck = FpnNeck(256, backbone_channels)
- >>> print(fpn_neck)
- """
- super().__init__()
- self.position_encoding = PositionEmbeddingSine(num_pos_feats=256)
- self.convs = nn.ModuleList()
- self.backbone_channel_list = backbone_channel_list
- for dim in backbone_channel_list:
- current = nn.Sequential()
- current.add_module(
- "conv",
- nn.Conv2d(
- in_channels=dim,
- out_channels=d_model,
- kernel_size=kernel_size,
- stride=stride,
- padding=padding,
- ),
- )
-
- self.convs.append(current)
- self.fpn_interp_model = fpn_interp_model
- assert fuse_type in ["sum", "avg"]
- self.fuse_type = fuse_type
-
- # levels to have top-down features in its outputs
- # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
- # have top-down propagation, while outputs of level 0 and level 1 have only
- # lateral features from the same backbone level.
- if fpn_top_down_levels is None:
- # default is to have top-down features on all levels
- fpn_top_down_levels = range(len(self.convs))
- self.fpn_top_down_levels = list(fpn_top_down_levels)
-
- def forward(self, xs: List[torch.Tensor]):
- """
- Performs forward pass through the Feature Pyramid Network (FPN) neck.
-
- Args:
- xs (List[torch.Tensor]): List of input tensors from the backbone, with shape (B, C, H, W) for each tensor.
-
- Returns:
- (Tuple[List[torch.Tensor], List[torch.Tensor]]): A tuple containing two lists:
- - out: List of output feature maps after FPN processing, with shape (B, d_model, H, W) for each tensor.
- - pos: List of positional encodings corresponding to each output feature map.
-
- Examples:
- >>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512])
- >>> inputs = [torch.rand(1, c, 32, 32) for c in [64, 128, 256, 512]]
- >>> outputs, positions = fpn_neck(inputs)
- """
- out = [None] * len(self.convs)
- pos = [None] * len(self.convs)
- assert len(xs) == len(self.convs)
- # fpn forward pass
- # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
- prev_features = None
- # forward in top-down order (from low to high resolution)
- n = len(self.convs) - 1
- for i in range(n, -1, -1):
- x = xs[i]
- lateral_features = self.convs[n - i](x)
- if i in self.fpn_top_down_levels and prev_features is not None:
- top_down_features = F.interpolate(
- prev_features.to(dtype=torch.float32),
- scale_factor=2.0,
- mode=self.fpn_interp_model,
- align_corners=(None if self.fpn_interp_model == "nearest" else False),
- antialias=False,
- )
- prev_features = lateral_features + top_down_features
- if self.fuse_type == "avg":
- prev_features /= 2
- else:
- prev_features = lateral_features
- x_out = prev_features
- out[i] = x_out
- pos[i] = self.position_encoding(x_out).to(x_out.dtype)
-
- return out, pos
-
-
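
The top-down loop in `FpnNeck.forward` above adds a 2x-upsampled copy of the previous (coarser) output to each level listed in `fpn_top_down_levels`, halving the sum when `fuse_type="avg"`. A stripped-down sketch of just that fusion loop (the 1x1 lateral convs and positional encodings are omitted; shapes are illustrative):

```python
import torch
import torch.nn.functional as F

# Toy lateral features, ordered high -> low resolution like backbone_fpn (1x1 convs already applied)
laterals = [torch.rand(1, 256, 64, 64), torch.rand(1, 256, 32, 32),
            torch.rand(1, 256, 16, 16), torch.rand(1, 256, 8, 8)]
fpn_top_down_levels, fuse_type = [2, 3], "sum"

out, prev = [None] * len(laterals), None
for i in range(len(laterals) - 1, -1, -1):  # coarsest level first
    x = laterals[i]
    if i in fpn_top_down_levels and prev is not None:
        x = x + F.interpolate(prev, scale_factor=2.0, mode="nearest")  # add upsampled coarser output
        if fuse_type == "avg":
            x = x / 2
    out[i] = x
    prev = x

print([o.shape[-1] for o in out])  # [64, 32, 16, 8] -- per-level spatial sizes are preserved
```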
-class Hiera(nn.Module):
- """Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks."""
-
- def __init__(
- self,
- embed_dim: int = 96, # initial embed dim
- num_heads: int = 1, # initial number of heads
- drop_path_rate: float = 0.0, # stochastic depth
- q_pool: int = 3, # number of q_pool stages
- q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
- stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
- dim_mul: float = 2.0, # dim_mul factor at stage shift
- head_mul: float = 2.0, # head_mul factor at stage shift
- window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
- # window size per stage, when not using global att.
- window_spec: Tuple[int, ...] = (
- 8,
- 4,
- 14,
- 7,
- ),
- # global attn in these blocks
- global_att_blocks: Tuple[int, ...] = (
- 12,
- 16,
- 20,
- ),
- return_interm_layers=True, # return feats from every stage
- ):
- """Initializes a Hiera model with configurable architecture for hierarchical vision transformers."""
- super().__init__()
-
- assert len(stages) == len(window_spec)
- self.window_spec = window_spec
-
- depth = sum(stages)
- self.q_stride = q_stride
- self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
- assert 0 <= q_pool <= len(self.stage_ends[:-1])
- self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
- self.return_interm_layers = return_interm_layers
-
- self.patch_embed = PatchEmbed(
- embed_dim=embed_dim,
- kernel_size=(7, 7),
- stride=(4, 4),
- padding=(3, 3),
- )
- # Which blocks have global att?
- self.global_att_blocks = global_att_blocks
-
- # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
- self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
- self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size))
- self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))
-
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
-
- cur_stage = 1
- self.blocks = nn.ModuleList()
-
- for i in range(depth):
- dim_out = embed_dim
- # lags by a block, so first block of
- # next stage uses an initial window size
- # of previous stage and final window size of current stage
- window_size = self.window_spec[cur_stage - 1]
-
- if self.global_att_blocks is not None:
- window_size = 0 if i in self.global_att_blocks else window_size
-
- if i - 1 in self.stage_ends:
- dim_out = int(embed_dim * dim_mul)
- num_heads = int(num_heads * head_mul)
- cur_stage += 1
-
- block = MultiScaleBlock(
- dim=embed_dim,
- dim_out=dim_out,
- num_heads=num_heads,
- drop_path=dpr[i],
- q_stride=self.q_stride if i in self.q_pool_blocks else None,
- window_size=window_size,
- )
-
- embed_dim = dim_out
- self.blocks.append(block)
-
- self.channel_list = (
- [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
- if return_interm_layers
- else [self.blocks[-1].dim_out]
- )
-
- def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
- """Generate positional embeddings by interpolating and combining window and background embeddings."""
- h, w = hw
- window_embed = self.pos_embed_window
- pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
- pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
- pos_embed = pos_embed.permute(0, 2, 3, 1)
- return pos_embed
-
- def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
- """Performs hierarchical vision transformer forward pass, returning multiscale feature maps."""
- x = self.patch_embed(x)
- # x: (B, H, W, C)
-
- # Add pos embed
- x = x + self._get_pos_embed(x.shape[1:3])
-
- outputs = []
- for i, blk in enumerate(self.blocks):
- x = blk(x)
- if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers):
- feats = x.permute(0, 3, 1, 2)
- outputs.append(feats)
-
- return outputs
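
The Hiera constructor above derives the whole per-block schedule from `stages`: `stage_ends` marks the last block of each stage, `q_pool_blocks` marks where query pooling happens, and blocks listed in `global_att_blocks` drop windowed attention (window size 0). A small sketch of that bookkeeping, plugging in the tiny-model configuration from the deleted `build_sam2_t` for concreteness:

```python
stages = [1, 2, 7, 2]              # blocks per stage (tiny config, assumed for illustration)
window_spec = [8, 4, 14, 7]        # window size per stage
global_att_blocks = [5, 7, 9]      # these block indices use global attention
q_pool = 3                         # number of query-pooling stages

depth = sum(stages)                                                    # 12 blocks
stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]  # [0, 2, 9, 11]
q_pool_blocks = [x + 1 for x in stage_ends[:-1]][:q_pool]              # [1, 3, 10]

cur_stage = 1
for i in range(depth):
    window = 0 if i in global_att_blocks else window_spec[cur_stage - 1]  # window "lags by a block" at stage shifts
    if i - 1 in stage_ends:
        cur_stage += 1  # the dim/head multipliers are applied at the same point in the real module
    print(f"block {i:2d}: stage {cur_stage}, window {window}, q_pool={i in q_pool_blocks}")
```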
diff --git a/ultralytics/models/sam2/modules/sam2.py b/ultralytics/models/sam2/modules/sam2.py
deleted file mode 100644
index 363241a19..000000000
--- a/ultralytics/models/sam2/modules/sam2.py
+++ /dev/null
@@ -1,804 +0,0 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
-
-import torch
-import torch.distributed
-import torch.nn.functional as F
-from torch.nn.init import trunc_normal_
-
-from ultralytics.models.sam.modules.encoders import PromptEncoder
-from ultralytics.nn.modules import MLP
-
-from .decoders import MaskDecoder
-from .sam2_blocks import TwoWayTransformer
-from .utils import get_1d_sine_pe, select_closest_cond_frames
-
-# a large negative value as a placeholder score for missing objects
-NO_OBJ_SCORE = -1024.0
-
-
-class SAM2Model(torch.nn.Module):
- """SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities."""
-
- mask_threshold: float = 0.0
-
- def __init__(
- self,
- image_encoder,
- memory_attention,
- memory_encoder,
- num_maskmem=7, # default 1 input frame + 6 previous frames
- image_size=512,
- backbone_stride=16, # stride of the image backbone output
- sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob
- sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob
- # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks
- binarize_mask_from_pts_for_mem_enc=False,
- use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder
- # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit,
- # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model
- # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM.
- max_cond_frames_in_attn=-1,
- # on the first frame, whether to directly add the no-memory embedding to the image feature
- # (instead of using the transformer encoder)
- directly_add_no_mem_embed=False,
- # whether to use high-resolution feature maps in the SAM mask decoder
- use_high_res_features_in_sam=False,
- # whether to output multiple (3) masks for the first click on initial conditioning frames
- multimask_output_in_sam=False,
- # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`;
- # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points)
- multimask_min_pt_num=1,
- multimask_max_pt_num=1,
- # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`)
- multimask_output_for_tracking=False,
- # Whether to use multimask tokens for obj ptr; Only relevant when both
- # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True
- use_multimask_token_for_obj_ptr: bool = False,
- # whether to use sigmoid to restrict ious prediction to [0-1]
- iou_prediction_use_sigmoid=False,
- # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5).
- # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of
- # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame.
- memory_temporal_stride_for_eval=1,
- # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
-        # if `add_all_frames_to_correct_as_cond` is False, we restrict the conditioning frame list to only the initial conditioning frames
- add_all_frames_to_correct_as_cond=False,
- # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks)
- non_overlap_masks_for_mem_enc=False,
- # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
- use_obj_ptrs_in_encoder=False,
- # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`)
- max_obj_ptrs_in_encoder=16,
- # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`)
- add_tpos_enc_to_obj_ptrs=True,
- # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference
- # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
- proj_tpos_enc_in_obj_ptrs=False,
- # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation
- # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking)
- only_obj_ptrs_in_the_past_for_eval=False,
- # Whether to predict if there is an object in the frame
- pred_obj_scores: bool = False,
- # Whether to use an MLP to predict object scores
- pred_obj_scores_mlp: bool = False,
- # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True;
- # Whether to have a fixed no obj pointer when there is no object present
- # or to use it as an additive embedding with obj_ptr produced by decoder
- fixed_no_obj_ptr: bool = False,
- # Soft no object, i.e. mix in no_obj_ptr softly,
-        # hoping to make recovery easier if there is a mistake and to mitigate accumulation of errors
- soft_no_obj_ptr: bool = False,
- use_mlp_for_obj_ptr_proj: bool = False,
- # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
- sam_mask_decoder_extra_args=None,
- compile_image_encoder: bool = False,
- ):
- """Initializes SAM2Model model with image encoder, memory attention, and memory encoder components."""
- super().__init__()
-
- # Part 1: the image backbone
- self.image_encoder = image_encoder
- # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
- self.use_high_res_features_in_sam = use_high_res_features_in_sam
- self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
- self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
- self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
- if use_obj_ptrs_in_encoder:
- # A conv layer to downsample the mask prompt to stride 4 (the same stride as
- # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
- # so that it can be fed into the SAM mask decoder to generate a pointer.
- self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
- self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
- if proj_tpos_enc_in_obj_ptrs:
- assert add_tpos_enc_to_obj_ptrs # these options need to be used together
- self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
- self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
-
- # Part 2: memory attention to condition current frame's visual features
- # with memories (and obj ptrs) from past frames
- self.memory_attention = memory_attention
- self.hidden_dim = memory_attention.d_model
-
- # Part 3: memory encoder for the previous frame's outputs
- self.memory_encoder = memory_encoder
- self.mem_dim = self.hidden_dim
- if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"):
- # if there is compression of memories along channel dim
- self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
- self.num_maskmem = num_maskmem # Number of memories accessible
- # Temporal encoding of the memories
- self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))
- trunc_normal_(self.maskmem_tpos_enc, std=0.02)
- # a single token to indicate no memory embedding from previous frames
- self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
- self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
- trunc_normal_(self.no_mem_embed, std=0.02)
- trunc_normal_(self.no_mem_pos_enc, std=0.02)
- self.directly_add_no_mem_embed = directly_add_no_mem_embed
- # Apply sigmoid to the output raw mask logits (to turn them from
- # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
- self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
- self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
- self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
- self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
- self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
- # On frames with mask input, whether to directly output the input mask without
- # using a SAM prompt encoder + mask decoder
- self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
- self.multimask_output_in_sam = multimask_output_in_sam
- self.multimask_min_pt_num = multimask_min_pt_num
- self.multimask_max_pt_num = multimask_max_pt_num
- self.multimask_output_for_tracking = multimask_output_for_tracking
- self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
- self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
-
- # Part 4: SAM-style prompt encoder (for both mask and point inputs)
- # and SAM-style mask decoder for the final mask output
- self.image_size = image_size
- self.backbone_stride = backbone_stride
- self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
- self.pred_obj_scores = pred_obj_scores
- self.pred_obj_scores_mlp = pred_obj_scores_mlp
- self.fixed_no_obj_ptr = fixed_no_obj_ptr
- self.soft_no_obj_ptr = soft_no_obj_ptr
- if self.fixed_no_obj_ptr:
- assert self.pred_obj_scores
- assert self.use_obj_ptrs_in_encoder
- if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
- self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
- trunc_normal_(self.no_obj_ptr, std=0.02)
- self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
-
- self._build_sam_heads()
- self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
- self.max_cond_frames_in_attn = max_cond_frames_in_attn
-
- # Model compilation
- if compile_image_encoder:
- # Compile the forward function (not the full module) to allow loading checkpoints.
- print("Image encoder compilation is enabled. First forward pass will be slow.")
- self.image_encoder.forward = torch.compile(
- self.image_encoder.forward,
- mode="max-autotune",
- fullgraph=True,
- dynamic=False,
- )
-
- @property
- def device(self):
- """Returns the device on which the model's parameters are stored."""
- return next(self.parameters()).device
-
- def forward(self, *args, **kwargs):
- """Processes input frames and prompts to generate object masks and scores in video sequences."""
- raise NotImplementedError(
- "Please use the corresponding methods in SAM2VideoPredictor for inference."
- "See notebooks/video_predictor_example.ipynb for an example."
- )
-
- def _build_sam_heads(self):
- """Builds SAM-style prompt encoder and mask decoder for image segmentation tasks."""
- self.sam_prompt_embed_dim = self.hidden_dim
- self.sam_image_embedding_size = self.image_size // self.backbone_stride
-
- # build PromptEncoder and MaskDecoder from SAM
- # (their hyperparameters like `mask_in_chans=16` are from SAM code)
- self.sam_prompt_encoder = PromptEncoder(
- embed_dim=self.sam_prompt_embed_dim,
- image_embedding_size=(
- self.sam_image_embedding_size,
- self.sam_image_embedding_size,
- ),
- input_image_size=(self.image_size, self.image_size),
- mask_in_chans=16,
- )
- self.sam_mask_decoder = MaskDecoder(
- num_multimask_outputs=3,
- transformer=TwoWayTransformer(
- depth=2,
- embedding_dim=self.sam_prompt_embed_dim,
- mlp_dim=2048,
- num_heads=8,
- ),
- transformer_dim=self.sam_prompt_embed_dim,
- iou_head_depth=3,
- iou_head_hidden_dim=256,
- use_high_res_features=self.use_high_res_features_in_sam,
- iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
- pred_obj_scores=self.pred_obj_scores,
- pred_obj_scores_mlp=self.pred_obj_scores_mlp,
- use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
- **(self.sam_mask_decoder_extra_args or {}),
- )
- if self.use_obj_ptrs_in_encoder:
- # a linear projection on SAM output tokens to turn them into object pointers
- self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
- if self.use_mlp_for_obj_ptr_proj:
- self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
- else:
- self.obj_ptr_proj = torch.nn.Identity()
- if self.proj_tpos_enc_in_obj_ptrs:
- # a linear projection on temporal positional encoding in object pointers to
- # avoid potential interference with spatial positional encoding
- self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
- else:
- self.obj_ptr_tpos_proj = torch.nn.Identity()
-
- def _forward_sam_heads(
- self,
- backbone_features,
- point_inputs=None,
- mask_inputs=None,
- high_res_features=None,
- multimask_output=False,
- ):
- """
- Forward SAM prompt encoders and mask heads.
-
- Args:
- backbone_features (torch.Tensor): Image features with shape (B, C, H, W).
- point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts.
- 'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute
- pixel-unit coordinates in (x, y) format for P input points.
- 'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks,
- 0 means negative clicks, and -1 means padding.
- mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the
- same spatial size as the image.
- high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes
- (B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps
- for SAM decoder.
- multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False,
- output only 1 mask and its IoU estimate.
-
- Returns:
- (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]):
- low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
- high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
- ious: Tensor of shape (B, M) with estimated IoU for each output mask.
- low_res_masks: Tensor of shape (B, 1, H*4, W*4) with best low-resolution mask.
- high_res_masks: Tensor of shape (B, 1, H*16, W*16) with best high-resolution mask.
- obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
- object_score_logits: Tensor of shape (B,) with object score logits.
-
- Where M is 3 if multimask_output=True, and 1 if multimask_output=False.
-
- Examples:
- >>> backbone_features = torch.rand(1, 256, 32, 32)
- >>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])}
- >>> mask_inputs = torch.rand(1, 1, 512, 512)
- >>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs)
- >>> low_res_multimasks, high_res_multimasks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits = results
- """
- B = backbone_features.size(0)
- device = backbone_features.device
- assert backbone_features.size(1) == self.sam_prompt_embed_dim
- assert backbone_features.size(2) == self.sam_image_embedding_size
- assert backbone_features.size(3) == self.sam_image_embedding_size
-
- # a) Handle point prompts
- if point_inputs is not None:
- sam_point_coords = point_inputs["point_coords"]
- sam_point_labels = point_inputs["point_labels"]
- assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
- else:
-            # If no points are provided, pad with an empty point (with label -1)
- sam_point_coords = torch.zeros(B, 1, 2, device=device)
- sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
-
- # b) Handle mask prompts
- if mask_inputs is not None:
- # If mask_inputs is provided, downsize it into low-res mask input if needed
- # and feed it as a dense mask prompt into the SAM mask encoder
- assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
- if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
- sam_mask_prompt = F.interpolate(
- mask_inputs.float(),
- size=self.sam_prompt_encoder.mask_input_size,
- align_corners=False,
- mode="bilinear",
- antialias=True, # use antialias for downsampling
- )
- else:
- sam_mask_prompt = mask_inputs
- else:
- # Otherwise, simply feed None (and SAM's prompt encoder will add
- # a learned `no_mask_embed` to indicate no mask input in this case).
- sam_mask_prompt = None
-
- sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
- points=(sam_point_coords, sam_point_labels),
- boxes=None,
- masks=sam_mask_prompt,
- )
- (
- low_res_multimasks,
- ious,
- sam_output_tokens,
- object_score_logits,
- ) = self.sam_mask_decoder(
- image_embeddings=backbone_features,
- image_pe=self.sam_prompt_encoder.get_dense_pe(),
- sparse_prompt_embeddings=sparse_embeddings,
- dense_prompt_embeddings=dense_embeddings,
- multimask_output=multimask_output,
- repeat_image=False, # the image is already batched
- high_res_features=high_res_features,
- )
- if self.pred_obj_scores:
- is_obj_appearing = object_score_logits > 0
-
- # Mask used for spatial memories is always a *hard* choice between obj and no obj,
- # consistent with the actual mask prediction
- low_res_multimasks = torch.where(
- is_obj_appearing[:, None, None],
- low_res_multimasks,
- NO_OBJ_SCORE,
- )
-
- # convert masks from possibly bfloat16 (or float16) to float32
- # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
- low_res_multimasks = low_res_multimasks.float()
- high_res_multimasks = F.interpolate(
- low_res_multimasks,
- size=(self.image_size, self.image_size),
- mode="bilinear",
- align_corners=False,
- )
-
- sam_output_token = sam_output_tokens[:, 0]
- if multimask_output:
- # take the best mask prediction (with the highest IoU estimation)
- best_iou_inds = torch.argmax(ious, dim=-1)
- batch_inds = torch.arange(B, device=device)
- low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
- high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
- if sam_output_tokens.size(1) > 1:
- sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
- else:
- low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
-
- # Extract object pointer from the SAM output token (with occlusion handling)
- obj_ptr = self.obj_ptr_proj(sam_output_token)
- if self.pred_obj_scores:
- # Allow *soft* no obj ptr, unlike for masks
- if self.soft_no_obj_ptr:
- # Only hard possible with gt
- assert not self.teacher_force_obj_scores_for_mem
- lambda_is_obj_appearing = object_score_logits.sigmoid()
- else:
- lambda_is_obj_appearing = is_obj_appearing.float()
-
- if self.fixed_no_obj_ptr:
- obj_ptr = lambda_is_obj_appearing * obj_ptr
- obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
-
- return (
- low_res_multimasks,
- high_res_multimasks,
- ious,
- low_res_masks,
- high_res_masks,
- obj_ptr,
- object_score_logits,
- )
-
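
When `multimask_output=True`, the code above keeps only the candidate with the highest predicted IoU per image, using paired advanced indexing over the batch and candidate dimensions. A toy-shape sketch of that selection:

```python
import torch

B, M = 2, 3
multimasks = torch.rand(B, M, 64, 64)   # (B, M, H, W) candidate mask logits
ious = torch.rand(B, M)                 # (B, M) predicted IoU per candidate

best_iou_inds = torch.argmax(ious, dim=-1)                       # (B,) best candidate per image
batch_inds = torch.arange(B)
best_masks = multimasks[batch_inds, best_iou_inds].unsqueeze(1)  # (B, 1, H, W)
print(best_masks.shape)                                          # torch.Size([2, 1, 64, 64])
```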
- def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
- """Processes mask inputs to generate output mask logits and object pointers without using SAM."""
- # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
- out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
- mask_inputs_float = mask_inputs.float()
- high_res_masks = mask_inputs_float * out_scale + out_bias
- low_res_masks = F.interpolate(
- high_res_masks,
- size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
- align_corners=False,
- mode="bilinear",
- antialias=True, # use antialias for downsampling
- )
- # a dummy IoU prediction of all 1's under mask input
- ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
- if not self.use_obj_ptrs_in_encoder:
- # all zeros as a dummy object pointer (of shape [B, C])
- obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
- else:
- # produce an object pointer using the SAM decoder from the mask input
- _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
- backbone_features=backbone_features,
- mask_inputs=self.mask_downsample(mask_inputs_float),
- high_res_features=high_res_features,
- )
-        # In this method, we treat mask_inputs as the output, e.g. using it directly to create the spatial memory;
-        # below, we follow the same principle and use mask_inputs to decide whether an object appears, instead of
-        # relying on the object scores from the SAM decoder.
- is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
- is_obj_appearing = is_obj_appearing[..., None]
- lambda_is_obj_appearing = is_obj_appearing.float()
- object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
- if self.pred_obj_scores:
- if self.fixed_no_obj_ptr:
- obj_ptr = lambda_is_obj_appearing * obj_ptr
- obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
-
- return (
- low_res_masks,
- high_res_masks,
- ious,
- low_res_masks,
- high_res_masks,
- obj_ptr,
- object_score_logits,
- )
-
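# A small sketch of the mask-to-logit conversion used in _use_mask_as_output above: binary
# mask inputs are mapped to logits of -10/+10 so that sigmoid gives probabilities very close
# to 0/1. The tiny mask below is illustrative only.
import torch

out_scale, out_bias = 20.0, -10.0
mask_inputs = torch.tensor([[[[0.0, 1.0], [1.0, 0.0]]]])   # dummy binary mask, shape (1, 1, 2, 2)
mask_logits = mask_inputs * out_scale + out_bias           # 0 -> -10, 1 -> +10
print(mask_logits.sigmoid())                               # ~4.54e-05 for 0 pixels, ~0.99995 for 1 pixels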
- def forward_image(self, img_batch: torch.Tensor):
- """Process image batch through encoder to extract multi-level features for SAM model."""
- backbone_out = self.image_encoder(img_batch)
- if self.use_high_res_features_in_sam:
- # precompute projected level 0 and level 1 features in SAM decoder
- # to avoid running it again on every SAM click
- backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
- backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
- return backbone_out
-
- def _prepare_backbone_features(self, backbone_out):
- """Prepare and flatten visual features from the image backbone output."""
- backbone_out = backbone_out.copy()
- assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
- assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
-
- feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
- vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
-
- feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
- # flatten NxCxHxW to HWxNxC
- vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
- vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
-
- return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
-
- def _prepare_memory_conditioned_features(
- self,
- frame_idx,
- is_init_cond_frame,
- current_vision_feats,
- current_vision_pos_embeds,
- feat_sizes,
- output_dict,
- num_frames,
- track_in_reverse=False, # tracking in reverse time order (for demo usage)
- ):
- """Prepares memory-conditioned features by fusing current frame's visual features with previous memories."""
- B = current_vision_feats[-1].size(1) # batch size on this frame
- C = self.hidden_dim
- H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
- device = current_vision_feats[-1].device
- # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
- # In this case, we skip the fusion with any memory.
- if self.num_maskmem == 0: # Disable memory and skip fusion
- pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
- return pix_feat
-
- num_obj_ptr_tokens = 0
- # Step 1: condition the visual features of the current frame on previous memories
- if not is_init_cond_frame:
- # Retrieve the memories encoded with the maskmem backbone
- to_cat_memory, to_cat_memory_pos_embed = [], []
-            # Add the conditioning frames' outputs first (all cond frames have t_pos=0
-            # when computing the temporal positional embedding below)
- assert len(output_dict["cond_frame_outputs"]) > 0
- # Select a maximum number of temporally closest cond frames for cross attention
- cond_outputs = output_dict["cond_frame_outputs"]
- selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
- frame_idx, cond_outputs, self.max_cond_frames_in_attn
- )
- t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
-            # Add the last (self.num_maskmem - 1) frames before the current frame as non-conditioning memory;
-            # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1.
-            # We also allow taking memory frames non-consecutively (with r>1), in which case we take
-            # (self.num_maskmem - 2) frames at stride r plus the last frame (see the sketch after this method).
- r = self.memory_temporal_stride_for_eval
- for t_pos in range(1, self.num_maskmem):
- t_rel = self.num_maskmem - t_pos # how many frames before current frame
- if t_rel == 1:
- # for t_rel == 1, we take the last frame (regardless of r)
- if not track_in_reverse:
- # the frame immediately before this frame (i.e. frame_idx - 1)
- prev_frame_idx = frame_idx - t_rel
- else:
- # the frame immediately after this frame (i.e. frame_idx + 1)
- prev_frame_idx = frame_idx + t_rel
- else:
-                    # for t_rel >= 2, we take the memory frames at stride r
- if not track_in_reverse:
-                        # first find the nearest multiple of r at or before (frame_idx - 2)
-                        # for r=1, this would be (frame_idx - 2)
- prev_frame_idx = ((frame_idx - 2) // r) * r
-                        # then step further back in increments of r
- prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
- else:
-                        # first find the nearest multiple of r at or after (frame_idx + 2)
-                        # for r=1, this would be (frame_idx + 2)
- prev_frame_idx = -(-(frame_idx + 2) // r) * r
-                        # then step further ahead in increments of r
- prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
- out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
- if out is None:
- # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
- # frames, we still attend to it as if it's a non-conditioning frame.
- out = unselected_cond_outputs.get(prev_frame_idx, None)
- t_pos_and_prevs.append((t_pos, out))
-
- for t_pos, prev in t_pos_and_prevs:
- if prev is None:
- continue # skip padding frames
- # "maskmem_features" might have been offloaded to CPU in demo use cases,
- # so we load it back to GPU (it's a no-op if it's already on GPU).
- feats = prev["maskmem_features"].cuda(non_blocking=True)
- to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
- # Spatial positional encoding (it might have been offloaded to CPU in eval)
- maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
- maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
- # Temporal positional encoding
- maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
- to_cat_memory_pos_embed.append(maskmem_enc)
-
- # Construct the list of past object pointers
- if self.use_obj_ptrs_in_encoder:
- max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
- # First add those object pointers from selected conditioning frames
- # (optionally, only include object pointers in the past during evaluation)
- if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
- ptr_cond_outputs = {
- t: out
- for t, out in selected_cond_outputs.items()
- if (t >= frame_idx if track_in_reverse else t <= frame_idx)
- }
- else:
- ptr_cond_outputs = selected_cond_outputs
- pos_and_ptrs = [
- # Temporal pos encoding contains how far away each pointer is from current frame
- (abs(frame_idx - t), out["obj_ptr"])
- for t, out in ptr_cond_outputs.items()
- ]
- # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
- for t_diff in range(1, max_obj_ptrs_in_encoder):
- t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
- if t < 0 or (num_frames is not None and t >= num_frames):
- break
- out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None))
- if out is not None:
- pos_and_ptrs.append((t_diff, out["obj_ptr"]))
-                # If we have at least one object pointer, add them to the cross attention
- if len(pos_and_ptrs) > 0:
- pos_list, ptrs_list = zip(*pos_and_ptrs)
- # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
- obj_ptrs = torch.stack(ptrs_list, dim=0)
- # a temporal positional embedding based on how far each object pointer is from
- # the current frame (sine embedding normalized by the max pointer num).
- if self.add_tpos_enc_to_obj_ptrs:
- t_diff_max = max_obj_ptrs_in_encoder - 1
- tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
- obj_pos = torch.tensor(pos_list, device=device)
- obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
- obj_pos = self.obj_ptr_tpos_proj(obj_pos)
- obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
- else:
- obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
- if self.mem_dim < C:
- # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
- obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
- obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
- obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
- to_cat_memory.append(obj_ptrs)
- to_cat_memory_pos_embed.append(obj_pos)
- num_obj_ptr_tokens = obj_ptrs.shape[0]
- else:
- num_obj_ptr_tokens = 0
- else:
- # for initial conditioning frames, encode them without using any previous memory
- if self.directly_add_no_mem_embed:
- # directly add no-mem embedding (instead of using the transformer encoder)
- pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
- pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
- return pix_feat_with_mem
-
- # Use a dummy token on the first frame (to avoid empty memory input to transformer encoder)
- to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
- to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
-
- # Step 2: Concatenate the memories and forward through the transformer encoder
- memory = torch.cat(to_cat_memory, dim=0)
- memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
-
- pix_feat_with_mem = self.memory_attention(
- curr=current_vision_feats,
- curr_pos=current_vision_pos_embeds,
- memory=memory,
- memory_pos=memory_pos_embed,
- num_obj_ptr_tokens=num_obj_ptr_tokens,
- )
- # reshape the output (HW)BC => BCHW
- pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
- return pix_feat_with_mem
-
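# A standalone sketch of the non-conditioning memory frame selection above, for forward
# tracking only: t_rel == 1 takes the immediately preceding frame, and t_rel >= 2 walks back
# over frames at stride r. The helper and its default values are illustrative, not part of
# the module.
def previous_memory_frames(frame_idx: int, num_maskmem: int = 7, r: int = 2) -> list:
    """Return the frame indices attended to as non-conditioning memory (forward tracking)."""
    prev_idxs = []
    for t_pos in range(1, num_maskmem):
        t_rel = num_maskmem - t_pos                 # how many frames before the current frame
        if t_rel == 1:
            prev_frame_idx = frame_idx - 1          # always the last frame, regardless of r
        else:
            prev_frame_idx = ((frame_idx - 2) // r) * r - (t_rel - 2) * r
        prev_idxs.append(prev_frame_idx)
    return prev_idxs


print(previous_memory_frames(frame_idx=20))  # [10, 12, 14, 16, 18, 19]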
- def _encode_new_memory(
- self,
- current_vision_feats,
- feat_sizes,
- pred_masks_high_res,
- is_mask_from_pts,
- ):
- """Encodes the current frame's features and predicted masks into a new memory representation."""
- B = current_vision_feats[-1].size(1) # batch size on this frame
- C = self.hidden_dim
- H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
- # top-level feature, (HW)BC => BCHW
- pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
- if self.non_overlap_masks_for_mem_enc and not self.training:
- # optionally, apply non-overlapping constraints to the masks (it's applied
- # in the batch dimension and should only be used during eval, where all
- # the objects come from the same video under batch size 1).
- pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)
- # scale the raw mask logits with a temperature before applying sigmoid
- binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
- if binarize and not self.training:
- mask_for_mem = (pred_masks_high_res > 0).float()
- else:
- # apply sigmoid on the raw mask logits to turn them into range (0, 1)
- mask_for_mem = torch.sigmoid(pred_masks_high_res)
- # apply scale and bias terms to the sigmoid probabilities
- if self.sigmoid_scale_for_mem_enc != 1.0:
- mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
- if self.sigmoid_bias_for_mem_enc != 0.0:
- mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
- maskmem_out = self.memory_encoder(
- pix_feat,
- mask_for_mem,
- skip_mask_sigmoid=True, # sigmoid already applied
- )
- maskmem_features = maskmem_out["vision_features"]
- maskmem_pos_enc = maskmem_out["vision_pos_enc"]
-
- return maskmem_features, maskmem_pos_enc
-
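# A brief sketch of the mask preparation in _encode_new_memory above: masks that came from
# point prompts can be binarized at eval time, otherwise a sigmoid (optionally scaled and
# biased) is applied to the raw logits. The tensor and scale/bias values are illustrative.
import torch

pred_masks_high_res = torch.randn(1, 1, 16, 16)      # dummy mask logits
sigmoid_scale, sigmoid_bias = 1.0, 0.0
binarize = True                                       # e.g. the mask came from point clicks at eval

if binarize:
    mask_for_mem = (pred_masks_high_res > 0).float()
else:
    mask_for_mem = torch.sigmoid(pred_masks_high_res)
if sigmoid_scale != 1.0:
    mask_for_mem = mask_for_mem * sigmoid_scale
if sigmoid_bias != 0.0:
    mask_for_mem = mask_for_mem + sigmoid_bias
print(mask_for_mem.unique())                          # tensor([0., 1.]) when binarized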
- def track_step(
- self,
- frame_idx,
- is_init_cond_frame,
- current_vision_feats,
- current_vision_pos_embeds,
- feat_sizes,
- point_inputs,
- mask_inputs,
- output_dict,
- num_frames,
- track_in_reverse=False, # tracking in reverse time order (for demo usage)
- # Whether to run the memory encoder on the predicted masks. Sometimes we might want
- # to skip the memory encoder with `run_mem_encoder=False`. For example,
- # in demo we might call `track_step` multiple times for each user click,
- # and only encode the memory when the user finalizes their clicks. And in ablation
- # settings like SAM training on static images, we don't need the memory encoder.
- run_mem_encoder=True,
- # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
- prev_sam_mask_logits=None,
- ):
- """Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
- current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
- # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
- if len(current_vision_feats) > 1:
- high_res_features = [
- x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
- for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
- ]
- else:
- high_res_features = None
- if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
-            # When use_mask_input_as_output_without_sam=True, we directly output the mask input
-            # (treating it as a GT mask) without using the SAM prompt encoder + mask decoder.
- pix_feat = current_vision_feats[-1].permute(1, 2, 0)
- pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
- sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
- else:
-            # fuse the visual features with previous memory features in the memory bank
- pix_feat_with_mem = self._prepare_memory_conditioned_features(
- frame_idx=frame_idx,
- is_init_cond_frame=is_init_cond_frame,
- current_vision_feats=current_vision_feats[-1:],
- current_vision_pos_embeds=current_vision_pos_embeds[-1:],
- feat_sizes=feat_sizes[-1:],
- output_dict=output_dict,
- num_frames=num_frames,
- track_in_reverse=track_in_reverse,
- )
- # apply SAM-style segmentation head
- # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
- # e.g. in demo where such logits come from earlier interaction instead of correction sampling
- # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
- if prev_sam_mask_logits is not None:
- assert point_inputs is not None and mask_inputs is None
- mask_inputs = prev_sam_mask_logits
- multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
- sam_outputs = self._forward_sam_heads(
- backbone_features=pix_feat_with_mem,
- point_inputs=point_inputs,
- mask_inputs=mask_inputs,
- high_res_features=high_res_features,
- multimask_output=multimask_output,
- )
- (
- _,
- _,
- _,
- low_res_masks,
- high_res_masks,
- obj_ptr,
- _,
- ) = sam_outputs
-
- current_out["pred_masks"] = low_res_masks
- current_out["pred_masks_high_res"] = high_res_masks
- current_out["obj_ptr"] = obj_ptr
-
- # Finally run the memory encoder on the predicted mask to encode
- # it into a new memory feature (that can be used in future frames)
- if run_mem_encoder and self.num_maskmem > 0:
- high_res_masks_for_mem_enc = high_res_masks
- maskmem_features, maskmem_pos_enc = self._encode_new_memory(
- current_vision_feats=current_vision_feats,
- feat_sizes=feat_sizes,
- pred_masks_high_res=high_res_masks_for_mem_enc,
- is_mask_from_pts=(point_inputs is not None),
- )
- current_out["maskmem_features"] = maskmem_features
- current_out["maskmem_pos_enc"] = maskmem_pos_enc
- else:
- current_out["maskmem_features"] = None
- current_out["maskmem_pos_enc"] = None
-
- return current_out
-
- def _use_multimask(self, is_init_cond_frame, point_inputs):
- """Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs."""
- num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
- multimask_output = (
- self.multimask_output_in_sam
- and (is_init_cond_frame or self.multimask_output_for_tracking)
- and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
- )
- return multimask_output
-
- def _apply_non_overlapping_constraints(self, pred_masks):
- """Applies non-overlapping constraints to object masks, keeping highest scoring object at each location."""
- batch_size = pred_masks.size(0)
- if batch_size == 1:
- return pred_masks
-
- device = pred_masks.device
- # "max_obj_inds": object index of the object with the highest score at each location
- max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
- # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
- batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
- keep = max_obj_inds == batch_obj_inds
- # suppress overlapping regions' scores below -10.0 so that the foreground regions
- # don't overlap (here sigmoid(-10.0)=4.5398e-05)
- pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
- return pred_masks
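# A minimal check of the non-overlapping constraint above: for every pixel, only the object
# with the highest logit keeps its score; all other objects are clamped to at most -10.0 so
# their sigmoid probabilities become negligible. Shapes and values are illustrative only.
import torch

pred_masks = torch.randn(3, 1, 4, 4)                                     # 3 overlapping object logit maps
max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)             # (1, 1, 4, 4) winning object per pixel
batch_obj_inds = torch.arange(3)[:, None, None, None]                    # (3, 1, 1, 1)
keep = max_obj_inds == batch_obj_inds                                     # (3, 1, 4, 4)
suppressed = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))

assert (suppressed.sigmoid() > 0.5).sum(dim=0).max() <= 1                # at most one foreground object per pixel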
diff --git a/ultralytics/models/sam2/modules/sam2_blocks.py b/ultralytics/models/sam2/modules/sam2_blocks.py
deleted file mode 100644
index 67dc587e9..000000000
--- a/ultralytics/models/sam2/modules/sam2_blocks.py
+++ /dev/null
@@ -1,715 +0,0 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
-
-import copy
-import math
-from functools import partial
-from typing import Optional, Tuple, Type, Union
-
-import torch
-import torch.nn.functional as F
-from torch import Tensor, nn
-
-from ultralytics.models.sam.modules.transformer import (
- Attention,
-)
-from ultralytics.models.sam.modules.transformer import (
- TwoWayAttentionBlock as SAMTwoWayAttentionBlock,
-)
-from ultralytics.models.sam.modules.transformer import (
- TwoWayTransformer as SAMTwoWayTransformer,
-)
-from ultralytics.nn.modules import MLP, LayerNorm2d
-
-from .utils import apply_rotary_enc, compute_axial_cis, window_partition, window_unpartition
-
-
-class DropPath(nn.Module):
- """Implements stochastic depth regularization for neural networks during training."""
-
- def __init__(self, drop_prob=0.0, scale_by_keep=True):
- """Initialize DropPath module with specified drop probability and scaling option."""
- super(DropPath, self).__init__()
- self.drop_prob = drop_prob
- self.scale_by_keep = scale_by_keep
-
- def forward(self, x):
- """Applies stochastic depth to input tensor during training, with optional scaling."""
- if self.drop_prob == 0.0 or not self.training:
- return x
- keep_prob = 1 - self.drop_prob
- shape = (x.shape[0],) + (1,) * (x.ndim - 1)
- random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
- if keep_prob > 0.0 and self.scale_by_keep:
- random_tensor.div_(keep_prob)
- return x * random_tensor
-
-
-class MaskDownSampler(nn.Module):
- """Downsamples and embeds masks using convolutional layers and layer normalization for efficient processing."""
-
- def __init__(
- self,
- embed_dim=256,
- kernel_size=4,
- stride=4,
- padding=0,
- total_stride=16,
- activation=nn.GELU,
- ):
- """Initializes a mask downsampler module for progressive downsampling and channel expansion."""
- super().__init__()
- num_layers = int(math.log2(total_stride) // math.log2(stride))
- assert stride**num_layers == total_stride
- self.encoder = nn.Sequential()
- mask_in_chans, mask_out_chans = 1, 1
- for _ in range(num_layers):
- mask_out_chans = mask_in_chans * (stride**2)
- self.encoder.append(
- nn.Conv2d(
- mask_in_chans,
- mask_out_chans,
- kernel_size=kernel_size,
- stride=stride,
- padding=padding,
- )
- )
- self.encoder.append(LayerNorm2d(mask_out_chans))
- self.encoder.append(activation())
- mask_in_chans = mask_out_chans
-
- self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
-
- def forward(self, x):
- """Downsamples and encodes input mask to embed_dim channels using convolutional layers and LayerNorm2d."""
- return self.encoder(x)
-
-
-# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
-class CXBlock(nn.Module):
- """
- ConvNeXt Block for efficient feature extraction in convolutional neural networks.
-
-    This block implements a lightly modified ConvNeXt block consisting of a depthwise convolution, layer
-    normalization, pointwise (linear) convolutions with GELU activation, optional layer scaling, and a
-    residual connection with stochastic depth.
-
- Attributes:
- dwconv (nn.Conv2d): Depthwise convolution layer.
- norm (LayerNorm2d): Layer normalization applied to channels.
- pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer.
- act (nn.GELU): GELU activation function.
- pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer.
- gamma (nn.Parameter | None): Learnable scale parameter for layer scaling.
- drop_path (nn.Module): DropPath layer for stochastic depth regularization.
-
- Methods:
- forward: Processes the input tensor through the ConvNeXt block.
-
- Examples:
- >>> import torch
- >>> x = torch.randn(1, 64, 56, 56)
- >>> block = CXBlock(dim=64, kernel_size=7, padding=3)
- >>> output = block(x)
- >>> print(output.shape)
- torch.Size([1, 64, 56, 56])
- """
-
- def __init__(
- self,
- dim,
- kernel_size=7,
- padding=3,
- drop_path=0.0,
- layer_scale_init_value=1e-6,
- use_dwconv=True,
- ):
- """
- Initialize a ConvNeXt Block.
-
- This block implements a ConvNeXt architecture with optional depthwise convolution, layer normalization,
- pointwise convolutions, and GELU activation.
-
- Args:
- dim (int): Number of input channels.
- kernel_size (int): Size of the convolutional kernel. Default is 7.
- padding (int): Padding size for the convolution. Default is 3.
- drop_path (float): Stochastic depth rate. Default is 0.0.
- layer_scale_init_value (float): Initial value for Layer Scale. Default is 1e-6.
- use_dwconv (bool): Whether to use depthwise convolution. Default is True.
-
- Attributes:
- dwconv (nn.Conv2d): Depthwise or standard 2D convolution layer.
- norm (LayerNorm2d): Layer normalization applied to the output of dwconv.
- pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer.
- act (nn.GELU): GELU activation function.
- pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer.
- gamma (nn.Parameter | None): Learnable scale parameter for the residual path.
-
- Examples:
- >>> block = CXBlock(dim=64, kernel_size=7, padding=3)
- >>> x = torch.randn(1, 64, 32, 32)
- >>> output = block(x)
- >>> print(output.shape)
- torch.Size([1, 64, 32, 32])
- """
- super().__init__()
- self.dwconv = nn.Conv2d(
- dim,
- dim,
- kernel_size=kernel_size,
- padding=padding,
- groups=dim if use_dwconv else 1,
- ) # depthwise conv
- self.norm = LayerNorm2d(dim, eps=1e-6)
- self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
- self.act = nn.GELU()
- self.pwconv2 = nn.Linear(4 * dim, dim)
- self.gamma = (
- nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
- if layer_scale_init_value > 0
- else None
- )
- self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-
- def forward(self, x):
- """Applies ConvNeXt block operations to input tensor, including convolutions and residual connection."""
- input = x
- x = self.dwconv(x)
- x = self.norm(x)
- x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
- x = self.pwconv1(x)
- x = self.act(x)
- x = self.pwconv2(x)
- if self.gamma is not None:
- x = self.gamma * x
- x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
-
- x = input + self.drop_path(x)
- return x
-
-
-class Fuser(nn.Module):
- """
- A module for fusing features through multiple layers of a neural network.
-
- This class applies a series of identical layers to an input tensor, optionally projecting the input first.
-
- Attributes:
- proj (nn.Module): An optional input projection layer. Identity if no projection is needed.
- layers (nn.ModuleList): A list of identical layers to be applied sequentially.
-
- Methods:
- forward: Applies the fuser to an input tensor.
-
- Examples:
- >>> layer = CXBlock(dim=256)
- >>> fuser = Fuser(layer, num_layers=3, dim=256, input_projection=True)
- >>> x = torch.randn(1, 256, 32, 32)
- >>> output = fuser(x)
- >>> print(output.shape)
- torch.Size([1, 256, 32, 32])
- """
-
- def __init__(self, layer, num_layers, dim=None, input_projection=False):
- """
- Initializes the Fuser module.
-
- This module creates a sequence of identical layers and optionally applies an input projection.
-
- Args:
- layer (nn.Module): The layer to be replicated in the fuser.
- num_layers (int): The number of times to replicate the layer.
- dim (int | None): The dimension for input projection, if used.
- input_projection (bool): Whether to use input projection.
-
- Attributes:
- proj (nn.Module): The input projection layer, or nn.Identity if not used.
- layers (nn.ModuleList): A list of replicated layers.
-
- Examples:
- >>> layer = nn.Linear(64, 64)
- >>> fuser = Fuser(layer, num_layers=3, dim=64, input_projection=True)
- >>> input_tensor = torch.randn(1, 64)
- >>> output = fuser(input_tensor)
- """
- super().__init__()
- self.proj = nn.Identity()
- self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
-
- if input_projection:
- assert dim is not None
- self.proj = nn.Conv2d(dim, dim, kernel_size=1)
-
- def forward(self, x):
- """Applies a series of layers to the input tensor, optionally projecting it first."""
- x = self.proj(x)
- for layer in self.layers:
- x = layer(x)
- return x
-
-
-class TwoWayAttentionBlock(SAMTwoWayAttentionBlock):
- """
- A two-way attention block for performing self-attention and cross-attention in both directions.
-
- This block extends the SAMTwoWayAttentionBlock and consists of four main components: self-attention on
- sparse inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and
- cross-attention from dense to sparse inputs.
-
- Attributes:
- self_attn (Attention): Self-attention layer for queries.
- norm1 (nn.LayerNorm): Layer normalization after the first attention block.
- cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
- norm2 (nn.LayerNorm): Layer normalization after the second attention block.
- mlp (MLP): MLP block for transforming query embeddings.
- norm3 (nn.LayerNorm): Layer normalization after the MLP block.
- norm4 (nn.LayerNorm): Layer normalization after the third attention block.
- cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
- skip_first_layer_pe (bool): Flag to skip positional encoding in the first layer.
-
- Methods:
- forward: Processes input through the attention blocks and MLP.
-
- Examples:
- >>> block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8)
- >>> sparse_input = torch.randn(1, 100, 256)
- >>> dense_input = torch.randn(1, 256, 16, 16)
- >>> sparse_output, dense_output = block(sparse_input, dense_input)
- """
-
- def __init__(
- self,
- embedding_dim: int,
- num_heads: int,
- mlp_dim: int = 2048,
- activation: Type[nn.Module] = nn.ReLU,
- attention_downsample_rate: int = 2,
- skip_first_layer_pe: bool = False,
- ) -> None:
- """
- Initializes a TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.
-
- This block consists of four main layers: self-attention on sparse inputs, cross-attention of sparse inputs
- to dense inputs, an MLP block on sparse inputs, and cross-attention of dense inputs to sparse inputs.
-
- Args:
- embedding_dim (int): The channel dimension of the embeddings.
- num_heads (int): The number of heads in the attention layers.
- mlp_dim (int): The hidden dimension of the MLP block.
- activation (Type[nn.Module]): The activation function of the MLP block.
- attention_downsample_rate (int): The downsample rate for attention computations.
- skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
-
- Attributes:
- self_attn (Attention): The self-attention layer for the queries.
- norm1 (nn.LayerNorm): Layer normalization following the first attention block.
- cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
- norm2 (nn.LayerNorm): Layer normalization following the second attention block.
- mlp (MLP): MLP block that transforms the query embeddings.
- norm3 (nn.LayerNorm): Layer normalization following the MLP block.
- norm4 (nn.LayerNorm): Layer normalization following the third attention block.
- cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
- skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
-
- Examples:
- >>> block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048)
- >>> sparse_inputs = torch.randn(1, 100, 256)
- >>> dense_inputs = torch.randn(1, 256, 32, 32)
- >>> sparse_outputs, dense_outputs = block(sparse_inputs, dense_inputs)
- """
- super().__init__(embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate, skip_first_layer_pe)
- self.mlp = MLP(embedding_dim, mlp_dim, embedding_dim, num_layers=2, act=activation)
-
-
-class TwoWayTransformer(SAMTwoWayTransformer):
- """
- A Two-Way Transformer module for simultaneous attention to image and query points.
-
- This class implements a specialized transformer decoder that attends to an input image using queries with
- supplied positional embeddings. It is particularly useful for tasks like object detection, image
- segmentation, and point cloud processing.
-
- Attributes:
- depth (int): Number of layers in the transformer.
- embedding_dim (int): Channel dimension for input embeddings.
- num_heads (int): Number of heads for multihead attention.
- mlp_dim (int): Internal channel dimension for the MLP block.
- layers (nn.ModuleList): List of TwoWayAttentionBlock layers comprising the transformer.
- final_attn_token_to_image (Attention): Final attention layer from queries to image.
- norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
-
- Methods:
- forward: Processes input image embeddings and query embeddings through the transformer.
-
- Examples:
- >>> transformer = TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)
- >>> image_embedding = torch.randn(1, 256, 64, 64)
- >>> query_embedding = torch.randn(1, 100, 256)
- >>> output = transformer(image_embedding, query_embedding)
- """
-
- def __init__(
- self,
- depth: int,
- embedding_dim: int,
- num_heads: int,
- mlp_dim: int,
- activation: Type[nn.Module] = nn.ReLU,
- attention_downsample_rate: int = 2,
- ) -> None:
- """
- Initializes a TwoWayTransformer instance.
-
- This transformer decoder attends to an input image using queries with supplied positional embeddings.
- It is designed for tasks like object detection, image segmentation, and point cloud processing.
-
- Args:
- depth (int): Number of layers in the transformer.
- embedding_dim (int): Channel dimension for the input embeddings.
- num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
- mlp_dim (int): Channel dimension internal to the MLP block.
- activation (Type[nn.Module]): Activation function to use in the MLP block.
- attention_downsample_rate (int): Downsampling rate for attention computations.
-
- Attributes:
- depth (int): Number of layers in the transformer.
- embedding_dim (int): Channel dimension for the input embeddings.
- num_heads (int): Number of heads for multihead attention.
- mlp_dim (int): Internal channel dimension for the MLP block.
- layers (nn.ModuleList): List of TwoWayAttentionBlock layers comprising the transformer.
- final_attn_token_to_image (Attention): Final attention layer from queries to image.
- norm_final_attn (nn.LayerNorm): Layer normalization applied to the final queries.
-
- Examples:
- >>> transformer = TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)
- >>> transformer
- TwoWayTransformer(
- (layers): ModuleList(
- (0-4): 5 x TwoWayAttentionBlock(...)
- )
- (final_attn_token_to_image): Attention(...)
- (norm_final_attn): LayerNorm(...)
- )
- """
- super().__init__(depth, embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate)
- self.layers = nn.ModuleList()
- for i in range(depth):
- self.layers.append(
- TwoWayAttentionBlock(
- embedding_dim=embedding_dim,
- num_heads=num_heads,
- mlp_dim=mlp_dim,
- activation=activation,
- attention_downsample_rate=attention_downsample_rate,
- skip_first_layer_pe=(i == 0),
- )
- )
-
-
-class RoPEAttention(Attention):
- """Implements rotary position encoding for attention mechanisms in transformer architectures."""
-
- def __init__(
- self,
- *args,
- rope_theta=10000.0,
- # whether to repeat q rope to match k length
- # this is needed for cross-attention to memories
- rope_k_repeat=False,
- feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
- **kwargs,
- ):
- """Initializes RoPEAttention with rotary position encoding for attention mechanisms."""
- super().__init__(*args, **kwargs)
-
- self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta)
- freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
- self.freqs_cis = freqs_cis
- self.rope_k_repeat = rope_k_repeat
-
- def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor:
- """Applies rotary position encoding and computes attention between query, key, and value tensors."""
- q = self.q_proj(q)
- k = self.k_proj(k)
- v = self.v_proj(v)
-
- # Separate into heads
- q = self._separate_heads(q, self.num_heads)
- k = self._separate_heads(k, self.num_heads)
- v = self._separate_heads(v, self.num_heads)
-
- # Apply rotary position encoding
- w = h = math.sqrt(q.shape[-2])
- self.freqs_cis = self.freqs_cis.to(q.device)
- if self.freqs_cis.shape[0] != q.shape[-2]:
- self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
- if q.shape[-2] != k.shape[-2]:
- assert self.rope_k_repeat
-
- num_k_rope = k.size(-2) - num_k_exclude_rope
- q, k[:, :, :num_k_rope] = apply_rotary_enc(
- q,
- k[:, :, :num_k_rope],
- freqs_cis=self.freqs_cis,
- repeat_freqs_k=self.rope_k_repeat,
- )
-
- # Attention
- _, _, _, c_per_head = q.shape
- attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
- attn = attn / math.sqrt(c_per_head)
- attn = torch.softmax(attn, dim=-1)
-
- # Get output
- out = attn @ v
-
- out = self._recombine_heads(out)
- out = self.out_proj(out)
-
- return out
-
-
-def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
- """Applies pooling and optional normalization to a tensor, handling permutations for spatial operations."""
- if pool is None:
- return x
- # (B, H, W, C) -> (B, C, H, W)
- x = x.permute(0, 3, 1, 2)
- x = pool(x)
- # (B, C, H', W') -> (B, H', W', C)
- x = x.permute(0, 2, 3, 1)
- if norm:
- x = norm(x)
-
- return x
-
-
-class MultiScaleAttention(nn.Module):
- """Implements multi-scale self-attention with optional query pooling for efficient feature extraction."""
-
- def __init__(
- self,
- dim: int,
- dim_out: int,
- num_heads: int,
- q_pool: nn.Module = None,
- ):
- """Initializes a multi-scale attention module with configurable query pooling and linear projections."""
- super().__init__()
-
- self.dim = dim
- self.dim_out = dim_out
-
- self.num_heads = num_heads
- head_dim = dim_out // num_heads
- self.scale = head_dim**-0.5
-
- self.q_pool = q_pool
- self.qkv = nn.Linear(dim, dim_out * 3)
- self.proj = nn.Linear(dim_out, dim_out)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Applies multi-scale attention to input tensor, optionally downsampling query features."""
- B, H, W, _ = x.shape
- # qkv with shape (B, H * W, 3, nHead, C)
- qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
- # q, k, v with shape (B, H * W, nheads, C)
- q, k, v = torch.unbind(qkv, 2)
-
- # Q pooling (for downsample at stage changes)
- if self.q_pool:
- q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
- H, W = q.shape[1:3] # downsampled shape
- q = q.reshape(B, H * W, self.num_heads, -1)
-
- # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
- x = F.scaled_dot_product_attention(
- q.transpose(1, 2),
- k.transpose(1, 2),
- v.transpose(1, 2),
- )
- # Transpose back
- x = x.transpose(1, 2)
- x = x.reshape(B, H, W, -1)
-
- x = self.proj(x)
-
- return x
-
-
-class MultiScaleBlock(nn.Module):
- """Multiscale attention block with window partitioning and query pooling for efficient vision transformers."""
-
- def __init__(
- self,
- dim: int,
- dim_out: int,
- num_heads: int,
- mlp_ratio: float = 4.0,
- drop_path: float = 0.0,
- norm_layer: Union[nn.Module, str] = "LayerNorm",
- q_stride: Tuple[int, int] = None,
- act_layer: nn.Module = nn.GELU,
- window_size: int = 0,
- ):
- """Initializes a multi-scale attention block with optional window partitioning and downsampling."""
- super().__init__()
-
- if isinstance(norm_layer, str):
- norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
-
- self.dim = dim
- self.dim_out = dim_out
- self.norm1 = norm_layer(dim)
-
- self.window_size = window_size
-
- self.pool, self.q_stride = None, q_stride
- if self.q_stride:
- self.pool = nn.MaxPool2d(kernel_size=q_stride, stride=q_stride, ceil_mode=False)
-
- self.attn = MultiScaleAttention(
- dim,
- dim_out,
- num_heads=num_heads,
- q_pool=self.pool,
- )
- self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-
- self.norm2 = norm_layer(dim_out)
- self.mlp = MLP(
- dim_out,
- int(dim_out * mlp_ratio),
- dim_out,
- num_layers=2,
- act=act_layer,
- )
-
- if dim != dim_out:
- self.proj = nn.Linear(dim, dim_out)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Applies multi-scale attention and MLP processing to input tensor, with optional windowing."""
- shortcut = x # B, H, W, C
- x = self.norm1(x)
-
- # Skip connection
- if self.dim != self.dim_out:
- shortcut = do_pool(self.proj(x), self.pool)
-
- # Window partition
- window_size = self.window_size
- if window_size > 0:
- H, W = x.shape[1], x.shape[2]
- x, pad_hw = window_partition(x, window_size)
-
- # Window Attention + Q Pooling (if stage change)
- x = self.attn(x)
- if self.q_stride:
- # Shapes have changed due to Q pooling
- window_size = self.window_size // self.q_stride[0]
- H, W = shortcut.shape[1:3]
-
- pad_h = (window_size - H % window_size) % window_size
- pad_w = (window_size - W % window_size) % window_size
- pad_hw = (H + pad_h, W + pad_w)
-
- # Reverse window partition
- if self.window_size > 0:
- x = window_unpartition(x, window_size, pad_hw, (H, W))
-
- x = shortcut + self.drop_path(x)
- # MLP
- x = x + self.drop_path(self.mlp(self.norm2(x)))
- return x
-
-
-class PositionEmbeddingSine(nn.Module):
- """Generates sinusoidal positional embeddings for 2D inputs like images."""
-
- def __init__(
- self,
- num_pos_feats,
- temperature: int = 10000,
- normalize: bool = True,
- scale: Optional[float] = None,
- ):
- """Initializes sinusoidal position embeddings for 2D image inputs."""
- super().__init__()
- assert num_pos_feats % 2 == 0, "Expecting even model width"
- self.num_pos_feats = num_pos_feats // 2
- self.temperature = temperature
- self.normalize = normalize
- if scale is not None and normalize is False:
- raise ValueError("normalize should be True if scale is passed")
- if scale is None:
- scale = 2 * math.pi
- self.scale = scale
-
- self.cache = {}
-
- def _encode_xy(self, x, y):
- """Encodes 2D positions using sine and cosine functions for positional embeddings."""
- assert len(x) == len(y) and x.ndim == y.ndim == 1
- x_embed = x * self.scale
- y_embed = y * self.scale
-
- dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
- dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
-
- pos_x = x_embed[:, None] / dim_t
- pos_y = y_embed[:, None] / dim_t
- pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
- pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
- return pos_x, pos_y
-
- @torch.no_grad()
- def encode_boxes(self, x, y, w, h):
- """Encodes box coordinates and dimensions into positional embeddings for object detection tasks."""
- pos_x, pos_y = self._encode_xy(x, y)
- pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
- return pos
-
- encode = encode_boxes # Backwards compatibility
-
- @torch.no_grad()
- def encode_points(self, x, y, labels):
- """Encodes 2D point coordinates with sinusoidal positional embeddings and appends labels."""
- (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
- assert bx == by and nx == ny and bx == bl and nx == nl
- pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
- pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
- pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
- return pos
-
- @torch.no_grad()
- def forward(self, x: torch.Tensor):
- """Generate sinusoidal position embeddings for 2D inputs."""
- cache_key = (x.shape[-2], x.shape[-1])
- if cache_key in self.cache:
- return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
- y_embed = (
- torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
- .view(1, -1, 1)
- .repeat(x.shape[0], 1, x.shape[-1])
- )
- x_embed = (
- torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
- .view(1, 1, -1)
- .repeat(x.shape[0], x.shape[-2], 1)
- )
-
- if self.normalize:
- eps = 1e-6
- y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
- x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
-
- dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
- dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
-
- pos_x = x_embed[:, :, :, None] / dim_t
- pos_y = y_embed[:, :, :, None] / dim_t
- pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
- pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
- pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
- self.cache[cache_key] = pos[0]
- return pos
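# A self-contained sketch of the sine/cosine coordinate embedding used by _encode_xy above:
# normalized coordinates are scaled by 2*pi, divided by a geometric frequency ladder, and the
# interleaved sin/cos pairs are flattened into the feature dimension. The helper name and
# num_pos_feats default are illustrative, not part of the module.
import math

import torch


def encode_1d(coords: torch.Tensor, num_pos_feats: int = 128, temperature: float = 10000.0) -> torch.Tensor:
    """Embed normalized 1D coordinates in [0, 1] as (len(coords), num_pos_feats) sin/cos features."""
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
    dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
    pos = (coords * 2 * math.pi)[:, None] / dim_t
    return torch.stack((pos[:, 0::2].sin(), pos[:, 1::2].cos()), dim=2).flatten(1)


print(encode_1d(torch.tensor([0.1, 0.5, 0.9])).shape)  # torch.Size([3, 128])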
diff --git a/ultralytics/models/sam2/predict.py b/ultralytics/models/sam2/predict.py
deleted file mode 100644
index f40c01582..000000000
--- a/ultralytics/models/sam2/predict.py
+++ /dev/null
@@ -1,177 +0,0 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
-
-import torch
-
-from ..sam.predict import Predictor
-from .build import build_sam2
-
-
-class SAM2Predictor(Predictor):
- """
- A predictor class for the Segment Anything Model 2 (SAM2), extending the base Predictor class.
-
- This class provides an interface for model inference tailored to image segmentation tasks, leveraging SAM2's
- advanced architecture and promptable segmentation capabilities. It facilitates flexible and real-time mask
- generation, working with various types of prompts such as bounding boxes, points, and low-resolution masks.
-
- Attributes:
- cfg (Dict): Configuration dictionary specifying model and task-related parameters.
- overrides (Dict): Dictionary containing values that override the default configuration.
- _callbacks (Dict): Dictionary of user-defined callback functions to augment behavior.
- args (namespace): Namespace to hold command-line arguments or other operational variables.
- im (torch.Tensor): Preprocessed input image tensor.
- features (torch.Tensor): Extracted image features used for inference.
- prompts (Dict): Collection of various prompt types, such as bounding boxes and points.
- segment_all (bool): Flag to control whether to segment all objects in the image or only specified ones.
- model (torch.nn.Module): The loaded SAM2 model.
- device (torch.device): The device (CPU or GPU) on which the model is loaded.
- _bb_feat_sizes (List[Tuple[int, int]]): List of feature sizes for different backbone levels.
-
- Methods:
- get_model: Builds and returns the SAM2 model.
- prompt_inference: Performs image segmentation inference based on various prompts.
- set_image: Preprocesses and sets a single image for inference.
- get_im_features: Extracts image features from the SAM2 image encoder.
-
- Examples:
- >>> predictor = SAM2Predictor(model='sam2_l.pt')
- >>> predictor.set_image('path/to/image.jpg')
- >>> masks, scores = predictor.prompt_inference(im=predictor.im, points=[[500, 375]], labels=[1])
- >>> print(f"Generated {len(masks)} mask(s) with scores: {scores}")
- """
-
- _bb_feat_sizes = [
- (256, 256),
- (128, 128),
- (64, 64),
- ]
-
- def get_model(self):
- """Retrieves and initializes the Segment Anything Model (SAM) for image segmentation tasks."""
- return build_sam2(self.args.model)
-
- def prompt_inference(
- self,
- im,
- bboxes=None,
- points=None,
- labels=None,
- masks=None,
- multimask_output=False,
- img_idx=-1,
- ):
- """
- Performs image segmentation inference based on various prompts using SAM2 architecture.
-
- Args:
- im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
- bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
- points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
- labels (np.ndarray | List | None): Labels for point prompts with shape (N,). 1 = foreground, 0 = background.
- masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W).
- multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
- img_idx (int): Index of the image in the batch to process.
-
-        Returns:
-            (tuple): Tuple containing:
-                - torch.Tensor: Output masks with shape (C, H, W), where C is the number of generated masks.
-                - torch.Tensor: Quality scores for each mask, with length C.
-
- Examples:
- >>> predictor = SAM2Predictor(cfg)
- >>> image = torch.rand(1, 3, 640, 640)
- >>> bboxes = [[100, 100, 200, 200]]
-            >>> masks, scores = predictor.prompt_inference(image, bboxes=bboxes)
- """
- features = self.get_im_features(im) if self.features is None else self.features
-
- src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
- r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
- # Transform input prompts
- if points is not None:
- points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
- points = points[None] if points.ndim == 1 else points
- # Assuming labels are all positive if users don't pass labels.
- if labels is None:
- labels = torch.ones(points.shape[0])
- labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
- points *= r
- # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
- points, labels = points[:, None], labels[:, None]
- if bboxes is not None:
- bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
- bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
- bboxes = bboxes.view(-1, 2, 2) * r
- bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
- # NOTE: merge "boxes" and "points" into a single "points" input
- # (where boxes are added at the beginning) to model.sam_prompt_encoder
- if points is not None:
- points = torch.cat([bboxes, points], dim=1)
- labels = torch.cat([bbox_labels, labels], dim=1)
- else:
- points, labels = bboxes, bbox_labels
- if masks is not None:
- masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
-
- points = (points, labels) if points is not None else None
-
- sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
- points=points,
- boxes=None,
- masks=masks,
- )
- # Predict masks
- batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction
- high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]]
- pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(
- image_embeddings=features["image_embed"][img_idx].unsqueeze(0),
- image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
- sparse_prompt_embeddings=sparse_embeddings,
- dense_prompt_embeddings=dense_embeddings,
- multimask_output=multimask_output,
- repeat_image=batched_mode,
- high_res_features=high_res_features,
- )
- # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
-        # `d` could be 1 or 3 depending on `multimask_output`.
- return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
-
- def set_image(self, image):
- """
- Preprocesses and sets a single image for inference.
-
- This function sets up the model if not already initialized, configures the data source to the specified image,
- and preprocesses the image for feature extraction. Only one image can be set at a time.
-
- Args:
- image (str | np.ndarray): Image file path as a string, or a numpy array image read by cv2.
-
- Raises:
- AssertionError: If more than one image is set.
-
- Examples:
- >>> predictor = SAM2Predictor()
- >>> predictor.set_image("path/to/image.jpg")
- >>> predictor.set_image(np.array([...])) # Using a numpy array
- """
- if self.model is None:
- self.setup_model(model=None)
- self.setup_source(image)
- assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
- for batch in self.dataset:
- im = self.preprocess(batch[1])
- self.features = self.get_im_features(im)
- break
-
- def get_im_features(self, im):
- """Extracts and processes image features using SAM2's image encoder for subsequent segmentation tasks."""
- backbone_out = self.model.forward_image(im)
- _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
- if self.model.directly_add_no_mem_embed:
- vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
- feats = [
- feat.permute(1, 2, 0).view(1, -1, *feat_size)
- for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
- ][::-1]
- return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}