diff --git a/docs/en/reference/models/sam/build.md b/docs/en/reference/models/sam/build.md index c5a66c656..5b608c7ec 100644 --- a/docs/en/reference/models/sam/build.md +++ b/docs/en/reference/models/sam/build.md @@ -1,6 +1,6 @@ --- -description: Discover detailed instructions for building various Segment Anything Model (SAM) architectures with Ultralytics, including SAM ViT and Mobile-SAM. -keywords: Ultralytics, SAM model, Segment Anything Model, SAM ViT, Mobile-SAM, model building, deep learning, AI +description: Discover detailed instructions for building various Segment Anything Model (SAM) and Segment Anything Model 2 (SAM 2) architectures with Ultralytics, including SAM ViT and Mobile-SAM. +keywords: Ultralytics, SAM model, Segment Anything Model, SAM 2 model, Segment Anything Model 2, SAM ViT, Mobile-SAM, model building, deep learning, AI --- # Reference for `ultralytics/models/sam/build.py` @@ -27,10 +27,30 @@ keywords: Ultralytics, SAM model, Segment Anything Model, SAM ViT, Mobile-SAM, m



+## ::: ultralytics.models.sam.build.build_sam2_t + +



+ +## ::: ultralytics.models.sam.build.build_sam2_s + +



+ +## ::: ultralytics.models.sam.build.build_sam2_b + +



+ +## ::: ultralytics.models.sam.build.build_sam2_l + +



+ ## ::: ultralytics.models.sam.build._build_sam



+## ::: ultralytics.models.sam.build._build_sam2 + +



+ ## ::: ultralytics.models.sam.build.build_sam

diff --git a/docs/en/reference/models/sam/model.md b/docs/en/reference/models/sam/model.md index 2255ad882..9c0f9301c 100644 --- a/docs/en/reference/models/sam/model.md +++ b/docs/en/reference/models/sam/model.md @@ -1,6 +1,6 @@ --- -description: Explore the SAM (Segment Anything Model) interface for real-time image segmentation. Learn about promptable segmentation and zero-shot capabilities. -keywords: Ultralytics, SAM, Segment Anything Model, image segmentation, real-time segmentation, zero-shot performance, promptable segmentation, SA-1B dataset +description: Explore the SAM (Segment Anything Model) and SAM 2 (Segment Anything Model 2) interface for real-time image segmentation. Learn about promptable segmentation and zero-shot capabilities. +keywords: Ultralytics, SAM, Segment Anything Model, SAM 2, Segment Anything Model 2, image segmentation, real-time segmentation, zero-shot performance, promptable segmentation, SA-1B dataset --- # Reference for `ultralytics/models/sam/model.py` diff --git a/docs/en/reference/models/sam/modules/blocks.md b/docs/en/reference/models/sam/modules/blocks.md new file mode 100644 index 000000000..5303a852e --- /dev/null +++ b/docs/en/reference/models/sam/modules/blocks.md @@ -0,0 +1,72 @@ +--- +description: Explore detailed documentation of various SAM and SAM 2 modules such as MaskDownSampler, CXBlock, and more, available in Ultralytics' repository. +keywords: Ultralytics, SAM encoder, SAM 2 encoder, DropPath, MaskDownSampler, CXBlock, Fuser, TwoWayTransformer, TwoWayAttentionBlock, RoPEAttention, MultiScaleAttention, MultiScaleBlock. PositionEmbeddingSine, do_pool +--- + +# Reference for `ultralytics/models/sam/modules/blocks.py` + +!!! Note + + This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/blocks.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/blocks.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam/modules/blocks.py) 🛠️. Thank you 🙏! + +
+ +## ::: ultralytics.models.sam.modules.blocks.DropPath + +



+ +## ::: ultralytics.models.sam.modules.blocks.MaskDownSampler + +



+ +## ::: ultralytics.models.sam.modules.blocks.CXBlock + +



+ +## ::: ultralytics.models.sam.modules.blocks.Fuser + +



+ +## ::: ultralytics.models.sam.modules.blocks.SAM2TwoWayAttentionBlock + +



+ +## ::: ultralytics.models.sam.modules.blocks.SAM2TwoWayTransformer + +



+ +## ::: ultralytics.models.sam.modules.blocks.RoPEAttention + +



+ +## ::: ultralytics.models.sam.modules.blocks.MultiScaleAttention + +



+ +## ::: ultralytics.models.sam.modules.blocks.MultiScaleBlock + +



+ +## ::: ultralytics.models.sam.modules.blocks.PositionEmbeddingSine + +



+ +## ::: ultralytics.models.sam.modules.blocks.PositionEmbeddingRandom + +



+ +## ::: ultralytics.models.sam.modules.blocks.Block + +



+ +## ::: ultralytics.models.sam.modules.blocks.REAttention + +



+ +## ::: ultralytics.models.sam.modules.blocks.PatchEmbed + +



+ +## ::: ultralytics.models.sam.modules.blocks.do_pool + +

diff --git a/docs/en/reference/models/sam/modules/decoders.md b/docs/en/reference/models/sam/modules/decoders.md index a224738b7..df2fc528f 100644 --- a/docs/en/reference/models/sam/modules/decoders.md +++ b/docs/en/reference/models/sam/modules/decoders.md @@ -13,4 +13,8 @@ keywords: Ultralytics, MaskDecoder, MLP, machine learning, transformer architect ## ::: ultralytics.models.sam.modules.decoders.MaskDecoder +



+ +## ::: ultralytics.models.sam.modules.decoders.SAM2MaskDecoder +

diff --git a/docs/en/reference/models/sam/modules/encoders.md b/docs/en/reference/models/sam/modules/encoders.md index 4e839e479..a2ff493f6 100644 --- a/docs/en/reference/models/sam/modules/encoders.md +++ b/docs/en/reference/models/sam/modules/encoders.md @@ -19,34 +19,18 @@ keywords: Ultralytics, SAM encoder, ImageEncoderViT, PromptEncoder, PositionEmbe



-## ::: ultralytics.models.sam.modules.encoders.PositionEmbeddingRandom +## ::: ultralytics.models.sam.modules.encoders.MemoryEncoder



-## ::: ultralytics.models.sam.modules.encoders.Block +## ::: ultralytics.models.sam.modules.encoders.ImageEncoder



-## ::: ultralytics.models.sam.modules.encoders.Attention +## ::: ultralytics.models.sam.modules.encoders.FpnNeck



-## ::: ultralytics.models.sam.modules.encoders.PatchEmbed - -



- -## ::: ultralytics.models.sam.modules.encoders.window_partition - -



- -## ::: ultralytics.models.sam.modules.encoders.window_unpartition - -



- -## ::: ultralytics.models.sam.modules.encoders.get_rel_pos - -



- -## ::: ultralytics.models.sam.modules.encoders.add_decomposed_rel_pos +## ::: ultralytics.models.sam.modules.encoders.Hiera

diff --git a/docs/en/reference/models/sam/modules/memory_attention.md b/docs/en/reference/models/sam/modules/memory_attention.md new file mode 100644 index 000000000..a341d6fed --- /dev/null +++ b/docs/en/reference/models/sam/modules/memory_attention.md @@ -0,0 +1,20 @@ +--- +description: Explore detailed documentation of various SAM 2 encoder modules such as MemoryAttentionLayer, MemoryAttention, available in Ultralytics' repository. +keywords: Ultralytics, SAM 2 encoder, MemoryAttentionLayer, MemoryAttention +--- + +# Reference for `ultralytics/models/sam/modules/memory_attention.py` + +!!! Note + + This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/memory_attention.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/memory_attention.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam/modules/memory_attention.py) 🛠️. Thank you 🙏! + +
+ +## ::: ultralytics.models.sam.modules.memory_attention.MemoryAttentionLayer + +



+ +## ::: ultralytics.models.sam.modules.memory_attention.MemoryAttention + +

diff --git a/docs/en/reference/models/sam/modules/sam.md b/docs/en/reference/models/sam/modules/sam.md index a31cece15..d66508bb5 100644 --- a/docs/en/reference/models/sam/modules/sam.md +++ b/docs/en/reference/models/sam/modules/sam.md @@ -1,6 +1,6 @@ --- -description: Discover the Ultralytics SAM module for object segmentation. Learn about its components, such as image encoders and mask decoders, in this comprehensive guide. -keywords: Ultralytics, SAM Module, object segmentation, image encoder, mask decoder, prompt encoder, AI, machine learning +description: Discover the Ultralytics SAM and SAM 2 module for object segmentation. Learn about its components, such as image encoders and mask decoders, in this comprehensive guide. +keywords: Ultralytics, SAM Module, SAM 2 Module, object segmentation, image encoder, mask decoder, prompt encoder, AI, machine learning --- # Reference for `ultralytics/models/sam/modules/sam.py` @@ -13,4 +13,8 @@ keywords: Ultralytics, SAM Module, object segmentation, image encoder, mask deco ## ::: ultralytics.models.sam.modules.sam.SAMModel +



+ +## ::: ultralytics.models.sam.modules.sam.SAM2Model +

diff --git a/docs/en/reference/models/sam/modules/tiny_encoder.md b/docs/en/reference/models/sam/modules/tiny_encoder.md index a5cac0529..fa5e0988d 100644 --- a/docs/en/reference/models/sam/modules/tiny_encoder.md +++ b/docs/en/reference/models/sam/modules/tiny_encoder.md @@ -47,10 +47,6 @@ keywords: Ultralytics, TinyViT, Conv2d_BN, PatchEmbed, MBConv, Attention, PyTorc



-## ::: ultralytics.models.sam.modules.tiny_encoder.LayerNorm2d - -



- ## ::: ultralytics.models.sam.modules.tiny_encoder.TinyViT

diff --git a/docs/en/reference/models/sam/modules/utils.md b/docs/en/reference/models/sam/modules/utils.md new file mode 100644 index 000000000..5611d156e --- /dev/null +++ b/docs/en/reference/models/sam/modules/utils.md @@ -0,0 +1,52 @@ +--- +description: Explore the detailed API reference for Ultralytics SAM and SAM 2 models. +keywords: Ultralytics, SAM, SAM 2, API Reference, models, window partition, data processing, YOLO +--- + +# Reference for `ultralytics/models/sam/modules/utils.py` + +!!! Note + + This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/utils.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam/modules/utils.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam/modules/utils.py) 🛠️. Thank you 🙏! + +
+ +## ::: ultralytics.models.sam.modules.utils.select_closest_cond_frames + +



+ +## ::: ultralytics.models.sam.modules.utils.get_1d_sine_pe + +



+ +## ::: ultralytics.models.sam.modules.utils.init_t_xy + +



+ +## ::: ultralytics.models.sam.modules.utils.compute_axial_cis + +



+ +## ::: ultralytics.models.sam.modules.utils.reshape_for_broadcast + +



+ +## ::: ultralytics.models.sam.modules.utils.apply_rotary_enc + +



+ +## ::: ultralytics.models.sam.modules.utils.window_partition + +



+ +## ::: ultralytics.models.sam.modules.utils.window_unpartition + +



+ +## ::: ultralytics.models.sam.modules.utils.get_rel_pos + +



+ +## ::: ultralytics.models.sam.modules.utils.add_decomposed_rel_pos + +

diff --git a/docs/en/reference/models/sam/predict.md b/docs/en/reference/models/sam/predict.md index 6c35fdc00..76bdce18b 100644 --- a/docs/en/reference/models/sam/predict.md +++ b/docs/en/reference/models/sam/predict.md @@ -1,6 +1,6 @@ --- -description: Explore Ultralytics SAM Predictor for advanced, real-time image segmentation using the Segment Anything Model (SAM). Complete implementation details and auxiliary utilities. -keywords: Ultralytics, SAM, Segment Anything Model, image segmentation, real-time, prediction, AI, machine learning, Python, torch, inference +description: Explore Ultralytics SAM and SAM 2 Predictor for advanced, real-time image segmentation using the Segment Anything Model (SAM and SAM 2). Complete implementation details and auxiliary utilities. +keywords: Ultralytics, SAM, Segment Anything Model, SAM 2, Segment Anything Model 2, image segmentation, real-time, prediction, AI, machine learning, Python, torch, inference --- # Reference for `ultralytics/models/sam/predict.py` @@ -13,4 +13,8 @@ keywords: Ultralytics, SAM, Segment Anything Model, image segmentation, real-tim ## ::: ultralytics.models.sam.predict.Predictor +



+ +## ::: ultralytics.models.sam.predict.SAM2Predictor +

diff --git a/docs/en/reference/models/sam2/build.md b/docs/en/reference/models/sam2/build.md deleted file mode 100644 index 94e9f5bef..000000000 --- a/docs/en/reference/models/sam2/build.md +++ /dev/null @@ -1,36 +0,0 @@ ---- -description: Discover detailed instructions for building various Segment Anything Model 2 (SAM 2) architectures with Ultralytics. -keywords: Ultralytics, SAM 2 model, Segment Anything Model 2, SAM, model building, deep learning, AI ---- - -# Reference for `ultralytics/models/sam2/build.py` - -!!! Note - - This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/build.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/build.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/build.py) 🛠️. Thank you 🙏! - -
- -## ::: ultralytics.models.sam2.build.build_sam2_t - -



- -## ::: ultralytics.models.sam2.build.build_sam2_s - -



- -## ::: ultralytics.models.sam2.build.build_sam2_b - -



- -## ::: ultralytics.models.sam2.build.build_sam2_l - -



- -## ::: ultralytics.models.sam2.build._build_sam2 - -



- -## ::: ultralytics.models.sam2.build.build_sam2 - -

diff --git a/docs/en/reference/models/sam2/model.md b/docs/en/reference/models/sam2/model.md deleted file mode 100644 index fc0d2e250..000000000 --- a/docs/en/reference/models/sam2/model.md +++ /dev/null @@ -1,16 +0,0 @@ ---- -description: Explore the SAM 2 (Segment Anything Model 2) interface for real-time image segmentation. Learn about promptable segmentation and zero-shot capabilities. -keywords: Ultralytics, SAM 2, Segment Anything Model 2, image segmentation, real-time segmentation, zero-shot performance, promptable segmentation, SA-1B dataset ---- - -# Reference for `ultralytics/models/sam2/model.py` - -!!! Note - - This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/model.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/model.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/model.py) 🛠️. Thank you 🙏! - -
- -## ::: ultralytics.models.sam2.model.SAM2 - -

diff --git a/docs/en/reference/models/sam2/modules/decoders.md b/docs/en/reference/models/sam2/modules/decoders.md deleted file mode 100644 index 989169a77..000000000 --- a/docs/en/reference/models/sam2/modules/decoders.md +++ /dev/null @@ -1,16 +0,0 @@ ---- -description: Explore the MaskDecoder and MLP modules in Ultralytics for efficient mask prediction using transformer architecture. Detailed attributes, functionalities, and implementation. -keywords: Ultralytics, MaskDecoder, MLP, machine learning, transformer architecture, mask prediction, neural networks, PyTorch modules ---- - -# Reference for `ultralytics/models/sam2/modules/decoders.py` - -!!! Note - - This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/decoders.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/decoders.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/decoders.py) 🛠️. Thank you 🙏! - -
- -## ::: ultralytics.models.sam2.modules.decoders.MaskDecoder - -

diff --git a/docs/en/reference/models/sam2/modules/encoders.md b/docs/en/reference/models/sam2/modules/encoders.md deleted file mode 100644 index a3da286e1..000000000 --- a/docs/en/reference/models/sam2/modules/encoders.md +++ /dev/null @@ -1,28 +0,0 @@ ---- -description: Discover the Ultralytics SAM 2 module for object segmentation. Learn about its components, such as image encoders and mask decoders, in this comprehensive guide. -keywords: Ultralytics, SAM 2 Module, object segmentation, image encoder, mask decoder, prompt encoder, AI, machine learning ---- - -# Reference for `ultralytics/models/sam2/modules/encoders.py` - -!!! Note - - This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/encoders.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/encoders.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/encoders.py) 🛠️. Thank you 🙏! - -
- -## ::: ultralytics.models.sam2.modules.encoders.MemoryEncoder - -



- -## ::: ultralytics.models.sam2.modules.encoders.ImageEncoder - -



- -## ::: ultralytics.models.sam2.modules.encoders.FpnNeck - -



- -## ::: ultralytics.models.sam2.modules.encoders.Hiera - -

diff --git a/docs/en/reference/models/sam2/modules/memory_attention.md b/docs/en/reference/models/sam2/modules/memory_attention.md deleted file mode 100644 index fef22174e..000000000 --- a/docs/en/reference/models/sam2/modules/memory_attention.md +++ /dev/null @@ -1,20 +0,0 @@ ---- -description: Explore detailed documentation of various SAM 2 encoder modules such as MemoryAttentionLayer, MemoryAttention, available in Ultralytics' repository. -keywords: Ultralytics, SAM 2 encoder, MemoryAttentionLayer, MemoryAttention ---- - -# Reference for `ultralytics/models/sam2/modules/memory_attention.py` - -!!! Note - - This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/memory_attention.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/memory_attention.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/memory_attention.py) 🛠️. Thank you 🙏! - -
- -## ::: ultralytics.models.sam2.modules.memory_attention.MemoryAttentionLayer - -



- -## ::: ultralytics.models.sam2.modules.memory_attention.MemoryAttention - -

diff --git a/docs/en/reference/models/sam2/modules/sam2.md b/docs/en/reference/models/sam2/modules/sam2.md deleted file mode 100644 index fcd47bbdc..000000000 --- a/docs/en/reference/models/sam2/modules/sam2.md +++ /dev/null @@ -1,16 +0,0 @@ ---- -description: Discover the Ultralytics SAM 2 module for object segmentation. Learn about its components, such as image encoders and mask decoders, in this comprehensive guide. -keywords: Ultralytics, SAM 2 Module, object segmentation, image encoder, mask decoder, prompt encoder, AI, machine learning ---- - -# Reference for `ultralytics/models/sam2/modules/sam2.py` - -!!! Note - - This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/sam2.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/sam2.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/sam2.py) 🛠️. Thank you 🙏! - -
- -## ::: ultralytics.models.sam2.modules.sam2.SAM2Model - -

diff --git a/docs/en/reference/models/sam2/modules/sam2_blocks.md b/docs/en/reference/models/sam2/modules/sam2_blocks.md deleted file mode 100644 index 796669a03..000000000 --- a/docs/en/reference/models/sam2/modules/sam2_blocks.md +++ /dev/null @@ -1,56 +0,0 @@ ---- -description: Explore detailed documentation of various SAM 2 modules such as MaskDownSampler, CXBlock, and more, available in Ultralytics' repository. -keywords: Ultralytics, SAM 2 encoder, DropPath, MaskDownSampler, CXBlock, Fuser, TwoWayTransformer, TwoWayAttentionBlock, RoPEAttention, MultiScaleAttention, MultiScaleBlock. PositionEmbeddingSine, do_pool ---- - -# Reference for `ultralytics/models/sam2/modules/sam2_blocks.py` - -!!! Note - - This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/sam2_blocks.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/sam2_blocks.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/sam2_blocks.py) 🛠️. Thank you 🙏! - -
- -## ::: ultralytics.models.sam2.modules.sam2_blocks.DropPath - -



- -## ::: ultralytics.models.sam2.modules.sam2_blocks.MaskDownSampler - -



- -## ::: ultralytics.models.sam2.modules.sam2_blocks.CXBlock - -



- -## ::: ultralytics.models.sam2.modules.sam2_blocks.Fuser - -



- -## ::: ultralytics.models.sam2.modules.sam2_blocks.TwoWayAttentionBlock - -



- -## ::: ultralytics.models.sam2.modules.sam2_blocks.TwoWayTransformer - -



- -## ::: ultralytics.models.sam2.modules.sam2_blocks.RoPEAttention - -



- -## ::: ultralytics.models.sam2.modules.sam2_blocks.MultiScaleAttention - -



- -## ::: ultralytics.models.sam2.modules.sam2_blocks.MultiScaleBlock - -



- -## ::: ultralytics.models.sam2.modules.sam2_blocks.PositionEmbeddingSine - -



- -## ::: ultralytics.models.sam2.modules.sam2_blocks.do_pool - -

diff --git a/docs/en/reference/models/sam2/modules/utils.md b/docs/en/reference/models/sam2/modules/utils.md deleted file mode 100644 index 357cea623..000000000 --- a/docs/en/reference/models/sam2/modules/utils.md +++ /dev/null @@ -1,44 +0,0 @@ ---- -description: Explore the detailed API reference for Ultralytics SAM 2 models. -keywords: Ultralytics, SAM 2, API Reference, models, window partition, data processing, YOLO ---- - -# Reference for `ultralytics/models/sam2/modules/utils.py` - -!!! Note - - This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/utils.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/utils.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/utils.py) 🛠️. Thank you 🙏! - -
- -## ::: ultralytics.models.sam2.modules.utils.select_closest_cond_frames - -



- -## ::: ultralytics.models.sam2.modules.utils.get_1d_sine_pe - -



- -## ::: ultralytics.models.sam2.modules.utils.init_t_xy - -



- -## ::: ultralytics.models.sam2.modules.utils.compute_axial_cis - -



- -## ::: ultralytics.models.sam2.modules.utils.reshape_for_broadcast - -



- -## ::: ultralytics.models.sam2.modules.utils.apply_rotary_enc - -



- -## ::: ultralytics.models.sam2.modules.utils.window_partition - -



- -## ::: ultralytics.models.sam2.modules.utils.window_unpartition - -

diff --git a/docs/en/reference/models/sam2/predict.md b/docs/en/reference/models/sam2/predict.md deleted file mode 100644 index 23b30140b..000000000 --- a/docs/en/reference/models/sam2/predict.md +++ /dev/null @@ -1,16 +0,0 @@ ---- -description: Explore Ultralytics SAM 2 Predictor for advanced, real-time image segmentation using the Segment Anything Model 2 (SAM 2). Complete implementation details and auxiliary utilities. -keywords: Ultralytics, SAM 2, Segment Anything Model 2, image segmentation, real-time, prediction, AI, machine learning, Python, torch, inference ---- - -# Reference for `ultralytics/models/sam2/predict.py` - -!!! Note - - This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/predict.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/predict.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/predict.py) 🛠️. Thank you 🙏! - -
- -## ::: ultralytics.models.sam2.predict.SAM2Predictor - -

diff --git a/mkdocs.yml b/mkdocs.yml index fcf5516b5..eb150fd7a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -503,23 +503,15 @@ nav: - build: reference/models/sam/build.md - model: reference/models/sam/model.md - modules: + - blocks: reference/models/sam/modules/blocks.md - decoders: reference/models/sam/modules/decoders.md - encoders: reference/models/sam/modules/encoders.md + - memory_attention: reference/models/sam/modules/memory_attention.md - sam: reference/models/sam/modules/sam.md - tiny_encoder: reference/models/sam/modules/tiny_encoder.md - transformer: reference/models/sam/modules/transformer.md + - utils: reference/models/sam/modules/utils.md - predict: reference/models/sam/predict.md - - sam2: - - build: reference/models/sam2/build.md - - model: reference/models/sam2/model.md - - modules: - - decoders: reference/models/sam2/modules/decoders.md - - encoders: reference/models/sam2/modules/encoders.md - - memory_attention: reference/models/sam2/modules/memory_attention.md - - sam2: reference/models/sam2/modules/sam2.md - - sam2_blocks: reference/models/sam2/modules/sam2_blocks.md - - utils: reference/models/sam2/modules/utils.md - - predict: reference/models/sam2/predict.md - utils: - loss: reference/models/utils/loss.md - ops: reference/models/utils/ops.md diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 16bacbbe5..313c06fd0 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.2.72" +__version__ = "8.2.73" import os @@ -8,7 +8,7 @@ import os os.environ["OMP_NUM_THREADS"] = "1" # reduce CPU utilization during training from ultralytics.data.explorer.explorer import Explorer -from ultralytics.models import NAS, RTDETR, SAM, SAM2, YOLO, FastSAM, YOLOWorld +from ultralytics.models import NAS, RTDETR, SAM, YOLO, FastSAM, YOLOWorld from ultralytics.utils import ASSETS, SETTINGS from ultralytics.utils.checks import check_yolo as checks from ultralytics.utils.downloads import download @@ -21,7 +21,6 @@ __all__ = ( "YOLOWorld", "NAS", "SAM", - "SAM2", "FastSAM", "RTDETR", "checks", diff --git a/ultralytics/models/__init__.py b/ultralytics/models/__init__.py index bbade0472..aff620a9a 100644 --- a/ultralytics/models/__init__.py +++ b/ultralytics/models/__init__.py @@ -4,7 +4,6 @@ from .fastsam import FastSAM from .nas import NAS from .rtdetr import RTDETR from .sam import SAM -from .sam2 import SAM2 from .yolo import YOLO, YOLOWorld -__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld", "SAM2" # allow simpler import +__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld" # allow simpler import diff --git a/ultralytics/models/sam/__init__.py b/ultralytics/models/sam/__init__.py index 8701fccea..a29f5cb3f 100644 --- a/ultralytics/models/sam/__init__.py +++ b/ultralytics/models/sam/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license from .model import SAM -from .predict import Predictor +from .predict import Predictor, SAM2Predictor -__all__ = "SAM", "Predictor" # tuple or list +__all__ = "SAM", "Predictor", "SAM2Predictor" # tuple or list diff --git a/ultralytics/models/sam/amg.py b/ultralytics/models/sam/amg.py index b61c6a710..55db3e011 100644 --- a/ultralytics/models/sam/amg.py +++ b/ultralytics/models/sam/amg.py @@ -11,7 +11,7 @@ import torch def is_box_near_crop_edge( boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 ) -> torch.Tensor: - """Return a boolean tensor indicating if boxes are near the crop edge.""" + """Determines if bounding boxes are near the edge of a cropped image region using a specified tolerance.""" crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) boxes = uncrop_boxes_xyxy(boxes, crop_box).float() @@ -22,7 +22,7 @@ def is_box_near_crop_edge( def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: - """Yield batches of data from the input arguments.""" + """Yields batches of data from input arguments with specified batch size for efficient processing.""" assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs." n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) for b in range(n_batches): @@ -33,12 +33,26 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh """ Computes the stability score for a batch of masks. - The stability score is the IoU between the binary masks obtained by thresholding the predicted mask logits at high - and low values. + The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at + high and low values. + + Args: + masks (torch.Tensor): Batch of predicted mask logits. + mask_threshold (float): Threshold value for creating binary masks. + threshold_offset (float): Offset applied to the threshold for creating high and low binary masks. + + Returns: + (torch.Tensor): Stability scores for each mask in the batch. Notes: - One mask is always contained inside the other. - - Save memory by preventing unnecessary cast to torch.int64 + - Memory is saved by preventing unnecessary cast to torch.int64. + + Examples: + >>> masks = torch.rand(10, 256, 256) # Batch of 10 masks + >>> mask_threshold = 0.5 + >>> threshold_offset = 0.1 + >>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset) """ intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) @@ -46,7 +60,7 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh def build_point_grid(n_per_side: int) -> np.ndarray: - """Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1].""" + """Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1] for image segmentation tasks.""" offset = 1 / (2 * n_per_side) points_one_side = np.linspace(offset, 1 - offset, n_per_side) points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) @@ -55,18 +69,14 @@ def build_point_grid(n_per_side: int) -> np.ndarray: def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]: - """Generate point grids for all crop layers.""" + """Generates point grids for multiple crop layers with varying scales and densities.""" return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)] def generate_crop_boxes( im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float ) -> Tuple[List[List[int]], List[int]]: - """ - Generates a list of crop boxes of different sizes. - - Each layer has (2**i)**2 boxes for the ith layer. - """ + """Generates crop boxes of varying sizes for multi-scale image processing, with layered overlapping regions.""" crop_boxes, layer_idxs = [], [] im_h, im_w = im_size short_side = min(im_h, im_w) @@ -99,7 +109,7 @@ def generate_crop_boxes( def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: - """Uncrop bounding boxes by adding the crop box offset.""" + """Uncrop bounding boxes by adding the crop box offset to their coordinates.""" x0, y0, _, _ = crop_box offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) # Check if boxes has a channel dimension @@ -109,7 +119,7 @@ def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: - """Uncrop points by adding the crop box offset.""" + """Uncrop points by adding the crop box offset to their coordinates.""" x0, y0, _, _ = crop_box offset = torch.tensor([[x0, y0]], device=points.device) # Check if points has a channel dimension @@ -119,7 +129,7 @@ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor: - """Uncrop masks by padding them to the original image size.""" + """Uncrop masks by padding them to the original image size, handling coordinate transformations.""" x0, y0, x1, y1 = crop_box if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: return masks @@ -130,7 +140,7 @@ def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]: - """Remove small disconnected regions or holes in a mask, returning the mask and a modification indicator.""" + """Removes small disconnected regions or holes in a mask based on area threshold and mode.""" import cv2 # type: ignore assert mode in {"holes", "islands"}, f"Provided mode {mode} is invalid" @@ -150,11 +160,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: - """ - Calculates boxes in XYXY format around masks. - - Return [0,0,0,0] for an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. - """ + """Calculates bounding boxes in XYXY format around binary masks, handling empty masks and various input shapes.""" # torch.max below raises an error on empty inputs, just skip in this case if torch.numel(masks) == 0: return torch.zeros(*masks.shape[:-2], 4, device=masks.device) diff --git a/ultralytics/models/sam/build.py b/ultralytics/models/sam/build.py index 253c0159d..0e7ddedcf 100644 --- a/ultralytics/models/sam/build.py +++ b/ultralytics/models/sam/build.py @@ -13,14 +13,15 @@ import torch from ultralytics.utils.downloads import attempt_download_asset from .modules.decoders import MaskDecoder -from .modules.encoders import ImageEncoderViT, PromptEncoder -from .modules.sam import SAMModel +from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder +from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer +from .modules.sam import SAM2Model, SAMModel from .modules.tiny_encoder import TinyViT from .modules.transformer import TwoWayTransformer def build_sam_vit_h(checkpoint=None): - """Build and return a Segment Anything Model (SAM) h-size model.""" + """Builds and returns a Segment Anything Model (SAM) h-size model with specified encoder parameters.""" return _build_sam( encoder_embed_dim=1280, encoder_depth=32, @@ -31,7 +32,7 @@ def build_sam_vit_h(checkpoint=None): def build_sam_vit_l(checkpoint=None): - """Build and return a Segment Anything Model (SAM) l-size model.""" + """Builds and returns a Segment Anything Model (SAM) l-size model with specified encoder parameters.""" return _build_sam( encoder_embed_dim=1024, encoder_depth=24, @@ -42,7 +43,7 @@ def build_sam_vit_l(checkpoint=None): def build_sam_vit_b(checkpoint=None): - """Build and return a Segment Anything Model (SAM) b-size model.""" + """Constructs and returns a Segment Anything Model (SAM) with b-size architecture and optional checkpoint.""" return _build_sam( encoder_embed_dim=768, encoder_depth=12, @@ -53,7 +54,7 @@ def build_sam_vit_b(checkpoint=None): def build_mobile_sam(checkpoint=None): - """Build and return Mobile Segment Anything Model (Mobile-SAM).""" + """Builds and returns a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation.""" return _build_sam( encoder_embed_dim=[64, 128, 160, 320], encoder_depth=[2, 2, 6, 2], @@ -64,10 +65,85 @@ def build_mobile_sam(checkpoint=None): ) +def build_sam2_t(checkpoint=None): + """Builds and returns a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters.""" + return _build_sam2( + encoder_embed_dim=96, + encoder_stages=[1, 2, 7, 2], + encoder_num_heads=1, + encoder_global_att_blocks=[5, 7, 9], + encoder_window_spec=[8, 4, 14, 7], + encoder_backbone_channel_list=[768, 384, 192, 96], + checkpoint=checkpoint, + ) + + +def build_sam2_s(checkpoint=None): + """Builds and returns a small-size Segment Anything Model (SAM2) with specified architecture parameters.""" + return _build_sam2( + encoder_embed_dim=96, + encoder_stages=[1, 2, 11, 2], + encoder_num_heads=1, + encoder_global_att_blocks=[7, 10, 13], + encoder_window_spec=[8, 4, 14, 7], + encoder_backbone_channel_list=[768, 384, 192, 96], + checkpoint=checkpoint, + ) + + +def build_sam2_b(checkpoint=None): + """Builds and returns a SAM2 base-size model with specified architecture parameters.""" + return _build_sam2( + encoder_embed_dim=112, + encoder_stages=[2, 3, 16, 3], + encoder_num_heads=2, + encoder_global_att_blocks=[12, 16, 20], + encoder_window_spec=[8, 4, 14, 7], + encoder_window_spatial_size=[14, 14], + encoder_backbone_channel_list=[896, 448, 224, 112], + checkpoint=checkpoint, + ) + + +def build_sam2_l(checkpoint=None): + """Builds and returns a large-size Segment Anything Model (SAM2) with specified architecture parameters.""" + return _build_sam2( + encoder_embed_dim=144, + encoder_stages=[2, 6, 36, 4], + encoder_num_heads=2, + encoder_global_att_blocks=[23, 33, 43], + encoder_window_spec=[8, 4, 16, 8], + encoder_backbone_channel_list=[1152, 576, 288, 144], + checkpoint=checkpoint, + ) + + def _build_sam( - encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint=None, mobile_sam=False + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, + mobile_sam=False, ): - """Builds the selected SAM model architecture.""" + """ + Builds a Segment Anything Model (SAM) with specified encoder parameters. + + Args: + encoder_embed_dim (int | List[int]): Embedding dimension for the encoder. + encoder_depth (int | List[int]): Depth of the encoder. + encoder_num_heads (int | List[int]): Number of attention heads in the encoder. + encoder_global_attn_indexes (List[int] | None): Indexes for global attention in the encoder. + checkpoint (str | None): Path to the model checkpoint file. + mobile_sam (bool): Whether to build a Mobile-SAM model. + + Returns: + (SAMModel): A Segment Anything Model instance with the specified architecture. + + Examples: + >>> sam = _build_sam(768, 12, 12, [2, 5, 8, 11]) + >>> sam = _build_sam([64, 128, 160, 320], [2, 2, 6, 2], [2, 4, 5, 10], None, mobile_sam=True) + """ prompt_embed_dim = 256 image_size = 1024 vit_patch_size = 16 @@ -139,16 +215,131 @@ def _build_sam( return sam +def _build_sam2( + encoder_embed_dim=1280, + encoder_stages=[2, 6, 36, 4], + encoder_num_heads=2, + encoder_global_att_blocks=[7, 15, 23, 31], + encoder_backbone_channel_list=[1152, 576, 288, 144], + encoder_window_spatial_size=[7, 7], + encoder_window_spec=[8, 4, 16, 8], + checkpoint=None, +): + """ + Builds and returns a Segment Anything Model 2 (SAM2) with specified architecture parameters. + + Args: + encoder_embed_dim (int): Embedding dimension for the encoder. + encoder_stages (List[int]): Number of blocks in each stage of the encoder. + encoder_num_heads (int): Number of attention heads in the encoder. + encoder_global_att_blocks (List[int]): Indices of global attention blocks in the encoder. + encoder_backbone_channel_list (List[int]): Channel dimensions for each level of the encoder backbone. + encoder_window_spatial_size (List[int]): Spatial size of the window for position embeddings. + encoder_window_spec (List[int]): Window specifications for each stage of the encoder. + checkpoint (str | None): Path to the checkpoint file for loading pre-trained weights. + + Returns: + (SAM2Model): A configured and initialized SAM2 model. + + Examples: + >>> sam2_model = _build_sam2(encoder_embed_dim=96, encoder_stages=[1, 2, 7, 2]) + >>> sam2_model.eval() + """ + image_encoder = ImageEncoder( + trunk=Hiera( + embed_dim=encoder_embed_dim, + num_heads=encoder_num_heads, + stages=encoder_stages, + global_att_blocks=encoder_global_att_blocks, + window_pos_embed_bkg_spatial_size=encoder_window_spatial_size, + window_spec=encoder_window_spec, + ), + neck=FpnNeck( + d_model=256, + backbone_channel_list=encoder_backbone_channel_list, + fpn_top_down_levels=[2, 3], + fpn_interp_model="nearest", + ), + scalp=1, + ) + memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer()) + memory_encoder = MemoryEncoder(out_dim=64) + + sam2 = SAM2Model( + image_encoder=image_encoder, + memory_attention=memory_attention, + memory_encoder=memory_encoder, + num_maskmem=7, + image_size=1024, + sigmoid_scale_for_mem_enc=20.0, + sigmoid_bias_for_mem_enc=-10.0, + use_mask_input_as_output_without_sam=True, + directly_add_no_mem_embed=True, + use_high_res_features_in_sam=True, + multimask_output_in_sam=True, + iou_prediction_use_sigmoid=True, + use_obj_ptrs_in_encoder=True, + add_tpos_enc_to_obj_ptrs=True, + only_obj_ptrs_in_the_past_for_eval=True, + pred_obj_scores=True, + pred_obj_scores_mlp=True, + fixed_no_obj_ptr=True, + multimask_output_for_tracking=True, + use_multimask_token_for_obj_ptr=True, + multimask_min_pt_num=0, + multimask_max_pt_num=1, + use_mlp_for_obj_ptr_proj=True, + compile_image_encoder=False, + sam_mask_decoder_extra_args=dict( + dynamic_multimask_via_stability=True, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + ), + ) + + if checkpoint is not None: + checkpoint = attempt_download_asset(checkpoint) + with open(checkpoint, "rb") as f: + state_dict = torch.load(f)["model"] + sam2.load_state_dict(state_dict) + sam2.eval() + return sam2 + + sam_model_map = { "sam_h.pt": build_sam_vit_h, "sam_l.pt": build_sam_vit_l, "sam_b.pt": build_sam_vit_b, "mobile_sam.pt": build_mobile_sam, + "sam2_t.pt": build_sam2_t, + "sam2_s.pt": build_sam2_s, + "sam2_b.pt": build_sam2_b, + "sam2_l.pt": build_sam2_l, } def build_sam(ckpt="sam_b.pt"): - """Build a SAM model specified by ckpt.""" + """ + Builds and returns a Segment Anything Model (SAM) based on the provided checkpoint. + + Args: + ckpt (str | Path): Path to the checkpoint file or name of a pre-defined SAM model. + + Returns: + (SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance. + + Raises: + FileNotFoundError: If the provided checkpoint is not a supported SAM model. + + Examples: + >>> sam_model = build_sam("sam_b.pt") + >>> sam_model = build_sam("path/to/custom_checkpoint.pt") + + Notes: + Supported pre-defined models include: + - SAM: 'sam_h.pt', 'sam_l.pt', 'sam_b.pt', 'mobile_sam.pt' + - SAM2: 'sam2_t.pt', 'sam2_s.pt', 'sam2_b.pt', 'sam2_l.pt' + """ model_builder = None ckpt = str(ckpt) # to allow Path ckpt types for k in sam_model_map.keys(): diff --git a/ultralytics/models/sam/model.py b/ultralytics/models/sam/model.py index a819b6eaf..30d8644bb 100644 --- a/ultralytics/models/sam/model.py +++ b/ultralytics/models/sam/model.py @@ -20,27 +20,46 @@ from ultralytics.engine.model import Model from ultralytics.utils.torch_utils import model_info from .build import build_sam -from .predict import Predictor +from .predict import Predictor, SAM2Predictor class SAM(Model): """ - SAM (Segment Anything Model) interface class. - - SAM is designed for promptable real-time image segmentation. It can be used with a variety of prompts such as - bounding boxes, points, or labels. The model has capabilities for zero-shot performance and is trained on the SA-1B - dataset. + SAM (Segment Anything Model) interface class for real-time image segmentation tasks. + + This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for + promptable segmentation with versatility in image analysis. It supports various prompts such as bounding + boxes, points, or labels, and features zero-shot performance capabilities. + + Attributes: + model (torch.nn.Module): The loaded SAM model. + is_sam2 (bool): Indicates whether the model is SAM2 variant. + task (str): The task type, set to "segment" for SAM models. + + Methods: + predict: Performs segmentation prediction on the given image or video source. + info: Logs information about the SAM model. + + Examples: + >>> sam = SAM('sam_b.pt') + >>> results = sam.predict('image.jpg', points=[[500, 375]]) + >>> for r in results: + >>> print(f"Detected {len(r.masks)} masks") """ def __init__(self, model="sam_b.pt") -> None: """ - Initializes the SAM model with a pre-trained model file. + Initializes the SAM (Segment Anything Model) instance. Args: model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension. Raises: NotImplementedError: If the model file extension is not .pt or .pth. + + Examples: + >>> sam = SAM('sam_b.pt') + >>> print(sam.is_sam2) """ if model and Path(model).suffix not in {".pt", ".pth"}: raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.") @@ -51,30 +70,40 @@ class SAM(Model): """ Loads the specified weights into the SAM model. + This method initializes the SAM model with the provided weights file, setting up the model architecture + and loading the pre-trained parameters. + Args: - weights (str): Path to the weights file. - task (str, optional): Task name. Defaults to None. - """ - if self.is_sam2: - from ..sam2.build import build_sam2 + weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters. + task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for. - self.model = build_sam2(weights) - else: - self.model = build_sam(weights) + Examples: + >>> sam = SAM('sam_b.pt') + >>> sam._load('path/to/custom_weights.pt') + """ + self.model = build_sam(weights) def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs): """ Performs segmentation prediction on the given image or video source. Args: - source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object. - stream (bool, optional): If True, enables real-time streaming. Defaults to False. - bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None. - points (list, optional): List of points for prompted segmentation. Defaults to None. - labels (list, optional): List of labels for prompted segmentation. Defaults to None. + source (str | PIL.Image | numpy.ndarray): Path to the image or video file, or a PIL.Image object, or + a numpy.ndarray object. + stream (bool): If True, enables real-time streaming. + bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation. + points (List[List[float]] | None): List of points for prompted segmentation. + labels (List[int] | None): List of labels for prompted segmentation. + **kwargs (Any): Additional keyword arguments for prediction. Returns: - (list): The model predictions. + (List): The model predictions. + + Examples: + >>> sam = SAM('sam_b.pt') + >>> results = sam.predict('image.jpg', points=[[500, 375]]) + >>> for r in results: + ... print(f"Detected {len(r.masks)} masks") """ overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024) kwargs.update(overrides) @@ -83,17 +112,27 @@ class SAM(Model): def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs): """ - Alias for the 'predict' method. + Performs segmentation prediction on the given image or video source. + + This method is an alias for the 'predict' method, providing a convenient way to call the SAM model + for segmentation tasks. Args: - source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object. - stream (bool, optional): If True, enables real-time streaming. Defaults to False. - bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None. - points (list, optional): List of points for prompted segmentation. Defaults to None. - labels (list, optional): List of labels for prompted segmentation. Defaults to None. + source (str | PIL.Image | numpy.ndarray | None): Path to the image or video file, or a PIL.Image + object, or a numpy.ndarray object. + stream (bool): If True, enables real-time streaming. + bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation. + points (List[List[float]] | None): List of points for prompted segmentation. + labels (List[int] | None): List of labels for prompted segmentation. + **kwargs (Any): Additional keyword arguments to be passed to the predict method. Returns: - (list): The model predictions. + (List): The model predictions, typically containing segmentation masks and other relevant information. + + Examples: + >>> sam = SAM('sam_b.pt') + >>> results = sam('image.jpg', points=[[500, 375]]) + >>> print(f"Detected {len(results[0].masks)} masks") """ return self.predict(source, stream, bboxes, points, labels, **kwargs) @@ -101,12 +140,20 @@ class SAM(Model): """ Logs information about the SAM model. + This method provides details about the Segment Anything Model (SAM), including its architecture, + parameters, and computational requirements. + Args: - detailed (bool, optional): If True, displays detailed information about the model. Defaults to False. - verbose (bool, optional): If True, displays information on the console. Defaults to True. + detailed (bool): If True, displays detailed information about the model layers and operations. + verbose (bool): If True, prints the information to the console. Returns: - (tuple): A tuple containing the model's information. + (Tuple): A tuple containing the model's information (string representations of the model). + + Examples: + >>> sam = SAM('sam_b.pt') + >>> info = sam.info() + >>> print(info[0]) # Print summary information """ return model_info(self.model, detailed=detailed, verbose=verbose) @@ -116,8 +163,13 @@ class SAM(Model): Provides a mapping from the 'segment' task to its corresponding 'Predictor'. Returns: - (dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'. + (Dict[str, Type[Predictor]]): A dictionary mapping the 'segment' task to its corresponding Predictor + class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor. + + Examples: + >>> sam = SAM('sam_b.pt') + >>> task_map = sam.task_map + >>> print(task_map) + {'segment': } """ - from ..sam2.predict import SAM2Predictor - return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}} diff --git a/ultralytics/models/sam/modules/blocks.py b/ultralytics/models/sam/modules/blocks.py new file mode 100644 index 000000000..008185a29 --- /dev/null +++ b/ultralytics/models/sam/modules/blocks.py @@ -0,0 +1,1131 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import copy +import math +from functools import partial +from typing import Any, Optional, Tuple, Type, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from ultralytics.nn.modules import MLP, LayerNorm2d, MLPBlock + +from .transformer import Attention, TwoWayAttentionBlock, TwoWayTransformer +from .utils import add_decomposed_rel_pos, apply_rotary_enc, compute_axial_cis, window_partition, window_unpartition + + +class DropPath(nn.Module): + """ + Implements stochastic depth regularization for neural networks during training. + + Attributes: + drop_prob (float): Probability of dropping a path during training. + scale_by_keep (bool): Whether to scale the output by the keep probability. + + Methods: + forward: Applies stochastic depth to input tensor during training, with optional scaling. + + Examples: + >>> drop_path = DropPath(drop_prob=0.2, scale_by_keep=True) + >>> x = torch.randn(32, 64, 224, 224) + >>> output = drop_path(x) + """ + + def __init__(self, drop_prob=0.0, scale_by_keep=True): + """Initialize DropPath module for stochastic depth regularization during training.""" + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + """Applies stochastic depth to input tensor during training, with optional scaling.""" + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and self.scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class MaskDownSampler(nn.Module): + """ + A mask downsampling and embedding module for efficient processing of input masks. + + This class implements a mask downsampler that progressively reduces the spatial dimensions of input masks + while expanding their channel dimensions using convolutional layers, layer normalization, and activation + functions. + + Attributes: + encoder (nn.Sequential): A sequential container of convolutional layers, layer normalization, and + activation functions for downsampling and embedding masks. + + Methods: + forward: Downsamples and encodes input mask to embed_dim channels. + + Examples: + >>> mask_downsampler = MaskDownSampler(embed_dim=256, kernel_size=4, stride=4, padding=0, total_stride=16) + >>> input_mask = torch.randn(1, 1, 256, 256) + >>> output = mask_downsampler(input_mask) + >>> print(output.shape) + torch.Size([1, 256, 16, 16]) + """ + + def __init__( + self, + embed_dim=256, + kernel_size=4, + stride=4, + padding=0, + total_stride=16, + activation=nn.GELU, + ): + """Initializes a mask downsampler module for progressive downsampling and channel expansion.""" + super().__init__() + num_layers = int(math.log2(total_stride) // math.log2(stride)) + assert stride**num_layers == total_stride + self.encoder = nn.Sequential() + mask_in_chans, mask_out_chans = 1, 1 + for _ in range(num_layers): + mask_out_chans = mask_in_chans * (stride**2) + self.encoder.append( + nn.Conv2d( + mask_in_chans, + mask_out_chans, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + ) + self.encoder.append(LayerNorm2d(mask_out_chans)) + self.encoder.append(activation()) + mask_in_chans = mask_out_chans + + self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) + + def forward(self, x): + """Downsamples and encodes input mask to embed_dim channels using convolutional layers and LayerNorm2d.""" + return self.encoder(x) + + +class CXBlock(nn.Module): + """ + ConvNeXt Block for efficient feature extraction in convolutional neural networks. + + This block implements a modified version of the ConvNeXt architecture, offering improved performance and + flexibility in feature extraction. + + Attributes: + dwconv (nn.Conv2d): Depthwise or standard 2D convolution layer. + norm (LayerNorm2d): Layer normalization applied to channels. + pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer. + act (nn.GELU): GELU activation function. + pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer. + gamma (nn.Parameter | None): Learnable scale parameter for layer scaling. + drop_path (nn.Module): DropPath layer for stochastic depth regularization. + + Methods: + forward: Processes the input tensor through the ConvNeXt block. + + Examples: + >>> import torch + >>> x = torch.randn(1, 64, 56, 56) + >>> block = CXBlock(dim=64, kernel_size=7, padding=3) + >>> output = block(x) + >>> print(output.shape) + torch.Size([1, 64, 56, 56]) + """ + + def __init__( + self, + dim, + kernel_size=7, + padding=3, + drop_path=0.0, + layer_scale_init_value=1e-6, + use_dwconv=True, + ): + """ + Initialize a ConvNeXt Block for efficient feature extraction in convolutional neural networks. + + This block implements a modified version of the ConvNeXt architecture, offering improved performance and + flexibility in feature extraction. + + Args: + dim (int): Number of input channels. + kernel_size (int): Size of the convolutional kernel. + padding (int): Padding size for the convolution. + drop_path (float): Stochastic depth rate. + layer_scale_init_value (float): Initial value for Layer Scale. + use_dwconv (bool): Whether to use depthwise convolution. + + Examples: + >>> block = CXBlock(dim=64, kernel_size=7, padding=3) + >>> x = torch.randn(1, 64, 32, 32) + >>> output = block(x) + >>> print(output.shape) + torch.Size([1, 64, 32, 32]) + """ + super().__init__() + self.dwconv = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=padding, + groups=dim if use_dwconv else 1, + ) # depthwise conv + self.norm = LayerNorm2d(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + """Applies ConvNeXt block operations to input tensor, including convolutions and residual connection.""" + input = x + x = self.dwconv(x) + x = self.norm(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class Fuser(nn.Module): + """ + A module for fusing features through multiple layers of a neural network. + + This class applies a series of identical layers to an input tensor, optionally projecting the input first. + + Attributes: + proj (nn.Module): An optional input projection layer. Identity if no projection is needed. + layers (nn.ModuleList): A list of identical layers to be applied sequentially. + + Methods: + forward: Applies the fuser to an input tensor. + + Examples: + >>> layer = CXBlock(dim=256) + >>> fuser = Fuser(layer, num_layers=3, dim=256, input_projection=True) + >>> x = torch.randn(1, 256, 32, 32) + >>> output = fuser(x) + >>> print(output.shape) + torch.Size([1, 256, 32, 32]) + """ + + def __init__(self, layer, num_layers, dim=None, input_projection=False): + """ + Initializes the Fuser module for feature fusion through multiple layers. + + This module creates a sequence of identical layers and optionally applies an input projection. + + Args: + layer (nn.Module): The layer to be replicated in the fuser. + num_layers (int): The number of times to replicate the layer. + dim (int | None): The dimension for input projection, if used. + input_projection (bool): Whether to use input projection. + + Examples: + >>> layer = nn.Linear(64, 64) + >>> fuser = Fuser(layer, num_layers=3, dim=64, input_projection=True) + >>> input_tensor = torch.randn(1, 64) + >>> output = fuser(input_tensor) + """ + super().__init__() + self.proj = nn.Identity() + self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)]) + + if input_projection: + assert dim is not None + self.proj = nn.Conv2d(dim, dim, kernel_size=1) + + def forward(self, x): + """Applies a series of layers to the input tensor, optionally projecting it first.""" + x = self.proj(x) + for layer in self.layers: + x = layer(x) + return x + + +class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock): + """ + A two-way attention block for performing self-attention and cross-attention in both directions. + + This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on + sparse inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and + cross-attention from dense to sparse inputs. + + Attributes: + self_attn (Attention): Self-attention layer for queries. + norm1 (nn.LayerNorm): Layer normalization after the first attention block. + cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys. + norm2 (nn.LayerNorm): Layer normalization after the second attention block. + mlp (MLP): MLP block for transforming query embeddings. + norm3 (nn.LayerNorm): Layer normalization after the MLP block. + norm4 (nn.LayerNorm): Layer normalization after the third attention block. + cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries. + skip_first_layer_pe (bool): Flag to skip positional encoding in the first layer. + + Methods: + forward: Processes input through the attention blocks and MLP. + + Examples: + >>> block = SAM2TwoWayAttentionBlock(embedding_dim=256, num_heads=8) + >>> sparse_input = torch.randn(1, 100, 256) + >>> dense_input = torch.randn(1, 256, 16, 16) + >>> sparse_output, dense_output = block(sparse_input, dense_input) + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + Initializes a SAM2TwoWayAttentionBlock for performing self-attention and cross-attention in two directions. + + This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on sparse + inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and cross-attention + from dense to sparse inputs. + + Args: + embedding_dim (int): The channel dimension of the embeddings. + num_heads (int): The number of heads in the attention layers. + mlp_dim (int): The hidden dimension of the MLP block. + activation (Type[nn.Module]): The activation function of the MLP block. + attention_downsample_rate (int): The downsample rate for attention computations. + skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer. + + Examples: + >>> block = SAM2TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048) + >>> sparse_inputs = torch.randn(1, 100, 256) + >>> dense_inputs = torch.randn(1, 256, 32, 32) + >>> sparse_outputs, dense_outputs = block(sparse_inputs, dense_inputs) + """ + super().__init__(embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate, skip_first_layer_pe) + self.mlp = MLP(embedding_dim, mlp_dim, embedding_dim, num_layers=2, act=activation) + + +class SAM2TwoWayTransformer(TwoWayTransformer): + """ + A Two-Way Transformer module for simultaneous attention to image and query points. + + This class extends the TwoWayTransformer, implementing a specialized transformer decoder that attends to an + input image using queries with supplied positional embeddings. It is particularly useful for tasks like + object detection, image segmentation, and point cloud processing. + + Attributes: + depth (int): Number of layers in the transformer. + embedding_dim (int): Channel dimension for input embeddings. + num_heads (int): Number of heads for multihead attention. + mlp_dim (int): Internal channel dimension for the MLP block. + layers (nn.ModuleList): List of SAM2TwoWayAttentionBlock layers comprising the transformer. + final_attn_token_to_image (Attention): Final attention layer from queries to image. + norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries. + + Methods: + forward: Processes input image embeddings and query embeddings through the transformer. + + Examples: + >>> transformer = SAM2TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048) + >>> image_embedding = torch.randn(1, 256, 64, 64) + >>> query_embedding = torch.randn(1, 100, 256) + >>> output = transformer(image_embedding, query_embedding) + >>> print(output[0].shape, output[1].shape) + torch.Size([1, 100, 256]) torch.Size([1, 256, 64, 64]) + """ + + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + Initializes a SAM2TwoWayTransformer instance. + + This transformer decoder attends to an input image using queries with supplied positional embeddings. + It is designed for tasks like object detection, image segmentation, and point cloud processing. + + Args: + depth (int): Number of layers in the transformer. + embedding_dim (int): Channel dimension for the input embeddings. + num_heads (int): Number of heads for multihead attention. Must divide embedding_dim. + mlp_dim (int): Channel dimension internal to the MLP block. + activation (Type[nn.Module]): Activation function to use in the MLP block. + attention_downsample_rate (int): Downsampling rate for attention computations. + + Examples: + >>> transformer = SAM2TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048) + >>> transformer + SAM2TwoWayTransformer( + (layers): ModuleList( + (0-4): 5 x SAM2TwoWayAttentionBlock(...) + ) + (final_attn_token_to_image): Attention(...) + (norm_final_attn): LayerNorm(...) + ) + """ + super().__init__(depth, embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate) + self.layers = nn.ModuleList() + for i in range(depth): + self.layers.append( + SAM2TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + +class RoPEAttention(Attention): + """ + Implements rotary position encoding for attention mechanisms in transformer architectures. + + This class extends the base Attention class by incorporating Rotary Position Encoding (RoPE) to enhance + the positional awareness of the attention mechanism. + + Attributes: + compute_cis (Callable): Function to compute axial complex numbers for rotary encoding. + freqs_cis (Tensor): Precomputed frequency tensor for rotary encoding. + rope_k_repeat (bool): Flag to repeat query RoPE to match key length for cross-attention to memories. + + Methods: + forward: Applies rotary position encoding and computes attention between query, key, and value tensors. + + Examples: + >>> rope_attn = RoPEAttention(embedding_dim=256, num_heads=8, rope_theta=10000.0, feat_sizes=(32, 32)) + >>> q = torch.randn(1, 1024, 256) + >>> k = torch.randn(1, 1024, 256) + >>> v = torch.randn(1, 1024, 256) + >>> output = rope_attn(q, k, v) + >>> print(output.shape) + torch.Size([1, 1024, 256]) + """ + + def __init__( + self, + *args, + rope_theta=10000.0, + rope_k_repeat=False, + feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution + **kwargs, + ): + """Initializes RoPEAttention with rotary position encoding for enhanced positional awareness.""" + super().__init__(*args, **kwargs) + + self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta) + freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) + self.freqs_cis = freqs_cis + self.rope_k_repeat = rope_k_repeat # repeat q rope to match k length, needed for cross-attention to memories + + def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor: + """Applies rotary position encoding and computes attention between query, key, and value tensors.""" + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Apply rotary position encoding + w = h = math.sqrt(q.shape[-2]) + self.freqs_cis = self.freqs_cis.to(q.device) + if self.freqs_cis.shape[0] != q.shape[-2]: + self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) + if q.shape[-2] != k.shape[-2]: + assert self.rope_k_repeat + + num_k_rope = k.size(-2) - num_k_exclude_rope + q, k[:, :, :num_k_rope] = apply_rotary_enc( + q, + k[:, :, :num_k_rope], + freqs_cis=self.freqs_cis, + repeat_freqs_k=self.rope_k_repeat, + ) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + +def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: + """Applies pooling and optional normalization to a tensor, handling spatial dimension permutations.""" + if pool is None: + return x + # (B, H, W, C) -> (B, C, H, W) + x = x.permute(0, 3, 1, 2) + x = pool(x) + # (B, C, H', W') -> (B, H', W', C) + x = x.permute(0, 2, 3, 1) + if norm: + x = norm(x) + + return x + + +class MultiScaleAttention(nn.Module): + """ + Implements multi-scale self-attention with optional query pooling for efficient feature extraction. + + This class provides a flexible implementation of multi-scale attention, allowing for optional + downsampling of query features through pooling. It's designed to enhance the model's ability to + capture multi-scale information in visual tasks. + + Attributes: + dim (int): Input dimension of the feature map. + dim_out (int): Output dimension of the attention module. + num_heads (int): Number of attention heads. + scale (float): Scaling factor for dot-product attention. + q_pool (nn.Module | None): Optional pooling module for query features. + qkv (nn.Linear): Linear projection for query, key, and value. + proj (nn.Linear): Output projection. + + Methods: + forward: Applies multi-scale attention to the input tensor. + + Examples: + >>> import torch + >>> from torch import nn + >>> x = torch.randn(1, 64, 64, 256) + >>> msa = MultiScaleAttention(dim=256, dim_out=256, num_heads=8) + >>> output = msa(x) + >>> print(output.shape) + torch.Size([1, 64, 64, 256]) + """ + + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + q_pool: nn.Module = None, + ): + """Initializes multi-scale attention with optional query pooling for efficient feature extraction.""" + super().__init__() + + self.dim = dim + self.dim_out = dim_out + + self.num_heads = num_heads + head_dim = dim_out // num_heads + self.scale = head_dim**-0.5 + + self.q_pool = q_pool + self.qkv = nn.Linear(dim, dim_out * 3) + self.proj = nn.Linear(dim_out, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Applies multi-scale attention with optional query pooling to extract multi-scale features.""" + B, H, W, _ = x.shape + # qkv with shape (B, H * W, 3, nHead, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) + # q, k, v with shape (B, H * W, nheads, C) + q, k, v = torch.unbind(qkv, 2) + + # Q pooling (for downsample at stage changes) + if self.q_pool: + q = do_pool(q.reshape(B, H, W, -1), self.q_pool) + H, W = q.shape[1:3] # downsampled shape + q = q.reshape(B, H * W, self.num_heads, -1) + + # Torch's SDPA expects [B, nheads, H*W, C] so we transpose + x = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + ) + # Transpose back + x = x.transpose(1, 2) + x = x.reshape(B, H, W, -1) + + x = self.proj(x) + + return x + + +class MultiScaleBlock(nn.Module): + """ + A multi-scale attention block with window partitioning and query pooling for efficient vision transformers. + + This class implements a multi-scale attention mechanism with optional window partitioning and downsampling, + designed for use in vision transformer architectures. + + Attributes: + dim (int): Input dimension of the block. + dim_out (int): Output dimension of the block. + norm1 (nn.Module): First normalization layer. + window_size (int): Size of the window for partitioning. + pool (nn.Module | None): Pooling layer for query downsampling. + q_stride (Tuple[int, int] | None): Stride for query pooling. + attn (MultiScaleAttention): Multi-scale attention module. + drop_path (nn.Module): Drop path layer for regularization. + norm2 (nn.Module): Second normalization layer. + mlp (MLP): Multi-layer perceptron module. + proj (nn.Linear | None): Projection layer for dimension mismatch. + + Methods: + forward: Processes input tensor through the multi-scale block. + + Examples: + >>> block = MultiScaleBlock(dim=256, dim_out=512, num_heads=8, window_size=7) + >>> x = torch.randn(1, 56, 56, 256) + >>> output = block(x) + >>> print(output.shape) + torch.Size([1, 28, 28, 512]) + """ + + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + mlp_ratio: float = 4.0, + drop_path: float = 0.0, + norm_layer: Union[nn.Module, str] = "LayerNorm", + q_stride: Tuple[int, int] = None, + act_layer: nn.Module = nn.GELU, + window_size: int = 0, + ): + """Initializes a multi-scale attention block with window partitioning and optional query pooling.""" + super().__init__() + + if isinstance(norm_layer, str): + norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) + + self.dim = dim + self.dim_out = dim_out + self.norm1 = norm_layer(dim) + + self.window_size = window_size + + self.pool, self.q_stride = None, q_stride + if self.q_stride: + self.pool = nn.MaxPool2d(kernel_size=q_stride, stride=q_stride, ceil_mode=False) + + self.attn = MultiScaleAttention( + dim, + dim_out, + num_heads=num_heads, + q_pool=self.pool, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim_out) + self.mlp = MLP( + dim_out, + int(dim_out * mlp_ratio), + dim_out, + num_layers=2, + act=act_layer, + ) + + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Processes input through multi-scale attention and MLP, with optional windowing and downsampling.""" + shortcut = x # B, H, W, C + x = self.norm1(x) + + # Skip connection + if self.dim != self.dim_out: + shortcut = do_pool(self.proj(x), self.pool) + + # Window partition + window_size = self.window_size + if window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, window_size) + + # Window Attention + Q Pooling (if stage change) + x = self.attn(x) + if self.q_stride: + # Shapes have changed due to Q pooling + window_size = self.window_size // self.q_stride[0] + H, W = shortcut.shape[1:3] + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + pad_hw = (H + pad_h, W + pad_w) + + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, window_size, pad_hw, (H, W)) + + x = shortcut + self.drop_path(x) + # MLP + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PositionEmbeddingSine(nn.Module): + """ + A module for generating sinusoidal positional embeddings for 2D inputs like images. + + This class implements sinusoidal position encoding for 2D spatial positions, which can be used in + transformer-based models for computer vision tasks. + + Attributes: + num_pos_feats (int): Number of positional features (half of the embedding dimension). + temperature (int): Temperature parameter for the sinusoidal functions. + normalize (bool): Whether to normalize the positional embeddings. + scale (float): Scaling factor for the embeddings when normalize is True. + cache (Dict): Cache for storing precomputed embeddings. + + Methods: + _encode_xy: Encodes 2D positions using sine and cosine functions. + encode_boxes: Encodes box coordinates and dimensions into positional embeddings. + encode_points: Encodes 2D point coordinates with sinusoidal positional embeddings. + forward: Generates sinusoidal position embeddings for 2D inputs. + + Examples: + >>> pos_emb = PositionEmbeddingSine(num_pos_feats=128) + >>> x = torch.randn(1, 3, 224, 224) + >>> embeddings = pos_emb(x) + >>> print(embeddings.shape) + torch.Size([1, 256, 224, 224]) + """ + + def __init__( + self, + num_pos_feats, + temperature: int = 10000, + normalize: bool = True, + scale: Optional[float] = None, + ): + """Initializes sinusoidal position embeddings for 2D image inputs.""" + super().__init__() + assert num_pos_feats % 2 == 0, "Expecting even model width" + self.num_pos_feats = num_pos_feats // 2 + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + self.cache = {} + + def _encode_xy(self, x, y): + """Encodes 2D positions using sine/cosine functions for transformer positional embeddings.""" + assert len(x) == len(y) and x.ndim == y.ndim == 1 + x_embed = x * self.scale + y_embed = y * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, None] / dim_t + pos_y = y_embed[:, None] / dim_t + pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1) + pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1) + return pos_x, pos_y + + @torch.no_grad() + def encode_boxes(self, x, y, w, h): + """Encodes box coordinates and dimensions into positional embeddings for detection.""" + pos_x, pos_y = self._encode_xy(x, y) + pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) + return pos + + encode = encode_boxes # Backwards compatibility + + @torch.no_grad() + def encode_points(self, x, y, labels): + """Encodes 2D points with sinusoidal embeddings and appends labels.""" + (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape + assert bx == by and nx == ny and bx == bl and nx == nl + pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) + pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) + pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) + return pos + + @torch.no_grad() + def forward(self, x: torch.Tensor): + """Generates sinusoidal position embeddings for 2D inputs like images.""" + cache_key = (x.shape[-2], x.shape[-1]) + if cache_key in self.cache: + return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) + y_embed = ( + torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) + .view(1, -1, 1) + .repeat(x.shape[0], 1, x.shape[-1]) + ) + x_embed = ( + torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) + .view(1, 1, -1) + .repeat(x.shape[0], x.shape[-2], 1) + ) + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + self.cache[cache_key] = pos[0] + return pos + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + + This class generates positional embeddings for input coordinates using random spatial frequencies. It is + particularly useful for transformer-based models that require position information. + + Attributes: + positional_encoding_gaussian_matrix (torch.Tensor): A buffer containing random values for encoding. + + Methods: + _pe_encoding: Positionally encodes points that are normalized to [0,1]. + forward: Generates positional encoding for a grid of the specified size. + forward_with_coords: Positionally encodes points that are not normalized to [0,1]. + + Examples: + >>> pe = PositionEmbeddingRandom(num_pos_feats=64) + >>> size = (32, 32) + >>> encoding = pe(size) + >>> print(encoding.shape) + torch.Size([128, 32, 32]) + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + """Initializes random spatial frequency position embedding for transformers.""" + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats))) + + # Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation' + torch.use_deterministic_algorithms(False) + torch.backends.cudnn.deterministic = False + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Encodes normalized [0,1] coordinates using random spatial frequencies.""" + # Assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # Outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generates positional encoding for a grid using random spatial frequencies.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor: + """Positionally encodes input coordinates, normalizing them to [0,1] based on the given image size.""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C + + +class Block(nn.Module): + """ + Transformer block with support for window attention and residual propagation. + + This class implements a transformer block that can use either global or windowed self-attention, + followed by a feed-forward network. It supports relative positional embeddings and is designed + for use in vision transformer architectures. + + Attributes: + norm1 (nn.Module): First normalization layer. + attn (REAttention): Self-attention layer with optional relative positional encoding. + norm2 (nn.Module): Second normalization layer. + mlp (MLPBlock): Multi-layer perceptron block. + window_size (int): Size of attention window. If 0, global attention is used. + + Methods: + forward: Processes input through the transformer block. + + Examples: + >>> import torch + >>> block = Block(dim=256, num_heads=8, window_size=7) + >>> x = torch.randn(1, 56, 56, 256) + >>> output = block(x) + >>> print(output.shape) + torch.Size([1, 56, 56, 256]) + """ + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Initializes a transformer block with optional window attention and relative positional embeddings. + + This constructor sets up a transformer block that can use either global or windowed self-attention, + followed by a feed-forward network. It supports relative positional embeddings and is designed + for use in vision transformer architectures. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in the self-attention layer. + mlp_ratio (float): Ratio of mlp hidden dimension to embedding dimension. + qkv_bias (bool): If True, adds a learnable bias to query, key, value projections. + norm_layer (Type[nn.Module]): Type of normalization layer to use. + act_layer (Type[nn.Module]): Type of activation function to use in the MLP block. + use_rel_pos (bool): If True, uses relative positional embeddings in attention. + rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero. + window_size (int): Size of attention window. If 0, uses global attention. + input_size (Optional[Tuple[int, int]]): Input resolution for calculating relative positional parameter size. + + Examples: + >>> block = Block(dim=256, num_heads=8, window_size=7) + >>> x = torch.randn(1, 56, 56, 256) + >>> output = block(x) + >>> print(output.shape) + torch.Size([1, 56, 56, 256]) + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = REAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Processes input through transformer block with optional windowed self-attention and residual connection.""" + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + return x + self.mlp(self.norm2(x)) + + +class REAttention(nn.Module): + """ + Rotary Embedding Attention module for efficient self-attention in transformer architectures. + + This class implements a multi-head attention mechanism with rotary positional embeddings, designed + for use in vision transformer models. It supports optional query pooling and window partitioning + for efficient processing of large inputs. + + Attributes: + compute_cis (Callable): Function to compute axial complex numbers for rotary encoding. + freqs_cis (Tensor): Precomputed frequency tensor for rotary encoding. + rope_k_repeat (bool): Flag to repeat query RoPE to match key length for cross-attention to memories. + q_proj (nn.Linear): Linear projection for query. + k_proj (nn.Linear): Linear projection for key. + v_proj (nn.Linear): Linear projection for value. + out_proj (nn.Linear): Output projection. + num_heads (int): Number of attention heads. + internal_dim (int): Internal dimension for attention computation. + + Methods: + forward: Applies rotary position encoding and computes attention between query, key, and value tensors. + + Examples: + >>> rope_attn = REAttention(embedding_dim=256, num_heads=8, rope_theta=10000.0, feat_sizes=(32, 32)) + >>> q = torch.randn(1, 1024, 256) + >>> k = torch.randn(1, 1024, 256) + >>> v = torch.randn(1, 1024, 256) + >>> output = rope_attn(q, k, v) + >>> print(output.shape) + torch.Size([1, 1024, 256]) + """ + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Initializes a Relative Position Attention module for transformer-based architectures. + + This module implements multi-head attention with optional relative positional encodings, designed + specifically for vision tasks in transformer models. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. Default is 8. + qkv_bias (bool): If True, adds a learnable bias to query, key, value projections. Default is True. + use_rel_pos (bool): If True, uses relative positional encodings. Default is False. + rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero. Default is True. + input_size (Tuple[int, int] | None): Input resolution for calculating relative positional parameter size. + Required if use_rel_pos is True. Default is None. + + Examples: + >>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32)) + >>> x = torch.randn(1, 32, 32, 256) + >>> output = attention(x) + >>> print(output.shape) + torch.Size([1, 32, 32, 256]) + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert input_size is not None, "Input size must be provided if using relative positional encoding." + # Initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Applies multi-head attention with optional relative positional encoding to input tensor.""" + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + return self.proj(x) + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding module for vision transformer architectures. + + This module converts an input image into a sequence of patch embeddings using a convolutional layer. + It is commonly used as the first layer in vision transformer architectures to transform image data + into a suitable format for subsequent transformer blocks. + + Attributes: + proj (nn.Conv2d): Convolutional layer for projecting image patches to embeddings. + + Methods: + forward: Applies patch embedding to the input tensor. + + Examples: + >>> patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768) + >>> x = torch.randn(1, 3, 224, 224) + >>> output = patch_embed(x) + >>> print(output.shape) + torch.Size([1, 768, 14, 14]) + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Initializes the PatchEmbed module for converting image patches to embeddings. + + This module is typically used as the first layer in vision transformer architectures to transform + image data into a suitable format for subsequent transformer blocks. + + Args: + kernel_size (Tuple[int, int]): Size of the convolutional kernel for patch extraction. + stride (Tuple[int, int]): Stride of the convolutional operation. + padding (Tuple[int, int]): Padding applied to the input before convolution. + in_chans (int): Number of input image channels. + embed_dim (int): Dimensionality of the output patch embeddings. + + Examples: + >>> patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768) + >>> x = torch.randn(1, 3, 224, 224) + >>> output = patch_embed(x) + >>> print(output.shape) + torch.Size([1, 768, 14, 14]) + """ + super().__init__() + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Computes patch embedding by applying convolution and transposing resulting tensor.""" + return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C diff --git a/ultralytics/models/sam/modules/decoders.py b/ultralytics/models/sam/modules/decoders.py index eeaab6b45..5b172bdd2 100644 --- a/ultralytics/models/sam/modules/decoders.py +++ b/ultralytics/models/sam/modules/decoders.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -from typing import List, Tuple, Type +from typing import List, Optional, Tuple, Type import torch from torch import nn @@ -10,12 +10,14 @@ from ultralytics.nn.modules import MLP, LayerNorm2d class MaskDecoder(nn.Module): """ - Decoder module for generating masks and their associated quality scores, using a transformer architecture to predict - masks given image and prompt embeddings. + Decoder module for generating masks and their associated quality scores using a transformer architecture. + + This class predicts masks given image and prompt embeddings, utilizing a transformer to process the inputs and + generate mask predictions along with their quality scores. Attributes: transformer_dim (int): Channel dimension for the transformer module. - transformer (nn.Module): The transformer module used for mask prediction. + transformer (nn.Module): Transformer module used for mask prediction. num_multimask_outputs (int): Number of masks to predict for disambiguating masks. iou_token (nn.Embedding): Embedding for the IoU token. num_mask_tokens (int): Number of mask tokens. @@ -23,6 +25,16 @@ class MaskDecoder(nn.Module): output_upscaling (nn.Sequential): Neural network sequence for upscaling the output. output_hypernetworks_mlps (nn.ModuleList): Hypernetwork MLPs for generating masks. iou_prediction_head (nn.Module): MLP for predicting mask quality. + + Methods: + forward: Predicts masks given image and prompt embeddings. + predict_masks: Internal method for mask prediction. + + Examples: + >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module) + >>> masks, iou_pred = decoder(image_embeddings, image_pe, sparse_prompt_embeddings, + ... dense_prompt_embeddings, multimask_output=True) + >>> print(f"Predicted masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}") """ def __init__( @@ -35,15 +47,20 @@ class MaskDecoder(nn.Module): iou_head_hidden_dim: int = 256, ) -> None: """ - Predicts masks given an image and prompt embeddings, using a transformer architecture. + Initializes the MaskDecoder module for generating masks and their quality scores. Args: - transformer_dim (int): the channel dimension of the transformer module - transformer (nn.Module): the transformer used to predict masks - num_multimask_outputs (int): the number of masks to predict when disambiguating masks - activation (nn.Module): the type of activation to use when upscaling masks - iou_head_depth (int): the depth of the MLP used to predict mask quality - iou_head_hidden_dim (int): the hidden dimension of the MLP used to predict mask quality + transformer_dim (int): Channel dimension for the transformer module. + transformer (nn.Module): Transformer module used for mask prediction. + num_multimask_outputs (int): Number of masks to predict for disambiguating masks. + activation (Type[nn.Module]): Type of activation to use when upscaling masks. + iou_head_depth (int): Depth of the MLP used to predict mask quality. + iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality. + + Examples: + >>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6) + >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer) + >>> print(decoder) """ super().__init__() self.transformer_dim = transformer_dim @@ -77,18 +94,28 @@ class MaskDecoder(nn.Module): multimask_output: bool, ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Predict masks given image and prompt embeddings. + Predicts masks given image and prompt embeddings. Args: - image_embeddings (torch.Tensor): the embeddings from the image encoder - image_pe (torch.Tensor): positional encoding with the shape of image_embeddings - sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes - dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + image_embeddings (torch.Tensor): Embeddings from the image encoder. + image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings. + sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes. + dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs. multimask_output (bool): Whether to return multiple masks or a single mask. Returns: - torch.Tensor: batched predicted masks - torch.Tensor: batched predictions of mask quality + (Tuple[torch.Tensor, torch.Tensor]): A tuple containing: + - masks (torch.Tensor): Batched predicted masks. + - iou_pred (torch.Tensor): Batched predictions of mask quality. + + Examples: + >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module) + >>> image_emb = torch.rand(1, 256, 64, 64) + >>> image_pe = torch.rand(1, 256, 64, 64) + >>> sparse_emb = torch.rand(1, 2, 256) + >>> dense_emb = torch.rand(1, 256, 64, 64) + >>> masks, iou_pred = decoder(image_emb, image_pe, sparse_emb, dense_emb, multimask_output=True) + >>> print(f"Masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}") """ masks, iou_pred = self.predict_masks( image_embeddings=image_embeddings, @@ -112,11 +139,7 @@ class MaskDecoder(nn.Module): sparse_prompt_embeddings: torch.Tensor, dense_prompt_embeddings: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Predicts masks. - - See 'forward' for more details. - """ + """Predicts masks and quality scores using image and prompt embeddings via transformer architecture.""" # Concatenate output tokens output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1) @@ -147,3 +170,347 @@ class MaskDecoder(nn.Module): iou_pred = self.iou_prediction_head(iou_token_out) return masks, iou_pred + + +class SAM2MaskDecoder(nn.Module): + """ + Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings. + + This class extends the functionality of the MaskDecoder, incorporating additional features such as + high-resolution feature processing, dynamic multimask output, and object score prediction. + + Attributes: + transformer_dim (int): Channel dimension of the transformer. + transformer (nn.Module): Transformer used to predict masks. + num_multimask_outputs (int): Number of masks to predict when disambiguating masks. + iou_token (nn.Embedding): Embedding for IOU token. + num_mask_tokens (int): Total number of mask tokens. + mask_tokens (nn.Embedding): Embedding for mask tokens. + pred_obj_scores (bool): Whether to predict object scores. + obj_score_token (nn.Embedding): Embedding for object score token. + use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer. + output_upscaling (nn.Sequential): Upscaling layers for output. + use_high_res_features (bool): Whether to use high-resolution features. + conv_s0 (nn.Conv2d): Convolutional layer for high-resolution features (s0). + conv_s1 (nn.Conv2d): Convolutional layer for high-resolution features (s1). + output_hypernetworks_mlps (nn.ModuleList): List of MLPs for output hypernetworks. + iou_prediction_head (MLP): MLP for IOU prediction. + pred_obj_score_head (nn.Linear | MLP): Linear layer or MLP for object score prediction. + dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability. + dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability. + dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability. + + Methods: + forward: Predicts masks given image and prompt embeddings. + predict_masks: Predicts instance segmentation masks from image and prompt embeddings. + _get_stability_scores: Computes mask stability scores based on IoU between thresholds. + _dynamic_multimask_via_stability: Dynamically selects the most stable mask output. + + Examples: + >>> image_embeddings = torch.rand(1, 256, 64, 64) + >>> image_pe = torch.rand(1, 256, 64, 64) + >>> sparse_prompt_embeddings = torch.rand(1, 2, 256) + >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64) + >>> decoder = SAM2MaskDecoder(256, transformer) + >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward( + ... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False) + """ + + def __init__( + self, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + use_high_res_features: bool = False, + iou_prediction_use_sigmoid=False, + dynamic_multimask_via_stability=False, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + pred_obj_scores: bool = False, + pred_obj_scores_mlp: bool = False, + use_multimask_token_for_obj_ptr: bool = False, + ) -> None: + """ + Initializes the SAM2MaskDecoder module for predicting instance segmentation masks. + + This decoder extends the functionality of MaskDecoder, incorporating additional features such as + high-resolution feature processing, dynamic multimask output, and object score prediction. + + Args: + transformer_dim (int): Channel dimension of the transformer. + transformer (nn.Module): Transformer used to predict masks. + num_multimask_outputs (int): Number of masks to predict when disambiguating masks. + activation (Type[nn.Module]): Type of activation to use when upscaling masks. + iou_head_depth (int): Depth of the MLP used to predict mask quality. + iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality. + use_high_res_features (bool): Whether to use high-resolution features. + iou_prediction_use_sigmoid (bool): Whether to use sigmoid for IOU prediction. + dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability. + dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability. + dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability. + pred_obj_scores (bool): Whether to predict object scores. + pred_obj_scores_mlp (bool): Whether to use MLP for object score prediction. + use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer. + + Examples: + >>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6) + >>> decoder = SAM2MaskDecoder(transformer_dim=256, transformer=transformer) + >>> print(decoder) + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.pred_obj_scores = pred_obj_scores + if self.pred_obj_scores: + self.obj_score_token = nn.Embedding(1, transformer_dim) + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.use_high_res_features = use_high_res_features + if use_high_res_features: + self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1) + self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1) + + self.output_hypernetworks_mlps = nn.ModuleList( + [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)] + ) + + self.iou_prediction_head = MLP( + transformer_dim, + iou_head_hidden_dim, + self.num_mask_tokens, + iou_head_depth, + sigmoid=iou_prediction_use_sigmoid, + ) + if self.pred_obj_scores: + self.pred_obj_score_head = nn.Linear(transformer_dim, 1) + if pred_obj_scores_mlp: + self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) + + # When outputting a single mask, optionally we can dynamically fall back to the best + # multimask output token if the single mask output token gives low stability scores. + self.dynamic_multimask_via_stability = dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predicts masks given image and prompt embeddings. + + Args: + image_embeddings (torch.Tensor): Embeddings from the image encoder with shape (B, C, H, W). + image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings (B, C, H, W). + sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes with shape (B, N, C). + dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs with shape (B, C, H, W). + multimask_output (bool): Whether to return multiple masks or a single mask. + repeat_image (bool): Flag to repeat the image embeddings. + high_res_features (List[torch.Tensor] | None): Optional high-resolution features. + + Returns: + (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing: + - masks (torch.Tensor): Batched predicted masks with shape (B, N, H, W). + - iou_pred (torch.Tensor): Batched predictions of mask quality with shape (B, N). + - sam_tokens_out (torch.Tensor): Batched SAM token for mask output with shape (B, N, C). + - object_score_logits (torch.Tensor): Batched object score logits with shape (B, 1). + + Examples: + >>> image_embeddings = torch.rand(1, 256, 64, 64) + >>> image_pe = torch.rand(1, 256, 64, 64) + >>> sparse_prompt_embeddings = torch.rand(1, 2, 256) + >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64) + >>> decoder = SAM2MaskDecoder(256, transformer) + >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward( + ... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False) + """ + masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + repeat_image=repeat_image, + high_res_features=high_res_features, + ) + + # Select the correct mask or masks for output + if multimask_output: + masks = masks[:, 1:, :, :] + iou_pred = iou_pred[:, 1:] + elif self.dynamic_multimask_via_stability and not self.training: + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + masks = masks[:, 0:1, :, :] + iou_pred = iou_pred[:, 0:1] + + if multimask_output and self.use_multimask_token_for_obj_ptr: + sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape + else: + # Take the mask output token. Here we *always* use the token for single mask output. + # At test time, even if we track after 1-click (and using multimask_output=True), + # we still take the single mask token here. The rationale is that we always track + # after multiple clicks during training, so the past tokens seen during training + # are always the single mask token (and we'll let it be the object-memory token). + sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape + + # Prepare output + return masks, iou_pred, sam_tokens_out, object_score_logits + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts instance segmentation masks from image and prompt embeddings using a transformer.""" + # Concatenate output tokens + s = 0 + if self.pred_obj_scores: + output_tokens = torch.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + s = 1 + else: + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + if repeat_image: + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + else: + assert image_embeddings.shape[0] == tokens.shape[0] + src = image_embeddings + src = src + dense_prompt_embeddings + assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, s, :] + mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + if not self.use_high_res_features: + upscaled_embedding = self.output_upscaling(src) + else: + dc1, ln1, act1, dc2, act2 = self.output_upscaling + feat_s0, feat_s1 = high_res_features + upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) + upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) + + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + if self.pred_obj_scores: + assert s == 1 + object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) + else: + # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 + object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) + + return masks, iou_pred, mask_tokens_out, object_score_logits + + def _get_stability_scores(self, mask_logits): + """Computes mask stability scores based on IoU between upper and lower thresholds.""" + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() + area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() + stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + Dynamically selects the most stable mask output based on stability scores and IoU predictions. + + This method is used when outputting a single mask. If the stability score from the current single-mask + output (based on output token 0) falls below a threshold, it instead selects from multi-mask outputs + (based on output tokens 1-3) the mask with the highest predicted IoU score. This ensures a valid mask + for both clicking and tracking scenarios. + + Args: + all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is + batch size, N is number of masks (typically 4), and H, W are mask dimensions. + all_iou_scores (torch.Tensor): Predicted IoU scores for all masks, shape (B, N). + + Returns: + (Tuple[torch.Tensor, torch.Tensor]): + - mask_logits_out (torch.Tensor): Selected mask logits, shape (B, 1, H, W). + - iou_scores_out (torch.Tensor): Selected IoU scores, shape (B, 1). + + Examples: + >>> decoder = SAM2MaskDecoder(...) + >>> all_mask_logits = torch.rand(2, 4, 256, 256) # 2 images, 4 masks each + >>> all_iou_scores = torch.rand(2, 4) + >>> mask_logits, iou_scores = decoder._dynamic_multimask_via_stability(all_mask_logits, all_iou_scores) + >>> print(mask_logits.shape, iou_scores.shape) + torch.Size([2, 1, 256, 256]) torch.Size([2, 1]) + """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, 1:] + best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) + batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device) + best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] + best_multimask_logits = best_multimask_logits.unsqueeze(1) + best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] + best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + mask_logits_out = torch.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = torch.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out diff --git a/ultralytics/models/sam/modules/encoders.py b/ultralytics/models/sam/modules/encoders.py index 2bf3945dc..9432fec4b 100644 --- a/ultralytics/models/sam/modules/encoders.py +++ b/ultralytics/models/sam/modules/encoders.py @@ -1,30 +1,48 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -from typing import Any, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from ultralytics.nn.modules import LayerNorm2d, MLPBlock +from ultralytics.nn.modules import LayerNorm2d + +from .blocks import ( + Block, + CXBlock, + Fuser, + MaskDownSampler, + MultiScaleBlock, + PatchEmbed, + PositionEmbeddingRandom, + PositionEmbeddingSine, +) class ImageEncoderViT(nn.Module): """ - An image encoder using Vision Transformer (ViT) architecture for encoding an image into a compact latent space. The - encoder takes an image, splits it into patches, and processes these patches through a series of transformer blocks. - The encoded patches are then processed through a neck to generate the final encoded representation. + An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space. - This class and its supporting functions below lightly adapted from the ViTDet backbone available at - https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py. + This class processes images by splitting them into patches, applying transformer blocks, and generating a final + encoded representation through a neck module. Attributes: img_size (int): Dimension of input images, assumed to be square. patch_embed (PatchEmbed): Module for patch embedding. - pos_embed (nn.Parameter, optional): Absolute positional embedding for patches. + pos_embed (nn.Parameter | None): Absolute positional embedding for patches. blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings. neck (nn.Sequential): Neck module to further process the output. + + Methods: + forward: Processes input through patch embedding, positional embedding, blocks, and neck. + + Examples: + >>> import torch + >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12) + >>> input_image = torch.randn(1, 3, 224, 224) + >>> output = encoder(input_image) + >>> print(output.shape) """ def __init__( @@ -47,22 +65,38 @@ class ImageEncoderViT(nn.Module): global_attn_indexes: Tuple[int, ...] = (), ) -> None: """ + Initializes an ImageEncoderViT instance for encoding images using Vision Transformer architecture. + Args: - img_size (int): Input image size. - patch_size (int): Patch size. + img_size (int): Input image size, assumed to be square. + patch_size (int): Size of image patches. in_chans (int): Number of input image channels. - embed_dim (int): Patch embedding dimension. - depth (int): Depth of ViT. - num_heads (int): Number of attention heads in each ViT block. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool): If True, add a learnable bias to query, key, value. - norm_layer (nn.Module): Normalization layer. - act_layer (nn.Module): Activation layer. - use_abs_pos (bool): If True, use absolute positional embeddings. - use_rel_pos (bool): If True, add relative positional embeddings to the attention map. - rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. - window_size (int): Window size for window attention blocks. - global_attn_indexes (list): Indexes for blocks using global attention. + embed_dim (int): Dimension of patch embeddings. + depth (int): Number of transformer blocks. + num_heads (int): Number of attention heads in each block. + mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension. + out_chans (int): Number of output channels from the neck module. + qkv_bias (bool): If True, adds learnable bias to query, key, value projections. + norm_layer (Type[nn.Module]): Type of normalization layer to use. + act_layer (Type[nn.Module]): Type of activation layer to use. + use_abs_pos (bool): If True, uses absolute positional embeddings. + use_rel_pos (bool): If True, adds relative positional embeddings to attention maps. + rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero. + window_size (int): Size of attention window for windowed attention blocks. + global_attn_indexes (Tuple[int, ...]): Indices of blocks that use global attention. + + Attributes: + img_size (int): Dimension of input images. + patch_embed (PatchEmbed): Module for patch embedding. + pos_embed (nn.Parameter | None): Absolute positional embedding for patches. + blocks (nn.ModuleList): List of transformer blocks. + neck (nn.Sequential): Neck module for final processing. + + Examples: + >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12) + >>> input_image = torch.randn(1, 3, 224, 224) + >>> output = encoder(input_image) + >>> print(output.shape) """ super().__init__() self.img_size = img_size @@ -114,9 +148,7 @@ class ImageEncoderViT(nn.Module): ) def forward(self, x: torch.Tensor) -> torch.Tensor: - """Processes input through patch embedding, applies positional embedding if present, and passes through blocks - and neck. - """ + """Processes input through patch embedding, positional embedding, transformer blocks, and neck module.""" x = self.patch_embed(x) if self.pos_embed is not None: x = x + self.pos_embed @@ -127,8 +159,7 @@ class ImageEncoderViT(nn.Module): class PromptEncoder(nn.Module): """ - Encodes different types of prompts, including points, boxes, and masks, for input to SAM's mask decoder. The encoder - produces both sparse and dense embeddings for the input prompts. + Encodes different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings. Attributes: embed_dim (int): Dimension of the embeddings. @@ -137,10 +168,23 @@ class PromptEncoder(nn.Module): pe_layer (PositionEmbeddingRandom): Module for random position embedding. num_point_embeddings (int): Number of point embeddings for different types of points. point_embeddings (nn.ModuleList): List of point embeddings. - not_a_point_embed (nn.Embedding): Embedding for points that are not a part of any label. + not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label. mask_input_size (Tuple[int, int]): Size of the input mask. mask_downscaling (nn.Sequential): Neural network for downscaling the mask. no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided. + + Methods: + get_dense_pe: Returns the positional encoding used to encode point prompts. + forward: Embeds different types of prompts, returning both sparse and dense embeddings. + + Examples: + >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16) + >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5))) + >>> boxes = torch.rand(1, 2, 2) + >>> masks = torch.rand(1, 1, 256, 256) + >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks) + >>> print(sparse_embeddings.shape, dense_embeddings.shape) + torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64]) """ def __init__( @@ -152,18 +196,37 @@ class PromptEncoder(nn.Module): activation: Type[nn.Module] = nn.GELU, ) -> None: """ - Encodes prompts for input to SAM's mask decoder. + Initializes the PromptEncoder module for encoding various types of prompts. + + This module encodes different types of prompts (points, boxes, masks) for input to SAM's mask decoder, + producing both sparse and dense embeddings. Args: - embed_dim (int): The prompts' embedding dimension - image_embedding_size (tuple(int, int)): The spatial size of the - image embedding, as (H, W). - input_image_size (int): The padded size of the image as input - to the image encoder, as (H, W). - mask_in_chans (int): The number of hidden channels used for - encoding input masks. - activation (nn.Module): The activation to use when encoding - input masks. + embed_dim (int): The dimension of the embeddings. + image_embedding_size (Tuple[int, int]): The spatial size of the image embedding as (H, W). + input_image_size (Tuple[int, int]): The padded size of the input image as (H, W). + mask_in_chans (int): The number of hidden channels used for encoding input masks. + activation (Type[nn.Module]): The activation function to use when encoding input masks. + + Attributes: + embed_dim (int): Dimension of the embeddings. + input_image_size (Tuple[int, int]): Size of the input image as (H, W). + image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W). + pe_layer (PositionEmbeddingRandom): Module for random position embedding. + num_point_embeddings (int): Number of point embeddings for different types of points. + point_embeddings (nn.ModuleList): List of point embeddings. + not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label. + mask_input_size (Tuple[int, int]): Size of the input mask. + mask_downscaling (nn.Sequential): Neural network for downscaling the mask. + + Examples: + >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16) + >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5))) + >>> boxes = torch.rand(1, 2, 2) + >>> masks = torch.rand(1, 1, 256, 256) + >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks) + >>> print(sparse_embeddings.shape, dense_embeddings.shape) + torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64]) """ super().__init__() self.embed_dim = embed_dim @@ -190,16 +253,25 @@ class PromptEncoder(nn.Module): def get_dense_pe(self) -> torch.Tensor: """ - Returns the positional encoding used to encode point prompts, applied to a dense set of points the shape of the - image encoding. + Returns the dense positional encoding used for encoding point prompts. + + This method generates a positional encoding for a dense set of points matching the shape of the image + encoding. The encoding is used to provide spatial information to the model when processing point prompts. Returns: - torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w) + (torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the + height and width of the image embedding size, respectively. + + Examples: + >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16) + >>> dense_pe = prompt_encoder.get_dense_pe() + >>> print(dense_pe.shape) + torch.Size([1, 256, 64, 64]) """ return self.pe_layer(self.image_embedding_size).unsqueeze(0) def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor: - """Embeds point prompts.""" + """Embeds point prompts by applying positional encoding and label-specific embeddings.""" points = points + 0.5 # Shift to center of pixel if pad: padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) @@ -216,7 +288,7 @@ class PromptEncoder(nn.Module): return point_embedding def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: - """Embeds box prompts.""" + """Embeds box prompts by applying positional encoding and adding corner embeddings.""" boxes = boxes + 0.5 # Shift to center of pixel coords = boxes.reshape(-1, 2, 2) corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) @@ -225,7 +297,7 @@ class PromptEncoder(nn.Module): return corner_embedding def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: - """Embeds mask inputs.""" + """Embeds mask inputs by downscaling and processing through convolutional layers.""" return self.mask_downscaling(masks) @staticmethod @@ -258,14 +330,25 @@ class PromptEncoder(nn.Module): Embeds different types of prompts, returning both sparse and dense embeddings. Args: - points (tuple(torch.Tensor, torch.Tensor), None): point coordinates and labels to embed. - boxes (torch.Tensor, None): boxes to embed - masks (torch.Tensor, None): masks to embed + points (Tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first + tensor contains coordinates with shape (B, N, 2), and the second tensor contains labels with + shape (B, N). + boxes (torch.Tensor | None): Boxes to embed with shape (B, M, 2, 2), where M is the number of boxes. + masks (torch.Tensor | None): Masks to embed with shape (B, 1, H, W). Returns: - torch.Tensor: sparse embeddings for the points and boxes, with shape BxNx(embed_dim), where N is determined - by the number of input points and boxes. - torch.Tensor: dense embeddings for the masks, in the shape Bx(embed_dim)x(embed_H)x(embed_W) + (Tuple[torch.Tensor, torch.Tensor]): A tuple containing: + - sparse_embeddings (torch.Tensor): Sparse embeddings for points and boxes with shape (B, N, embed_dim). + - dense_embeddings (torch.Tensor): Dense embeddings for masks of shape (B, embed_dim, embed_H, embed_W). + + Examples: + >>> encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16) + >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5))) + >>> boxes = torch.rand(1, 2, 2, 2) + >>> masks = torch.rand(1, 1, 256, 256) + >>> sparse_emb, dense_emb = encoder(points, boxes, masks) + >>> print(sparse_emb.shape, dense_emb.shape) + torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64]) """ bs = self._get_batch_size(points, boxes, masks) sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) @@ -287,319 +370,421 @@ class PromptEncoder(nn.Module): return sparse_embeddings, dense_embeddings -class PositionEmbeddingRandom(nn.Module): - """Positional encoding using random spatial frequencies.""" +class MemoryEncoder(nn.Module): + """ + Encodes pixel features and masks into a memory representation for efficient image segmentation. - def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: - """Initializes a position embedding using random spatial frequencies.""" - super().__init__() - if scale is None or scale <= 0.0: - scale = 1.0 - self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats))) - - # Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation' - torch.use_deterministic_algorithms(False) - torch.backends.cudnn.deterministic = False - - def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: - """Positionally encode points that are normalized to [0,1].""" - # Assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape - coords = 2 * coords - 1 - coords = coords @ self.positional_encoding_gaussian_matrix - coords = 2 * np.pi * coords - # Outputs d_1 x ... x d_n x C shape - return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) - - def forward(self, size: Tuple[int, int]) -> torch.Tensor: - """Generate positional encoding for a grid of the specified size.""" - h, w = size - device: Any = self.positional_encoding_gaussian_matrix.device - grid = torch.ones((h, w), device=device, dtype=torch.float32) - y_embed = grid.cumsum(dim=0) - 0.5 - x_embed = grid.cumsum(dim=1) - 0.5 - y_embed = y_embed / h - x_embed = x_embed / w - - pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) - return pe.permute(2, 0, 1) # C x H x W - - def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor: - """Positionally encode points that are not normalized to [0,1].""" - coords = coords_input.clone() - coords[:, :, 0] = coords[:, :, 0] / image_size[1] - coords[:, :, 1] = coords[:, :, 1] / image_size[0] - return self._pe_encoding(coords.to(torch.float)) # B x N x C - - -class Block(nn.Module): - """Transformer blocks with support of window attention and residual propagation blocks.""" + This class processes pixel-level features and masks, fusing them to generate encoded memory representations + suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model). + + Attributes: + mask_downsampler (MaskDownSampler): Module for downsampling input masks. + pix_feat_proj (nn.Conv2d): Convolutional layer for projecting pixel features. + fuser (Fuser): Module for fusing pixel features and masks. + position_encoding (PositionEmbeddingSine): Module for adding positional encoding to features. + out_proj (nn.Module): Output projection layer, either nn.Identity or nn.Conv2d. + + Methods: + forward: Processes input pixel features and masks to generate encoded memory representations. + + Examples: + >>> import torch + >>> encoder = MemoryEncoder(out_dim=256, in_dim=256) + >>> pix_feat = torch.randn(1, 256, 64, 64) + >>> masks = torch.randn(1, 1, 64, 64) + >>> encoded_feat, pos = encoder(pix_feat, masks) + >>> print(encoded_feat.shape, pos.shape) + torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 64, 64]) + """ def __init__( self, - dim: int, - num_heads: int, - mlp_ratio: float = 4.0, - qkv_bias: bool = True, - norm_layer: Type[nn.Module] = nn.LayerNorm, - act_layer: Type[nn.Module] = nn.GELU, - use_rel_pos: bool = False, - rel_pos_zero_init: bool = True, - window_size: int = 0, - input_size: Optional[Tuple[int, int]] = None, - ) -> None: - """ - Args: - dim (int): Number of input channels. - num_heads (int): Number of attention heads in each ViT block. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool): If True, add a learnable bias to query, key, value. - norm_layer (nn.Module): Normalization layer. - act_layer (nn.Module): Activation layer. - use_rel_pos (bool): If True, add relative positional embeddings to the attention map. - rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. - window_size (int): Window size for window attention blocks. If it equals 0, then - use global attention. - input_size (tuple(int, int), None): Input resolution for calculating the relative - positional parameter size. - """ + out_dim, + in_dim=256, # in_dim of pix_feats + ): + """Initializes the MemoryEncoder for encoding pixel features and masks into memory representations.""" super().__init__() - self.norm1 = norm_layer(dim) - self.attn = Attention( - dim, - num_heads=num_heads, - qkv_bias=qkv_bias, - use_rel_pos=use_rel_pos, - rel_pos_zero_init=rel_pos_zero_init, - input_size=input_size if window_size == 0 else (window_size, window_size), - ) - self.norm2 = norm_layer(dim) - self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1) - self.window_size = window_size + self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) + self.fuser = Fuser(CXBlock(dim=256), num_layers=2) + self.position_encoding = PositionEmbeddingSine(num_pos_feats=64) + self.out_proj = nn.Identity() + if out_dim != in_dim: + self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Executes a forward pass through the transformer block with window attention and non-overlapping windows.""" - shortcut = x - x = self.norm1(x) - # Window partition - if self.window_size > 0: - H, W = x.shape[1], x.shape[2] - x, pad_hw = window_partition(x, self.window_size) + def forward( + self, + pix_feat: torch.Tensor, + masks: torch.Tensor, + skip_mask_sigmoid: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Processes pixel features and masks to generate encoded memory representations for segmentation.""" + if not skip_mask_sigmoid: + masks = F.sigmoid(masks) + masks = self.mask_downsampler(masks) - x = self.attn(x) - # Reverse window partition - if self.window_size > 0: - x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + # Fuse pix_feats and downsampled masks, in case the visual features are on CPU, cast them to CUDA + pix_feat = pix_feat.to(masks.device) - x = shortcut + x - return x + self.mlp(self.norm2(x)) + x = self.pix_feat_proj(pix_feat) + x = x + masks + x = self.fuser(x) + x = self.out_proj(x) + pos = self.position_encoding(x).to(x.dtype) -class Attention(nn.Module): - """Multi-head Attention block with relative position embeddings.""" + return {"vision_features": x, "vision_pos_enc": [pos]} - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = True, - use_rel_pos: bool = False, - rel_pos_zero_init: bool = True, - input_size: Optional[Tuple[int, int]] = None, - ) -> None: - """ - Initialize Attention module. - Args: - dim (int): Number of input channels. - num_heads (int): Number of attention heads. - qkv_bias (bool): If True, add a learnable bias to query, key, value. - rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. - input_size (tuple(int, int), None): Input resolution for calculating the relative - positional parameter size. - """ - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = head_dim**-0.5 +class ImageEncoder(nn.Module): + """ + Encodes images using a trunk-neck architecture, producing multiscale features and positional encodings. - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.proj = nn.Linear(dim, dim) + This class combines a trunk network for feature extraction with a neck network for feature refinement + and positional encoding generation. It can optionally discard the lowest resolution features. - self.use_rel_pos = use_rel_pos - if self.use_rel_pos: - assert input_size is not None, "Input size must be provided if using relative positional encoding." - # Initialize relative positional embeddings - self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) - self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + Attributes: + trunk (nn.Module): The trunk network for initial feature extraction. + neck (nn.Module): The neck network for feature refinement and positional encoding generation. + scalp (int): Number of lowest resolution feature levels to discard. + + Methods: + forward: Processes the input image through the trunk and neck networks. + + Examples: + >>> trunk = SomeTrunkNetwork() + >>> neck = SomeNeckNetwork() + >>> encoder = ImageEncoder(trunk, neck, scalp=1) + >>> image = torch.randn(1, 3, 224, 224) + >>> output = encoder(image) + >>> print(output.keys()) + dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn']) + """ - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Applies the forward operation including attention, normalization, MLP, and indexing within window limits.""" - B, H, W, _ = x.shape - # qkv with shape (3, B, nHead, H * W, C) - qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - # q, k, v with shape (B * nHead, H * W, C) - q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + def __init__( + self, + trunk: nn.Module, + neck: nn.Module, + scalp: int = 0, + ): + """Initializes the ImageEncoder with trunk and neck networks for feature extraction and refinement.""" + super().__init__() + self.trunk = trunk + self.neck = neck + self.scalp = scalp + assert ( + self.trunk.channel_list == self.neck.backbone_channel_list + ), f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match." + + def forward(self, sample: torch.Tensor): + """Encodes input through patch embedding, positional embedding, transformer blocks, and neck module.""" + features, pos = self.neck(self.trunk(sample)) + if self.scalp > 0: + # Discard the lowest resolution features + features, pos = features[: -self.scalp], pos[: -self.scalp] + + src = features[-1] + output = { + "vision_features": src, + "vision_pos_enc": pos, + "backbone_fpn": features, + } + return output + + +class FpnNeck(nn.Module): + """ + A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models. - attn = (q * self.scale) @ k.transpose(-2, -1) + This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing, + similar to ViT positional embedding interpolation. - if self.use_rel_pos: - attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + Attributes: + position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding module. + convs (nn.ModuleList): List of convolutional layers for each backbone level. + backbone_channel_list (List[int]): List of channel dimensions from the backbone. + fpn_interp_model (str): Interpolation mode for FPN feature resizing. + fuse_type (str): Type of feature fusion, either 'sum' or 'avg'. + fpn_top_down_levels (List[int]): Levels to have top-down features in outputs. + + Methods: + forward: Performs forward pass through the FPN neck. + + Examples: + >>> backbone_channels = [64, 128, 256, 512] + >>> fpn_neck = FpnNeck(256, backbone_channels) + >>> inputs = [torch.rand(1, c, 32, 32) for c in backbone_channels] + >>> outputs, positions = fpn_neck(inputs) + >>> print(len(outputs), len(positions)) + 4 4 + """ - attn = attn.softmax(dim=-1) - x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) - return self.proj(x) + def __init__( + self, + d_model: int, + backbone_channel_list: List[int], + kernel_size: int = 1, + stride: int = 1, + padding: int = 0, + fpn_interp_model: str = "bilinear", + fuse_type: str = "sum", + fpn_top_down_levels: Optional[List[int]] = None, + ): + """ + Initializes a modified Feature Pyramid Network (FPN) neck. + This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing, + similar to ViT positional embedding interpolation. -def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: - """ - Partition into non-overlapping windows with padding if needed. - Args: - x (tensor): input tokens with [B, H, W, C]. - window_size (int): window size. - - Returns: - windows: windows after partition with [B * num_windows, window_size, window_size, C]. - (Hp, Wp): padded height and width before partition - """ - B, H, W, C = x.shape + Args: + d_model (int): Dimension of the model. + backbone_channel_list (List[int]): List of channel dimensions from the backbone. + kernel_size (int): Kernel size for the convolutional layers. + stride (int): Stride for the convolutional layers. + padding (int): Padding for the convolutional layers. + fpn_interp_model (str): Interpolation mode for FPN feature resizing. + fuse_type (str): Type of feature fusion, either 'sum' or 'avg'. + fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs. + + Examples: + >>> backbone_channels = [64, 128, 256, 512] + >>> fpn_neck = FpnNeck(256, backbone_channels) + >>> print(fpn_neck) + """ + super().__init__() + self.position_encoding = PositionEmbeddingSine(num_pos_feats=256) + self.convs = nn.ModuleList() + self.backbone_channel_list = backbone_channel_list + for dim in backbone_channel_list: + current = nn.Sequential() + current.add_module( + "conv", + nn.Conv2d( + in_channels=dim, + out_channels=d_model, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ), + ) - pad_h = (window_size - H % window_size) % window_size - pad_w = (window_size - W % window_size) % window_size - if pad_h > 0 or pad_w > 0: - x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) - Hp, Wp = H + pad_h, W + pad_w + self.convs.append(current) + self.fpn_interp_model = fpn_interp_model + assert fuse_type in ["sum", "avg"] + self.fuse_type = fuse_type + + # levels to have top-down features in its outputs + # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 + # have top-down propagation, while outputs of level 0 and level 1 have only + # lateral features from the same backbone level. + if fpn_top_down_levels is None: + # default is to have top-down features on all levels + fpn_top_down_levels = range(len(self.convs)) + self.fpn_top_down_levels = list(fpn_top_down_levels) + + def forward(self, xs: List[torch.Tensor]): + """ + Performs forward pass through the Feature Pyramid Network (FPN) neck. - x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows, (Hp, Wp) + This method processes a list of input tensors from the backbone through the FPN, applying lateral connections + and top-down feature fusion. It generates output feature maps and corresponding positional encodings. + Args: + xs (List[torch.Tensor]): List of input tensors from the backbone, each with shape (B, C, H, W). -def window_unpartition( - windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] -) -> torch.Tensor: + Returns: + (Tuple[List[torch.Tensor], List[torch.Tensor]]): A tuple containing: + - out (List[torch.Tensor]): List of output feature maps after FPN processing, each with shape + (B, d_model, H, W). + - pos (List[torch.Tensor]): List of positional encodings corresponding to each output feature map. + + Examples: + >>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512]) + >>> inputs = [torch.rand(1, c, 32, 32) for c in [64, 128, 256, 512]] + >>> outputs, positions = fpn_neck(inputs) + >>> print(len(outputs), len(positions)) + 4 4 + """ + out = [None] * len(self.convs) + pos = [None] * len(self.convs) + assert len(xs) == len(self.convs) + # fpn forward pass + # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py + prev_features = None + # forward in top-down order (from low to high resolution) + n = len(self.convs) - 1 + for i in range(n, -1, -1): + x = xs[i] + lateral_features = self.convs[n - i](x) + if i in self.fpn_top_down_levels and prev_features is not None: + top_down_features = F.interpolate( + prev_features.to(dtype=torch.float32), + scale_factor=2.0, + mode=self.fpn_interp_model, + align_corners=(None if self.fpn_interp_model == "nearest" else False), + antialias=False, + ) + prev_features = lateral_features + top_down_features + if self.fuse_type == "avg": + prev_features /= 2 + else: + prev_features = lateral_features + x_out = prev_features + out[i] = x_out + pos[i] = self.position_encoding(x_out).to(x_out.dtype) + + return out, pos + + +class Hiera(nn.Module): """ - Window unpartition into original sequences and removing padding. + Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks. - Args: - windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. - window_size (int): window size. - pad_hw (Tuple): padded height and width (Hp, Wp). - hw (Tuple): original height and width (H, W) before padding. + This class implements a Hiera model, which is a hierarchical vision transformer architecture designed for + efficient multiscale feature extraction. It uses a series of transformer blocks organized into stages, + with optional pooling and global attention mechanisms. - Returns: - x: unpartitioned sequences with [B, H, W, C]. + Attributes: + window_spec (Tuple[int, ...]): Window sizes for each stage. + q_stride (Tuple[int, int]): Downsampling stride between stages. + stage_ends (List[int]): Indices of the last block in each stage. + q_pool_blocks (List[int]): Indices of blocks where pooling is applied. + return_interm_layers (bool): Whether to return intermediate layer outputs. + patch_embed (PatchEmbed): Module for patch embedding. + global_att_blocks (Tuple[int, ...]): Indices of blocks with global attention. + window_pos_embed_bkg_spatial_size (Tuple[int, int]): Spatial size for window positional embedding background. + pos_embed (nn.Parameter): Positional embedding for the background. + pos_embed_window (nn.Parameter): Positional embedding for the window. + blocks (nn.ModuleList): List of MultiScaleBlock modules. + channel_list (List[int]): List of output channel dimensions for each stage. + + Methods: + _get_pos_embed: Generates positional embeddings by interpolating and combining window and background embeddings. + forward: Performs the forward pass through the Hiera model. + + Examples: + >>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3)) + >>> input_tensor = torch.randn(1, 3, 224, 224) + >>> output_features = model(input_tensor) + >>> for feat in output_features: + ... print(feat.shape) """ - Hp, Wp = pad_hw - H, W = hw - B = windows.shape[0] // (Hp * Wp // window_size // window_size) - x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) - - if Hp > H or Wp > W: - x = x[:, :H, :W, :].contiguous() - return x + def __init__( + self, + embed_dim: int = 96, # initial embed dim + num_heads: int = 1, # initial number of heads + drop_path_rate: float = 0.0, # stochastic depth + q_pool: int = 3, # number of q_pool stages + q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages + stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage + dim_mul: float = 2.0, # dim_mul factor at stage shift + head_mul: float = 2.0, # head_mul factor at stage shift + window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), + # window size per stage, when not using global att. + window_spec: Tuple[int, ...] = ( + 8, + 4, + 14, + 7, + ), + # global attn in these blocks + global_att_blocks: Tuple[int, ...] = ( + 12, + 16, + 20, + ), + return_interm_layers=True, # return feats from every stage + ): + """Initializes the Hiera model, configuring its hierarchical vision transformer architecture.""" + super().__init__() -def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: - """ - Get relative positional embeddings according to the relative positions of query and key sizes. + assert len(stages) == len(window_spec) + self.window_spec = window_spec - Args: - q_size (int): size of query q. - k_size (int): size of key k. - rel_pos (Tensor): relative position embeddings (L, C). + depth = sum(stages) + self.q_stride = q_stride + self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] + assert 0 <= q_pool <= len(self.stage_ends[:-1]) + self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] + self.return_interm_layers = return_interm_layers - Returns: - Extracted positional embeddings according to relative positions. - """ - max_rel_dist = int(2 * max(q_size, k_size) - 1) - # Interpolate rel pos if needed. - if rel_pos.shape[0] != max_rel_dist: - # Interpolate rel pos. - rel_pos_resized = F.interpolate( - rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), - size=max_rel_dist, - mode="linear", + self.patch_embed = PatchEmbed( + embed_dim=embed_dim, + kernel_size=(7, 7), + stride=(4, 4), + padding=(3, 3), ) - rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) - else: - rel_pos_resized = rel_pos - - # Scale the coords with short length if shapes for q and k are different. - q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) - k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) - relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) - - return rel_pos_resized[relative_coords.long()] - - -def add_decomposed_rel_pos( - attn: torch.Tensor, - q: torch.Tensor, - rel_pos_h: torch.Tensor, - rel_pos_w: torch.Tensor, - q_size: Tuple[int, int], - k_size: Tuple[int, int], -) -> torch.Tensor: - """ - Calculate decomposed Relative Positional Embeddings from mvitv2 paper at - https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py. - - Args: - attn (Tensor): attention map. - q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). - rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. - rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. - q_size (Tuple): spatial sequence size of query q with (q_h, q_w). - k_size (Tuple): spatial sequence size of key k with (k_h, k_w). - - Returns: - attn (Tensor): attention map with added relative positional embeddings. - """ - q_h, q_w = q_size - k_h, k_w = k_size - Rh = get_rel_pos(q_h, k_h, rel_pos_h) - Rw = get_rel_pos(q_w, k_w, rel_pos_w) + # Which blocks have global att? + self.global_att_blocks = global_att_blocks - B, _, dim = q.shape - r_q = q.reshape(B, q_h, q_w, dim) - rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) - rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + # Windowed positional embedding (https://arxiv.org/abs/2311.05613) + self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size + self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)) + self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])) - attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view( - B, q_h * q_w, k_h * k_w - ) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule - return attn + cur_stage = 1 + self.blocks = nn.ModuleList() + for i in range(depth): + dim_out = embed_dim + # lags by a block, so first block of + # next stage uses an initial window size + # of previous stage and final window size of current stage + window_size = self.window_spec[cur_stage - 1] -class PatchEmbed(nn.Module): - """Image to Patch Embedding.""" + if self.global_att_blocks is not None: + window_size = 0 if i in self.global_att_blocks else window_size - def __init__( - self, - kernel_size: Tuple[int, int] = (16, 16), - stride: Tuple[int, int] = (16, 16), - padding: Tuple[int, int] = (0, 0), - in_chans: int = 3, - embed_dim: int = 768, - ) -> None: - """ - Initialize PatchEmbed module. + if i - 1 in self.stage_ends: + dim_out = int(embed_dim * dim_mul) + num_heads = int(num_heads * head_mul) + cur_stage += 1 - Args: - kernel_size (Tuple): kernel size of the projection layer. - stride (Tuple): stride of the projection layer. - padding (Tuple): padding size of the projection layer. - in_chans (int): Number of input image channels. - embed_dim (int): Patch embedding dimension. - """ - super().__init__() + block = MultiScaleBlock( + dim=embed_dim, + dim_out=dim_out, + num_heads=num_heads, + drop_path=dpr[i], + q_stride=self.q_stride if i in self.q_pool_blocks else None, + window_size=window_size, + ) - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) + embed_dim = dim_out + self.blocks.append(block) - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Computes patch embedding by applying convolution and transposing resulting tensor.""" - return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C + self.channel_list = ( + [self.blocks[i].dim_out for i in self.stage_ends[::-1]] + if return_interm_layers + else [self.blocks[-1].dim_out] + ) + + def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: + """Generates positional embeddings by interpolating and combining window and background embeddings.""" + h, w = hw + window_embed = self.pos_embed_window + pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") + pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)]) + pos_embed = pos_embed.permute(0, 2, 3, 1) + return pos_embed + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + """Performs forward pass through Hiera model, extracting multiscale features from input images.""" + x = self.patch_embed(x) + # x: (B, H, W, C) + + # Add pos embed + x = x + self._get_pos_embed(x.shape[1:3]) + + outputs = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers): + feats = x.permute(0, 3, 1, 2) + outputs.append(feats) + + return outputs diff --git a/ultralytics/models/sam2/modules/memory_attention.py b/ultralytics/models/sam/modules/memory_attention.py similarity index 57% rename from ultralytics/models/sam2/modules/memory_attention.py rename to ultralytics/models/sam/modules/memory_attention.py index 8b5673c31..b55b07302 100644 --- a/ultralytics/models/sam2/modules/memory_attention.py +++ b/ultralytics/models/sam/modules/memory_attention.py @@ -6,11 +6,50 @@ from typing import Optional import torch from torch import Tensor, nn -from .sam2_blocks import RoPEAttention +from .blocks import RoPEAttention class MemoryAttentionLayer(nn.Module): - """Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.""" + """ + Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks. + + This class combines self-attention, cross-attention, and feedforward components to process input tensors and + generate memory-based attention outputs. + + Attributes: + d_model (int): Dimensionality of the model. + dim_feedforward (int): Dimensionality of the feedforward network. + dropout_value (float): Dropout rate for regularization. + self_attn (RoPEAttention): Self-attention mechanism using RoPE (Rotary Position Embedding). + cross_attn_image (RoPEAttention): Cross-attention mechanism for image processing. + linear1 (nn.Linear): First linear layer of the feedforward network. + linear2 (nn.Linear): Second linear layer of the feedforward network. + norm1 (nn.LayerNorm): Layer normalization for self-attention output. + norm2 (nn.LayerNorm): Layer normalization for cross-attention output. + norm3 (nn.LayerNorm): Layer normalization for feedforward network output. + dropout1 (nn.Dropout): Dropout layer after self-attention. + dropout2 (nn.Dropout): Dropout layer after cross-attention. + dropout3 (nn.Dropout): Dropout layer after feedforward network. + activation (nn.ReLU): Activation function for the feedforward network. + pos_enc_at_attn (bool): Flag to add positional encoding at attention. + pos_enc_at_cross_attn_queries (bool): Flag to add positional encoding to cross-attention queries. + pos_enc_at_cross_attn_keys (bool): Flag to add positional encoding to cross-attention keys. + + Methods: + forward: Performs the full memory attention operation on input tensors. + _forward_sa: Performs self-attention on input tensor. + _forward_ca: Performs cross-attention between target and memory tensors. + + Examples: + >>> layer = MemoryAttentionLayer(d_model=256, dim_feedforward=2048, dropout=0.1) + >>> tgt = torch.randn(1, 100, 256) + >>> memory = torch.randn(1, 100, 64) + >>> pos = torch.randn(1, 100, 256) + >>> query_pos = torch.randn(1, 100, 256) + >>> output = layer(tgt, memory, pos, query_pos) + >>> print(output.shape) + torch.Size([1, 100, 256]) + """ def __init__( self, @@ -21,7 +60,7 @@ class MemoryAttentionLayer(nn.Module): pos_enc_at_cross_attn_keys: bool = True, pos_enc_at_cross_attn_queries: bool = False, ): - """Initializes a MemoryAttentionLayer with self-attention, cross-attention, and feedforward components.""" + """Initializes a memory attention layer with self-attention, cross-attention, and feedforward components.""" super().__init__() self.d_model = d_model self.dim_feedforward = dim_feedforward @@ -88,7 +127,7 @@ class MemoryAttentionLayer(nn.Module): query_pos: Optional[Tensor] = None, num_k_exclude_rope: int = 0, ) -> torch.Tensor: - """Performs self-attention, cross-attention, and MLP operations on input tensors for memory-based attention.""" + """Processes input tensors using self-attention, cross-attention, and MLP for memory-based attention.""" tgt = self._forward_sa(tgt, query_pos) tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) # MLP @@ -99,7 +138,35 @@ class MemoryAttentionLayer(nn.Module): class MemoryAttention(nn.Module): - """Memory attention module for processing sequential data with self and cross-attention mechanisms.""" + """ + Memory attention module for processing sequential data with self and cross-attention mechanisms. + + This class implements a multi-layer attention mechanism that combines self-attention and cross-attention + for processing sequential data, particularly useful in transformer-like architectures. + + Attributes: + d_model (int): The dimension of the model's hidden state. + layers (nn.ModuleList): A list of MemoryAttentionLayer modules. + num_layers (int): The number of attention layers. + norm (nn.LayerNorm): Layer normalization applied to the output. + pos_enc_at_input (bool): Whether to apply positional encoding at the input. + batch_first (bool): Whether the input tensors are in batch-first format. + + Methods: + forward: Processes input tensors through the attention layers. + + Examples: + >>> d_model = 256 + >>> layer = MemoryAttentionLayer(d_model) + >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3) + >>> curr = torch.randn(10, 32, d_model) # (seq_len, batch_size, d_model) + >>> memory = torch.randn(20, 32, d_model) # (mem_len, batch_size, d_model) + >>> curr_pos = torch.randn(10, 32, d_model) + >>> memory_pos = torch.randn(20, 32, d_model) + >>> output = attention(curr, memory, curr_pos, memory_pos) + >>> print(output.shape) + torch.Size([10, 32, 256]) + """ def __init__( self, @@ -126,7 +193,7 @@ class MemoryAttention(nn.Module): memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* ): - """Applies self-attention and cross-attention to input tensors, processing through multiple layers.""" + """Processes input tensors through multiple attention layers, applying self and cross-attention mechanisms.""" if isinstance(curr, list): assert isinstance(curr_pos, list) assert len(curr) == len(curr_pos) == 1 diff --git a/ultralytics/models/sam/modules/sam.py b/ultralytics/models/sam/modules/sam.py index 1617527f0..27509575b 100644 --- a/ultralytics/models/sam/modules/sam.py +++ b/ultralytics/models/sam/modules/sam.py @@ -9,25 +9,48 @@ from typing import List import torch +import torch.nn.functional as F from torch import nn +from torch.nn.init import trunc_normal_ -from .decoders import MaskDecoder +from ultralytics.nn.modules import MLP + +from .blocks import SAM2TwoWayTransformer +from .decoders import MaskDecoder, SAM2MaskDecoder from .encoders import ImageEncoderViT, PromptEncoder +from .utils import get_1d_sine_pe, select_closest_cond_frames + +# a large negative value as a placeholder score for missing objects +NO_OBJ_SCORE = -1024.0 class SAMModel(nn.Module): """ - SAMModel (Segment Anything Model) is designed for object segmentation tasks. It uses image encoders to generate - image embeddings, and prompt encoders to encode various types of input prompts. These embeddings are then used by - the mask decoder to predict object masks. + Segment Anything Model (SAM) for object segmentation tasks. + + This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images + and input prompts. Attributes: mask_threshold (float): Threshold value for mask prediction. - image_encoder (ImageEncoderViT): The backbone used to encode the image into embeddings. - prompt_encoder (PromptEncoder): Encodes various types of input prompts. - mask_decoder (MaskDecoder): Predicts object masks from the image and prompt embeddings. - pixel_mean (List[float]): Mean pixel values for image normalization. - pixel_std (List[float]): Standard deviation values for image normalization. + image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings. + prompt_encoder (PromptEncoder): Encoder for various types of input prompts. + mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings. + pixel_mean (torch.Tensor): Mean pixel values for image normalization, shape (3, 1, 1). + pixel_std (torch.Tensor): Standard deviation values for image normalization, shape (3, 1, 1). + + Methods: + __init__: Initializes the SAMModel with encoders, decoder, and normalization parameters. + + Examples: + >>> image_encoder = ImageEncoderViT(...) + >>> prompt_encoder = PromptEncoder(...) + >>> mask_decoder = MaskDecoder(...) + >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder) + >>> # Further usage depends on SAMPredictor class + + Notes: + All forward() operations are implemented in the SAMPredictor class. """ mask_threshold: float = 0.0 @@ -43,17 +66,22 @@ class SAMModel(nn.Module): """ Initialize the SAMModel class to predict object masks from an image and input prompts. - Note: - All forward() operations moved to SAMPredictor. - Args: image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings. prompt_encoder (PromptEncoder): Encodes various types of input prompts. mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts. - pixel_mean (List[float], optional): Mean values for normalizing pixels in the input image. Defaults to - (123.675, 116.28, 103.53). - pixel_std (List[float], optional): Std values for normalizing pixels in the input image. Defaults to - (58.395, 57.12, 57.375). + pixel_mean (List[float]): Mean values for normalizing pixels in the input image. + pixel_std (List[float]): Std values for normalizing pixels in the input image. + + Examples: + >>> image_encoder = ImageEncoderViT(...) + >>> prompt_encoder = PromptEncoder(...) + >>> mask_decoder = MaskDecoder(...) + >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder) + >>> # Further usage depends on SAMPredictor class + + Notes: + All forward() operations moved to SAMPredictor. """ super().__init__() self.image_encoder = image_encoder @@ -61,3 +89,846 @@ class SAMModel(nn.Module): self.mask_decoder = mask_decoder self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + +class SAM2Model(torch.nn.Module): + """ + SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities. + + This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms + for temporal consistency and efficient tracking of objects across frames. + + Attributes: + mask_threshold (float): Threshold value for mask prediction. + image_encoder (ImageEncoderViT): Visual encoder for extracting image features. + memory_attention (nn.Module): Module for attending to memory features. + memory_encoder (nn.Module): Encoder for generating memory representations. + num_maskmem (int): Number of accessible memory frames. + image_size (int): Size of input images. + backbone_stride (int): Stride of the backbone network output. + sam_prompt_embed_dim (int): Dimension of SAM prompt embeddings. + sam_image_embedding_size (int): Size of SAM image embeddings. + sam_prompt_encoder (PromptEncoder): Encoder for processing input prompts. + sam_mask_decoder (SAM2MaskDecoder): Decoder for generating object masks. + obj_ptr_proj (nn.Module): Projection layer for object pointers. + obj_ptr_tpos_proj (nn.Module): Projection for temporal positional encoding in object pointers. + + Methods: + forward_image: Processes image batch through encoder to extract multi-level features. + track_step: Performs a single tracking step, updating object masks and memory features. + + Examples: + >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder) + >>> image_batch = torch.rand(1, 3, 512, 512) + >>> features = model.forward_image(image_batch) + >>> track_results = model.track_step(0, True, features, None, None, None, {}) + """ + + mask_threshold: float = 0.0 + + def __init__( + self, + image_encoder, + memory_attention, + memory_encoder, + num_maskmem=7, + image_size=512, + backbone_stride=16, + sigmoid_scale_for_mem_enc=1.0, + sigmoid_bias_for_mem_enc=0.0, + binarize_mask_from_pts_for_mem_enc=False, + use_mask_input_as_output_without_sam=False, + max_cond_frames_in_attn=-1, + directly_add_no_mem_embed=False, + use_high_res_features_in_sam=False, + multimask_output_in_sam=False, + multimask_min_pt_num=1, + multimask_max_pt_num=1, + multimask_output_for_tracking=False, + use_multimask_token_for_obj_ptr: bool = False, + iou_prediction_use_sigmoid=False, + memory_temporal_stride_for_eval=1, + add_all_frames_to_correct_as_cond=False, + non_overlap_masks_for_mem_enc=False, + use_obj_ptrs_in_encoder=False, + max_obj_ptrs_in_encoder=16, + add_tpos_enc_to_obj_ptrs=True, + proj_tpos_enc_in_obj_ptrs=False, + only_obj_ptrs_in_the_past_for_eval=False, + pred_obj_scores: bool = False, + pred_obj_scores_mlp: bool = False, + fixed_no_obj_ptr: bool = False, + soft_no_obj_ptr: bool = False, + use_mlp_for_obj_ptr_proj: bool = False, + sam_mask_decoder_extra_args=None, + compile_image_encoder: bool = False, + ): + """ + Initializes the SAM2Model for video object segmentation with memory-based tracking. + + Args: + image_encoder (nn.Module): Visual encoder for extracting image features. + memory_attention (nn.Module): Module for attending to memory features. + memory_encoder (nn.Module): Encoder for generating memory representations. + num_maskmem (int): Number of accessible memory frames. Default is 7 (1 input frame + 6 previous frames). + image_size (int): Size of input images. + backbone_stride (int): Stride of the image backbone output. + sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability. + sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability. + binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames + with clicks during evaluation. + use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM + prompt encoder and mask decoder on frames with mask input. + max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention. + -1 means no limit. + directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the + first frame. + use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder. + multimask_output_in_sam (bool): Whether to output multiple (3) masks for the first click on initial + conditioning frames. + multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM. + multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM. + multimask_output_for_tracking (bool): Whether to use multimask output for tracking. + use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers. + iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1]. + memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation. + add_all_frames_to_correct_as_cond (bool): Whether to append frames with correction clicks to conditioning + frame list. + non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in + memory encoder during evaluation. + use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder. + max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder + cross-attention. + add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in + the encoder. + proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional + encoding in object pointers. + only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past + during evaluation. + pred_obj_scores (bool): Whether to predict if there is an object in the frame. + pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores. + fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present. + soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation. + use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection. + sam_mask_decoder_extra_args (Dict | None): Extra arguments for constructing the SAM mask decoder. + compile_image_encoder (bool): Whether to compile the image encoder for faster inference. + + Examples: + >>> image_encoder = ImageEncoderViT(...) + >>> memory_attention = SAM2TwoWayTransformer(...) + >>> memory_encoder = nn.Sequential(...) + >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder) + >>> image_batch = torch.rand(1, 3, 512, 512) + >>> features = model.forward_image(image_batch) + >>> track_results = model.track_step(0, True, features, None, None, None, {}) + """ + super().__init__() + + # Part 1: the image backbone + self.image_encoder = image_encoder + # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting + self.use_high_res_features_in_sam = use_high_res_features_in_sam + self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 + self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder + self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder + if use_obj_ptrs_in_encoder: + # A conv layer to downsample the mask prompt to stride 4 (the same stride as + # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, + # so that it can be fed into the SAM mask decoder to generate a pointer. + self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) + self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs + if proj_tpos_enc_in_obj_ptrs: + assert add_tpos_enc_to_obj_ptrs # these options need to be used together + self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs + self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval + + # Part 2: memory attention to condition current frame's visual features + # with memories (and obj ptrs) from past frames + self.memory_attention = memory_attention + self.hidden_dim = memory_attention.d_model + + # Part 3: memory encoder for the previous frame's outputs + self.memory_encoder = memory_encoder + self.mem_dim = self.hidden_dim + if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"): + # if there is compression of memories along channel dim + self.mem_dim = self.memory_encoder.out_proj.weight.shape[0] + self.num_maskmem = num_maskmem # Number of memories accessible + # Temporal encoding of the memories + self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim)) + trunc_normal_(self.maskmem_tpos_enc, std=0.02) + # a single token to indicate no memory embedding from previous frames + self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + trunc_normal_(self.no_mem_embed, std=0.02) + trunc_normal_(self.no_mem_pos_enc, std=0.02) + self.directly_add_no_mem_embed = directly_add_no_mem_embed + # Apply sigmoid to the output raw mask logits (to turn them from + # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder + self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc + self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc + self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc + self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc + self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval + # On frames with mask input, whether to directly output the input mask without + # using a SAM prompt encoder + mask decoder + self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam + self.multimask_output_in_sam = multimask_output_in_sam + self.multimask_min_pt_num = multimask_min_pt_num + self.multimask_max_pt_num = multimask_max_pt_num + self.multimask_output_for_tracking = multimask_output_for_tracking + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid + + # Part 4: SAM-style prompt encoder (for both mask and point inputs) + # and SAM-style mask decoder for the final mask output + self.image_size = image_size + self.backbone_stride = backbone_stride + self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args + self.pred_obj_scores = pred_obj_scores + self.pred_obj_scores_mlp = pred_obj_scores_mlp + self.fixed_no_obj_ptr = fixed_no_obj_ptr + self.soft_no_obj_ptr = soft_no_obj_ptr + if self.fixed_no_obj_ptr: + assert self.pred_obj_scores + assert self.use_obj_ptrs_in_encoder + if self.pred_obj_scores and self.use_obj_ptrs_in_encoder: + self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) + trunc_normal_(self.no_obj_ptr, std=0.02) + self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj + + self._build_sam_heads() + self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond + self.max_cond_frames_in_attn = max_cond_frames_in_attn + + # Model compilation + if compile_image_encoder: + # Compile the forward function (not the full module) to allow loading checkpoints. + print("Image encoder compilation is enabled. First forward pass will be slow.") + self.image_encoder.forward = torch.compile( + self.image_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + @property + def device(self): + """Returns the device on which the model's parameters are stored.""" + return next(self.parameters()).device + + def forward(self, *args, **kwargs): + """Processes image and prompt inputs to generate object masks and scores in video sequences.""" + raise NotImplementedError( + "Please use the corresponding methods in SAM2VideoPredictor for inference." + "See notebooks/video_predictor_example.ipynb for an example." + ) + + def _build_sam_heads(self): + """Builds SAM-style prompt encoder and mask decoder for image segmentation tasks.""" + self.sam_prompt_embed_dim = self.hidden_dim + self.sam_image_embedding_size = self.image_size // self.backbone_stride + + # build PromptEncoder and MaskDecoder from SAM + # (their hyperparameters like `mask_in_chans=16` are from SAM code) + self.sam_prompt_encoder = PromptEncoder( + embed_dim=self.sam_prompt_embed_dim, + image_embedding_size=( + self.sam_image_embedding_size, + self.sam_image_embedding_size, + ), + input_image_size=(self.image_size, self.image_size), + mask_in_chans=16, + ) + self.sam_mask_decoder = SAM2MaskDecoder( + num_multimask_outputs=3, + transformer=SAM2TwoWayTransformer( + depth=2, + embedding_dim=self.sam_prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=self.sam_prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + use_high_res_features=self.use_high_res_features_in_sam, + iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, + pred_obj_scores=self.pred_obj_scores, + pred_obj_scores_mlp=self.pred_obj_scores_mlp, + use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, + **(self.sam_mask_decoder_extra_args or {}), + ) + if self.use_obj_ptrs_in_encoder: + # a linear projection on SAM output tokens to turn them into object pointers + self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) + if self.use_mlp_for_obj_ptr_proj: + self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3) + else: + self.obj_ptr_proj = torch.nn.Identity() + if self.proj_tpos_enc_in_obj_ptrs: + # a linear projection on temporal positional encoding in object pointers to + # avoid potential interference with spatial positional encoding + self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) + else: + self.obj_ptr_tpos_proj = torch.nn.Identity() + + def _forward_sam_heads( + self, + backbone_features, + point_inputs=None, + mask_inputs=None, + high_res_features=None, + multimask_output=False, + ): + """ + Forward pass through SAM prompt encoders and mask heads. + + This method processes image features and optional point/mask inputs to generate object masks and scores. + + Args: + backbone_features (torch.Tensor): Image features with shape (B, C, H, W). + point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts. + 'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute + pixel-unit coordinates in (x, y) format for P input points. + 'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks, + 0 means negative clicks, and -1 means padding. + mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the + same spatial size as the image. + high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes + (B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps + for SAM decoder. + multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False, + output only 1 mask and its IoU estimate. + + Returns: + (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]): + low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits. + high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits. + ious: Tensor of shape (B, M) with estimated IoU for each output mask. + low_res_masks: Tensor of shape (B, 1, H*4, W*4) with best low-resolution mask. + high_res_masks: Tensor of shape (B, 1, H*16, W*16) with best high-resolution mask. + obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask. + object_score_logits: Tensor of shape (B,) with object score logits. + + Where M is 3 if multimask_output=True, and 1 if multimask_output=False. + + Examples: + >>> backbone_features = torch.rand(1, 256, 32, 32) + >>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])} + >>> mask_inputs = torch.rand(1, 1, 512, 512) + >>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs) + >>> low_res_multimasks, high_res_multimasks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits = results + """ + B = backbone_features.size(0) + device = backbone_features.device + assert backbone_features.size(1) == self.sam_prompt_embed_dim + assert backbone_features.size(2) == self.sam_image_embedding_size + assert backbone_features.size(3) == self.sam_image_embedding_size + + # a) Handle point prompts + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B + else: + # If no points are provide, pad with an empty point (with label -1) + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + # b) Handle mask prompts + if mask_inputs is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) + if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: + sam_mask_prompt = F.interpolate( + mask_inputs.float(), + size=self.sam_prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + sam_mask_prompt = mask_inputs + else: + # Otherwise, simply feed None (and SAM's prompt encoder will add + # a learned `no_mask_embed` to indicate no mask input in this case). + sam_mask_prompt = None + + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=(sam_point_coords, sam_point_labels), + boxes=None, + masks=sam_mask_prompt, + ) + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.sam_mask_decoder( + image_embeddings=backbone_features, + image_pe=self.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=False, # the image is already batched + high_res_features=high_res_features, + ) + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = torch.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) + low_res_multimasks = low_res_multimasks.float() + high_res_multimasks = F.interpolate( + low_res_multimasks, + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + # Extract object pointer from the SAM output token (with occlusion handling) + obj_ptr = self.obj_ptr_proj(sam_output_token) + if self.pred_obj_scores: + # Allow *soft* no obj ptr, unlike for masks + if self.soft_no_obj_ptr: + # Only hard possible with gt + assert not self.teacher_force_obj_scores_for_mem + lambda_is_obj_appearing = object_score_logits.sigmoid() + else: + lambda_is_obj_appearing = is_obj_appearing.float() + + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + """Processes mask inputs directly as output, bypassing SAM encoder/decoder.""" + # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). + out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 + mask_inputs_float = mask_inputs.float() + high_res_masks = mask_inputs_float * out_scale + out_bias + low_res_masks = F.interpolate( + high_res_masks, + size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + # a dummy IoU prediction of all 1's under mask input + ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float() + if not self.use_obj_ptrs_in_encoder: + # all zeros as a dummy object pointer (of shape [B, C]) + obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device) + else: + # produce an object pointer using the SAM decoder from the mask input + _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads( + backbone_features=backbone_features, + mask_inputs=self.mask_downsample(mask_inputs_float), + high_res_features=high_res_features, + ) + # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; + # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying + # on the object_scores from the SAM decoder. + is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1) + is_obj_appearing = is_obj_appearing[..., None] + lambda_is_obj_appearing = is_obj_appearing.float() + object_score_logits = out_scale * lambda_is_obj_appearing + out_bias + if self.pred_obj_scores: + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_masks, + high_res_masks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def forward_image(self, img_batch: torch.Tensor): + """Processes image batch through encoder to extract multi-level features for SAM model.""" + backbone_out = self.image_encoder(img_batch) + if self.use_high_res_features_in_sam: + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0]) + backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1]) + return backbone_out + + def _prepare_backbone_features(self, backbone_out): + """Prepares and flattens visual features from the image backbone output for further processing.""" + backbone_out = backbone_out.copy() + assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) + assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels + + feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :] + vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :] + + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + # flatten NxCxHxW to HWxNxC + vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] + + return backbone_out, vision_feats, vision_pos_embeds, feat_sizes + + def _prepare_memory_conditioned_features( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + ): + """Prepares memory-conditioned features by fusing current frame's visual features with previous memories.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + device = current_vision_feats[-1].device + # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images. + # In this case, we skip the fusion with any memory. + if self.num_maskmem == 0: # Disable memory and skip fusion + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + return pix_feat + + num_obj_ptr_tokens = 0 + # Step 1: condition the visual features of the current frame on previous memories + if not is_init_cond_frame: + # Retrieve the memories encoded with the maskmem backbone + to_cat_memory, to_cat_memory_pos_embed = [], [] + # Add conditioning frames's output first (all cond frames have t_pos=0 for + # when getting temporal positional embedding below) + assert len(output_dict["cond_frame_outputs"]) > 0 + # Select a maximum number of temporally closest cond frames for cross attention + cond_outputs = output_dict["cond_frame_outputs"] + selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( + frame_idx, cond_outputs, self.max_cond_frames_in_attn + ) + t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()] + # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory + # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 + # We also allow taking the memory frame non-consecutively (with r>1), in which case + # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame. + r = self.memory_temporal_stride_for_eval + for t_pos in range(1, self.num_maskmem): + t_rel = self.num_maskmem - t_pos # how many frames before current frame + if t_rel == 1: + # for t_rel == 1, we take the last frame (regardless of r) + if not track_in_reverse: + # the frame immediately before this frame (i.e. frame_idx - 1) + prev_frame_idx = frame_idx - t_rel + else: + # the frame immediately after this frame (i.e. frame_idx + 1) + prev_frame_idx = frame_idx + t_rel + else: + # for t_rel >= 2, we take the memory frame from every r-th frames + if not track_in_reverse: + # first find the nearest frame among every r-th frames before this frame + # for r=1, this would be (frame_idx - 2) + prev_frame_idx = ((frame_idx - 2) // r) * r + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx - (t_rel - 2) * r + else: + # first find the nearest frame among every r-th frames after this frame + # for r=1, this would be (frame_idx + 2) + prev_frame_idx = -(-(frame_idx + 2) // r) * r + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx + (t_rel - 2) * r + out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) + if out is None: + # If an unselected conditioning frame is among the last (self.num_maskmem - 1) + # frames, we still attend to it as if it's a non-conditioning frame. + out = unselected_cond_outputs.get(prev_frame_idx, None) + t_pos_and_prevs.append((t_pos, out)) + + for t_pos, prev in t_pos_and_prevs: + if prev is None: + continue # skip padding frames + # "maskmem_features" might have been offloaded to CPU in demo use cases, + # so we load it back to GPU (it's a no-op if it's already on GPU). + feats = prev["maskmem_features"].cuda(non_blocking=True) + to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) + # Spatial positional encoding (it might have been offloaded to CPU in eval) + maskmem_enc = prev["maskmem_pos_enc"][-1].cuda() + maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) + # Temporal positional encoding + maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1] + to_cat_memory_pos_embed.append(maskmem_enc) + + # Construct the list of past object pointers + if self.use_obj_ptrs_in_encoder: + max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) + # First add those object pointers from selected conditioning frames + # (optionally, only include object pointers in the past during evaluation) + if not self.training and self.only_obj_ptrs_in_the_past_for_eval: + ptr_cond_outputs = { + t: out + for t, out in selected_cond_outputs.items() + if (t >= frame_idx if track_in_reverse else t <= frame_idx) + } + else: + ptr_cond_outputs = selected_cond_outputs + pos_and_ptrs = [ + # Temporal pos encoding contains how far away each pointer is from current frame + (abs(frame_idx - t), out["obj_ptr"]) + for t, out in ptr_cond_outputs.items() + ] + # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame + for t_diff in range(1, max_obj_ptrs_in_encoder): + t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff + if t < 0 or (num_frames is not None and t >= num_frames): + break + out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None)) + if out is not None: + pos_and_ptrs.append((t_diff, out["obj_ptr"])) + # If we have at least one object pointer, add them to the across attention + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list = zip(*pos_and_ptrs) + # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape + obj_ptrs = torch.stack(ptrs_list, dim=0) + # a temporal positional embedding based on how far each object pointer is from + # the current frame (sine embedding normalized by the max pointer num). + if self.add_tpos_enc_to_obj_ptrs: + t_diff_max = max_obj_ptrs_in_encoder - 1 + tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim + obj_pos = torch.tensor(pos_list, device=device) + obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) + obj_pos = self.obj_ptr_tpos_proj(obj_pos) + obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) + else: + obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim) + if self.mem_dim < C: + # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C + obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim) + obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1) + obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0) + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos_embed.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[0] + else: + num_obj_ptr_tokens = 0 + else: + # for initial conditioning frames, encode them without using any previous memory + if self.directly_add_no_mem_embed: + # directly add no-mem embedding (instead of using the transformer encoder) + pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + # Use a dummy token on the first frame (to avoid empty memory input to transformer encoder) + to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] + to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] + + # Step 2: Concatenate the memories and forward through the transformer encoder + memory = torch.cat(to_cat_memory, dim=0) + memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0) + + pix_feat_with_mem = self.memory_attention( + curr=current_vision_feats, + curr_pos=current_vision_pos_embeds, + memory=memory, + memory_pos=memory_pos_embed, + num_obj_ptr_tokens=num_obj_ptr_tokens, + ) + # reshape the output (HW)BC => BCHW + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + def _encode_new_memory( + self, + current_vision_feats, + feat_sizes, + pred_masks_high_res, + is_mask_from_pts, + ): + """Encodes frame features and masks into a new memory representation for video segmentation.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + if self.non_overlap_masks_for_mem_enc and not self.training: + # optionally, apply non-overlapping constraints to the masks (it's applied + # in the batch dimension and should only be used during eval, where all + # the objects come from the same video under batch size 1). + pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res) + # scale the raw mask logits with a temperature before applying sigmoid + binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts + if binarize and not self.training: + mask_for_mem = (pred_masks_high_res > 0).float() + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + if self.sigmoid_scale_for_mem_enc != 1.0: + mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc + if self.sigmoid_bias_for_mem_enc != 0.0: + mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc + maskmem_out = self.memory_encoder( + pix_feat, + mask_for_mem, + skip_mask_sigmoid=True, # sigmoid already applied + ) + maskmem_features = maskmem_out["vision_features"] + maskmem_pos_enc = maskmem_out["vision_pos_enc"] + + return maskmem_features, maskmem_pos_enc + + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + # Whether to run the memory encoder on the predicted masks. Sometimes we might want + # to skip the memory encoder with `run_mem_encoder=False`. For example, + # in demo we might call `track_step` multiple times for each user click, + # and only encode the memory when the user finalizes their clicks. And in ablation + # settings like SAM training on static images, we don't need the memory encoder. + run_mem_encoder=True, + # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). + prev_sam_mask_logits=None, + ): + """Performs a single tracking step, updating object masks and memory features based on current frame inputs.""" + current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None and self.use_mask_input_as_output_without_sam: + # When use_mask_input_as_output_without_sam=True, we directly output the mask input + # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. + pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat_with_mem = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats[-1:], + current_vision_pos_embeds=current_vision_pos_embeds[-1:], + feat_sizes=feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + ) + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + if prev_sam_mask_logits is not None: + assert point_inputs is not None and mask_inputs is None + mask_inputs = prev_sam_mask_logits + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + _, + ) = sam_outputs + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + if run_mem_encoder and self.num_maskmem > 0: + high_res_masks_for_mem_enc = high_res_masks + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks_for_mem_enc, + is_mask_from_pts=(point_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + return current_out + + def _use_multimask(self, is_init_cond_frame, point_inputs): + """Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs.""" + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + multimask_output = ( + self.multimask_output_in_sam + and (is_init_cond_frame or self.multimask_output_for_tracking) + and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) + ) + return multimask_output + + def _apply_non_overlapping_constraints(self, pred_masks): + """Applies non-overlapping constraints to masks, keeping highest scoring object per location.""" + batch_size = pred_masks.size(0) + if batch_size == 1: + return pred_masks + + device = pred_masks.device + # "max_obj_inds": object index of the object with the highest score at each location + max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) + # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` + batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] + keep = max_obj_inds == batch_obj_inds + # suppress overlapping regions' scores below -10.0 so that the foreground regions + # don't overlap (here sigmoid(-10.0)=4.5398e-05) + pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) + return pred_masks diff --git a/ultralytics/models/sam/modules/tiny_encoder.py b/ultralytics/models/sam/modules/tiny_encoder.py index bc026dd64..0aca78ed6 100644 --- a/ultralytics/models/sam/modules/tiny_encoder.py +++ b/ultralytics/models/sam/modules/tiny_encoder.py @@ -17,16 +17,40 @@ import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint +from ultralytics.nn.modules import LayerNorm2d from ultralytics.utils.instance import to_2tuple class Conv2d_BN(torch.nn.Sequential): - """A sequential container that performs 2D convolution followed by batch normalization.""" + """ + A sequential container that performs 2D convolution followed by batch normalization. + + Attributes: + c (torch.nn.Conv2d): 2D convolution layer. + 1 (torch.nn.BatchNorm2d): Batch normalization layer. + + Methods: + __init__: Initializes the Conv2d_BN with specified parameters. + + Args: + a (int): Number of input channels. + b (int): Number of output channels. + ks (int): Kernel size for the convolution. Defaults to 1. + stride (int): Stride for the convolution. Defaults to 1. + pad (int): Padding for the convolution. Defaults to 0. + dilation (int): Dilation factor for the convolution. Defaults to 1. + groups (int): Number of groups for the convolution. Defaults to 1. + bn_weight_init (float): Initial value for batch normalization weight. Defaults to 1. + + Examples: + >>> conv_bn = Conv2d_BN(3, 64, ks=3, stride=1, pad=1) + >>> input_tensor = torch.randn(1, 3, 224, 224) + >>> output = conv_bn(input_tensor) + >>> print(output.shape) + """ def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1): - """Initializes the MBConv model with given input channels, output channels, expansion ratio, activation, and - drop path. - """ + """Initializes a sequential container with 2D convolution followed by batch normalization.""" super().__init__() self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) bn = torch.nn.BatchNorm2d(b) @@ -36,12 +60,29 @@ class Conv2d_BN(torch.nn.Sequential): class PatchEmbed(nn.Module): - """Embeds images into patches and projects them into a specified embedding dimension.""" + """ + Embeds images into patches and projects them into a specified embedding dimension. + + Attributes: + patches_resolution (Tuple[int, int]): Resolution of the patches after embedding. + num_patches (int): Total number of patches. + in_chans (int): Number of input channels. + embed_dim (int): Dimension of the embedding. + seq (nn.Sequential): Sequence of convolutional and activation layers for patch embedding. + + Methods: + forward: Processes the input tensor through the patch embedding sequence. + + Examples: + >>> import torch + >>> patch_embed = PatchEmbed(in_chans=3, embed_dim=96, resolution=224, activation=nn.GELU) + >>> x = torch.randn(1, 3, 224, 224) + >>> output = patch_embed(x) + >>> print(output.shape) + """ def __init__(self, in_chans, embed_dim, resolution, activation): - """Initialize the PatchMerging class with specified input, output dimensions, resolution and activation - function. - """ + """Initializes patch embedding with convolutional layers for image-to-patch conversion and projection.""" super().__init__() img_size: Tuple[int, int] = to_2tuple(resolution) self.patches_resolution = (img_size[0] // 4, img_size[1] // 4) @@ -56,17 +97,40 @@ class PatchEmbed(nn.Module): ) def forward(self, x): - """Runs input tensor 'x' through the PatchMerging model's sequence of operations.""" + """Processes input tensor through patch embedding sequence, converting images to patch embeddings.""" return self.seq(x) class MBConv(nn.Module): - """Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.""" + """ + Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture. + + Attributes: + in_chans (int): Number of input channels. + hidden_chans (int): Number of hidden channels. + out_chans (int): Number of output channels. + conv1 (Conv2d_BN): First convolutional layer. + act1 (nn.Module): First activation function. + conv2 (Conv2d_BN): Depthwise convolutional layer. + act2 (nn.Module): Second activation function. + conv3 (Conv2d_BN): Final convolutional layer. + act3 (nn.Module): Third activation function. + drop_path (nn.Module): Drop path layer (Identity for inference). + + Methods: + forward: Performs the forward pass through the MBConv layer. + + Examples: + >>> in_chans, out_chans = 32, 64 + >>> mbconv = MBConv(in_chans, out_chans, expand_ratio=4, activation=nn.ReLU, drop_path=0.1) + >>> x = torch.randn(1, in_chans, 56, 56) + >>> output = mbconv(x) + >>> print(output.shape) + torch.Size([1, 64, 56, 56]) + """ def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path): - """Initializes a convolutional layer with specified dimensions, input resolution, depth, and activation - function. - """ + """Initializes the MBConv layer with specified input/output channels, expansion ratio, and activation.""" super().__init__() self.in_chans = in_chans self.hidden_chans = int(in_chans * expand_ratio) @@ -86,7 +150,7 @@ class MBConv(nn.Module): self.drop_path = nn.Identity() def forward(self, x): - """Implements the forward pass for the model architecture.""" + """Implements the forward pass of MBConv, applying convolutions and skip connection.""" shortcut = x x = self.conv1(x) x = self.act1(x) @@ -99,12 +163,34 @@ class MBConv(nn.Module): class PatchMerging(nn.Module): - """Merges neighboring patches in the feature map and projects to a new dimension.""" + """ + Merges neighboring patches in the feature map and projects to a new dimension. + + This class implements a patch merging operation that combines spatial information and adjusts the feature + dimension. It uses a series of convolutional layers with batch normalization to achieve this. + + Attributes: + input_resolution (Tuple[int, int]): The input resolution (height, width) of the feature map. + dim (int): The input dimension of the feature map. + out_dim (int): The output dimension after merging and projection. + act (nn.Module): The activation function used between convolutions. + conv1 (Conv2d_BN): The first convolutional layer for dimension projection. + conv2 (Conv2d_BN): The second convolutional layer for spatial merging. + conv3 (Conv2d_BN): The third convolutional layer for final projection. + + Methods: + forward: Applies the patch merging operation to the input tensor. + + Examples: + >>> input_resolution = (56, 56) + >>> patch_merging = PatchMerging(input_resolution, dim=64, out_dim=128, activation=nn.ReLU) + >>> x = torch.randn(4, 64, 56, 56) + >>> output = patch_merging(x) + >>> print(output.shape) + """ def __init__(self, input_resolution, dim, out_dim, activation): - """Initializes the ConvLayer with specific dimension, input resolution, depth, activation, drop path, and other - optional parameters. - """ + """Initializes the PatchMerging module for merging and projecting neighboring patches in feature maps.""" super().__init__() self.input_resolution = input_resolution @@ -117,7 +203,7 @@ class PatchMerging(nn.Module): self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0) def forward(self, x): - """Applies forward pass on the input utilizing convolution and activation layers, and returns the result.""" + """Applies patch merging and dimension projection to the input feature map.""" if x.ndim == 3: H, W = self.input_resolution B = len(x) @@ -137,7 +223,24 @@ class ConvLayer(nn.Module): """ Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv). - Optionally applies downsample operations to the output, and provides support for gradient checkpointing. + This layer optionally applies downsample operations to the output and supports gradient checkpointing. + + Attributes: + dim (int): Dimensionality of the input and output. + input_resolution (Tuple[int, int]): Resolution of the input image. + depth (int): Number of MBConv layers in the block. + use_checkpoint (bool): Whether to use gradient checkpointing to save memory. + blocks (nn.ModuleList): List of MBConv layers. + downsample (Optional[Callable]): Function for downsampling the output. + + Methods: + forward: Processes the input through the convolutional layers. + + Examples: + >>> input_tensor = torch.randn(1, 64, 56, 56) + >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU) + >>> output = conv_layer(input_tensor) + >>> print(output.shape) """ def __init__( @@ -155,16 +258,25 @@ class ConvLayer(nn.Module): """ Initializes the ConvLayer with the given dimensions and settings. + This layer consists of multiple MobileNetV3-style inverted bottleneck convolutions (MBConv) and + optionally applies downsampling to the output. + Args: dim (int): The dimensionality of the input and output. input_resolution (Tuple[int, int]): The resolution of the input image. depth (int): The number of MBConv layers in the block. activation (Callable): Activation function applied after each convolution. - drop_path (Union[float, List[float]]): Drop path rate. Single float or a list of floats for each MBConv. + drop_path (float | List[float]): Drop path rate. Single float or a list of floats for each MBConv. downsample (Optional[Callable]): Function for downsampling the output. None to skip downsampling. use_checkpoint (bool): Whether to use gradient checkpointing to save memory. out_dim (Optional[int]): The dimensionality of the output. None means it will be the same as `dim`. conv_expand_ratio (float): Expansion ratio for the MBConv layers. + + Examples: + >>> input_tensor = torch.randn(1, 64, 56, 56) + >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU) + >>> output = conv_layer(input_tensor) + >>> print(output.shape) """ super().__init__() self.dim = dim @@ -194,7 +306,7 @@ class ConvLayer(nn.Module): ) def forward(self, x): - """Processes the input through a series of convolutional layers and returns the activated output.""" + """Processes input through convolutional layers, applying MBConv blocks and optional downsampling.""" for blk in self.blocks: x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x) return x if self.downsample is None else self.downsample(x) @@ -202,13 +314,33 @@ class ConvLayer(nn.Module): class Mlp(nn.Module): """ - Multi-layer Perceptron (MLP) for transformer architectures. + Multi-layer Perceptron (MLP) module for transformer architectures. + + This module applies layer normalization, two fully-connected layers with an activation function in between, + and dropout. It is commonly used in transformer-based architectures. - This layer takes an input with in_features, applies layer normalization and two fully-connected layers. + Attributes: + norm (nn.LayerNorm): Layer normalization applied to the input. + fc1 (nn.Linear): First fully-connected layer. + fc2 (nn.Linear): Second fully-connected layer. + act (nn.Module): Activation function applied after the first fully-connected layer. + drop (nn.Dropout): Dropout layer applied after the activation function. + + Methods: + forward: Applies the MLP operations on the input tensor. + + Examples: + >>> import torch + >>> from torch import nn + >>> mlp = Mlp(in_features=256, hidden_features=512, out_features=256, act_layer=nn.GELU, drop=0.1) + >>> x = torch.randn(32, 100, 256) + >>> output = mlp(x) + >>> print(output.shape) + torch.Size([32, 100, 256]) """ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): - """Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc.""" + """Initializes a multi-layer perceptron with configurable input, hidden, and output dimensions.""" super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -219,7 +351,7 @@ class Mlp(nn.Module): self.drop = nn.Dropout(drop) def forward(self, x): - """Applies operations on input x and returns modified x, runs downsample if not None.""" + """Applies MLP operations: layer norm, FC layers, activation, and dropout to the input tensor.""" x = self.norm(x) x = self.fc1(x) x = self.act(x) @@ -230,12 +362,37 @@ class Mlp(nn.Module): class Attention(torch.nn.Module): """ - Multi-head attention module with support for spatial awareness, applying attention biases based on spatial - resolution. Implements trainable attention biases for each unique offset between spatial positions in the resolution - grid. + Multi-head attention module with spatial awareness and trainable attention biases. + + This module implements a multi-head attention mechanism with support for spatial awareness, applying + attention biases based on spatial resolution. It includes trainable attention biases for each unique + offset between spatial positions in the resolution grid. Attributes: - ab (Tensor, optional): Cached attention biases for inference, deleted during training. + num_heads (int): Number of attention heads. + scale (float): Scaling factor for attention scores. + key_dim (int): Dimensionality of the keys and queries. + nh_kd (int): Product of num_heads and key_dim. + d (int): Dimensionality of the value vectors. + dh (int): Product of d and num_heads. + attn_ratio (float): Attention ratio affecting the dimensions of the value vectors. + norm (nn.LayerNorm): Layer normalization applied to input. + qkv (nn.Linear): Linear layer for computing query, key, and value projections. + proj (nn.Linear): Linear layer for final projection. + attention_biases (nn.Parameter): Learnable attention biases. + attention_bias_idxs (Tensor): Indices for attention biases. + ab (Tensor): Cached attention biases for inference, deleted during training. + + Methods: + train: Sets the module in training mode and handles the 'ab' attribute. + forward: Performs the forward pass of the attention mechanism. + + Examples: + >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14)) + >>> x = torch.randn(1, 196, 256) + >>> output = attn(x) + >>> print(output.shape) + torch.Size([1, 196, 256]) """ def __init__( @@ -247,17 +404,28 @@ class Attention(torch.nn.Module): resolution=(14, 14), ): """ - Initializes the Attention module. + Initializes the Attention module for multi-head attention with spatial awareness. + + This module implements a multi-head attention mechanism with support for spatial awareness, applying + attention biases based on spatial resolution. It includes trainable attention biases for each unique + offset between spatial positions in the resolution grid. Args: dim (int): The dimensionality of the input and output. key_dim (int): The dimensionality of the keys and queries. - num_heads (int, optional): Number of attention heads. Default is 8. - attn_ratio (float, optional): Attention ratio, affecting the dimensions of the value vectors. Default is 4. - resolution (Tuple[int, int], optional): Spatial resolution of the input feature map. Default is (14, 14). + num_heads (int): Number of attention heads. Default is 8. + attn_ratio (float): Attention ratio, affecting the dimensions of the value vectors. Default is 4. + resolution (Tuple[int, int]): Spatial resolution of the input feature map. Default is (14, 14). Raises: - AssertionError: If `resolution` is not a tuple of length 2. + AssertionError: If 'resolution' is not a tuple of length 2. + + Examples: + >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14)) + >>> x = torch.randn(1, 196, 256) + >>> output = attn(x) + >>> print(output.shape) + torch.Size([1, 196, 256]) """ super().__init__() @@ -290,7 +458,7 @@ class Attention(torch.nn.Module): @torch.no_grad() def train(self, mode=True): - """Sets the module in training mode and handles attribute 'ab' based on the mode.""" + """Performs multi-head attention with spatial awareness and trainable attention biases.""" super().train(mode) if mode and hasattr(self, "ab"): del self.ab @@ -298,7 +466,7 @@ class Attention(torch.nn.Module): self.ab = self.attention_biases[:, self.attention_bias_idxs] def forward(self, x): # x - """Performs forward pass over the input tensor 'x' by applying normalization and querying keys/values.""" + """Applies multi-head attention with spatial awareness and trainable attention biases.""" B, N, _ = x.shape # B, N, C # Normalization @@ -322,7 +490,34 @@ class Attention(torch.nn.Module): class TinyViTBlock(nn.Module): - """TinyViT Block that applies self-attention and a local convolution to the input.""" + """ + TinyViT Block that applies self-attention and a local convolution to the input. + + This block is a key component of the TinyViT architecture, combining self-attention mechanisms with + local convolutions to process input features efficiently. + + Attributes: + dim (int): The dimensionality of the input and output. + input_resolution (Tuple[int, int]): Spatial resolution of the input feature map. + num_heads (int): Number of attention heads. + window_size (int): Size of the attention window. + mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension. + drop_path (nn.Module): Stochastic depth layer, identity function during inference. + attn (Attention): Self-attention module. + mlp (Mlp): Multi-layer perceptron module. + local_conv (Conv2d_BN): Depth-wise local convolution layer. + + Methods: + forward: Processes the input through the TinyViT block. + extra_repr: Returns a string with extra information about the block's parameters. + + Examples: + >>> input_tensor = torch.randn(1, 196, 192) + >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3) + >>> output = block(input_tensor) + >>> print(output.shape) + torch.Size([1, 196, 192]) + """ def __init__( self, @@ -337,22 +532,32 @@ class TinyViTBlock(nn.Module): activation=nn.GELU, ): """ - Initializes the TinyViTBlock. + Initializes a TinyViT block with self-attention and local convolution. + + This block is a key component of the TinyViT architecture, combining self-attention mechanisms with + local convolutions to process input features efficiently. Args: - dim (int): The dimensionality of the input and output. - input_resolution (Tuple[int, int]): Spatial resolution of the input feature map. + dim (int): Dimensionality of the input and output features. + input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width). num_heads (int): Number of attention heads. - window_size (int, optional): Window size for attention. Default is 7. - mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4. - drop (float, optional): Dropout rate. Default is 0. - drop_path (float, optional): Stochastic depth rate. Default is 0. - local_conv_size (int, optional): The kernel size of the local convolution. Default is 3. - activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU. + window_size (int): Size of the attention window. Must be greater than 0. + mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension. + drop (float): Dropout rate. + drop_path (float): Stochastic depth rate. + local_conv_size (int): Kernel size of the local convolution. + activation (torch.nn.Module): Activation function for MLP. Raises: - AssertionError: If `window_size` is not greater than 0. - AssertionError: If `dim` is not divisible by `num_heads`. + AssertionError: If window_size is not greater than 0. + AssertionError: If dim is not divisible by num_heads. + + Examples: + >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3) + >>> input_tensor = torch.randn(1, 196, 192) + >>> output = block(input_tensor) + >>> print(output.shape) + torch.Size([1, 196, 192]) """ super().__init__() self.dim = dim @@ -380,9 +585,7 @@ class TinyViTBlock(nn.Module): self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim) def forward(self, x): - """Applies attention-based transformation or padding to input 'x' before passing it through a local - convolution. - """ + """Applies self-attention, local convolution, and MLP operations to the input tensor.""" h, w = self.input_resolution b, hw, c = x.shape # batch, height*width, channels assert hw == h * w, "input feature has wrong size" @@ -424,8 +627,19 @@ class TinyViTBlock(nn.Module): return x + self.drop_path(self.mlp(x)) def extra_repr(self) -> str: - """Returns a formatted string representing the TinyViTBlock's parameters: dimension, input resolution, number of - attentions heads, window size, and MLP ratio. + """ + Returns a string representation of the TinyViTBlock's parameters. + + This method provides a formatted string containing key information about the TinyViTBlock, including its + dimension, input resolution, number of attention heads, window size, and MLP ratio. + + Returns: + (str): A formatted string containing the block's parameters. + + Examples: + >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0) + >>> print(block.extra_repr()) + dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0 """ return ( f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " @@ -434,7 +648,31 @@ class TinyViTBlock(nn.Module): class BasicLayer(nn.Module): - """A basic TinyViT layer for one stage in a TinyViT architecture.""" + """ + A basic TinyViT layer for one stage in a TinyViT architecture. + + This class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks + and an optional downsampling operation. + + Attributes: + dim (int): The dimensionality of the input and output features. + input_resolution (Tuple[int, int]): Spatial resolution of the input feature map. + depth (int): Number of TinyViT blocks in this layer. + use_checkpoint (bool): Whether to use gradient checkpointing to save memory. + blocks (nn.ModuleList): List of TinyViT blocks that make up this layer. + downsample (nn.Module | None): Downsample layer at the end of the layer, if specified. + + Methods: + forward: Processes the input through the layer's blocks and optional downsampling. + extra_repr: Returns a string with the layer's parameters for printing. + + Examples: + >>> input_tensor = torch.randn(1, 3136, 192) + >>> layer = BasicLayer(dim=192, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7) + >>> output = layer(input_tensor) + >>> print(output.shape) + torch.Size([1, 784, 384]) + """ def __init__( self, @@ -453,25 +691,34 @@ class BasicLayer(nn.Module): out_dim=None, ): """ - Initializes the BasicLayer. + Initializes a BasicLayer in the TinyViT architecture. + + This layer consists of multiple TinyViT blocks and an optional downsampling operation. It is designed to + process feature maps at a specific resolution and dimensionality within the TinyViT model. Args: - dim (int): The dimensionality of the input and output. - input_resolution (Tuple[int, int]): Spatial resolution of the input feature map. - depth (int): Number of TinyViT blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4. - drop (float, optional): Dropout rate. Default is 0. - drop_path (float | tuple[float], optional): Stochastic depth rate. Default is 0. - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default is None. - use_checkpoint (bool, optional): Whether to use checkpointing to save memory. Default is False. - local_conv_size (int, optional): Kernel size of the local convolution. Default is 3. - activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU. - out_dim (int | None, optional): The output dimension of the layer. Default is None. + dim (int): Dimensionality of the input and output features. + input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width). + depth (int): Number of TinyViT blocks in this layer. + num_heads (int): Number of attention heads in each TinyViT block. + window_size (int): Size of the local window for attention computation. + mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension. + drop (float): Dropout rate. + drop_path (float | List[float]): Stochastic depth rate. Can be a float or a list of floats for each block. + downsample (nn.Module | None): Downsampling layer at the end of the layer. None to skip downsampling. + use_checkpoint (bool): Whether to use gradient checkpointing to save memory. + local_conv_size (int): Kernel size for the local convolution in each TinyViT block. + activation (nn.Module): Activation function used in the MLP. + out_dim (int | None): Output dimension after downsampling. None means it will be the same as `dim`. Raises: - ValueError: If `drop_path` is a list of float but its length doesn't match `depth`. + ValueError: If `drop_path` is a list and its length doesn't match `depth`. + + Examples: + >>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7) + >>> x = torch.randn(1, 56*56, 96) + >>> output = layer(x) + >>> print(output.shape) """ super().__init__() self.dim = dim @@ -505,58 +752,49 @@ class BasicLayer(nn.Module): ) def forward(self, x): - """Performs forward propagation on the input tensor and returns a normalized tensor.""" + """Processes input through TinyViT blocks and optional downsampling.""" for blk in self.blocks: x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x) return x if self.downsample is None else self.downsample(x) def extra_repr(self) -> str: - """Returns a string representation of the extra_repr function with the layer's parameters.""" + """Returns a string with the layer's parameters for printing.""" return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" -class LayerNorm2d(nn.Module): - """A PyTorch implementation of Layer Normalization in 2D.""" - - def __init__(self, num_channels: int, eps: float = 1e-6) -> None: - """Initialize LayerNorm2d with the number of channels and an optional epsilon.""" - super().__init__() - self.weight = nn.Parameter(torch.ones(num_channels)) - self.bias = nn.Parameter(torch.zeros(num_channels)) - self.eps = eps - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Perform a forward pass, normalizing the input tensor.""" - u = x.mean(1, keepdim=True) - s = (x - u).pow(2).mean(1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.eps) - return self.weight[:, None, None] * x + self.bias[:, None, None] - - class TinyViT(nn.Module): """ - The TinyViT architecture for vision tasks. + TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction. + + This class implements the TinyViT model, which combines elements of vision transformers and convolutional + neural networks for improved efficiency and performance on vision tasks. Attributes: img_size (int): Input image size. - in_chans (int): Number of input channels. num_classes (int): Number of classification classes. - embed_dims (List[int]): List of embedding dimensions for each layer. - depths (List[int]): List of depths for each layer. - num_heads (List[int]): List of number of attention heads for each layer. - window_sizes (List[int]): List of window sizes for each layer. + depths (List[int]): Number of blocks in each stage. + num_layers (int): Total number of layers in the network. mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension. - drop_rate (float): Dropout rate for drop layers. - drop_path_rate (float): Drop path rate for stochastic depth. - use_checkpoint (bool): Use checkpointing for efficient memory usage. - mbconv_expand_ratio (float): Expansion ratio for MBConv layer. - local_conv_size (int): Local convolution kernel size. - layer_lr_decay (float): Layer-wise learning rate decay. - - Note: - This implementation is generalized to accept a list of depths, attention heads, - embedding dimensions and window sizes, which allows you to create a - "stack" of TinyViT models of varying configurations. + patch_embed (PatchEmbed): Module for patch embedding. + patches_resolution (Tuple[int, int]): Resolution of embedded patches. + layers (nn.ModuleList): List of network layers. + norm_head (nn.LayerNorm): Layer normalization for the classifier head. + head (nn.Linear): Linear layer for final classification. + neck (nn.Sequential): Neck module for feature refinement. + + Methods: + set_layer_lr_decay: Sets layer-wise learning rate decay. + _init_weights: Initializes weights for linear and normalization layers. + no_weight_decay_keywords: Returns keywords for parameters that should not use weight decay. + forward_features: Processes input through the feature extraction layers. + forward: Performs a forward pass through the entire network. + + Examples: + >>> model = TinyViT(img_size=224, num_classes=1000) + >>> x = torch.randn(1, 3, 224, 224) + >>> features = model.forward_features(x) + >>> print(features.shape) + torch.Size([1, 256, 64, 64]) """ def __init__( @@ -579,21 +817,33 @@ class TinyViT(nn.Module): """ Initializes the TinyViT model. + This constructor sets up the TinyViT architecture, including patch embedding, multiple layers of + attention and convolution blocks, and a classification head. + Args: - img_size (int, optional): The input image size. Defaults to 224. - in_chans (int, optional): Number of input channels. Defaults to 3. - num_classes (int, optional): Number of classification classes. Defaults to 1000. - embed_dims (List[int], optional): List of embedding dimensions per layer. Defaults to [96, 192, 384, 768]. - depths (List[int], optional): List of depths for each layer. Defaults to [2, 2, 6, 2]. - num_heads (List[int], optional): List of number of attention heads per layer. Defaults to [3, 6, 12, 24]. - window_sizes (List[int], optional): List of window sizes for each layer. Defaults to [7, 7, 14, 7]. - mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension. Defaults to 4. - drop_rate (float, optional): Dropout rate. Defaults to 0. - drop_path_rate (float, optional): Drop path rate for stochastic depth. Defaults to 0.1. - use_checkpoint (bool, optional): Whether to use checkpointing for efficient memory usage. Defaults to False. - mbconv_expand_ratio (float, optional): Expansion ratio for MBConv layer. Defaults to 4.0. - local_conv_size (int, optional): Local convolution kernel size. Defaults to 3. - layer_lr_decay (float, optional): Layer-wise learning rate decay. Defaults to 1.0. + img_size (int): Size of the input image. Default is 224. + in_chans (int): Number of input channels. Default is 3. + num_classes (int): Number of classes for classification. Default is 1000. + embed_dims (Tuple[int, int, int, int]): Embedding dimensions for each stage. + Default is (96, 192, 384, 768). + depths (Tuple[int, int, int, int]): Number of blocks in each stage. Default is (2, 2, 6, 2). + num_heads (Tuple[int, int, int, int]): Number of attention heads in each stage. + Default is (3, 6, 12, 24). + window_sizes (Tuple[int, int, int, int]): Window sizes for each stage. Default is (7, 7, 14, 7). + mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. Default is 4.0. + drop_rate (float): Dropout rate. Default is 0.0. + drop_path_rate (float): Stochastic depth rate. Default is 0.1. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default is False. + mbconv_expand_ratio (float): Expansion ratio for MBConv layer. Default is 4.0. + local_conv_size (int): Kernel size for local convolutions. Default is 3. + layer_lr_decay (float): Layer-wise learning rate decay factor. Default is 1.0. + + Examples: + >>> model = TinyViT(img_size=224, num_classes=1000) + >>> x = torch.randn(1, 3, 224, 224) + >>> output = model(x) + >>> print(output.shape) + torch.Size([1, 1000]) """ super().__init__() self.img_size = img_size @@ -671,7 +921,7 @@ class TinyViT(nn.Module): ) def set_layer_lr_decay(self, layer_lr_decay): - """Sets the learning rate decay for each layer in the TinyViT model.""" + """Sets layer-wise learning rate decay for the TinyViT model based on depth.""" decay_rate = layer_lr_decay # Layers -> blocks (depth) @@ -706,7 +956,7 @@ class TinyViT(nn.Module): self.apply(_check_lr_scale) def _init_weights(self, m): - """Initializes weights for linear layers and layer normalization in the given module.""" + """Initializes weights for linear and normalization layers in the TinyViT model.""" if isinstance(m, nn.Linear): # NOTE: This initialization is needed only for training. # trunc_normal_(m.weight, std=.02) @@ -718,11 +968,11 @@ class TinyViT(nn.Module): @torch.jit.ignore def no_weight_decay_keywords(self): - """Returns a dictionary of parameter names where weight decay should not be applied.""" + """Returns a set of keywords for parameters that should not use weight decay.""" return {"attention_biases"} def forward_features(self, x): - """Runs the input through the model layers and returns the transformed output.""" + """Processes input through feature extraction layers, returning spatial features.""" x = self.patch_embed(x) # x input is (N, C, H, W) x = self.layers[0](x) @@ -737,5 +987,5 @@ class TinyViT(nn.Module): return self.neck(x) def forward(self, x): - """Executes a forward pass on the input tensor through the constructed model layers.""" + """Performs the forward pass through the TinyViT model, extracting features from the input image.""" return self.forward_features(x) diff --git a/ultralytics/models/sam/modules/transformer.py b/ultralytics/models/sam/modules/transformer.py index 6375c2adc..b8c408193 100644 --- a/ultralytics/models/sam/modules/transformer.py +++ b/ultralytics/models/sam/modules/transformer.py @@ -11,19 +11,31 @@ from ultralytics.nn.modules import MLPBlock class TwoWayTransformer(nn.Module): """ - A Two-Way Transformer module that enables the simultaneous attention to both image and query points. This class - serves as a specialized transformer decoder that attends to an input image using queries whose positional embedding - is supplied. This is particularly useful for tasks like object detection, image segmentation, and point cloud - processing. + A Two-Way Transformer module for simultaneous attention to image and query points. + + This class implements a specialized transformer decoder that attends to an input image using queries with + supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point + cloud processing. Attributes: - depth (int): The number of layers in the transformer. - embedding_dim (int): The channel dimension for the input embeddings. - num_heads (int): The number of heads for multihead attention. - mlp_dim (int): The internal channel dimension for the MLP block. - layers (nn.ModuleList): The list of TwoWayAttentionBlock layers that make up the transformer. - final_attn_token_to_image (Attention): The final attention layer applied from the queries to the image. - norm_final_attn (nn.LayerNorm): The layer normalization applied to the final queries. + depth (int): Number of layers in the transformer. + embedding_dim (int): Channel dimension for input embeddings. + num_heads (int): Number of heads for multihead attention. + mlp_dim (int): Internal channel dimension for the MLP block. + layers (nn.ModuleList): List of TwoWayAttentionBlock layers composing the transformer. + final_attn_token_to_image (Attention): Final attention layer from queries to image. + norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries. + + Methods: + forward: Processes image and point embeddings through the transformer. + + Examples: + >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048) + >>> image_embedding = torch.randn(1, 256, 32, 32) + >>> image_pe = torch.randn(1, 256, 32, 32) + >>> point_embedding = torch.randn(1, 100, 256) + >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding) + >>> print(output_queries.shape, output_image.shape) """ def __init__( @@ -36,15 +48,33 @@ class TwoWayTransformer(nn.Module): attention_downsample_rate: int = 2, ) -> None: """ - A transformer decoder that attends to an input image using queries whose positional embedding is supplied. + Initialize a Two-Way Transformer for simultaneous attention to image and query points. Args: - depth (int): number of layers in the transformer - embedding_dim (int): the channel dimension for the input embeddings - num_heads (int): the number of heads for multihead attention. Must - divide embedding_dim - mlp_dim (int): the channel dimension internal to the MLP block - activation (nn.Module): the activation to use in the MLP block + depth (int): Number of layers in the transformer. + embedding_dim (int): Channel dimension for input embeddings. + num_heads (int): Number of heads for multihead attention. Must divide embedding_dim. + mlp_dim (int): Internal channel dimension for the MLP block. + activation (Type[nn.Module]): Activation function to use in the MLP block. + attention_downsample_rate (int): Downsampling rate for attention mechanism. + + Attributes: + depth (int): Number of layers in the transformer. + embedding_dim (int): Channel dimension for input embeddings. + embedding_dim (int): Channel dimension for input embeddings. + num_heads (int): Number of heads for multihead attention. + mlp_dim (int): Internal channel dimension for the MLP block. + layers (nn.ModuleList): List of TwoWayAttentionBlock layers. + final_attn_token_to_image (Attention): Final attention layer from queries to image. + norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries. + + Examples: + >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048) + >>> image_embedding = torch.randn(1, 256, 32, 32) + >>> image_pe = torch.randn(1, 256, 32, 32) + >>> point_embedding = torch.randn(1, 100, 256) + >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding) + >>> print(output_queries.shape, output_image.shape) """ super().__init__() self.depth = depth @@ -75,15 +105,23 @@ class TwoWayTransformer(nn.Module): point_embedding: Tensor, ) -> Tuple[Tensor, Tensor]: """ + Processes image and point embeddings through the Two-Way Transformer. + Args: - image_embedding (torch.Tensor): image to attend to. Should be shape B x embedding_dim x h x w for any h and w. - image_pe (torch.Tensor): the positional encoding to add to the image. Must have same shape as image_embedding. - point_embedding (torch.Tensor): the embedding to add to the query points. - Must have shape B x N_points x embedding_dim for any N_points. + image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W). + image_pe (torch.Tensor): Positional encoding to add to the image, with same shape as image_embedding. + point_embedding (torch.Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim). Returns: - (torch.Tensor): the processed point_embedding - (torch.Tensor): the processed image_embedding + (Tuple[torch.Tensor, torch.Tensor]): Processed point_embedding and image_embedding. + + Examples: + >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048) + >>> image_embedding = torch.randn(1, 256, 32, 32) + >>> image_pe = torch.randn(1, 256, 32, 32) + >>> point_embedding = torch.randn(1, 100, 256) + >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding) + >>> print(output_queries.shape, output_image.shape) """ # BxCxHxW -> BxHWxC == B x N_image_tokens x C image_embedding = image_embedding.flatten(2).permute(0, 2, 1) @@ -114,21 +152,34 @@ class TwoWayTransformer(nn.Module): class TwoWayAttentionBlock(nn.Module): """ - An attention block that performs both self-attention and cross-attention in two directions: queries to keys and - keys to queries. This block consists of four main layers: (1) self-attention on sparse inputs, (2) cross-attention - of sparse inputs to dense inputs, (3) an MLP block on sparse inputs, and (4) cross-attention of dense inputs to - sparse inputs. + A two-way attention block for simultaneous attention to image and query points. + + This class implements a specialized transformer block with four main layers: self-attention on sparse inputs, + cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense + inputs to sparse inputs. Attributes: - self_attn (Attention): The self-attention layer for the queries. - norm1 (nn.LayerNorm): Layer normalization following the first attention block. + self_attn (Attention): Self-attention layer for queries. + norm1 (nn.LayerNorm): Layer normalization after self-attention. cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys. - norm2 (nn.LayerNorm): Layer normalization following the second attention block. - mlp (MLPBlock): MLP block that transforms the query embeddings. - norm3 (nn.LayerNorm): Layer normalization following the MLP block. - norm4 (nn.LayerNorm): Layer normalization following the third attention block. + norm2 (nn.LayerNorm): Layer normalization after token-to-image attention. + mlp (MLPBlock): MLP block for transforming query embeddings. + norm3 (nn.LayerNorm): Layer normalization after MLP block. + norm4 (nn.LayerNorm): Layer normalization after image-to-token attention. cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries. - skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer. + skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer. + + Methods: + forward: Applies self-attention and cross-attention to queries and keys. + + Examples: + >>> embedding_dim, num_heads = 256, 8 + >>> block = TwoWayAttentionBlock(embedding_dim, num_heads) + >>> queries = torch.randn(1, 100, embedding_dim) + >>> keys = torch.randn(1, 1000, embedding_dim) + >>> query_pe = torch.randn(1, 100, embedding_dim) + >>> key_pe = torch.randn(1, 1000, embedding_dim) + >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe) """ def __init__( @@ -141,16 +192,28 @@ class TwoWayAttentionBlock(nn.Module): skip_first_layer_pe: bool = False, ) -> None: """ - A transformer block with four layers: (1) self-attention of sparse inputs, (2) cross attention of sparse - inputs to dense inputs, (3) mlp block on sparse inputs, and (4) cross attention of dense inputs to sparse - inputs. + Initializes a TwoWayAttentionBlock for simultaneous attention to image and query points. + + This block implements a specialized transformer layer with four main components: self-attention on sparse + inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention + of dense inputs to sparse inputs. Args: - embedding_dim (int): the channel dimension of the embeddings - num_heads (int): the number of heads in the attention layers - mlp_dim (int): the hidden dimension of the mlp block - activation (nn.Module): the activation of the mlp block - skip_first_layer_pe (bool): skip the PE on the first layer + embedding_dim (int): Channel dimension of the embeddings. + num_heads (int): Number of attention heads in the attention layers. + mlp_dim (int): Hidden dimension of the MLP block. + activation (Type[nn.Module]): Activation function for the MLP block. + attention_downsample_rate (int): Downsampling rate for the attention mechanism. + skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer. + + Examples: + >>> embedding_dim, num_heads = 256, 8 + >>> block = TwoWayAttentionBlock(embedding_dim, num_heads) + >>> queries = torch.randn(1, 100, embedding_dim) + >>> keys = torch.randn(1, 1000, embedding_dim) + >>> query_pe = torch.randn(1, 100, embedding_dim) + >>> key_pe = torch.randn(1, 1000, embedding_dim) + >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe) """ super().__init__() self.self_attn = Attention(embedding_dim, num_heads) @@ -168,7 +231,7 @@ class TwoWayAttentionBlock(nn.Module): self.skip_first_layer_pe = skip_first_layer_pe def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]: - """Apply self-attention and cross-attention to queries and keys and return the processed embeddings.""" + """Applies two-way attention to process query and key embeddings in a transformer block.""" # Self attention block if self.skip_first_layer_pe: @@ -202,8 +265,34 @@ class TwoWayAttentionBlock(nn.Module): class Attention(nn.Module): - """An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and - values. + """ + An attention layer with downscaling capability for embedding size after projection. + + This class implements a multi-head attention mechanism with the option to downsample the internal + dimension of queries, keys, and values. + + Attributes: + embedding_dim (int): Dimensionality of input embeddings. + kv_in_dim (int): Dimensionality of key and value inputs. + internal_dim (int): Internal dimension after downsampling. + num_heads (int): Number of attention heads. + q_proj (nn.Linear): Linear projection for queries. + k_proj (nn.Linear): Linear projection for keys. + v_proj (nn.Linear): Linear projection for values. + out_proj (nn.Linear): Linear projection for output. + + Methods: + _separate_heads: Separates input tensor into attention heads. + _recombine_heads: Recombines separated attention heads. + forward: Computes attention output for given query, key, and value tensors. + + Examples: + >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2) + >>> q = torch.randn(1, 100, 256) + >>> k = v = torch.randn(1, 50, 256) + >>> output = attn(q, k, v) + >>> print(output.shape) + torch.Size([1, 100, 256]) """ def __init__( @@ -214,15 +303,27 @@ class Attention(nn.Module): kv_in_dim: int = None, ) -> None: """ - Initializes the Attention model with the given dimensions and settings. + Initializes the Attention module with specified dimensions and settings. + + This class implements a multi-head attention mechanism with optional downsampling of the internal + dimension for queries, keys, and values. Args: - embedding_dim (int): The dimensionality of the input embeddings. - num_heads (int): The number of attention heads. - downsample_rate (int, optional): The factor by which the internal dimensions are downsampled. Defaults to 1. + embedding_dim (int): Dimensionality of input embeddings. + num_heads (int): Number of attention heads. + downsample_rate (int): Factor by which internal dimensions are downsampled. Defaults to 1. + kv_in_dim (int | None): Dimensionality of key and value inputs. If None, uses embedding_dim. Raises: - AssertionError: If 'num_heads' does not evenly divide the internal dim (embedding_dim / downsample_rate). + AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate). + + Examples: + >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2) + >>> q = torch.randn(1, 100, 256) + >>> k = v = torch.randn(1, 50, 256) + >>> output = attn(q, k, v) + >>> print(output.shape) + torch.Size([1, 100, 256]) """ super().__init__() self.embedding_dim = embedding_dim @@ -238,20 +339,20 @@ class Attention(nn.Module): @staticmethod def _separate_heads(x: Tensor, num_heads: int) -> Tensor: - """Separate the input tensor into the specified number of attention heads.""" + """Separates the input tensor into the specified number of attention heads.""" b, n, c = x.shape x = x.reshape(b, n, num_heads, c // num_heads) return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head @staticmethod def _recombine_heads(x: Tensor) -> Tensor: - """Recombine the separated attention heads into a single tensor.""" + """Recombines separated attention heads into a single tensor.""" b, n_heads, n_tokens, c_per_head = x.shape x = x.transpose(1, 2) return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: - """Compute the attention output given the input query, key, and value tensors.""" + """Applies multi-head attention to query, key, and value tensors with optional downsampling.""" # Input projections q = self.q_proj(q) diff --git a/ultralytics/models/sam2/modules/utils.py b/ultralytics/models/sam/modules/utils.py similarity index 65% rename from ultralytics/models/sam2/modules/utils.py rename to ultralytics/models/sam/modules/utils.py index b09dd9b2f..debc62e8e 100644 --- a/ultralytics/models/sam2/modules/utils.py +++ b/ultralytics/models/sam/modules/utils.py @@ -1,5 +1,7 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license +from typing import Tuple + import torch import torch.nn.functional as F @@ -70,7 +72,7 @@ def get_1d_sine_pe(pos_inds, dim, temperature=10000): def init_t_xy(end_x: int, end_y: int): - """Initializes 1D and 2D coordinate tensors for a grid of size end_x by end_y.""" + """Initializes 1D and 2D coordinate tensors for a grid of specified dimensions.""" t = torch.arange(end_x * end_y, dtype=torch.float32) t_x = (t % end_x).float() t_y = torch.div(t, end_x, rounding_mode="floor").float() @@ -78,7 +80,7 @@ def init_t_xy(end_x: int, end_y: int): def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): - """Computes axial complex exponential positional encodings for 2D spatial positions.""" + """Computes axial complex exponential positional encodings for 2D spatial positions in a grid.""" freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) @@ -91,7 +93,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): - """Reshapes frequency tensor for broadcasting, ensuring compatibility with input tensor dimensions.""" + """Reshapes frequency tensor for broadcasting with input tensor, ensuring dimensional compatibility.""" ndim = x.ndim assert 0 <= 1 < ndim assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) @@ -189,3 +191,103 @@ def window_unpartition(windows, window_size, pad_hw, hw): if Hp > H or Wp > W: x = x[:, :H, :W, :].contiguous() return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Extracts relative positional embeddings based on query and key sizes. + + Args: + q_size (int): Size of the query. + k_size (int): Size of the key. + rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative + distance and C is the embedding dimension. + + Returns: + (torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size, + k_size, C). + + Examples: + >>> q_size, k_size = 8, 16 + >>> rel_pos = torch.randn(31, 64) # 31 = 2 * max(8, 16) - 1 + >>> extracted_pos = get_rel_pos(q_size, k_size, rel_pos) + >>> print(extracted_pos.shape) + torch.Size([8, 16, 64]) + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Adds decomposed Relative Positional Embeddings to the attention map. + + This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2 + paper. It enhances the attention mechanism by incorporating spatial relationships between query and key + positions. + + Args: + attn (torch.Tensor): Attention map with shape (B, q_h * q_w, k_h * k_w). + q (torch.Tensor): Query tensor in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (torch.Tensor): Relative position embeddings for height axis with shape (Lh, C). + rel_pos_w (torch.Tensor): Relative position embeddings for width axis with shape (Lw, C). + q_size (Tuple[int, int]): Spatial sequence size of query q as (q_h, q_w). + k_size (Tuple[int, int]): Spatial sequence size of key k as (k_h, k_w). + + Returns: + (torch.Tensor): Updated attention map with added relative positional embeddings, shape + (B, q_h * q_w, k_h * k_w). + + Examples: + >>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8 + >>> attn = torch.rand(B, q_h * q_w, k_h * k_w) + >>> q = torch.rand(B, q_h * q_w, C) + >>> rel_pos_h = torch.rand(2 * max(q_h, k_h) - 1, C) + >>> rel_pos_w = torch.rand(2 * max(q_w, k_w) - 1, C) + >>> q_size, k_size = (q_h, q_w), (k_h, k_w) + >>> updated_attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size) + >>> print(updated_attn.shape) + torch.Size([1, 64, 64]) + + References: + https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view( + B, q_h * q_w, k_h * k_w + ) + + return attn diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py index e03ba6256..034133514 100644 --- a/ultralytics/models/sam/predict.py +++ b/ultralytics/models/sam/predict.py @@ -34,35 +34,64 @@ from .build import build_sam class Predictor(BasePredictor): """ - Predictor class for the Segment Anything Model (SAM), extending BasePredictor. + Predictor class for SAM, enabling real-time image segmentation with promptable capabilities. - The class provides an interface for model inference tailored to image segmentation tasks. - With advanced architecture and promptable segmentation capabilities, it facilitates flexible and real-time - mask generation. The class is capable of working with various types of prompts such as bounding boxes, - points, and low-resolution masks. + This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image + segmentation tasks. It supports various input prompts like points, bounding boxes, and masks for + fine-grained control over segmentation results. Attributes: - cfg (dict): Configuration dictionary specifying model and task-related parameters. - overrides (dict): Dictionary containing values that override the default configuration. - _callbacks (dict): Dictionary of user-defined callback functions to augment behavior. - args (namespace): Namespace to hold command-line arguments or other operational variables. - im (torch.Tensor): Preprocessed input image tensor. - features (torch.Tensor): Extracted image features used for inference. - prompts (dict): Collection of various prompt types, such as bounding boxes and points. - segment_all (bool): Flag to control whether to segment all objects in the image or only specified ones. + args (SimpleNamespace): Configuration arguments for the predictor. + model (torch.nn.Module): The loaded SAM model. + device (torch.device): The device (CPU or GPU) on which the model is loaded. + im (torch.Tensor): The preprocessed input image. + features (torch.Tensor): Extracted image features. + prompts (Dict): Dictionary to store various types of prompts (e.g., bboxes, points, masks). + segment_all (bool): Flag to indicate if full image segmentation should be performed. + mean (torch.Tensor): Mean values for image normalization. + std (torch.Tensor): Standard deviation values for image normalization. + + Methods: + preprocess: Prepares input images for model inference. + pre_transform: Performs initial transformations on the input image. + inference: Performs segmentation inference based on input prompts. + prompt_inference: Internal function for prompt-based segmentation inference. + generate: Generates segmentation masks for an entire image. + setup_model: Initializes the SAM model for inference. + get_model: Builds and returns a SAM model. + postprocess: Post-processes model outputs to generate final results. + setup_source: Sets up the data source for inference. + set_image: Sets and preprocesses a single image for inference. + get_im_features: Extracts image features using the SAM image encoder. + set_prompts: Sets prompts for subsequent inference. + reset_image: Resets the current image and its features. + remove_small_regions: Removes small disconnected regions and holes from masks. + + Examples: + >>> predictor = Predictor() + >>> predictor.setup_model(model_path='sam_model.pt') + >>> predictor.set_image('image.jpg') + >>> masks, scores, boxes = predictor.generate() + >>> results = predictor.postprocess((masks, scores, boxes), im, orig_img) """ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): """ Initialize the Predictor with configuration, overrides, and callbacks. - The method sets up the Predictor object and applies any configuration overrides or callbacks provided. It - initializes task-specific settings for SAM, such as retina_masks being set to True for optimal results. + Sets up the Predictor object for SAM (Segment Anything Model) and applies any configuration overrides or + callbacks provided. Initializes task-specific settings for SAM, such as retina_masks being set to True + for optimal results. Args: - cfg (dict): Configuration dictionary. - overrides (dict, optional): Dictionary of values to override default configuration. - _callbacks (dict, optional): Dictionary of callback functions to customize behavior. + cfg (Dict): Configuration dictionary containing default settings. + overrides (Dict | None): Dictionary of values to override default configuration. + _callbacks (Dict | None): Dictionary of callback functions to customize behavior. + + Examples: + >>> predictor = Predictor(cfg=DEFAULT_CFG) + >>> predictor = Predictor(overrides={'imgsz': 640}) + >>> predictor = Predictor(_callbacks={'on_predict_start': custom_callback}) """ if overrides is None: overrides = {} @@ -78,14 +107,19 @@ class Predictor(BasePredictor): """ Preprocess the input image for model inference. - The method prepares the input image by applying transformations and normalization. - It supports both torch.Tensor and list of np.ndarray as input formats. + This method prepares the input image by applying transformations and normalization. It supports both + torch.Tensor and list of np.ndarray as input formats. Args: - im (torch.Tensor | List[np.ndarray]): BCHW tensor format or list of HWC numpy arrays. + im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays. Returns: - (torch.Tensor): The preprocessed image tensor. + (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype. + + Examples: + >>> predictor = Predictor() + >>> image = torch.rand(1, 3, 640, 640) + >>> preprocessed_image = predictor.preprocess(image) """ if self.im is not None: return self.im @@ -106,14 +140,24 @@ class Predictor(BasePredictor): """ Perform initial transformations on the input image for preprocessing. - The method applies transformations such as resizing to prepare the image for further preprocessing. + This method applies transformations such as resizing to prepare the image for further preprocessing. Currently, batched inference is not supported; hence the list length should be 1. Args: - im (List[np.ndarray]): List containing images in HWC numpy array format. + im (List[np.ndarray]): List containing a single image in HWC numpy array format. Returns: - (List[np.ndarray]): List of transformed images. + (List[np.ndarray]): List containing the transformed image. + + Raises: + AssertionError: If the input list contains more than one image. + + Examples: + >>> predictor = Predictor() + >>> image = np.random.rand(480, 640, 3) # Single HWC image + >>> transformed = predictor.pre_transform([image]) + >>> print(len(transformed)) + 1 """ assert len(im) == 1, "SAM model does not currently support batched inference" letterbox = LetterBox(self.args.imgsz, auto=False, center=False) @@ -121,23 +165,32 @@ class Predictor(BasePredictor): def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs): """ - Perform image segmentation inference based on the given input cues, using the currently loaded image. This - method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and - mask decoder for real-time and promptable segmentation tasks. + Perform image segmentation inference based on the given input cues, using the currently loaded image. + + This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt + encoder, and mask decoder for real-time and promptable segmentation tasks. Args: im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W). - bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format. - points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels. - labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background. - masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256. - multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts. + bboxes (np.ndarray | List | None): Bounding boxes with shape (N, 4), in XYXY format. + points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels. + labels (np.ndarray | List | None): Labels for point prompts, shape (N,). 1 = foreground, 0 = background. + masks (np.ndarray | None): Low-resolution masks from previous predictions, shape (N, H, W). For SAM H=W=256. + multimask_output (bool): Flag to return multiple masks. Helpful for ambiguous prompts. + *args (Any): Additional positional arguments. + **kwargs (Any): Additional keyword arguments. Returns: - (tuple): Contains the following three elements. - - np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks. + (tuple): Contains the following three elements: + - np.ndarray: The output masks in shape (C, H, W), where C is the number of generated masks. - np.ndarray: An array of length C containing quality scores predicted by the model for each mask. - - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256. + - np.ndarray: Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256. + + Examples: + >>> predictor = Predictor() + >>> predictor.setup_model(model_path='sam_model.pt') + >>> predictor.set_image('image.jpg') + >>> masks, scores, logits = predictor.inference(im, bboxes=[[0, 0, 100, 100]]) """ # Override prompts if any stored in self.prompts bboxes = self.prompts.pop("bboxes", bboxes) @@ -151,22 +204,30 @@ class Predictor(BasePredictor): def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False): """ - Internal function for image segmentation inference based on cues like bounding boxes, points, and masks. - Leverages SAM's specialized architecture for prompt-based, real-time segmentation. + Performs image segmentation inference based on input cues using SAM's specialized architecture. + + This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation. + It processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks. Args: - im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W). - bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format. - points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels. - labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background. - masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256. - multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts. + im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W). + bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4). + points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels. + labels (np.ndarray | List | None): Point prompt labels with shape (N,). 1 for foreground, 0 for background. + masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256. + multimask_output (bool): Flag to return multiple masks for ambiguous prompts. Returns: - (tuple): Contains the following three elements. - - np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks. - - np.ndarray: An array of length C containing quality scores predicted by the model for each mask. - - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256. + (tuple): Tuple containing: + - np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks. + - np.ndarray: Quality scores predicted by the model for each mask, with length C. + - np.ndarray: Low-resolution logits with shape (C, H, W) for subsequent inference, where H=W=256. + + Examples: + >>> predictor = Predictor() + >>> im = torch.rand(1, 3, 1024, 1024) + >>> bboxes = [[100, 100, 200, 200]] + >>> masks, scores, logits = predictor.prompt_inference(im, bboxes=bboxes) """ features = self.get_im_features(im) if self.features is None else self.features @@ -224,27 +285,32 @@ class Predictor(BasePredictor): """ Perform image segmentation using the Segment Anything Model (SAM). - This function segments an entire image into constituent parts by leveraging SAM's advanced architecture + This method segments an entire image into constituent parts by leveraging SAM's advanced architecture and real-time performance capabilities. It can optionally work on image crops for finer segmentation. Args: - im (torch.Tensor): Input tensor representing the preprocessed image with dimensions (N, C, H, W). - crop_n_layers (int): Specifies the number of layers for additional mask predictions on image crops. - Each layer produces 2**i_layer number of image crops. - crop_overlap_ratio (float): Determines the overlap between crops. Scaled down in subsequent layers. - crop_downscale_factor (int): Scaling factor for the number of sampled points-per-side in each layer. - point_grids (list[np.ndarray], optional): Custom grids for point sampling normalized to [0,1]. - Used in the nth crop layer. - points_stride (int, optional): Number of points to sample along each side of the image. - Exclusive with 'point_grids'. + im (torch.Tensor): Input tensor representing the preprocessed image with shape (N, C, H, W). + crop_n_layers (int): Number of layers for additional mask predictions on image crops. + crop_overlap_ratio (float): Overlap between crops, scaled down in subsequent layers. + crop_downscale_factor (int): Scaling factor for sampled points-per-side in each layer. + point_grids (List[np.ndarray] | None): Custom grids for point sampling normalized to [0,1]. + points_stride (int): Number of points to sample along each side of the image. points_batch_size (int): Batch size for the number of points processed simultaneously. - conf_thres (float): Confidence threshold [0,1] for filtering based on the model's mask quality prediction. - stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on mask stability. + conf_thres (float): Confidence threshold [0,1] for filtering based on mask quality prediction. + stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on stability. stability_score_offset (float): Offset value for calculating stability score. crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops. Returns: - (tuple): A tuple containing segmented masks, confidence scores, and bounding boxes. + (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing: + - pred_masks (torch.Tensor): Segmented masks with shape (N, H, W). + - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,). + - pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4). + + Examples: + >>> predictor = Predictor() + >>> im = torch.rand(1, 3, 1024, 1024) # Example input image + >>> masks, scores, boxes = predictor.generate(im) """ import torchvision # scope for faster 'import ultralytics' @@ -326,11 +392,9 @@ class Predictor(BasePredictor): model (torch.nn.Module): A pre-trained SAM model. If None, a model will be built based on configuration. verbose (bool): If True, prints selected device information. - Attributes: - model (torch.nn.Module): The SAM model allocated to the chosen device for inference. - device (torch.device): The device to which the model and tensors are allocated. - mean (torch.Tensor): The mean values for image normalization. - std (torch.Tensor): The standard deviation values for image normalization. + Examples: + >>> predictor = Predictor() + >>> predictor.setup_model(model=sam_model, verbose=True) """ device = select_device(self.args.device, verbose=verbose) if model is None: @@ -349,23 +413,32 @@ class Predictor(BasePredictor): self.done_warmup = True def get_model(self): - """Built Segment Anything Model (SAM) model.""" + """Retrieves or builds the Segment Anything Model (SAM) for image segmentation tasks.""" return build_sam(self.args.model) def postprocess(self, preds, img, orig_imgs): """ Post-processes SAM's inference outputs to generate object detection masks and bounding boxes. - The method scales masks and boxes to the original image size and applies a threshold to the mask predictions. - The SAM model uses advanced architecture and promptable segmentation tasks to achieve real-time performance. + This method scales masks and boxes to the original image size and applies a threshold to the mask + predictions. It leverages SAM's advanced architecture for real-time, promptable segmentation tasks. Args: - preds (tuple): The output from SAM model inference, containing masks, scores, and optional bounding boxes. - img (torch.Tensor): The processed input image tensor. - orig_imgs (list | torch.Tensor): The original, unprocessed images. + preds (Tuple[torch.Tensor]): The output from SAM model inference, containing: + - pred_masks (torch.Tensor): Predicted masks with shape (N, 1, H, W). + - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N, 1). + - pred_bboxes (torch.Tensor, optional): Predicted bounding boxes if segment_all is True. + img (torch.Tensor): The processed input image tensor with shape (C, H, W). + orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images. Returns: - (list): List of Results objects containing detection masks, bounding boxes, and other metadata. + (List[Results]): List of Results objects containing detection masks, bounding boxes, and other + metadata for each processed image. + + Examples: + >>> predictor = Predictor() + >>> preds = predictor.inference(img) + >>> results = predictor.postprocess(preds, img, orig_imgs) """ # (N, 1, H, W), (N, 1) pred_masks, pred_scores = preds[:2] @@ -393,11 +466,23 @@ class Predictor(BasePredictor): """ Sets up the data source for inference. - This method configures the data source from which images will be fetched for inference. The source could be a - directory, a video file, or other types of image data sources. + This method configures the data source from which images will be fetched for inference. It supports + various input types such as image files, directories, video files, and other compatible data sources. Args: - source (str | Path): The path to the image data source for inference. + source (str | Path | None): The path or identifier for the image data source. Can be a file path, + directory path, URL, or other supported source types. + + Examples: + >>> predictor = Predictor() + >>> predictor.setup_source('path/to/images') + >>> predictor.setup_source('video.mp4') + >>> predictor.setup_source(None) # Uses default source if available + + Notes: + - If source is None, the method may use a default source if configured. + - The method adapts to different source types and prepares them for subsequent inference steps. + - Supported source types may include local files, directories, URLs, and video streams. """ if source is not None: super().setup_source(source) @@ -406,14 +491,25 @@ class Predictor(BasePredictor): """ Preprocesses and sets a single image for inference. - This function sets up the model if not already initialized, configures the data source to the specified image, - and preprocesses the image for feature extraction. Only one image can be set at a time. + This method prepares the model for inference on a single image by setting up the model if not already + initialized, configuring the data source, and preprocessing the image for feature extraction. It + ensures that only one image is set at a time and extracts image features for subsequent use. Args: - image (str | np.ndarray): Image file path as a string, or a np.ndarray image read by cv2. + image (str | np.ndarray): Path to the image file as a string, or a numpy array representing + an image read by cv2. Raises: - AssertionError: If more than one image is set. + AssertionError: If more than one image is attempted to be set. + + Examples: + >>> predictor = Predictor() + >>> predictor.set_image('path/to/image.jpg') + >>> predictor.set_image(cv2.imread('path/to/image.jpg')) + + Notes: + - This method should be called before performing inference on a new image. + - The extracted features are stored in the `self.features` attribute for later use. """ if self.model is None: self.setup_model(model=None) @@ -425,35 +521,44 @@ class Predictor(BasePredictor): break def get_im_features(self, im): - """Get image features from the SAM image encoder.""" + """Extracts image features using the SAM model's image encoder for subsequent mask prediction.""" return self.model.image_encoder(im) def set_prompts(self, prompts): - """Set prompts in advance.""" + """Sets prompts for subsequent inference operations.""" self.prompts = prompts def reset_image(self): - """Resets the image and its features to None.""" + """Resets the current image and its features, clearing them for subsequent inference.""" self.im = None self.features = None @staticmethod def remove_small_regions(masks, min_area=0, nms_thresh=0.7): """ - Perform post-processing on segmentation masks generated by the Segment Anything Model (SAM). Specifically, this - function removes small disconnected regions and holes from the input masks, and then performs Non-Maximum + Remove small disconnected regions and holes from segmentation masks. + + This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM). + It removes small disconnected regions and holes from the input masks, and then performs Non-Maximum Suppression (NMS) to eliminate any newly created duplicate boxes. Args: - masks (torch.Tensor): A tensor containing the masks to be processed. Shape should be (N, H, W), where N is - the number of masks, H is height, and W is width. - min_area (int): The minimum area below which disconnected regions and holes will be removed. Defaults to 0. - nms_thresh (float): The IoU threshold for the NMS algorithm. Defaults to 0.7. + masks (torch.Tensor): Segmentation masks to be processed, with shape (N, H, W) where N is the number of + masks, H is height, and W is width. + min_area (int): Minimum area threshold for removing disconnected regions and holes. Regions smaller than + this will be removed. + nms_thresh (float): IoU threshold for the NMS algorithm to remove duplicate boxes. Returns: - (tuple([torch.Tensor, List[int]])): - - new_masks (torch.Tensor): The processed masks with small regions removed. Shape is (N, H, W). - - keep (List[int]): The indices of the remaining masks post-NMS, which can be used to filter the boxes. + (tuple): + - new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W). + - keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes. + + Examples: + >>> masks = torch.rand(5, 640, 640) > 0.5 # 5 random binary masks + >>> new_masks, keep = remove_small_regions(masks, min_area=100, nms_thresh=0.7) + >>> print(f"Original masks: {masks.shape}, Processed masks: {new_masks.shape}") + >>> print(f"Indices of kept masks: {keep}") """ import torchvision # scope for faster 'import ultralytics' @@ -480,3 +585,188 @@ class Predictor(BasePredictor): keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh) return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep + + +class SAM2Predictor(Predictor): + """ + SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture. + + This class extends the base Predictor class to implement SAM2-specific functionality for image + segmentation tasks. It provides methods for model initialization, feature extraction, and + prompt-based inference. + + Attributes: + _bb_feat_sizes (List[Tuple[int, int]]): Feature sizes for different backbone levels. + model (torch.nn.Module): The loaded SAM2 model. + device (torch.device): The device (CPU or GPU) on which the model is loaded. + features (Dict[str, torch.Tensor]): Cached image features for efficient inference. + segment_all (bool): Flag to indicate if all segments should be predicted. + prompts (Dict): Dictionary to store various types of prompts for inference. + + Methods: + get_model: Retrieves and initializes the SAM2 model. + prompt_inference: Performs image segmentation inference based on various prompts. + set_image: Preprocesses and sets a single image for inference. + get_im_features: Extracts and processes image features using SAM2's image encoder. + + Examples: + >>> predictor = SAM2Predictor(cfg) + >>> predictor.set_image("path/to/image.jpg") + >>> bboxes = [[100, 100, 200, 200]] + >>> masks, scores, _ = predictor.prompt_inference(predictor.im, bboxes=bboxes) + >>> print(f"Predicted {len(masks)} masks with average score {scores.mean():.2f}") + """ + + _bb_feat_sizes = [ + (256, 256), + (128, 128), + (64, 64), + ] + + def get_model(self): + """Retrieves and initializes the Segment Anything Model 2 (SAM2) for image segmentation tasks.""" + return build_sam(self.args.model) + + def prompt_inference( + self, + im, + bboxes=None, + points=None, + labels=None, + masks=None, + multimask_output=False, + img_idx=-1, + ): + """ + Performs image segmentation inference based on various prompts using SAM2 architecture. + + This method leverages the Segment Anything Model 2 (SAM2) to generate segmentation masks for input images + based on provided prompts such as bounding boxes, points, or existing masks. It supports both single and + multi-object prediction scenarios. + + Args: + im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W). + bboxes (np.ndarray | List[List[float]] | None): Bounding boxes in XYXY format with shape (N, 4). + points (np.ndarray | List[List[float]] | None): Object location points with shape (N, 2), in pixels. + labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background. + masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W). + multimask_output (bool): Flag to return multiple masks for ambiguous prompts. + img_idx (int): Index of the image in the batch to process. + + Returns: + (tuple): Tuple containing: + - np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks. + - np.ndarray: Quality scores for each mask, with length C. + - np.ndarray: Low-resolution logits with shape (C, 256, 256) for subsequent inference. + + Examples: + >>> predictor = SAM2Predictor(cfg) + >>> image = torch.rand(1, 3, 640, 640) + >>> bboxes = [[100, 100, 200, 200]] + >>> masks, scores, logits = predictor.prompt_inference(image, bboxes=bboxes) + >>> print(f"Generated {masks.shape[0]} masks with average score {scores.mean():.2f}") + + Notes: + - The method supports batched inference for multiple objects when points or bboxes are provided. + - Input prompts (bboxes, points) are automatically scaled to match the input image dimensions. + - When both bboxes and points are provided, they are merged into a single 'points' input for the model. + + References: + - SAM2 Paper: [Add link to SAM2 paper when available] + """ + features = self.get_im_features(im) if self.features is None else self.features + + src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:] + r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1]) + # Transform input prompts + if points is not None: + points = torch.as_tensor(points, dtype=torch.float32, device=self.device) + points = points[None] if points.ndim == 1 else points + # Assuming labels are all positive if users don't pass labels. + if labels is None: + labels = torch.ones(points.shape[0]) + labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device) + points *= r + # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1) + points, labels = points[:, None], labels[:, None] + if bboxes is not None: + bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device) + bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes + bboxes = bboxes.view(-1, 2, 2) * r + bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1) + # NOTE: merge "boxes" and "points" into a single "points" input + # (where boxes are added at the beginning) to model.sam_prompt_encoder + if points is not None: + points = torch.cat([bboxes, points], dim=1) + labels = torch.cat([bbox_labels, labels], dim=1) + else: + points, labels = bboxes, bbox_labels + if masks is not None: + masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1) + + points = (points, labels) if points is not None else None + + sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( + points=points, + boxes=None, + masks=masks, + ) + # Predict masks + batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction + high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]] + pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder( + image_embeddings=features["image_embed"][img_idx].unsqueeze(0), + image_pe=self.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=batched_mode, + high_res_features=high_res_features, + ) + # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, ) + # `d` could be 1 or 3 depends on `multimask_output`. + return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1) + + def set_image(self, image): + """ + Preprocesses and sets a single image for inference using the SAM2 model. + + This method initializes the model if not already done, configures the data source to the specified image, + and preprocesses the image for feature extraction. It supports setting only one image at a time. + + Args: + image (str | np.ndarray): Path to the image file as a string, or a numpy array representing the image. + + Raises: + AssertionError: If more than one image is attempted to be set. + + Examples: + >>> predictor = SAM2Predictor() + >>> predictor.set_image("path/to/image.jpg") + >>> predictor.set_image(np.array([...])) # Using a numpy array + + Notes: + - This method must be called before performing any inference on a new image. + - The method caches the extracted features for efficient subsequent inferences on the same image. + - Only one image can be set at a time. To process multiple images, call this method for each new image. + """ + if self.model is None: + self.setup_model(model=None) + self.setup_source(image) + assert len(self.dataset) == 1, "`set_image` only supports setting one image!" + for batch in self.dataset: + im = self.preprocess(batch[1]) + self.features = self.get_im_features(im) + break + + def get_im_features(self, im): + """Extracts image features from the SAM image encoder for subsequent processing.""" + backbone_out = self.model.forward_image(im) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + feats = [ + feat.permute(1, 2, 0).view(1, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + return {"image_embed": feats[-1], "high_res_feats": feats[:-1]} diff --git a/ultralytics/models/sam2/__init__.py b/ultralytics/models/sam2/__init__.py deleted file mode 100644 index 755a160af..000000000 --- a/ultralytics/models/sam2/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Ultralytics YOLO 🚀, AGPL-3.0 license - -from .model import SAM2 -from .predict import SAM2Predictor - -__all__ = "SAM2", "SAM2Predictor" # tuple or list diff --git a/ultralytics/models/sam2/build.py b/ultralytics/models/sam2/build.py deleted file mode 100644 index 2791029a5..000000000 --- a/ultralytics/models/sam2/build.py +++ /dev/null @@ -1,156 +0,0 @@ -# Ultralytics YOLO 🚀, AGPL-3.0 license - -import torch - -from ultralytics.utils.downloads import attempt_download_asset - -from .modules.encoders import FpnNeck, Hiera, ImageEncoder, MemoryEncoder -from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer -from .modules.sam2 import SAM2Model - - -def build_sam2_t(checkpoint=None): - """Build and return a Segment Anything Model (SAM2) tiny-size model with specified architecture parameters.""" - return _build_sam2( - encoder_embed_dim=96, - encoder_stages=[1, 2, 7, 2], - encoder_num_heads=1, - encoder_global_att_blocks=[5, 7, 9], - encoder_window_spec=[8, 4, 14, 7], - encoder_backbone_channel_list=[768, 384, 192, 96], - checkpoint=checkpoint, - ) - - -def build_sam2_s(checkpoint=None): - """Builds and returns a small-size Segment Anything Model (SAM2) with specified architecture parameters.""" - return _build_sam2( - encoder_embed_dim=96, - encoder_stages=[1, 2, 11, 2], - encoder_num_heads=1, - encoder_global_att_blocks=[7, 10, 13], - encoder_window_spec=[8, 4, 14, 7], - encoder_backbone_channel_list=[768, 384, 192, 96], - checkpoint=checkpoint, - ) - - -def build_sam2_b(checkpoint=None): - """Builds and returns a Segment Anything Model (SAM2) base-size model with specified architecture parameters.""" - return _build_sam2( - encoder_embed_dim=112, - encoder_stages=[2, 3, 16, 3], - encoder_num_heads=2, - encoder_global_att_blocks=[12, 16, 20], - encoder_window_spec=[8, 4, 14, 7], - encoder_window_spatial_size=[14, 14], - encoder_backbone_channel_list=[896, 448, 224, 112], - checkpoint=checkpoint, - ) - - -def build_sam2_l(checkpoint=None): - """Build and return a Segment Anything Model (SAM2) large-size model with specified architecture parameters.""" - return _build_sam2( - encoder_embed_dim=144, - encoder_stages=[2, 6, 36, 4], - encoder_num_heads=2, - encoder_global_att_blocks=[23, 33, 43], - encoder_window_spec=[8, 4, 16, 8], - encoder_backbone_channel_list=[1152, 576, 288, 144], - checkpoint=checkpoint, - ) - - -def _build_sam2( - encoder_embed_dim=1280, - encoder_stages=[2, 6, 36, 4], - encoder_num_heads=2, - encoder_global_att_blocks=[7, 15, 23, 31], - encoder_backbone_channel_list=[1152, 576, 288, 144], - encoder_window_spatial_size=[7, 7], - encoder_window_spec=[8, 4, 16, 8], - checkpoint=None, -): - """Builds a SAM2 model with specified architecture parameters and optional checkpoint loading.""" - image_encoder = ImageEncoder( - trunk=Hiera( - embed_dim=encoder_embed_dim, - num_heads=encoder_num_heads, - stages=encoder_stages, - global_att_blocks=encoder_global_att_blocks, - window_pos_embed_bkg_spatial_size=encoder_window_spatial_size, - window_spec=encoder_window_spec, - ), - neck=FpnNeck( - d_model=256, - backbone_channel_list=encoder_backbone_channel_list, - fpn_top_down_levels=[2, 3], - fpn_interp_model="nearest", - ), - scalp=1, - ) - memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer()) - memory_encoder = MemoryEncoder(out_dim=64) - - sam2 = SAM2Model( - image_encoder=image_encoder, - memory_attention=memory_attention, - memory_encoder=memory_encoder, - num_maskmem=7, - image_size=1024, - sigmoid_scale_for_mem_enc=20.0, - sigmoid_bias_for_mem_enc=-10.0, - use_mask_input_as_output_without_sam=True, - directly_add_no_mem_embed=True, - use_high_res_features_in_sam=True, - multimask_output_in_sam=True, - iou_prediction_use_sigmoid=True, - use_obj_ptrs_in_encoder=True, - add_tpos_enc_to_obj_ptrs=True, - only_obj_ptrs_in_the_past_for_eval=True, - pred_obj_scores=True, - pred_obj_scores_mlp=True, - fixed_no_obj_ptr=True, - multimask_output_for_tracking=True, - use_multimask_token_for_obj_ptr=True, - multimask_min_pt_num=0, - multimask_max_pt_num=1, - use_mlp_for_obj_ptr_proj=True, - compile_image_encoder=False, - sam_mask_decoder_extra_args=dict( - dynamic_multimask_via_stability=True, - dynamic_multimask_stability_delta=0.05, - dynamic_multimask_stability_thresh=0.98, - ), - ) - - if checkpoint is not None: - checkpoint = attempt_download_asset(checkpoint) - with open(checkpoint, "rb") as f: - state_dict = torch.load(f)["model"] - sam2.load_state_dict(state_dict) - sam2.eval() - return sam2 - - -sam_model_map = { - "sam2_t.pt": build_sam2_t, - "sam2_s.pt": build_sam2_s, - "sam2_b.pt": build_sam2_b, - "sam2_l.pt": build_sam2_l, -} - - -def build_sam2(ckpt="sam_b.pt"): - """Constructs a Segment Anything Model (SAM2) based on the specified checkpoint, with various size options.""" - model_builder = None - ckpt = str(ckpt) # to allow Path ckpt types - for k in sam_model_map.keys(): - if ckpt.endswith(k): - model_builder = sam_model_map.get(k) - - if not model_builder: - raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}") - - return model_builder(ckpt) diff --git a/ultralytics/models/sam2/model.py b/ultralytics/models/sam2/model.py deleted file mode 100644 index 4b3265e19..000000000 --- a/ultralytics/models/sam2/model.py +++ /dev/null @@ -1,97 +0,0 @@ -# Ultralytics YOLO 🚀, AGPL-3.0 license -""" -SAM2 model interface. - -This module provides an interface to the Segment Anything Model (SAM2) from Ultralytics, designed for real-time image -segmentation tasks. The SAM2 model allows for promptable segmentation with unparalleled versatility in image analysis, -and has been trained on the SA-1B dataset. It features zero-shot performance capabilities, enabling it to adapt to new -image distributions and tasks without prior knowledge. - -Key Features: - - Promptable segmentation - - Real-time performance - - Zero-shot transfer capabilities - - Trained on SA-1B dataset -""" - -from ultralytics.models.sam import SAM - -from .build import build_sam2 -from .predict import SAM2Predictor - - -class SAM2(SAM): - """ - SAM2 class for real-time image segmentation using the Segment Anything Model (SAM2). - - This class extends the SAM base class, providing an interface to the SAM2 model for promptable segmentation - tasks. It supports loading pre-trained weights and offers zero-shot performance capabilities. - - Attributes: - model (torch.nn.Module): The loaded SAM2 model. - task_map (Dict[str, Type[SAM2Predictor]]): Mapping of 'segment' task to SAM2Predictor. - - Methods: - __init__: Initializes the SAM2 model with pre-trained weights. - _load: Loads specified weights into the SAM2 model. - - Examples: - >>> sam2 = SAM2("sam2_b.pt") - >>> sam2._load('path/to/sam2_weights.pt') - >>> task_map = sam2.task_map - >>> print(task_map) - {'segment': SAM2Predictor} - - Notes: - - Supports .pt and .pth file extensions for model weights. - - Offers zero-shot transfer capabilities for new image distributions and tasks. - """ - - def __init__(self, model="sam2_b.pt") -> None: - """ - Initializes the SAM2 model with a pre-trained model file. - - Args: - model (str): Path to the pre-trained SAM2 model file. File should have a .pt or .pth extension. - - Raises: - NotImplementedError: If the model file extension is not .pt or .pth. - - Examples: - >>> sam2 = SAM2("sam2_b.pt") - """ - super().__init__(model=model) - - def _load(self, weights: str, task=None): - """ - Loads the specified weights into the SAM2 model. - - This method is responsible for loading pre-trained weights into the SAM2 model. It supports loading - weights from files with .pt or .pth extensions. - - Args: - weights (str): Path to the weights file. Should be a file with .pt or .pth extension. - task (str | None): Task name. If provided, it may be used to configure model-specific settings. - - Examples: - >>> sam2_model = SAM2() - >>> sam2_model._load('path/to/sam2_weights.pt') - """ - self.model = build_sam2(weights) - - @property - def task_map(self): - """ - Provides a mapping from the 'segment' task to its corresponding 'Predictor'. - - Returns: - (Dict[str, Type[SAM2Predictor]]): A dictionary mapping the 'segment' task to its corresponding - SAM2Predictor class. - - Examples: - >>> sam2 = SAM2() - >>> task_map = sam2.task_map - >>> print(task_map) - {'segment': SAM2Predictor} - """ - return {"segment": {"predictor": SAM2Predictor}} diff --git a/ultralytics/models/sam2/modules/__init__.py b/ultralytics/models/sam2/modules/__init__.py deleted file mode 100644 index 9e68dc122..000000000 --- a/ultralytics/models/sam2/modules/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Ultralytics YOLO 🚀, AGPL-3.0 license diff --git a/ultralytics/models/sam2/modules/decoders.py b/ultralytics/models/sam2/modules/decoders.py deleted file mode 100644 index ac6c60a60..000000000 --- a/ultralytics/models/sam2/modules/decoders.py +++ /dev/null @@ -1,305 +0,0 @@ -# Ultralytics YOLO 🚀, AGPL-3.0 license - -from typing import List, Optional, Tuple, Type - -import torch -from torch import nn - -from ultralytics.nn.modules import MLP, LayerNorm2d - - -class MaskDecoder(nn.Module): - """Transformer-based decoder predicting instance segmentation masks from image and prompt embeddings.""" - - def __init__( - self, - transformer_dim: int, - transformer: nn.Module, - num_multimask_outputs: int = 3, - activation: Type[nn.Module] = nn.GELU, - iou_head_depth: int = 3, - iou_head_hidden_dim: int = 256, - use_high_res_features: bool = False, - iou_prediction_use_sigmoid=False, - dynamic_multimask_via_stability=False, - dynamic_multimask_stability_delta=0.05, - dynamic_multimask_stability_thresh=0.98, - pred_obj_scores: bool = False, - pred_obj_scores_mlp: bool = False, - use_multimask_token_for_obj_ptr: bool = False, - ) -> None: - """ - Initializes the MaskDecoder module for predicting instance segmentation masks. - - Args: - transformer_dim (int): Channel dimension of the transformer. - transformer (nn.Module): Transformer used to predict masks. - num_multimask_outputs (int): Number of masks to predict when disambiguating masks. - activation (Type[nn.Module]): Type of activation to use when upscaling masks. - iou_head_depth (int): Depth of the MLP used to predict mask quality. - iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality. - use_high_res_features (bool): Whether to use high-resolution features. - iou_prediction_use_sigmoid (bool): Whether to use sigmoid for IOU prediction. - dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability. - dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability. - dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability. - pred_obj_scores (bool): Whether to predict object scores. - pred_obj_scores_mlp (bool): Whether to use MLP for object score prediction. - use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer. - - Attributes: - transformer_dim (int): Channel dimension of the transformer. - transformer (nn.Module): Transformer used to predict masks. - num_multimask_outputs (int): Number of masks to predict when disambiguating masks. - iou_token (nn.Embedding): Embedding for IOU token. - num_mask_tokens (int): Total number of mask tokens. - mask_tokens (nn.Embedding): Embedding for mask tokens. - pred_obj_scores (bool): Whether to predict object scores. - obj_score_token (nn.Embedding): Embedding for object score token. - use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer. - output_upscaling (nn.Sequential): Upscaling layers for output. - use_high_res_features (bool): Whether to use high-resolution features. - conv_s0 (nn.Conv2d): Convolutional layer for high-resolution features (s0). - conv_s1 (nn.Conv2d): Convolutional layer for high-resolution features (s1). - output_hypernetworks_mlps (nn.ModuleList): List of MLPs for output hypernetworks. - iou_prediction_head (MLP): MLP for IOU prediction. - pred_obj_score_head (nn.Linear | MLP): Linear layer or MLP for object score prediction. - dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability. - dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability. - """ - super().__init__() - self.transformer_dim = transformer_dim - self.transformer = transformer - - self.num_multimask_outputs = num_multimask_outputs - - self.iou_token = nn.Embedding(1, transformer_dim) - self.num_mask_tokens = num_multimask_outputs + 1 - self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) - - self.pred_obj_scores = pred_obj_scores - if self.pred_obj_scores: - self.obj_score_token = nn.Embedding(1, transformer_dim) - self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr - - self.output_upscaling = nn.Sequential( - nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), - LayerNorm2d(transformer_dim // 4), - activation(), - nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), - activation(), - ) - self.use_high_res_features = use_high_res_features - if use_high_res_features: - self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1) - self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1) - - self.output_hypernetworks_mlps = nn.ModuleList( - [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)] - ) - - self.iou_prediction_head = MLP( - transformer_dim, - iou_head_hidden_dim, - self.num_mask_tokens, - iou_head_depth, - sigmoid=iou_prediction_use_sigmoid, - ) - if self.pred_obj_scores: - self.pred_obj_score_head = nn.Linear(transformer_dim, 1) - if pred_obj_scores_mlp: - self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) - - # When outputting a single mask, optionally we can dynamically fall back to the best - # multimask output token if the single mask output token gives low stability scores. - self.dynamic_multimask_via_stability = dynamic_multimask_via_stability - self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta - self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh - - def forward( - self, - image_embeddings: torch.Tensor, - image_pe: torch.Tensor, - sparse_prompt_embeddings: torch.Tensor, - dense_prompt_embeddings: torch.Tensor, - multimask_output: bool, - repeat_image: bool, - high_res_features: Optional[List[torch.Tensor]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Predicts masks given image and prompt embeddings. - - Args: - image_embeddings (torch.Tensor): Embeddings from the image encoder. - image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings. - sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes. - dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs. - multimask_output (bool): Whether to return multiple masks or a single mask. - repeat_image (bool): Flag to repeat the image embeddings. - high_res_features (List[torch.Tensor] | None): Optional high-resolution features. - - Returns: - (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing: - - masks (torch.Tensor): Batched predicted masks. - - iou_pred (torch.Tensor): Batched predictions of mask quality. - - sam_tokens_out (torch.Tensor): Batched SAM token for mask output. - - Examples: - >>> image_embeddings = torch.rand(1, 256, 64, 64) - >>> image_pe = torch.rand(1, 256, 64, 64) - >>> sparse_prompt_embeddings = torch.rand(1, 2, 256) - >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64) - >>> decoder = MaskDecoder(256, transformer) - >>> masks, iou_pred, sam_tokens_out = decoder.forward(image_embeddings, image_pe, - ... sparse_prompt_embeddings, dense_prompt_embeddings, True, False) - """ - masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( - image_embeddings=image_embeddings, - image_pe=image_pe, - sparse_prompt_embeddings=sparse_prompt_embeddings, - dense_prompt_embeddings=dense_prompt_embeddings, - repeat_image=repeat_image, - high_res_features=high_res_features, - ) - - # Select the correct mask or masks for output - if multimask_output: - masks = masks[:, 1:, :, :] - iou_pred = iou_pred[:, 1:] - elif self.dynamic_multimask_via_stability and not self.training: - masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) - else: - masks = masks[:, 0:1, :, :] - iou_pred = iou_pred[:, 0:1] - - if multimask_output and self.use_multimask_token_for_obj_ptr: - sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape - else: - # Take the mask output token. Here we *always* use the token for single mask output. - # At test time, even if we track after 1-click (and using multimask_output=True), - # we still take the single mask token here. The rationale is that we always track - # after multiple clicks during training, so the past tokens seen during training - # are always the single mask token (and we'll let it be the object-memory token). - sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape - - # Prepare output - return masks, iou_pred, sam_tokens_out, object_score_logits - - def predict_masks( - self, - image_embeddings: torch.Tensor, - image_pe: torch.Tensor, - sparse_prompt_embeddings: torch.Tensor, - dense_prompt_embeddings: torch.Tensor, - repeat_image: bool, - high_res_features: Optional[List[torch.Tensor]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Predicts instance segmentation masks from image and prompt embeddings using a transformer architecture.""" - # Concatenate output tokens - s = 0 - if self.pred_obj_scores: - output_tokens = torch.cat( - [ - self.obj_score_token.weight, - self.iou_token.weight, - self.mask_tokens.weight, - ], - dim=0, - ) - s = 1 - else: - output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) - output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) - tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) - - # Expand per-image data in batch direction to be per-mask - if repeat_image: - src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) - else: - assert image_embeddings.shape[0] == tokens.shape[0] - src = image_embeddings - src = src + dense_prompt_embeddings - assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" - pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) - b, c, h, w = src.shape - - # Run the transformer - hs, src = self.transformer(src, pos_src, tokens) - iou_token_out = hs[:, s, :] - mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] - - # Upscale mask embeddings and predict masks using the mask tokens - src = src.transpose(1, 2).view(b, c, h, w) - if not self.use_high_res_features: - upscaled_embedding = self.output_upscaling(src) - else: - dc1, ln1, act1, dc2, act2 = self.output_upscaling - feat_s0, feat_s1 = high_res_features - upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) - upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) - - hyper_in_list: List[torch.Tensor] = [] - for i in range(self.num_mask_tokens): - hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) - hyper_in = torch.stack(hyper_in_list, dim=1) - b, c, h, w = upscaled_embedding.shape - masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) - - # Generate mask quality predictions - iou_pred = self.iou_prediction_head(iou_token_out) - if self.pred_obj_scores: - assert s == 1 - object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) - else: - # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 - object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) - - return masks, iou_pred, mask_tokens_out, object_score_logits - - def _get_stability_scores(self, mask_logits): - """Computes mask stability scores based on IoU between upper and lower thresholds.""" - mask_logits = mask_logits.flatten(-2) - stability_delta = self.dynamic_multimask_stability_delta - area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() - area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() - stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) - return stability_scores - - def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): - """ - Dynamically selects the most stable mask output based on stability scores and IoU predictions. - - When outputting a single mask, if the stability score from the current single-mask output (based on output token - 0) falls below a threshold, we instead select from multi-mask outputs (based on output token 1~3) the mask with - the highest predicted IoU score. - - This is intended to ensure a valid mask for both clicking and tracking. - """ - # The best mask from multimask output tokens (1~3) - multimask_logits = all_mask_logits[:, 1:, :, :] - multimask_iou_scores = all_iou_scores[:, 1:] - best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) - batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device) - best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] - best_multimask_logits = best_multimask_logits.unsqueeze(1) - best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] - best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) - - # The mask from singlemask output token 0 and its stability score - singlemask_logits = all_mask_logits[:, 0:1, :, :] - singlemask_iou_scores = all_iou_scores[:, 0:1] - stability_scores = self._get_stability_scores(singlemask_logits) - is_stable = stability_scores >= self.dynamic_multimask_stability_thresh - - # Dynamically fall back to best multimask output upon low stability scores. - mask_logits_out = torch.where( - is_stable[..., None, None].expand_as(singlemask_logits), - singlemask_logits, - best_multimask_logits, - ) - iou_scores_out = torch.where( - is_stable.expand_as(singlemask_iou_scores), - singlemask_iou_scores, - best_multimask_iou_scores, - ) - return mask_logits_out, iou_scores_out diff --git a/ultralytics/models/sam2/modules/encoders.py b/ultralytics/models/sam2/modules/encoders.py deleted file mode 100644 index b4cd0f8f2..000000000 --- a/ultralytics/models/sam2/modules/encoders.py +++ /dev/null @@ -1,332 +0,0 @@ -# Ultralytics YOLO 🚀, AGPL-3.0 license - -from typing import List, Optional, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from ultralytics.models.sam.modules.encoders import PatchEmbed - -from .sam2_blocks import CXBlock, Fuser, MaskDownSampler, MultiScaleBlock, PositionEmbeddingSine - - -class MemoryEncoder(nn.Module): - """Encodes pixel features and masks into a memory representation for efficient image segmentation.""" - - def __init__( - self, - out_dim, - in_dim=256, # in_dim of pix_feats - ): - """Initializes the MemoryEncoder module for encoding pixel features and masks in SAM-like models.""" - super().__init__() - - self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1) - - self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) - self.fuser = Fuser(CXBlock(dim=256), num_layers=2) - self.position_encoding = PositionEmbeddingSine(num_pos_feats=64) - self.out_proj = nn.Identity() - if out_dim != in_dim: - self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) - - def forward( - self, - pix_feat: torch.Tensor, - masks: torch.Tensor, - skip_mask_sigmoid: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Processes pixel features and masks, fusing them to generate encoded memory representations.""" - if not skip_mask_sigmoid: - masks = F.sigmoid(masks) - masks = self.mask_downsampler(masks) - - # Fuse pix_feats and downsampled masks, in case the visual features are on CPU, cast them to CUDA - pix_feat = pix_feat.to(masks.device) - - x = self.pix_feat_proj(pix_feat) - x = x + masks - x = self.fuser(x) - x = self.out_proj(x) - - pos = self.position_encoding(x).to(x.dtype) - - return {"vision_features": x, "vision_pos_enc": [pos]} - - -class ImageEncoder(nn.Module): - """Encodes images using a trunk-neck architecture, producing multiscale features and positional encodings.""" - - def __init__( - self, - trunk: nn.Module, - neck: nn.Module, - scalp: int = 0, - ): - """Initializes an image encoder with a trunk, neck, and optional scalp for feature extraction.""" - super().__init__() - self.trunk = trunk - self.neck = neck - self.scalp = scalp - assert ( - self.trunk.channel_list == self.neck.backbone_channel_list - ), f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match." - - def forward(self, sample: torch.Tensor): - """Processes image input through trunk and neck, returning features, positional encodings, and FPN outputs.""" - features, pos = self.neck(self.trunk(sample)) - if self.scalp > 0: - # Discard the lowest resolution features - features, pos = features[: -self.scalp], pos[: -self.scalp] - - src = features[-1] - output = { - "vision_features": src, - "vision_pos_enc": pos, - "backbone_fpn": features, - } - return output - - -class FpnNeck(nn.Module): - """Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.""" - - def __init__( - self, - d_model: int, - backbone_channel_list: List[int], - kernel_size: int = 1, - stride: int = 1, - padding: int = 0, - fpn_interp_model: str = "bilinear", - fuse_type: str = "sum", - fpn_top_down_levels: Optional[List[int]] = None, - ): - """ - Initializes a modified Feature Pyramid Network (FPN) neck. - - This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing, - similar to ViT positional embedding interpolation. - - Args: - d_model (int): Dimension of the model. - backbone_channel_list (List[int]): List of channel dimensions from the backbone. - kernel_size (int): Kernel size for the convolutional layers. - stride (int): Stride for the convolutional layers. - padding (int): Padding for the convolutional layers. - fpn_interp_model (str): Interpolation mode for FPN feature resizing. - fuse_type (str): Type of feature fusion, either 'sum' or 'avg'. - fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs. - - Attributes: - position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding. - convs (nn.ModuleList): List of convolutional layers for each backbone level. - backbone_channel_list (List[int]): List of channel dimensions from the backbone. - fpn_interp_model (str): Interpolation mode for FPN feature resizing. - fuse_type (str): Type of feature fusion. - fpn_top_down_levels (List[int]): Levels with top-down feature propagation. - - Examples: - >>> backbone_channels = [64, 128, 256, 512] - >>> fpn_neck = FpnNeck(256, backbone_channels) - >>> print(fpn_neck) - """ - super().__init__() - self.position_encoding = PositionEmbeddingSine(num_pos_feats=256) - self.convs = nn.ModuleList() - self.backbone_channel_list = backbone_channel_list - for dim in backbone_channel_list: - current = nn.Sequential() - current.add_module( - "conv", - nn.Conv2d( - in_channels=dim, - out_channels=d_model, - kernel_size=kernel_size, - stride=stride, - padding=padding, - ), - ) - - self.convs.append(current) - self.fpn_interp_model = fpn_interp_model - assert fuse_type in ["sum", "avg"] - self.fuse_type = fuse_type - - # levels to have top-down features in its outputs - # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 - # have top-down propagation, while outputs of level 0 and level 1 have only - # lateral features from the same backbone level. - if fpn_top_down_levels is None: - # default is to have top-down features on all levels - fpn_top_down_levels = range(len(self.convs)) - self.fpn_top_down_levels = list(fpn_top_down_levels) - - def forward(self, xs: List[torch.Tensor]): - """ - Performs forward pass through the Feature Pyramid Network (FPN) neck. - - Args: - xs (List[torch.Tensor]): List of input tensors from the backbone, with shape (B, C, H, W) for each tensor. - - Returns: - (Tuple[List[torch.Tensor], List[torch.Tensor]]): A tuple containing two lists: - - out: List of output feature maps after FPN processing, with shape (B, d_model, H, W) for each tensor. - - pos: List of positional encodings corresponding to each output feature map. - - Examples: - >>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512]) - >>> inputs = [torch.rand(1, c, 32, 32) for c in [64, 128, 256, 512]] - >>> outputs, positions = fpn_neck(inputs) - """ - out = [None] * len(self.convs) - pos = [None] * len(self.convs) - assert len(xs) == len(self.convs) - # fpn forward pass - # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py - prev_features = None - # forward in top-down order (from low to high resolution) - n = len(self.convs) - 1 - for i in range(n, -1, -1): - x = xs[i] - lateral_features = self.convs[n - i](x) - if i in self.fpn_top_down_levels and prev_features is not None: - top_down_features = F.interpolate( - prev_features.to(dtype=torch.float32), - scale_factor=2.0, - mode=self.fpn_interp_model, - align_corners=(None if self.fpn_interp_model == "nearest" else False), - antialias=False, - ) - prev_features = lateral_features + top_down_features - if self.fuse_type == "avg": - prev_features /= 2 - else: - prev_features = lateral_features - x_out = prev_features - out[i] = x_out - pos[i] = self.position_encoding(x_out).to(x_out.dtype) - - return out, pos - - -class Hiera(nn.Module): - """Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.""" - - def __init__( - self, - embed_dim: int = 96, # initial embed dim - num_heads: int = 1, # initial number of heads - drop_path_rate: float = 0.0, # stochastic depth - q_pool: int = 3, # number of q_pool stages - q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages - stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage - dim_mul: float = 2.0, # dim_mul factor at stage shift - head_mul: float = 2.0, # head_mul factor at stage shift - window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), - # window size per stage, when not using global att. - window_spec: Tuple[int, ...] = ( - 8, - 4, - 14, - 7, - ), - # global attn in these blocks - global_att_blocks: Tuple[int, ...] = ( - 12, - 16, - 20, - ), - return_interm_layers=True, # return feats from every stage - ): - """Initializes a Hiera model with configurable architecture for hierarchical vision transformers.""" - super().__init__() - - assert len(stages) == len(window_spec) - self.window_spec = window_spec - - depth = sum(stages) - self.q_stride = q_stride - self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] - assert 0 <= q_pool <= len(self.stage_ends[:-1]) - self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] - self.return_interm_layers = return_interm_layers - - self.patch_embed = PatchEmbed( - embed_dim=embed_dim, - kernel_size=(7, 7), - stride=(4, 4), - padding=(3, 3), - ) - # Which blocks have global att? - self.global_att_blocks = global_att_blocks - - # Windowed positional embedding (https://arxiv.org/abs/2311.05613) - self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size - self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)) - self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])) - - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule - - cur_stage = 1 - self.blocks = nn.ModuleList() - - for i in range(depth): - dim_out = embed_dim - # lags by a block, so first block of - # next stage uses an initial window size - # of previous stage and final window size of current stage - window_size = self.window_spec[cur_stage - 1] - - if self.global_att_blocks is not None: - window_size = 0 if i in self.global_att_blocks else window_size - - if i - 1 in self.stage_ends: - dim_out = int(embed_dim * dim_mul) - num_heads = int(num_heads * head_mul) - cur_stage += 1 - - block = MultiScaleBlock( - dim=embed_dim, - dim_out=dim_out, - num_heads=num_heads, - drop_path=dpr[i], - q_stride=self.q_stride if i in self.q_pool_blocks else None, - window_size=window_size, - ) - - embed_dim = dim_out - self.blocks.append(block) - - self.channel_list = ( - [self.blocks[i].dim_out for i in self.stage_ends[::-1]] - if return_interm_layers - else [self.blocks[-1].dim_out] - ) - - def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: - """Generate positional embeddings by interpolating and combining window and background embeddings.""" - h, w = hw - window_embed = self.pos_embed_window - pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") - pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)]) - pos_embed = pos_embed.permute(0, 2, 3, 1) - return pos_embed - - def forward(self, x: torch.Tensor) -> List[torch.Tensor]: - """Performs hierarchical vision transformer forward pass, returning multiscale feature maps.""" - x = self.patch_embed(x) - # x: (B, H, W, C) - - # Add pos embed - x = x + self._get_pos_embed(x.shape[1:3]) - - outputs = [] - for i, blk in enumerate(self.blocks): - x = blk(x) - if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers): - feats = x.permute(0, 3, 1, 2) - outputs.append(feats) - - return outputs diff --git a/ultralytics/models/sam2/modules/sam2.py b/ultralytics/models/sam2/modules/sam2.py deleted file mode 100644 index 363241a19..000000000 --- a/ultralytics/models/sam2/modules/sam2.py +++ /dev/null @@ -1,804 +0,0 @@ -# Ultralytics YOLO 🚀, AGPL-3.0 license - -import torch -import torch.distributed -import torch.nn.functional as F -from torch.nn.init import trunc_normal_ - -from ultralytics.models.sam.modules.encoders import PromptEncoder -from ultralytics.nn.modules import MLP - -from .decoders import MaskDecoder -from .sam2_blocks import TwoWayTransformer -from .utils import get_1d_sine_pe, select_closest_cond_frames - -# a large negative value as a placeholder score for missing objects -NO_OBJ_SCORE = -1024.0 - - -class SAM2Model(torch.nn.Module): - """SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.""" - - mask_threshold: float = 0.0 - - def __init__( - self, - image_encoder, - memory_attention, - memory_encoder, - num_maskmem=7, # default 1 input frame + 6 previous frames - image_size=512, - backbone_stride=16, # stride of the image backbone output - sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob - sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob - # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks - binarize_mask_from_pts_for_mem_enc=False, - use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder - # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit, - # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model - # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM. - max_cond_frames_in_attn=-1, - # on the first frame, whether to directly add the no-memory embedding to the image feature - # (instead of using the transformer encoder) - directly_add_no_mem_embed=False, - # whether to use high-resolution feature maps in the SAM mask decoder - use_high_res_features_in_sam=False, - # whether to output multiple (3) masks for the first click on initial conditioning frames - multimask_output_in_sam=False, - # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`; - # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points) - multimask_min_pt_num=1, - multimask_max_pt_num=1, - # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`) - multimask_output_for_tracking=False, - # Whether to use multimask tokens for obj ptr; Only relevant when both - # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True - use_multimask_token_for_obj_ptr: bool = False, - # whether to use sigmoid to restrict ious prediction to [0-1] - iou_prediction_use_sigmoid=False, - # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5). - # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of - # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame. - memory_temporal_stride_for_eval=1, - # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click - # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames - add_all_frames_to_correct_as_cond=False, - # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks) - non_overlap_masks_for_mem_enc=False, - # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder - use_obj_ptrs_in_encoder=False, - # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`) - max_obj_ptrs_in_encoder=16, - # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`) - add_tpos_enc_to_obj_ptrs=True, - # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference - # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) - proj_tpos_enc_in_obj_ptrs=False, - # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation - # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking) - only_obj_ptrs_in_the_past_for_eval=False, - # Whether to predict if there is an object in the frame - pred_obj_scores: bool = False, - # Whether to use an MLP to predict object scores - pred_obj_scores_mlp: bool = False, - # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True; - # Whether to have a fixed no obj pointer when there is no object present - # or to use it as an additive embedding with obj_ptr produced by decoder - fixed_no_obj_ptr: bool = False, - # Soft no object, i.e. mix in no_obj_ptr softly, - # hope to make recovery easier if there is a mistake and mitigate accumulation of errors - soft_no_obj_ptr: bool = False, - use_mlp_for_obj_ptr_proj: bool = False, - # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class. - sam_mask_decoder_extra_args=None, - compile_image_encoder: bool = False, - ): - """Initializes SAM2Model model with image encoder, memory attention, and memory encoder components.""" - super().__init__() - - # Part 1: the image backbone - self.image_encoder = image_encoder - # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting - self.use_high_res_features_in_sam = use_high_res_features_in_sam - self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 - self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder - self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder - if use_obj_ptrs_in_encoder: - # A conv layer to downsample the mask prompt to stride 4 (the same stride as - # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, - # so that it can be fed into the SAM mask decoder to generate a pointer. - self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) - self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs - if proj_tpos_enc_in_obj_ptrs: - assert add_tpos_enc_to_obj_ptrs # these options need to be used together - self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs - self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval - - # Part 2: memory attention to condition current frame's visual features - # with memories (and obj ptrs) from past frames - self.memory_attention = memory_attention - self.hidden_dim = memory_attention.d_model - - # Part 3: memory encoder for the previous frame's outputs - self.memory_encoder = memory_encoder - self.mem_dim = self.hidden_dim - if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"): - # if there is compression of memories along channel dim - self.mem_dim = self.memory_encoder.out_proj.weight.shape[0] - self.num_maskmem = num_maskmem # Number of memories accessible - # Temporal encoding of the memories - self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim)) - trunc_normal_(self.maskmem_tpos_enc, std=0.02) - # a single token to indicate no memory embedding from previous frames - self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) - self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) - trunc_normal_(self.no_mem_embed, std=0.02) - trunc_normal_(self.no_mem_pos_enc, std=0.02) - self.directly_add_no_mem_embed = directly_add_no_mem_embed - # Apply sigmoid to the output raw mask logits (to turn them from - # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder - self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc - self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc - self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc - self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc - self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval - # On frames with mask input, whether to directly output the input mask without - # using a SAM prompt encoder + mask decoder - self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam - self.multimask_output_in_sam = multimask_output_in_sam - self.multimask_min_pt_num = multimask_min_pt_num - self.multimask_max_pt_num = multimask_max_pt_num - self.multimask_output_for_tracking = multimask_output_for_tracking - self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr - self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid - - # Part 4: SAM-style prompt encoder (for both mask and point inputs) - # and SAM-style mask decoder for the final mask output - self.image_size = image_size - self.backbone_stride = backbone_stride - self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args - self.pred_obj_scores = pred_obj_scores - self.pred_obj_scores_mlp = pred_obj_scores_mlp - self.fixed_no_obj_ptr = fixed_no_obj_ptr - self.soft_no_obj_ptr = soft_no_obj_ptr - if self.fixed_no_obj_ptr: - assert self.pred_obj_scores - assert self.use_obj_ptrs_in_encoder - if self.pred_obj_scores and self.use_obj_ptrs_in_encoder: - self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) - trunc_normal_(self.no_obj_ptr, std=0.02) - self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj - - self._build_sam_heads() - self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond - self.max_cond_frames_in_attn = max_cond_frames_in_attn - - # Model compilation - if compile_image_encoder: - # Compile the forward function (not the full module) to allow loading checkpoints. - print("Image encoder compilation is enabled. First forward pass will be slow.") - self.image_encoder.forward = torch.compile( - self.image_encoder.forward, - mode="max-autotune", - fullgraph=True, - dynamic=False, - ) - - @property - def device(self): - """Returns the device on which the model's parameters are stored.""" - return next(self.parameters()).device - - def forward(self, *args, **kwargs): - """Processes input frames and prompts to generate object masks and scores in video sequences.""" - raise NotImplementedError( - "Please use the corresponding methods in SAM2VideoPredictor for inference." - "See notebooks/video_predictor_example.ipynb for an example." - ) - - def _build_sam_heads(self): - """Builds SAM-style prompt encoder and mask decoder for image segmentation tasks.""" - self.sam_prompt_embed_dim = self.hidden_dim - self.sam_image_embedding_size = self.image_size // self.backbone_stride - - # build PromptEncoder and MaskDecoder from SAM - # (their hyperparameters like `mask_in_chans=16` are from SAM code) - self.sam_prompt_encoder = PromptEncoder( - embed_dim=self.sam_prompt_embed_dim, - image_embedding_size=( - self.sam_image_embedding_size, - self.sam_image_embedding_size, - ), - input_image_size=(self.image_size, self.image_size), - mask_in_chans=16, - ) - self.sam_mask_decoder = MaskDecoder( - num_multimask_outputs=3, - transformer=TwoWayTransformer( - depth=2, - embedding_dim=self.sam_prompt_embed_dim, - mlp_dim=2048, - num_heads=8, - ), - transformer_dim=self.sam_prompt_embed_dim, - iou_head_depth=3, - iou_head_hidden_dim=256, - use_high_res_features=self.use_high_res_features_in_sam, - iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, - pred_obj_scores=self.pred_obj_scores, - pred_obj_scores_mlp=self.pred_obj_scores_mlp, - use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, - **(self.sam_mask_decoder_extra_args or {}), - ) - if self.use_obj_ptrs_in_encoder: - # a linear projection on SAM output tokens to turn them into object pointers - self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) - if self.use_mlp_for_obj_ptr_proj: - self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3) - else: - self.obj_ptr_proj = torch.nn.Identity() - if self.proj_tpos_enc_in_obj_ptrs: - # a linear projection on temporal positional encoding in object pointers to - # avoid potential interference with spatial positional encoding - self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) - else: - self.obj_ptr_tpos_proj = torch.nn.Identity() - - def _forward_sam_heads( - self, - backbone_features, - point_inputs=None, - mask_inputs=None, - high_res_features=None, - multimask_output=False, - ): - """ - Forward SAM prompt encoders and mask heads. - - Args: - backbone_features (torch.Tensor): Image features with shape (B, C, H, W). - point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts. - 'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute - pixel-unit coordinates in (x, y) format for P input points. - 'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks, - 0 means negative clicks, and -1 means padding. - mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the - same spatial size as the image. - high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes - (B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps - for SAM decoder. - multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False, - output only 1 mask and its IoU estimate. - - Returns: - (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]): - low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits. - high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits. - ious: Tensor of shape (B, M) with estimated IoU for each output mask. - low_res_masks: Tensor of shape (B, 1, H*4, W*4) with best low-resolution mask. - high_res_masks: Tensor of shape (B, 1, H*16, W*16) with best high-resolution mask. - obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask. - object_score_logits: Tensor of shape (B,) with object score logits. - - Where M is 3 if multimask_output=True, and 1 if multimask_output=False. - - Examples: - >>> backbone_features = torch.rand(1, 256, 32, 32) - >>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])} - >>> mask_inputs = torch.rand(1, 1, 512, 512) - >>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs) - >>> low_res_multimasks, high_res_multimasks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits = results - """ - B = backbone_features.size(0) - device = backbone_features.device - assert backbone_features.size(1) == self.sam_prompt_embed_dim - assert backbone_features.size(2) == self.sam_image_embedding_size - assert backbone_features.size(3) == self.sam_image_embedding_size - - # a) Handle point prompts - if point_inputs is not None: - sam_point_coords = point_inputs["point_coords"] - sam_point_labels = point_inputs["point_labels"] - assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B - else: - # If no points are provide, pad with an empty point (with label -1) - sam_point_coords = torch.zeros(B, 1, 2, device=device) - sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) - - # b) Handle mask prompts - if mask_inputs is not None: - # If mask_inputs is provided, downsize it into low-res mask input if needed - # and feed it as a dense mask prompt into the SAM mask encoder - assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) - if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: - sam_mask_prompt = F.interpolate( - mask_inputs.float(), - size=self.sam_prompt_encoder.mask_input_size, - align_corners=False, - mode="bilinear", - antialias=True, # use antialias for downsampling - ) - else: - sam_mask_prompt = mask_inputs - else: - # Otherwise, simply feed None (and SAM's prompt encoder will add - # a learned `no_mask_embed` to indicate no mask input in this case). - sam_mask_prompt = None - - sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( - points=(sam_point_coords, sam_point_labels), - boxes=None, - masks=sam_mask_prompt, - ) - ( - low_res_multimasks, - ious, - sam_output_tokens, - object_score_logits, - ) = self.sam_mask_decoder( - image_embeddings=backbone_features, - image_pe=self.sam_prompt_encoder.get_dense_pe(), - sparse_prompt_embeddings=sparse_embeddings, - dense_prompt_embeddings=dense_embeddings, - multimask_output=multimask_output, - repeat_image=False, # the image is already batched - high_res_features=high_res_features, - ) - if self.pred_obj_scores: - is_obj_appearing = object_score_logits > 0 - - # Mask used for spatial memories is always a *hard* choice between obj and no obj, - # consistent with the actual mask prediction - low_res_multimasks = torch.where( - is_obj_appearing[:, None, None], - low_res_multimasks, - NO_OBJ_SCORE, - ) - - # convert masks from possibly bfloat16 (or float16) to float32 - # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) - low_res_multimasks = low_res_multimasks.float() - high_res_multimasks = F.interpolate( - low_res_multimasks, - size=(self.image_size, self.image_size), - mode="bilinear", - align_corners=False, - ) - - sam_output_token = sam_output_tokens[:, 0] - if multimask_output: - # take the best mask prediction (with the highest IoU estimation) - best_iou_inds = torch.argmax(ious, dim=-1) - batch_inds = torch.arange(B, device=device) - low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) - high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) - if sam_output_tokens.size(1) > 1: - sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] - else: - low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks - - # Extract object pointer from the SAM output token (with occlusion handling) - obj_ptr = self.obj_ptr_proj(sam_output_token) - if self.pred_obj_scores: - # Allow *soft* no obj ptr, unlike for masks - if self.soft_no_obj_ptr: - # Only hard possible with gt - assert not self.teacher_force_obj_scores_for_mem - lambda_is_obj_appearing = object_score_logits.sigmoid() - else: - lambda_is_obj_appearing = is_obj_appearing.float() - - if self.fixed_no_obj_ptr: - obj_ptr = lambda_is_obj_appearing * obj_ptr - obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr - - return ( - low_res_multimasks, - high_res_multimasks, - ious, - low_res_masks, - high_res_masks, - obj_ptr, - object_score_logits, - ) - - def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): - """Processes mask inputs to generate output mask logits and object pointers without using SAM.""" - # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). - out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 - mask_inputs_float = mask_inputs.float() - high_res_masks = mask_inputs_float * out_scale + out_bias - low_res_masks = F.interpolate( - high_res_masks, - size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), - align_corners=False, - mode="bilinear", - antialias=True, # use antialias for downsampling - ) - # a dummy IoU prediction of all 1's under mask input - ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float() - if not self.use_obj_ptrs_in_encoder: - # all zeros as a dummy object pointer (of shape [B, C]) - obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device) - else: - # produce an object pointer using the SAM decoder from the mask input - _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads( - backbone_features=backbone_features, - mask_inputs=self.mask_downsample(mask_inputs_float), - high_res_features=high_res_features, - ) - # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; - # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying - # on the object_scores from the SAM decoder. - is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1) - is_obj_appearing = is_obj_appearing[..., None] - lambda_is_obj_appearing = is_obj_appearing.float() - object_score_logits = out_scale * lambda_is_obj_appearing + out_bias - if self.pred_obj_scores: - if self.fixed_no_obj_ptr: - obj_ptr = lambda_is_obj_appearing * obj_ptr - obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr - - return ( - low_res_masks, - high_res_masks, - ious, - low_res_masks, - high_res_masks, - obj_ptr, - object_score_logits, - ) - - def forward_image(self, img_batch: torch.Tensor): - """Process image batch through encoder to extract multi-level features for SAM model.""" - backbone_out = self.image_encoder(img_batch) - if self.use_high_res_features_in_sam: - # precompute projected level 0 and level 1 features in SAM decoder - # to avoid running it again on every SAM click - backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0]) - backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1]) - return backbone_out - - def _prepare_backbone_features(self, backbone_out): - """Prepare and flatten visual features from the image backbone output.""" - backbone_out = backbone_out.copy() - assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) - assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels - - feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :] - vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :] - - feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] - # flatten NxCxHxW to HWxNxC - vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] - vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] - - return backbone_out, vision_feats, vision_pos_embeds, feat_sizes - - def _prepare_memory_conditioned_features( - self, - frame_idx, - is_init_cond_frame, - current_vision_feats, - current_vision_pos_embeds, - feat_sizes, - output_dict, - num_frames, - track_in_reverse=False, # tracking in reverse time order (for demo usage) - ): - """Prepares memory-conditioned features by fusing current frame's visual features with previous memories.""" - B = current_vision_feats[-1].size(1) # batch size on this frame - C = self.hidden_dim - H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size - device = current_vision_feats[-1].device - # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images. - # In this case, we skip the fusion with any memory. - if self.num_maskmem == 0: # Disable memory and skip fusion - pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) - return pix_feat - - num_obj_ptr_tokens = 0 - # Step 1: condition the visual features of the current frame on previous memories - if not is_init_cond_frame: - # Retrieve the memories encoded with the maskmem backbone - to_cat_memory, to_cat_memory_pos_embed = [], [] - # Add conditioning frames's output first (all cond frames have t_pos=0 for - # when getting temporal positional embedding below) - assert len(output_dict["cond_frame_outputs"]) > 0 - # Select a maximum number of temporally closest cond frames for cross attention - cond_outputs = output_dict["cond_frame_outputs"] - selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( - frame_idx, cond_outputs, self.max_cond_frames_in_attn - ) - t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()] - # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory - # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 - # We also allow taking the memory frame non-consecutively (with r>1), in which case - # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame. - r = self.memory_temporal_stride_for_eval - for t_pos in range(1, self.num_maskmem): - t_rel = self.num_maskmem - t_pos # how many frames before current frame - if t_rel == 1: - # for t_rel == 1, we take the last frame (regardless of r) - if not track_in_reverse: - # the frame immediately before this frame (i.e. frame_idx - 1) - prev_frame_idx = frame_idx - t_rel - else: - # the frame immediately after this frame (i.e. frame_idx + 1) - prev_frame_idx = frame_idx + t_rel - else: - # for t_rel >= 2, we take the memory frame from every r-th frames - if not track_in_reverse: - # first find the nearest frame among every r-th frames before this frame - # for r=1, this would be (frame_idx - 2) - prev_frame_idx = ((frame_idx - 2) // r) * r - # then seek further among every r-th frames - prev_frame_idx = prev_frame_idx - (t_rel - 2) * r - else: - # first find the nearest frame among every r-th frames after this frame - # for r=1, this would be (frame_idx + 2) - prev_frame_idx = -(-(frame_idx + 2) // r) * r - # then seek further among every r-th frames - prev_frame_idx = prev_frame_idx + (t_rel - 2) * r - out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) - if out is None: - # If an unselected conditioning frame is among the last (self.num_maskmem - 1) - # frames, we still attend to it as if it's a non-conditioning frame. - out = unselected_cond_outputs.get(prev_frame_idx, None) - t_pos_and_prevs.append((t_pos, out)) - - for t_pos, prev in t_pos_and_prevs: - if prev is None: - continue # skip padding frames - # "maskmem_features" might have been offloaded to CPU in demo use cases, - # so we load it back to GPU (it's a no-op if it's already on GPU). - feats = prev["maskmem_features"].cuda(non_blocking=True) - to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) - # Spatial positional encoding (it might have been offloaded to CPU in eval) - maskmem_enc = prev["maskmem_pos_enc"][-1].cuda() - maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) - # Temporal positional encoding - maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1] - to_cat_memory_pos_embed.append(maskmem_enc) - - # Construct the list of past object pointers - if self.use_obj_ptrs_in_encoder: - max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) - # First add those object pointers from selected conditioning frames - # (optionally, only include object pointers in the past during evaluation) - if not self.training and self.only_obj_ptrs_in_the_past_for_eval: - ptr_cond_outputs = { - t: out - for t, out in selected_cond_outputs.items() - if (t >= frame_idx if track_in_reverse else t <= frame_idx) - } - else: - ptr_cond_outputs = selected_cond_outputs - pos_and_ptrs = [ - # Temporal pos encoding contains how far away each pointer is from current frame - (abs(frame_idx - t), out["obj_ptr"]) - for t, out in ptr_cond_outputs.items() - ] - # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame - for t_diff in range(1, max_obj_ptrs_in_encoder): - t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff - if t < 0 or (num_frames is not None and t >= num_frames): - break - out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None)) - if out is not None: - pos_and_ptrs.append((t_diff, out["obj_ptr"])) - # If we have at least one object pointer, add them to the across attention - if len(pos_and_ptrs) > 0: - pos_list, ptrs_list = zip(*pos_and_ptrs) - # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape - obj_ptrs = torch.stack(ptrs_list, dim=0) - # a temporal positional embedding based on how far each object pointer is from - # the current frame (sine embedding normalized by the max pointer num). - if self.add_tpos_enc_to_obj_ptrs: - t_diff_max = max_obj_ptrs_in_encoder - 1 - tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim - obj_pos = torch.tensor(pos_list, device=device) - obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) - obj_pos = self.obj_ptr_tpos_proj(obj_pos) - obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) - else: - obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim) - if self.mem_dim < C: - # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C - obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim) - obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1) - obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0) - to_cat_memory.append(obj_ptrs) - to_cat_memory_pos_embed.append(obj_pos) - num_obj_ptr_tokens = obj_ptrs.shape[0] - else: - num_obj_ptr_tokens = 0 - else: - # for initial conditioning frames, encode them without using any previous memory - if self.directly_add_no_mem_embed: - # directly add no-mem embedding (instead of using the transformer encoder) - pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed - pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) - return pix_feat_with_mem - - # Use a dummy token on the first frame (to avoid empty memory input to transformer encoder) - to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] - to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] - - # Step 2: Concatenate the memories and forward through the transformer encoder - memory = torch.cat(to_cat_memory, dim=0) - memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0) - - pix_feat_with_mem = self.memory_attention( - curr=current_vision_feats, - curr_pos=current_vision_pos_embeds, - memory=memory, - memory_pos=memory_pos_embed, - num_obj_ptr_tokens=num_obj_ptr_tokens, - ) - # reshape the output (HW)BC => BCHW - pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) - return pix_feat_with_mem - - def _encode_new_memory( - self, - current_vision_feats, - feat_sizes, - pred_masks_high_res, - is_mask_from_pts, - ): - """Encodes the current frame's features and predicted masks into a new memory representation.""" - B = current_vision_feats[-1].size(1) # batch size on this frame - C = self.hidden_dim - H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size - # top-level feature, (HW)BC => BCHW - pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) - if self.non_overlap_masks_for_mem_enc and not self.training: - # optionally, apply non-overlapping constraints to the masks (it's applied - # in the batch dimension and should only be used during eval, where all - # the objects come from the same video under batch size 1). - pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res) - # scale the raw mask logits with a temperature before applying sigmoid - binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts - if binarize and not self.training: - mask_for_mem = (pred_masks_high_res > 0).float() - else: - # apply sigmoid on the raw mask logits to turn them into range (0, 1) - mask_for_mem = torch.sigmoid(pred_masks_high_res) - # apply scale and bias terms to the sigmoid probabilities - if self.sigmoid_scale_for_mem_enc != 1.0: - mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc - if self.sigmoid_bias_for_mem_enc != 0.0: - mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc - maskmem_out = self.memory_encoder( - pix_feat, - mask_for_mem, - skip_mask_sigmoid=True, # sigmoid already applied - ) - maskmem_features = maskmem_out["vision_features"] - maskmem_pos_enc = maskmem_out["vision_pos_enc"] - - return maskmem_features, maskmem_pos_enc - - def track_step( - self, - frame_idx, - is_init_cond_frame, - current_vision_feats, - current_vision_pos_embeds, - feat_sizes, - point_inputs, - mask_inputs, - output_dict, - num_frames, - track_in_reverse=False, # tracking in reverse time order (for demo usage) - # Whether to run the memory encoder on the predicted masks. Sometimes we might want - # to skip the memory encoder with `run_mem_encoder=False`. For example, - # in demo we might call `track_step` multiple times for each user click, - # and only encode the memory when the user finalizes their clicks. And in ablation - # settings like SAM training on static images, we don't need the memory encoder. - run_mem_encoder=True, - # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). - prev_sam_mask_logits=None, - ): - """Performs a single tracking step, updating object masks and memory features based on current frame inputs.""" - current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} - # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW - if len(current_vision_feats) > 1: - high_res_features = [ - x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) - for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) - ] - else: - high_res_features = None - if mask_inputs is not None and self.use_mask_input_as_output_without_sam: - # When use_mask_input_as_output_without_sam=True, we directly output the mask input - # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. - pix_feat = current_vision_feats[-1].permute(1, 2, 0) - pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) - sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs) - else: - # fused the visual feature with previous memory features in the memory bank - pix_feat_with_mem = self._prepare_memory_conditioned_features( - frame_idx=frame_idx, - is_init_cond_frame=is_init_cond_frame, - current_vision_feats=current_vision_feats[-1:], - current_vision_pos_embeds=current_vision_pos_embeds[-1:], - feat_sizes=feat_sizes[-1:], - output_dict=output_dict, - num_frames=num_frames, - track_in_reverse=track_in_reverse, - ) - # apply SAM-style segmentation head - # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, - # e.g. in demo where such logits come from earlier interaction instead of correction sampling - # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) - if prev_sam_mask_logits is not None: - assert point_inputs is not None and mask_inputs is None - mask_inputs = prev_sam_mask_logits - multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) - sam_outputs = self._forward_sam_heads( - backbone_features=pix_feat_with_mem, - point_inputs=point_inputs, - mask_inputs=mask_inputs, - high_res_features=high_res_features, - multimask_output=multimask_output, - ) - ( - _, - _, - _, - low_res_masks, - high_res_masks, - obj_ptr, - _, - ) = sam_outputs - - current_out["pred_masks"] = low_res_masks - current_out["pred_masks_high_res"] = high_res_masks - current_out["obj_ptr"] = obj_ptr - - # Finally run the memory encoder on the predicted mask to encode - # it into a new memory feature (that can be used in future frames) - if run_mem_encoder and self.num_maskmem > 0: - high_res_masks_for_mem_enc = high_res_masks - maskmem_features, maskmem_pos_enc = self._encode_new_memory( - current_vision_feats=current_vision_feats, - feat_sizes=feat_sizes, - pred_masks_high_res=high_res_masks_for_mem_enc, - is_mask_from_pts=(point_inputs is not None), - ) - current_out["maskmem_features"] = maskmem_features - current_out["maskmem_pos_enc"] = maskmem_pos_enc - else: - current_out["maskmem_features"] = None - current_out["maskmem_pos_enc"] = None - - return current_out - - def _use_multimask(self, is_init_cond_frame, point_inputs): - """Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs.""" - num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) - multimask_output = ( - self.multimask_output_in_sam - and (is_init_cond_frame or self.multimask_output_for_tracking) - and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) - ) - return multimask_output - - def _apply_non_overlapping_constraints(self, pred_masks): - """Applies non-overlapping constraints to object masks, keeping highest scoring object at each location.""" - batch_size = pred_masks.size(0) - if batch_size == 1: - return pred_masks - - device = pred_masks.device - # "max_obj_inds": object index of the object with the highest score at each location - max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) - # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` - batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] - keep = max_obj_inds == batch_obj_inds - # suppress overlapping regions' scores below -10.0 so that the foreground regions - # don't overlap (here sigmoid(-10.0)=4.5398e-05) - pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) - return pred_masks diff --git a/ultralytics/models/sam2/modules/sam2_blocks.py b/ultralytics/models/sam2/modules/sam2_blocks.py deleted file mode 100644 index 67dc587e9..000000000 --- a/ultralytics/models/sam2/modules/sam2_blocks.py +++ /dev/null @@ -1,715 +0,0 @@ -# Ultralytics YOLO 🚀, AGPL-3.0 license - -import copy -import math -from functools import partial -from typing import Optional, Tuple, Type, Union - -import torch -import torch.nn.functional as F -from torch import Tensor, nn - -from ultralytics.models.sam.modules.transformer import ( - Attention, -) -from ultralytics.models.sam.modules.transformer import ( - TwoWayAttentionBlock as SAMTwoWayAttentionBlock, -) -from ultralytics.models.sam.modules.transformer import ( - TwoWayTransformer as SAMTwoWayTransformer, -) -from ultralytics.nn.modules import MLP, LayerNorm2d - -from .utils import apply_rotary_enc, compute_axial_cis, window_partition, window_unpartition - - -class DropPath(nn.Module): - """Implements stochastic depth regularization for neural networks during training.""" - - def __init__(self, drop_prob=0.0, scale_by_keep=True): - """Initialize DropPath module with specified drop probability and scaling option.""" - super(DropPath, self).__init__() - self.drop_prob = drop_prob - self.scale_by_keep = scale_by_keep - - def forward(self, x): - """Applies stochastic depth to input tensor during training, with optional scaling.""" - if self.drop_prob == 0.0 or not self.training: - return x - keep_prob = 1 - self.drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) - random_tensor = x.new_empty(shape).bernoulli_(keep_prob) - if keep_prob > 0.0 and self.scale_by_keep: - random_tensor.div_(keep_prob) - return x * random_tensor - - -class MaskDownSampler(nn.Module): - """Downsamples and embeds masks using convolutional layers and layer normalization for efficient processing.""" - - def __init__( - self, - embed_dim=256, - kernel_size=4, - stride=4, - padding=0, - total_stride=16, - activation=nn.GELU, - ): - """Initializes a mask downsampler module for progressive downsampling and channel expansion.""" - super().__init__() - num_layers = int(math.log2(total_stride) // math.log2(stride)) - assert stride**num_layers == total_stride - self.encoder = nn.Sequential() - mask_in_chans, mask_out_chans = 1, 1 - for _ in range(num_layers): - mask_out_chans = mask_in_chans * (stride**2) - self.encoder.append( - nn.Conv2d( - mask_in_chans, - mask_out_chans, - kernel_size=kernel_size, - stride=stride, - padding=padding, - ) - ) - self.encoder.append(LayerNorm2d(mask_out_chans)) - self.encoder.append(activation()) - mask_in_chans = mask_out_chans - - self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) - - def forward(self, x): - """Downsamples and encodes input mask to embed_dim channels using convolutional layers and LayerNorm2d.""" - return self.encoder(x) - - -# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) -class CXBlock(nn.Module): - """ - ConvNeXt Block for efficient feature extraction in convolutional neural networks. - - This block implements a modified version of the ConvNeXt architecture, offering two equivalent - implementations for improved performance and flexibility. - - Attributes: - dwconv (nn.Conv2d): Depthwise convolution layer. - norm (LayerNorm2d): Layer normalization applied to channels. - pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer. - act (nn.GELU): GELU activation function. - pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer. - gamma (nn.Parameter | None): Learnable scale parameter for layer scaling. - drop_path (nn.Module): DropPath layer for stochastic depth regularization. - - Methods: - forward: Processes the input tensor through the ConvNeXt block. - - Examples: - >>> import torch - >>> x = torch.randn(1, 64, 56, 56) - >>> block = CXBlock(dim=64, kernel_size=7, padding=3) - >>> output = block(x) - >>> print(output.shape) - torch.Size([1, 64, 56, 56]) - """ - - def __init__( - self, - dim, - kernel_size=7, - padding=3, - drop_path=0.0, - layer_scale_init_value=1e-6, - use_dwconv=True, - ): - """ - Initialize a ConvNeXt Block. - - This block implements a ConvNeXt architecture with optional depthwise convolution, layer normalization, - pointwise convolutions, and GELU activation. - - Args: - dim (int): Number of input channels. - kernel_size (int): Size of the convolutional kernel. Default is 7. - padding (int): Padding size for the convolution. Default is 3. - drop_path (float): Stochastic depth rate. Default is 0.0. - layer_scale_init_value (float): Initial value for Layer Scale. Default is 1e-6. - use_dwconv (bool): Whether to use depthwise convolution. Default is True. - - Attributes: - dwconv (nn.Conv2d): Depthwise or standard 2D convolution layer. - norm (LayerNorm2d): Layer normalization applied to the output of dwconv. - pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer. - act (nn.GELU): GELU activation function. - pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer. - gamma (nn.Parameter | None): Learnable scale parameter for the residual path. - - Examples: - >>> block = CXBlock(dim=64, kernel_size=7, padding=3) - >>> x = torch.randn(1, 64, 32, 32) - >>> output = block(x) - >>> print(output.shape) - torch.Size([1, 64, 32, 32]) - """ - super().__init__() - self.dwconv = nn.Conv2d( - dim, - dim, - kernel_size=kernel_size, - padding=padding, - groups=dim if use_dwconv else 1, - ) # depthwise conv - self.norm = LayerNorm2d(dim, eps=1e-6) - self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers - self.act = nn.GELU() - self.pwconv2 = nn.Linear(4 * dim, dim) - self.gamma = ( - nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) - if layer_scale_init_value > 0 - else None - ) - self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - - def forward(self, x): - """Applies ConvNeXt block operations to input tensor, including convolutions and residual connection.""" - input = x - x = self.dwconv(x) - x = self.norm(x) - x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) - x = self.pwconv1(x) - x = self.act(x) - x = self.pwconv2(x) - if self.gamma is not None: - x = self.gamma * x - x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) - - x = input + self.drop_path(x) - return x - - -class Fuser(nn.Module): - """ - A module for fusing features through multiple layers of a neural network. - - This class applies a series of identical layers to an input tensor, optionally projecting the input first. - - Attributes: - proj (nn.Module): An optional input projection layer. Identity if no projection is needed. - layers (nn.ModuleList): A list of identical layers to be applied sequentially. - - Methods: - forward: Applies the fuser to an input tensor. - - Examples: - >>> layer = CXBlock(dim=256) - >>> fuser = Fuser(layer, num_layers=3, dim=256, input_projection=True) - >>> x = torch.randn(1, 256, 32, 32) - >>> output = fuser(x) - >>> print(output.shape) - torch.Size([1, 256, 32, 32]) - """ - - def __init__(self, layer, num_layers, dim=None, input_projection=False): - """ - Initializes the Fuser module. - - This module creates a sequence of identical layers and optionally applies an input projection. - - Args: - layer (nn.Module): The layer to be replicated in the fuser. - num_layers (int): The number of times to replicate the layer. - dim (int | None): The dimension for input projection, if used. - input_projection (bool): Whether to use input projection. - - Attributes: - proj (nn.Module): The input projection layer, or nn.Identity if not used. - layers (nn.ModuleList): A list of replicated layers. - - Examples: - >>> layer = nn.Linear(64, 64) - >>> fuser = Fuser(layer, num_layers=3, dim=64, input_projection=True) - >>> input_tensor = torch.randn(1, 64) - >>> output = fuser(input_tensor) - """ - super().__init__() - self.proj = nn.Identity() - self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)]) - - if input_projection: - assert dim is not None - self.proj = nn.Conv2d(dim, dim, kernel_size=1) - - def forward(self, x): - """Applies a series of layers to the input tensor, optionally projecting it first.""" - x = self.proj(x) - for layer in self.layers: - x = layer(x) - return x - - -class TwoWayAttentionBlock(SAMTwoWayAttentionBlock): - """ - A two-way attention block for performing self-attention and cross-attention in both directions. - - This block extends the SAMTwoWayAttentionBlock and consists of four main components: self-attention on - sparse inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and - cross-attention from dense to sparse inputs. - - Attributes: - self_attn (Attention): Self-attention layer for queries. - norm1 (nn.LayerNorm): Layer normalization after the first attention block. - cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys. - norm2 (nn.LayerNorm): Layer normalization after the second attention block. - mlp (MLP): MLP block for transforming query embeddings. - norm3 (nn.LayerNorm): Layer normalization after the MLP block. - norm4 (nn.LayerNorm): Layer normalization after the third attention block. - cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries. - skip_first_layer_pe (bool): Flag to skip positional encoding in the first layer. - - Methods: - forward: Processes input through the attention blocks and MLP. - - Examples: - >>> block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8) - >>> sparse_input = torch.randn(1, 100, 256) - >>> dense_input = torch.randn(1, 256, 16, 16) - >>> sparse_output, dense_output = block(sparse_input, dense_input) - """ - - def __init__( - self, - embedding_dim: int, - num_heads: int, - mlp_dim: int = 2048, - activation: Type[nn.Module] = nn.ReLU, - attention_downsample_rate: int = 2, - skip_first_layer_pe: bool = False, - ) -> None: - """ - Initializes a TwoWayAttentionBlock for performing self-attention and cross-attention in two directions. - - This block consists of four main layers: self-attention on sparse inputs, cross-attention of sparse inputs - to dense inputs, an MLP block on sparse inputs, and cross-attention of dense inputs to sparse inputs. - - Args: - embedding_dim (int): The channel dimension of the embeddings. - num_heads (int): The number of heads in the attention layers. - mlp_dim (int): The hidden dimension of the MLP block. - activation (Type[nn.Module]): The activation function of the MLP block. - attention_downsample_rate (int): The downsample rate for attention computations. - skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer. - - Attributes: - self_attn (Attention): The self-attention layer for the queries. - norm1 (nn.LayerNorm): Layer normalization following the first attention block. - cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys. - norm2 (nn.LayerNorm): Layer normalization following the second attention block. - mlp (MLP): MLP block that transforms the query embeddings. - norm3 (nn.LayerNorm): Layer normalization following the MLP block. - norm4 (nn.LayerNorm): Layer normalization following the third attention block. - cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries. - skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer. - - Examples: - >>> block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048) - >>> sparse_inputs = torch.randn(1, 100, 256) - >>> dense_inputs = torch.randn(1, 256, 32, 32) - >>> sparse_outputs, dense_outputs = block(sparse_inputs, dense_inputs) - """ - super().__init__(embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate, skip_first_layer_pe) - self.mlp = MLP(embedding_dim, mlp_dim, embedding_dim, num_layers=2, act=activation) - - -class TwoWayTransformer(SAMTwoWayTransformer): - """ - A Two-Way Transformer module for simultaneous attention to image and query points. - - This class implements a specialized transformer decoder that attends to an input image using queries with - supplied positional embeddings. It is particularly useful for tasks like object detection, image - segmentation, and point cloud processing. - - Attributes: - depth (int): Number of layers in the transformer. - embedding_dim (int): Channel dimension for input embeddings. - num_heads (int): Number of heads for multihead attention. - mlp_dim (int): Internal channel dimension for the MLP block. - layers (nn.ModuleList): List of TwoWayAttentionBlock layers comprising the transformer. - final_attn_token_to_image (Attention): Final attention layer from queries to image. - norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries. - - Methods: - forward: Processes input image embeddings and query embeddings through the transformer. - - Examples: - >>> transformer = TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048) - >>> image_embedding = torch.randn(1, 256, 64, 64) - >>> query_embedding = torch.randn(1, 100, 256) - >>> output = transformer(image_embedding, query_embedding) - """ - - def __init__( - self, - depth: int, - embedding_dim: int, - num_heads: int, - mlp_dim: int, - activation: Type[nn.Module] = nn.ReLU, - attention_downsample_rate: int = 2, - ) -> None: - """ - Initializes a TwoWayTransformer instance. - - This transformer decoder attends to an input image using queries with supplied positional embeddings. - It is designed for tasks like object detection, image segmentation, and point cloud processing. - - Args: - depth (int): Number of layers in the transformer. - embedding_dim (int): Channel dimension for the input embeddings. - num_heads (int): Number of heads for multihead attention. Must divide embedding_dim. - mlp_dim (int): Channel dimension internal to the MLP block. - activation (Type[nn.Module]): Activation function to use in the MLP block. - attention_downsample_rate (int): Downsampling rate for attention computations. - - Attributes: - depth (int): Number of layers in the transformer. - embedding_dim (int): Channel dimension for the input embeddings. - num_heads (int): Number of heads for multihead attention. - mlp_dim (int): Internal channel dimension for the MLP block. - layers (nn.ModuleList): List of TwoWayAttentionBlock layers comprising the transformer. - final_attn_token_to_image (Attention): Final attention layer from queries to image. - norm_final_attn (nn.LayerNorm): Layer normalization applied to the final queries. - - Examples: - >>> transformer = TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048) - >>> transformer - TwoWayTransformer( - (layers): ModuleList( - (0-4): 5 x TwoWayAttentionBlock(...) - ) - (final_attn_token_to_image): Attention(...) - (norm_final_attn): LayerNorm(...) - ) - """ - super().__init__(depth, embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate) - self.layers = nn.ModuleList() - for i in range(depth): - self.layers.append( - TwoWayAttentionBlock( - embedding_dim=embedding_dim, - num_heads=num_heads, - mlp_dim=mlp_dim, - activation=activation, - attention_downsample_rate=attention_downsample_rate, - skip_first_layer_pe=(i == 0), - ) - ) - - -class RoPEAttention(Attention): - """Implements rotary position encoding for attention mechanisms in transformer architectures.""" - - def __init__( - self, - *args, - rope_theta=10000.0, - # whether to repeat q rope to match k length - # this is needed for cross-attention to memories - rope_k_repeat=False, - feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution - **kwargs, - ): - """Initializes RoPEAttention with rotary position encoding for attention mechanisms.""" - super().__init__(*args, **kwargs) - - self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta) - freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) - self.freqs_cis = freqs_cis - self.rope_k_repeat = rope_k_repeat - - def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor: - """Applies rotary position encoding and computes attention between query, key, and value tensors.""" - q = self.q_proj(q) - k = self.k_proj(k) - v = self.v_proj(v) - - # Separate into heads - q = self._separate_heads(q, self.num_heads) - k = self._separate_heads(k, self.num_heads) - v = self._separate_heads(v, self.num_heads) - - # Apply rotary position encoding - w = h = math.sqrt(q.shape[-2]) - self.freqs_cis = self.freqs_cis.to(q.device) - if self.freqs_cis.shape[0] != q.shape[-2]: - self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) - if q.shape[-2] != k.shape[-2]: - assert self.rope_k_repeat - - num_k_rope = k.size(-2) - num_k_exclude_rope - q, k[:, :, :num_k_rope] = apply_rotary_enc( - q, - k[:, :, :num_k_rope], - freqs_cis=self.freqs_cis, - repeat_freqs_k=self.rope_k_repeat, - ) - - # Attention - _, _, _, c_per_head = q.shape - attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens - attn = attn / math.sqrt(c_per_head) - attn = torch.softmax(attn, dim=-1) - - # Get output - out = attn @ v - - out = self._recombine_heads(out) - out = self.out_proj(out) - - return out - - -def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: - """Applies pooling and optional normalization to a tensor, handling permutations for spatial operations.""" - if pool is None: - return x - # (B, H, W, C) -> (B, C, H, W) - x = x.permute(0, 3, 1, 2) - x = pool(x) - # (B, C, H', W') -> (B, H', W', C) - x = x.permute(0, 2, 3, 1) - if norm: - x = norm(x) - - return x - - -class MultiScaleAttention(nn.Module): - """Implements multi-scale self-attention with optional query pooling for efficient feature extraction.""" - - def __init__( - self, - dim: int, - dim_out: int, - num_heads: int, - q_pool: nn.Module = None, - ): - """Initializes a multi-scale attention module with configurable query pooling and linear projections.""" - super().__init__() - - self.dim = dim - self.dim_out = dim_out - - self.num_heads = num_heads - head_dim = dim_out // num_heads - self.scale = head_dim**-0.5 - - self.q_pool = q_pool - self.qkv = nn.Linear(dim, dim_out * 3) - self.proj = nn.Linear(dim_out, dim_out) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Applies multi-scale attention to input tensor, optionally downsampling query features.""" - B, H, W, _ = x.shape - # qkv with shape (B, H * W, 3, nHead, C) - qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) - # q, k, v with shape (B, H * W, nheads, C) - q, k, v = torch.unbind(qkv, 2) - - # Q pooling (for downsample at stage changes) - if self.q_pool: - q = do_pool(q.reshape(B, H, W, -1), self.q_pool) - H, W = q.shape[1:3] # downsampled shape - q = q.reshape(B, H * W, self.num_heads, -1) - - # Torch's SDPA expects [B, nheads, H*W, C] so we transpose - x = F.scaled_dot_product_attention( - q.transpose(1, 2), - k.transpose(1, 2), - v.transpose(1, 2), - ) - # Transpose back - x = x.transpose(1, 2) - x = x.reshape(B, H, W, -1) - - x = self.proj(x) - - return x - - -class MultiScaleBlock(nn.Module): - """Multiscale attention block with window partitioning and query pooling for efficient vision transformers.""" - - def __init__( - self, - dim: int, - dim_out: int, - num_heads: int, - mlp_ratio: float = 4.0, - drop_path: float = 0.0, - norm_layer: Union[nn.Module, str] = "LayerNorm", - q_stride: Tuple[int, int] = None, - act_layer: nn.Module = nn.GELU, - window_size: int = 0, - ): - """Initializes a multi-scale attention block with optional window partitioning and downsampling.""" - super().__init__() - - if isinstance(norm_layer, str): - norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) - - self.dim = dim - self.dim_out = dim_out - self.norm1 = norm_layer(dim) - - self.window_size = window_size - - self.pool, self.q_stride = None, q_stride - if self.q_stride: - self.pool = nn.MaxPool2d(kernel_size=q_stride, stride=q_stride, ceil_mode=False) - - self.attn = MultiScaleAttention( - dim, - dim_out, - num_heads=num_heads, - q_pool=self.pool, - ) - self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - - self.norm2 = norm_layer(dim_out) - self.mlp = MLP( - dim_out, - int(dim_out * mlp_ratio), - dim_out, - num_layers=2, - act=act_layer, - ) - - if dim != dim_out: - self.proj = nn.Linear(dim, dim_out) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Applies multi-scale attention and MLP processing to input tensor, with optional windowing.""" - shortcut = x # B, H, W, C - x = self.norm1(x) - - # Skip connection - if self.dim != self.dim_out: - shortcut = do_pool(self.proj(x), self.pool) - - # Window partition - window_size = self.window_size - if window_size > 0: - H, W = x.shape[1], x.shape[2] - x, pad_hw = window_partition(x, window_size) - - # Window Attention + Q Pooling (if stage change) - x = self.attn(x) - if self.q_stride: - # Shapes have changed due to Q pooling - window_size = self.window_size // self.q_stride[0] - H, W = shortcut.shape[1:3] - - pad_h = (window_size - H % window_size) % window_size - pad_w = (window_size - W % window_size) % window_size - pad_hw = (H + pad_h, W + pad_w) - - # Reverse window partition - if self.window_size > 0: - x = window_unpartition(x, window_size, pad_hw, (H, W)) - - x = shortcut + self.drop_path(x) - # MLP - x = x + self.drop_path(self.mlp(self.norm2(x))) - return x - - -class PositionEmbeddingSine(nn.Module): - """Generates sinusoidal positional embeddings for 2D inputs like images.""" - - def __init__( - self, - num_pos_feats, - temperature: int = 10000, - normalize: bool = True, - scale: Optional[float] = None, - ): - """Initializes sinusoidal position embeddings for 2D image inputs.""" - super().__init__() - assert num_pos_feats % 2 == 0, "Expecting even model width" - self.num_pos_feats = num_pos_feats // 2 - self.temperature = temperature - self.normalize = normalize - if scale is not None and normalize is False: - raise ValueError("normalize should be True if scale is passed") - if scale is None: - scale = 2 * math.pi - self.scale = scale - - self.cache = {} - - def _encode_xy(self, x, y): - """Encodes 2D positions using sine and cosine functions for positional embeddings.""" - assert len(x) == len(y) and x.ndim == y.ndim == 1 - x_embed = x * self.scale - y_embed = y * self.scale - - dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) - dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) - - pos_x = x_embed[:, None] / dim_t - pos_y = y_embed[:, None] / dim_t - pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1) - pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1) - return pos_x, pos_y - - @torch.no_grad() - def encode_boxes(self, x, y, w, h): - """Encodes box coordinates and dimensions into positional embeddings for object detection tasks.""" - pos_x, pos_y = self._encode_xy(x, y) - pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) - return pos - - encode = encode_boxes # Backwards compatibility - - @torch.no_grad() - def encode_points(self, x, y, labels): - """Encodes 2D point coordinates with sinusoidal positional embeddings and appends labels.""" - (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape - assert bx == by and nx == ny and bx == bl and nx == nl - pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) - pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) - pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) - return pos - - @torch.no_grad() - def forward(self, x: torch.Tensor): - """Generate sinusoidal position embeddings for 2D inputs.""" - cache_key = (x.shape[-2], x.shape[-1]) - if cache_key in self.cache: - return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) - y_embed = ( - torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) - .view(1, -1, 1) - .repeat(x.shape[0], 1, x.shape[-1]) - ) - x_embed = ( - torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) - .view(1, 1, -1) - .repeat(x.shape[0], x.shape[-2], 1) - ) - - if self.normalize: - eps = 1e-6 - y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale - x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - - dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) - dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) - - pos_x = x_embed[:, :, :, None] / dim_t - pos_y = y_embed[:, :, :, None] / dim_t - pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) - pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) - pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) - self.cache[cache_key] = pos[0] - return pos diff --git a/ultralytics/models/sam2/predict.py b/ultralytics/models/sam2/predict.py deleted file mode 100644 index f40c01582..000000000 --- a/ultralytics/models/sam2/predict.py +++ /dev/null @@ -1,177 +0,0 @@ -# Ultralytics YOLO 🚀, AGPL-3.0 license - -import torch - -from ..sam.predict import Predictor -from .build import build_sam2 - - -class SAM2Predictor(Predictor): - """ - A predictor class for the Segment Anything Model 2 (SAM2), extending the base Predictor class. - - This class provides an interface for model inference tailored to image segmentation tasks, leveraging SAM2's - advanced architecture and promptable segmentation capabilities. It facilitates flexible and real-time mask - generation, working with various types of prompts such as bounding boxes, points, and low-resolution masks. - - Attributes: - cfg (Dict): Configuration dictionary specifying model and task-related parameters. - overrides (Dict): Dictionary containing values that override the default configuration. - _callbacks (Dict): Dictionary of user-defined callback functions to augment behavior. - args (namespace): Namespace to hold command-line arguments or other operational variables. - im (torch.Tensor): Preprocessed input image tensor. - features (torch.Tensor): Extracted image features used for inference. - prompts (Dict): Collection of various prompt types, such as bounding boxes and points. - segment_all (bool): Flag to control whether to segment all objects in the image or only specified ones. - model (torch.nn.Module): The loaded SAM2 model. - device (torch.device): The device (CPU or GPU) on which the model is loaded. - _bb_feat_sizes (List[Tuple[int, int]]): List of feature sizes for different backbone levels. - - Methods: - get_model: Builds and returns the SAM2 model. - prompt_inference: Performs image segmentation inference based on various prompts. - set_image: Preprocesses and sets a single image for inference. - get_im_features: Extracts image features from the SAM2 image encoder. - - Examples: - >>> predictor = SAM2Predictor(model='sam2_l.pt') - >>> predictor.set_image('path/to/image.jpg') - >>> masks, scores = predictor.prompt_inference(im=predictor.im, points=[[500, 375]], labels=[1]) - >>> print(f"Generated {len(masks)} mask(s) with scores: {scores}") - """ - - _bb_feat_sizes = [ - (256, 256), - (128, 128), - (64, 64), - ] - - def get_model(self): - """Retrieves and initializes the Segment Anything Model (SAM) for image segmentation tasks.""" - return build_sam2(self.args.model) - - def prompt_inference( - self, - im, - bboxes=None, - points=None, - labels=None, - masks=None, - multimask_output=False, - img_idx=-1, - ): - """ - Performs image segmentation inference based on various prompts using SAM2 architecture. - - Args: - im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W). - bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4). - points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels. - labels (np.ndarray | List | None): Labels for point prompts with shape (N,). 1 = foreground, 0 = background. - masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W). - multimask_output (bool): Flag to return multiple masks for ambiguous prompts. - img_idx (int): Index of the image in the batch to process. - - Returns: - (tuple): Tuple containing: - - np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks. - - np.ndarray: Quality scores for each mask, with length C. - - np.ndarray: Low-resolution logits with shape (C, 256, 256) for subsequent inference. - - Examples: - >>> predictor = SAM2Predictor(cfg) - >>> image = torch.rand(1, 3, 640, 640) - >>> bboxes = [[100, 100, 200, 200]] - >>> masks, scores, logits = predictor.prompt_inference(image, bboxes=bboxes) - """ - features = self.get_im_features(im) if self.features is None else self.features - - src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:] - r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1]) - # Transform input prompts - if points is not None: - points = torch.as_tensor(points, dtype=torch.float32, device=self.device) - points = points[None] if points.ndim == 1 else points - # Assuming labels are all positive if users don't pass labels. - if labels is None: - labels = torch.ones(points.shape[0]) - labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device) - points *= r - # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1) - points, labels = points[:, None], labels[:, None] - if bboxes is not None: - bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device) - bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes - bboxes = bboxes.view(-1, 2, 2) * r - bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1) - # NOTE: merge "boxes" and "points" into a single "points" input - # (where boxes are added at the beginning) to model.sam_prompt_encoder - if points is not None: - points = torch.cat([bboxes, points], dim=1) - labels = torch.cat([bbox_labels, labels], dim=1) - else: - points, labels = bboxes, bbox_labels - if masks is not None: - masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1) - - points = (points, labels) if points is not None else None - - sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( - points=points, - boxes=None, - masks=masks, - ) - # Predict masks - batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction - high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]] - pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder( - image_embeddings=features["image_embed"][img_idx].unsqueeze(0), - image_pe=self.model.sam_prompt_encoder.get_dense_pe(), - sparse_prompt_embeddings=sparse_embeddings, - dense_prompt_embeddings=dense_embeddings, - multimask_output=multimask_output, - repeat_image=batched_mode, - high_res_features=high_res_features, - ) - # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, ) - # `d` could be 1 or 3 depends on `multimask_output`. - return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1) - - def set_image(self, image): - """ - Preprocesses and sets a single image for inference. - - This function sets up the model if not already initialized, configures the data source to the specified image, - and preprocesses the image for feature extraction. Only one image can be set at a time. - - Args: - image (str | np.ndarray): Image file path as a string, or a numpy array image read by cv2. - - Raises: - AssertionError: If more than one image is set. - - Examples: - >>> predictor = SAM2Predictor() - >>> predictor.set_image("path/to/image.jpg") - >>> predictor.set_image(np.array([...])) # Using a numpy array - """ - if self.model is None: - self.setup_model(model=None) - self.setup_source(image) - assert len(self.dataset) == 1, "`set_image` only supports setting one image!" - for batch in self.dataset: - im = self.preprocess(batch[1]) - self.features = self.get_im_features(im) - break - - def get_im_features(self, im): - """Extracts and processes image features using SAM2's image encoder for subsequent segmentation tasks.""" - backbone_out = self.model.forward_image(im) - _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) - if self.model.directly_add_no_mem_embed: - vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed - feats = [ - feat.permute(1, 2, 0).view(1, -1, *feat_size) - for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) - ][::-1] - return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}